Compare commits

..

1 Commits

Author SHA1 Message Date
f1c07a72c4 convert model_merging and video_model nodes to V3 schema 2026-03-11 12:25:59 +02:00
29 changed files with 848 additions and 2204 deletions

View File

@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict):
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
gradient_stops: NotRequired[list[dict]]
"""Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}."""
gradient_stops: NotRequired[list[list[float]]]
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
class HiddenInputTypeDict(TypedDict):

View File

@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
return tensor * m_mult
else:
for d in modulation_dims:
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1]
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
if m_add is not None:
tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1]
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
return tensor

View File

@ -44,22 +44,6 @@ class FluxParams:
txt_norm: bool = False
def invert_slices(slices, length):
sorted_slices = sorted(slices)
result = []
current = 0
for start, end in sorted_slices:
if current < start:
result.append((current, start))
current = max(current, end)
if current < length:
result.append((current, length))
return result
class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
@ -154,7 +138,6 @@ class Flux(nn.Module):
y: Tensor,
guidance: Tensor = None,
control = None,
timestep_zero_index=None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
@ -181,6 +164,10 @@ class Flux(nn.Module):
txt = self.txt_norm(txt)
txt = self.txt_in(txt)
vec_orig = vec
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
@ -195,24 +182,6 @@ class Flux(nn.Module):
else:
pe = None
vec_orig = vec
txt_vec = vec
extra_kwargs = {}
if timestep_zero_index is not None:
modulation_dims = []
batch = vec.shape[0] // 2
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
invert = invert_slices(timestep_zero_index, img.shape[1])
for s in invert:
modulation_dims.append((s[0], s[1], 0))
for s in timestep_zero_index:
modulation_dims.append((s[0], s[1], 1))
extra_kwargs["modulation_dims_img"] = modulation_dims
txt_vec = vec[:batch]
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
@ -226,8 +195,7 @@ class Flux(nn.Module):
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"),
**extra_kwargs)
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
@ -245,8 +213,7 @@ class Flux(nn.Module):
vec=vec,
pe=pe,
attn_mask=attn_mask,
transformer_options=transformer_options,
**extra_kwargs)
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@ -263,12 +230,6 @@ class Flux(nn.Module):
if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig)
extra_kwargs = {}
if timestep_zero_index is not None:
lambda a: 0 if a == 0 else a + txt.shape[1]
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
extra_kwargs["modulation_dims"] = modulation_dims_combined
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
@ -281,8 +242,7 @@ class Flux(nn.Module):
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"),
**extra_kwargs)
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
@ -293,7 +253,7 @@ class Flux(nn.Module):
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
@ -304,11 +264,7 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...]
extra_kwargs = {}
if timestep_zero_index is not None:
extra_kwargs["modulation_dims"] = modulation_dims
img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
return img
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
@ -356,16 +312,13 @@ class Flux(nn.Module):
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x, transformer_options=transformer_options)
img_tokens = img.shape[1]
timestep_zero_index = None
if ref_latents is not None:
ref_num_tokens = []
h = 0
w = 0
index = 0
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
timestep_zero = ref_latents_method == "index_timestep_zero"
for ref in ref_latents:
if ref_latents_method in ("index", "index_timestep_zero"):
if ref_latents_method == "index":
index += self.params.ref_index_scale
h_offset = 0
w_offset = 0
@ -389,13 +342,6 @@ class Flux(nn.Module):
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
@ -403,6 +349,6 @@ class Flux(nn.Module):
for i in self.params.txt_ids_dims:
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]

View File

@ -270,15 +270,10 @@ try:
except:
OOM_EXCEPTION = Exception
try:
ACCELERATOR_ERROR = torch.AcceleratorError
except AttributeError:
ACCELERATOR_ERROR = RuntimeError
def is_oom(e):
if isinstance(e, OOM_EXCEPTION):
return True
if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2:
discard_cuda_async_error()
return True
return False
@ -1280,7 +1275,7 @@ def discard_cuda_async_error():
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
synchronize()
except RuntimeError:
except torch.AcceleratorError:
#Dump it! We already know about it from the synchronous return
pass

View File

@ -25,7 +25,6 @@ class ComfyAPI_latest(ComfyAPIBase):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
self.caching = self.Caching()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
@ -85,36 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
class Caching(ProxiedSingleton):
"""
External cache provider API for sharing cached node outputs
across ComfyUI instances.
Example::
from comfy_api.latest import Caching
class MyCacheProvider(Caching.CacheProvider):
async def on_lookup(self, context):
... # check external storage
async def on_store(self, context, value):
... # store to external storage
Caching.register_provider(MyCacheProvider())
"""
from ._caching import CacheProvider, CacheContext, CacheValue
async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
"""Register an external cache provider. Providers are called in registration order."""
from comfy_execution.cache_provider import register_cache_provider
register_cache_provider(provider)
async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
"""Unregister a previously registered cache provider."""
from comfy_execution.cache_provider import unregister_cache_provider
unregister_cache_provider(provider)
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
@ -147,9 +116,6 @@ class Types:
VOXEL = VOXEL
File3D = File3D
Caching = ComfyAPI_latest.Caching
ComfyAPI = ComfyAPI_latest
# Create a synchronous version of the API
@ -169,7 +135,6 @@ __all__ = [
"Input",
"InputImpl",
"Types",
"Caching",
"ComfyExtension",
"io",
"IO",

View File

@ -1,42 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
@dataclass
class CacheContext:
node_id: str
class_type: str
cache_key_hash: str # SHA256 hex digest
@dataclass
class CacheValue:
outputs: list
ui: dict = None
class CacheProvider(ABC):
"""Abstract base class for external cache providers.
Exceptions from provider methods are caught by the caller and never break execution.
"""
@abstractmethod
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
pass
@abstractmethod
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
"""Called after local store. Dispatched via asyncio.create_task."""
pass
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
"""Return False to skip external caching for this node. Default: True."""
return True
def on_prompt_start(self, prompt_id: str) -> None:
pass
def on_prompt_end(self, prompt_id: str) -> None:
pass

View File

@ -272,7 +272,7 @@ class VideoFromFile(VideoInput):
has_first_frame = False
for frame in frames:
offset_seconds = start_time - frame.pts * audio_stream.time_base
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
to_skip = int(offset_seconds * audio_stream.sample_rate)
if to_skip < frame.samples:
has_first_frame = True
break
@ -280,7 +280,7 @@ class VideoFromFile(VideoInput):
audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames:
if self.__duration and frame.time > start_time + self.__duration:
if frame.time > start_time + self.__duration:
break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:

View File

@ -297,7 +297,7 @@ class Float(ComfyTypeIO):
'''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, gradient_stops: list[dict]=None,
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min

View File

@ -1,68 +0,0 @@
from pydantic import BaseModel, Field
class RevePostprocessingOperation(BaseModel):
process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
upscale_factor: int | None = Field(
None,
description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
ge=2,
le=4,
)
class ReveImageCreateRequest(BaseModel):
prompt: str = Field(...)
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageEditRequest(BaseModel):
edit_instruction: str = Field(...)
reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageRemixRequest(BaseModel):
prompt: str = Field(...)
reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageResponse(BaseModel):
image: str | None = Field(None, description="The base64 encoded image data.")
request_id: str | None = Field(None, description="A unique id for the request.")
credits_used: float | None = Field(None, description="The number of credits used for this request.")
version: str | None = Field(None, description="The specific model version used.")
content_violation: bool | None = Field(
None, description="Indicates whether the generated image violates the content policy."
)

View File

@ -1,7 +1,3 @@
import zipfile
from io import BytesIO
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, Types
@ -21,10 +17,7 @@ from comfy_api_nodes.apis.hunyuan3d import (
)
from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
download_url_to_bytesio,
download_url_to_file_3d,
download_url_to_image_tensor,
downscale_image_tensor_by_max_side,
poll_op,
sync_op,
@ -43,68 +36,6 @@ def _is_tencent_rate_limited(status: int, body: object) -> bool:
)
class ObjZipResult:
__slots__ = ("obj", "texture", "metallic", "normal", "roughness")
def __init__(
self,
obj: Types.File3D,
texture: Input.Image | None = None,
metallic: Input.Image | None = None,
normal: Input.Image | None = None,
roughness: Input.Image | None = None,
):
self.obj = obj
self.texture = texture
self.metallic = metallic
self.normal = normal
self.roughness = roughness
async def download_and_extract_obj_zip(url: str) -> ObjZipResult:
"""The Tencent API returns OBJ results as ZIP archives containing the .obj mesh, and texture images.
When PBR is enabled, the ZIP may contain additional metallic, normal, and roughness maps
identified by their filename suffixes.
"""
data = BytesIO()
await download_url_to_bytesio(url, data)
data.seek(0)
if not zipfile.is_zipfile(data):
data.seek(0)
return ObjZipResult(obj=Types.File3D(source=data, file_format="obj"))
data.seek(0)
obj_bytes = None
textures: dict[str, Input.Image] = {}
with zipfile.ZipFile(data) as zf:
for name in zf.namelist():
lower = name.lower()
if lower.endswith(".obj"):
obj_bytes = zf.read(name)
elif any(lower.endswith(ext) for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp")):
stem = lower.rsplit(".", 1)[0]
tensor = bytesio_to_image_tensor(BytesIO(zf.read(name)), mode="RGB")
matched_key = "texture"
for suffix, key in {
"_metallic": "metallic",
"_normal": "normal",
"_roughness": "roughness",
}.items():
if stem.endswith(suffix):
matched_key = key
break
textures[matched_key] = tensor
if obj_bytes is None:
raise ValueError("ZIP archive does not contain an OBJ file.")
return ObjZipResult(
obj=Types.File3D(source=BytesIO(obj_bytes), file_format="obj"),
texture=textures.get("texture"),
metallic=textures.get("metallic"),
normal=textures.get("normal"),
roughness=textures.get("roughness"),
)
def get_file_from_response(
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
) -> ResultFile3D | None:
@ -162,7 +93,6 @@ class TencentTextToModelNode(IO.ComfyNode):
IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.File3DGLB.Output(display_name="GLB"),
IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@ -221,14 +151,14 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
return IO.NodeOutput(
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
obj_result.obj,
obj_result.texture,
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
),
)
@ -281,10 +211,6 @@ class TencentImageToModelNode(IO.ComfyNode):
IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.File3DGLB.Output(display_name="GLB"),
IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
IO.Image.Output(display_name="optional_metallic"),
IO.Image.Output(display_name="optional_normal"),
IO.Image.Output(display_name="optional_roughness"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@ -378,17 +304,14 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
return IO.NodeOutput(
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
obj_result.obj,
obj_result.texture,
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
),
)
@ -508,8 +431,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
],
outputs=[
IO.File3DGLB.Output(display_name="GLB"),
IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@ -558,8 +480,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "texture_image").Url),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
)
@ -733,7 +654,7 @@ class TencentHunyuan3DExtension(ComfyExtension):
TencentTextToModelNode,
TencentImageToModelNode,
TencentModelTo3DUVNode,
Tencent3DTextureEditNode,
# Tencent3DTextureEditNode,
Tencent3DPartNode,
TencentSmartTopologyNode,
]

View File

@ -1,395 +0,0 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.reve import (
ReveImageCreateRequest,
ReveImageEditRequest,
ReveImageRemixRequest,
RevePostprocessingOperation,
)
from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
sync_op_raw,
tensor_to_base64_string,
validate_string,
)
def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
ops = []
if upscale["upscale"] == "enabled":
ops.append(
RevePostprocessingOperation(
process="upscale",
upscale_factor=upscale["upscale_factor"],
)
)
if remove_background:
ops.append(RevePostprocessingOperation(process="remove_background"))
return ops or None
def _postprocessing_inputs():
return [
IO.DynamicCombo.Input(
"upscale",
options=[
IO.DynamicCombo.Option("disabled", []),
IO.DynamicCombo.Option(
"enabled",
[
IO.Int.Input(
"upscale_factor",
default=2,
min=2,
max=4,
step=1,
tooltip="Upscale factor (2x, 3x, or 4x).",
),
],
),
],
tooltip="Upscale the generated image. May add additional cost.",
),
IO.Boolean.Input(
"remove_background",
default=False,
tooltip="Remove the background from the generated image. May add additional cost.",
),
]
def _reve_price_extractor(headers: dict) -> float | None:
credits_used = headers.get("x-reve-credits-used")
if credits_used is not None:
return float(credits_used) / 524.48
return None
def _reve_response_header_validator(headers: dict) -> None:
error_code = headers.get("x-reve-error-code")
if error_code:
raise ValueError(f"Reve API error: {error_code}")
if headers.get("x-reve-content-violation", "").lower() == "true":
raise ValueError("The generated image was flagged for content policy violation.")
def _model_inputs(versions: list[str], aspect_ratios: list[str]):
return [
IO.DynamicCombo.Option(
version,
[
IO.Combo.Input(
"aspect_ratio",
options=aspect_ratios,
tooltip="Aspect ratio of the output image.",
),
IO.Int.Input(
"test_time_scaling",
default=1,
min=1,
max=5,
step=1,
tooltip="Higher values produce better images but cost more credits.",
advanced=True,
),
],
)
for version in versions
]
class ReveImageCreateNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageCreateNode",
display_name="Reve Image Create",
category="api node/image/Reve",
description="Generate images from text descriptions using Reve.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-create@20250915"],
aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for generation.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2560)
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/create",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageCreateRequest(
prompt=prompt,
aspect_ratio=model["aspect_ratio"],
version=model["model"],
test_time_scaling=model["test_time_scaling"],
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveImageEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageEditNode",
display_name="Reve Image Edit",
category="api node/image/Reve",
description="Edit images using natural language instructions with Reve.",
inputs=[
IO.Image.Input("image", tooltip="The image to edit."),
IO.String.Input(
"edit_instruction",
multiline=True,
default="",
tooltip="Text description of how to edit the image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-edit@20250915", "reve-edit-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for editing.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
edit_instruction: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(edit_instruction, min_length=1, max_length=2560)
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/edit",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageEditRequest(
edit_instruction=edit_instruction,
reference_image=tensor_to_base64_string(image),
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveImageRemixNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageRemixNode",
display_name="Reve Image Remix",
category="api node/image/Reve",
description="Combine reference images with text prompts to create new images using Reve.",
inputs=[
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="image_",
min=1,
max=6,
),
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. "
"May include XML img tags to reference specific images by index, "
"e.g. <img>0</img>, <img>1</img>, etc.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-remix@20250915", "reve-remix-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for remixing.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
async def execute(
cls,
reference_images: IO.Autogrow.Type,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2560)
if not reference_images:
raise ValueError("At least one reference image is required.")
ref_base64_list = []
for key in reference_images:
ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
if len(ref_base64_list) > 6:
raise ValueError("Maximum 6 reference images are allowed.")
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/remix",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageRemixRequest(
prompt=prompt,
reference_images=ref_base64_list,
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ReveImageCreateNode,
ReveImageEditNode,
ReveImageRemixNode,
]
async def comfy_entrypoint() -> ReveExtension:
return ReveExtension()

View File

@ -67,7 +67,6 @@ class _RequestConfig:
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
is_rate_limited: Callable[[int, Any], bool] | None = None
response_header_validator: Callable[[dict[str, str]], None] | None = None
@dataclass
@ -203,13 +202,11 @@ async def sync_op_raw(
monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
response_header_validator: Callable[[dict[str, str]], None] | None = None,
) -> dict[str, Any] | bytes:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
- If as_binary=True: returns bytes.
- response_header_validator: optional callback receiving response headers dict
"""
if isinstance(data, BaseModel):
data = data.model_dump(exclude_none=True)
@ -235,7 +232,6 @@ async def sync_op_raw(
price_extractor=price_extractor,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
response_header_validator=response_header_validator,
)
return await _request_base(cfg, expect_binary=as_binary)
@ -773,12 +769,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
)
bytes_payload = bytes(buff)
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
if cfg.price_extractor:
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(resp_headers)
if cfg.response_header_validator:
cfg.response_header_validator(resp_headers)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
@ -786,7 +776,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=resp_headers,
response_headers=dict(resp.headers),
response_content=bytes_payload,
)
return bytes_payload

View File

@ -1,138 +0,0 @@
from typing import Any, Optional, Tuple, List
import hashlib
import json
import logging
import threading
# Public types — source of truth is comfy_api.latest._caching
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)
_logger = logging.getLogger(__name__)
_providers: List[CacheProvider] = []
_providers_lock = threading.Lock()
_providers_snapshot: Tuple[CacheProvider, ...] = ()
def register_cache_provider(provider: CacheProvider) -> None:
"""Register an external cache provider. Providers are called in registration order."""
global _providers_snapshot
with _providers_lock:
if provider in _providers:
_logger.warning(f"Provider {provider.__class__.__name__} already registered")
return
_providers.append(provider)
_providers_snapshot = tuple(_providers)
_logger.debug(f"Registered cache provider: {provider.__class__.__name__}")
def unregister_cache_provider(provider: CacheProvider) -> None:
global _providers_snapshot
with _providers_lock:
try:
_providers.remove(provider)
_providers_snapshot = tuple(_providers)
_logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
except ValueError:
_logger.warning(f"Provider {provider.__class__.__name__} was not registered")
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
return _providers_snapshot
def _has_cache_providers() -> bool:
return bool(_providers_snapshot)
def _clear_cache_providers() -> None:
global _providers_snapshot
with _providers_lock:
_providers.clear()
_providers_snapshot = ()
def _canonicalize(obj: Any) -> Any:
# Convert to canonical JSON-serializable form with deterministic ordering.
# Frozensets have non-deterministic iteration order between Python sessions.
# Raises ValueError for non-cacheable types (Unhashable, unknown) so that
# _serialize_cache_key returns None and external caching is skipped.
if isinstance(obj, frozenset):
return ("__frozenset__", sorted(
[_canonicalize(item) for item in obj],
key=lambda x: json.dumps(x, sort_keys=True)
))
elif isinstance(obj, set):
return ("__set__", sorted(
[_canonicalize(item) for item in obj],
key=lambda x: json.dumps(x, sort_keys=True)
))
elif isinstance(obj, tuple):
return ("__tuple__", [_canonicalize(item) for item in obj])
elif isinstance(obj, list):
return [_canonicalize(item) for item in obj]
elif isinstance(obj, dict):
return {"__dict__": sorted(
[[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
key=lambda x: json.dumps(x, sort_keys=True)
)}
elif isinstance(obj, (int, float, str, bool, type(None))):
return (type(obj).__name__, obj)
elif isinstance(obj, bytes):
return ("__bytes__", obj.hex())
else:
raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
# Returns deterministic SHA256 hex digest, or None on failure.
# Uses JSON (not pickle) because pickle is non-deterministic across sessions.
try:
canonical = _canonicalize(cache_key)
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
except Exception as e:
_logger.warning(f"Failed to serialize cache key: {e}")
return None
def _contains_self_unequal(obj: Any) -> bool:
# Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
# never hit locally, but serialized form would match externally. Skip these.
try:
if not (obj == obj):
return True
except Exception:
return True
if isinstance(obj, (frozenset, tuple, list, set)):
return any(_contains_self_unequal(item) for item in obj)
if isinstance(obj, dict):
return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
if hasattr(obj, 'value'):
return _contains_self_unequal(obj.value)
return False
def _estimate_value_size(value: CacheValue) -> int:
try:
import torch
except ImportError:
return 0
total = 0
def estimate(obj):
nonlocal total
if isinstance(obj, torch.Tensor):
total += obj.numel() * obj.element_size()
elif isinstance(obj, dict):
for v in obj.values():
estimate(v)
elif isinstance(obj, (list, tuple)):
for item in obj:
estimate(item)
for output in value.outputs:
estimate(output)
return total

View File

@ -1,4 +1,3 @@
import asyncio
import bisect
import gc
import itertools
@ -148,15 +147,13 @@ class CacheKeySetInputSignature(CacheKeySet):
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
class BasicCache:
def __init__(self, key_class, enable_providers=False):
def __init__(self, key_class):
self.key_class = key_class
self.initialized = False
self.enable_providers = enable_providers
self.dynprompt: DynamicPrompt
self.cache_key_set: CacheKeySet
self.cache = {}
self.subcaches = {}
self._pending_store_tasks: set = set()
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
@ -199,138 +196,18 @@ class BasicCache:
def poll(self, **kwargs):
pass
def get_local(self, node_id):
def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
def _get_immediate(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
return None
def set_local(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
async def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
await self._notify_providers_store(node_id, cache_key, value)
async def _get_immediate(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
external_result = await self._check_providers_lookup(node_id, cache_key)
if external_result is not None:
self.cache[cache_key] = external_result
return external_result
return None
async def _notify_providers_store(self, node_id, cache_key, value):
from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers,
CacheValue, _contains_self_unequal, _logger
)
if not self.enable_providers:
return
if not _has_cache_providers():
return
if not self._is_external_cacheable_value(value):
return
if _contains_self_unequal(cache_key):
return
context = self._build_context(node_id, cache_key)
if context is None:
return
cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
for provider in _get_cache_providers():
try:
if provider.should_cache(context, cache_value):
task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
self._pending_store_tasks.add(task)
task.add_done_callback(self._pending_store_tasks.discard)
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
@staticmethod
async def _safe_provider_store(provider, context, cache_value):
from comfy_execution.cache_provider import _logger
try:
await provider.on_store(context, cache_value)
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
async def _check_providers_lookup(self, node_id, cache_key):
from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers,
CacheValue, _contains_self_unequal, _logger
)
if not self.enable_providers:
return None
if not _has_cache_providers():
return None
if _contains_self_unequal(cache_key):
return None
context = self._build_context(node_id, cache_key)
if context is None:
return None
for provider in _get_cache_providers():
try:
if not provider.should_cache(context):
continue
result = await provider.on_lookup(context)
if result is not None:
if not isinstance(result, CacheValue):
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
continue
if not isinstance(result.outputs, (list, tuple)):
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
continue
from execution import CacheEntry
return CacheEntry(ui=result.ui, outputs=list(result.outputs))
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
return None
def _is_external_cacheable_value(self, value):
return hasattr(value, 'outputs') and hasattr(value, 'ui')
def _get_class_type(self, node_id):
if not self.initialized or not self.dynprompt:
return ''
try:
return self.dynprompt.get_node(node_id).get('class_type', '')
except Exception:
return ''
def _build_context(self, node_id, cache_key):
from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
try:
cache_key_hash = _serialize_cache_key(cache_key)
if cache_key_hash is None:
return None
return CacheContext(
node_id=node_id,
class_type=self._get_class_type(node_id),
cache_key_hash=cache_key_hash,
)
except Exception as e:
_logger.warning(f"Failed to build cache context for node {node_id}: {e}")
else:
return None
async def _ensure_subcache(self, node_id, children_ids):
@ -359,8 +236,8 @@ class BasicCache:
return result
class HierarchicalCache(BasicCache):
def __init__(self, key_class, enable_providers=False):
super().__init__(key_class, enable_providers=enable_providers)
def __init__(self, key_class):
super().__init__(key_class)
def _get_cache_for(self, node_id):
assert self.dynprompt is not None
@ -380,27 +257,16 @@ class HierarchicalCache(BasicCache):
return None
return cache
async def get(self, node_id):
def get(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
return await cache._get_immediate(node_id)
return cache._get_immediate(node_id)
def get_local(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
return BasicCache.get_local(cache, node_id)
async def set(self, node_id, value):
def set(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
await cache._set_immediate(node_id, value)
def set_local(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
BasicCache.set_local(cache, node_id, value)
cache._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
@ -421,24 +287,18 @@ class NullCache:
def poll(self, **kwargs):
pass
async def get(self, node_id):
def get(self, node_id):
return None
def get_local(self, node_id):
return None
async def set(self, node_id, value):
pass
def set_local(self, node_id, value):
def set(self, node_id, value):
pass
async def ensure_subcache_for(self, node_id, children_ids):
return self
class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100, enable_providers=False):
super().__init__(key_class, enable_providers=enable_providers)
def __init__(self, key_class, max_size=100):
super().__init__(key_class)
self.max_size = max_size
self.min_generation = 0
self.generation = 0
@ -462,18 +322,18 @@ class LRUCache(BasicCache):
del self.children[key]
self._clean_subcaches()
async def get(self, node_id):
def get(self, node_id):
self._mark_used(node_id)
return await self._get_immediate(node_id)
return self._get_immediate(node_id)
def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None:
self.used_generation[cache_key] = self.generation
async def set(self, node_id, value):
def set(self, node_id, value):
self._mark_used(node_id)
return await self._set_immediate(node_id, value)
return self._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
@ -506,20 +366,20 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache):
def __init__(self, key_class, enable_providers=False):
super().__init__(key_class, 0, enable_providers=enable_providers)
def __init__(self, key_class):
super().__init__(key_class, 0)
self.timestamps = {}
def clean_unused(self):
self._clean_subcaches()
async def set(self, node_id, value):
def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
await super().set(node_id, value)
super().set(node_id, value)
async def get(self, node_id):
def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return await super().get(node_id)
return super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():

View File

@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort):
self.execution_cache_listeners = {}
def is_cached(self, node_id):
return self.output_cache.get_local(node_id) is not None
return self.output_cache.get(node_id) is not None
def cache_link(self, from_node_id, to_node_id):
if to_node_id not in self.execution_cache:
self.execution_cache[to_node_id] = {}
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id)
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
if from_node_id not in self.execution_cache_listeners:
self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id)
@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort):
if value is None:
return None
#Write back to the main cache on touch.
self.output_cache.set_local(from_node_id, value)
self.output_cache.set(from_node_id, value)
return value
def cache_update(self, node_id, value):

View File

@ -6,7 +6,6 @@ import comfy.model_management
import torch
import math
import nodes
import comfy.ldm.flux.math
class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod
@ -232,68 +231,6 @@ class Flux2Scheduler(io.ComfyNode):
sigmas = get_schedule(steps, round(seq_len))
return io.NodeOutput(sigmas)
class KV_Attn_Input:
def __init__(self):
self.cache = {}
def __call__(self, q, k, v, extra_options, **kwargs):
reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
if len(reference_image_num_tokens) == 0:
return {}
ref_toks = sum(reference_image_num_tokens)
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
if cache_key in self.cache:
kk, vv = self.cache[cache_key]
self.set_cache = False
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone())
self.set_cache = True
return {"q": q, "k": k, "v": v}
def cleanup(self):
self.cache = {}
class FluxKVCache(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="FluxKVCache",
display_name="Flux KV Cache",
description="Enables KV Cache optimization for reference images on Flux family models.",
category="",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to use KV Cache on."),
],
outputs=[
io.Model.Output(tooltip="The patched model with KV Cache enabled."),
],
)
@classmethod
def execute(cls, model: io.Model.Type) -> io.NodeOutput:
m = model.clone()
input_patch_obj = KV_Attn_Input()
def model_input_patch(inputs):
if len(input_patch_obj.cache) > 0:
ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
if ref_image_tokens > 0:
img = inputs["img"]
inputs["img"] = img[:, :-ref_image_tokens]
return inputs
m.set_model_attn1_patch(input_patch_obj)
m.set_model_post_input_patch(model_input_patch)
if hasattr(model.model.diffusion_model, "params"):
m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
else:
m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
return io.NodeOutput(m)
class FluxExtension(ComfyExtension):
@override
@ -306,7 +243,6 @@ class FluxExtension(ComfyExtension):
FluxKontextMultiReferenceLatentMethod,
EmptyFlux2LatentImage,
Flux2Scheduler,
FluxKVCache,
]

View File

@ -10,146 +10,198 @@ import json
import os
from comfy.cli_args import args
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override
class ModelMergeSimple:
class ModelMergeSimple(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeSimple",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2, ratio):
@classmethod
def execute(cls, model1, model2, ratio) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)
class ModelSubtract:
merge = execute # TODO: remove
class ModelSubtract(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeSubtract",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2, multiplier):
@classmethod
def execute(cls, model1, model2, multiplier) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, - multiplier, multiplier)
return (m, )
return io.NodeOutput(m)
class ModelAdd:
merge = execute # TODO: remove
class ModelAdd(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeAdd",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2):
@classmethod
def execute(cls, model1, model2) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, 1.0, 1.0)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class CLIPMergeSimple:
class CLIPMergeSimple(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeSimple",
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2, ratio):
@classmethod
def execute(cls, clip1, clip2, ratio) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class CLIPSubtract:
SEARCH_ALIASES = ["clip difference", "text encoder subtract"]
class CLIPSubtract(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeSubtract",
search_aliases=["clip difference", "text encoder subtract"],
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2, multiplier):
@classmethod
def execute(cls, clip1, clip2, multiplier) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, - multiplier, multiplier)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class CLIPAdd:
SEARCH_ALIASES = ["combine clip"]
class CLIPAdd(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeAdd",
search_aliases=["combine clip"],
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2):
@classmethod
def execute(cls, clip1, clip2) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0, 1.0)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class ModelMergeBlocks:
class ModelMergeBlocks(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeBlocks",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("input", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("middle", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("out", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2, **kwargs):
@classmethod
def execute(cls, model1, model2, **kwargs) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
default_ratio = next(iter(kwargs.values()))
@ -165,7 +217,10 @@ class ModelMergeBlocks:
last_arg_size = len(arg)
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
@ -226,59 +281,65 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
class CheckpointSave:
SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
class CheckpointSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CheckpointSave",
display_name="Save Checkpoint",
search_aliases=["save model", "export checkpoint", "merge save"],
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.Clip.Input("clip"),
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
def execute(cls, model, clip, vae, filename_prefix) -> io.NodeOutput:
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()
CATEGORY = "advanced/model_merging"
save = execute # TODO: remove
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
class CLIPSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
class CLIPSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPSave",
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip"),
io.String.Input("filename_prefix", default="clip/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip": ("CLIP",),
"filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
def execute(cls, clip, filename_prefix) -> io.NodeOutput:
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
if cls.hidden.prompt is not None:
prompt_info = json.dumps(cls.hidden.prompt)
metadata = {}
if not args.disable_metadata:
metadata["format"] = "pt"
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
clip_sd = clip.get_sd()
output_dir = folder_paths.get_output_directory()
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
current_clip_sd = {}
@ -295,7 +356,7 @@ class CLIPSave:
replace_prefix[prefix] = ""
replace_prefix["transformer."] = ""
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, output_dir)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
@ -303,76 +364,88 @@ class CLIPSave:
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
return {}
return io.NodeOutput()
class VAESave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
save = execute # TODO: remove
class VAESave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VAESave",
category="advanced/model_merging",
inputs=[
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="vae/ComfyUI_vae"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
def execute(cls, vae, filename_prefix) -> io.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
if cls.hidden.prompt is not None:
prompt_info = json.dumps(cls.hidden.prompt)
metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {}
return io.NodeOutput()
class ModelSave:
SEARCH_ALIASES = ["export model", "checkpoint save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
save = execute # TODO: remove
class ModelSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ModelSave",
search_aliases=["export model", "checkpoint save"],
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.String.Input("filename_prefix", default="diffusion_models/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
def execute(cls, model, filename_prefix) -> io.NodeOutput:
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()
CATEGORY = "advanced/model_merging"
save = execute # TODO: remove
def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks,
"ModelMergeSubtract": ModelSubtract,
"ModelMergeAdd": ModelAdd,
"CheckpointSave": CheckpointSave,
"CLIPMergeSimple": CLIPMergeSimple,
"CLIPMergeSubtract": CLIPSubtract,
"CLIPMergeAdd": CLIPAdd,
"CLIPSave": CLIPSave,
"VAESave": VAESave,
"ModelSave": ModelSave,
}
class ModelMergingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ModelMergeSimple,
ModelMergeBlocks,
ModelSubtract,
ModelAdd,
CheckpointSave,
CLIPMergeSimple,
CLIPSubtract,
CLIPAdd,
CLIPSave,
VAESave,
ModelSave,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointSave": "Save Checkpoint",
}
async def comfy_entrypoint() -> ModelMergingExtension:
return ModelMergingExtension()

View File

@ -1,356 +1,455 @@
import comfy_extras.nodes_model_merging
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override
class ModelMergeSD1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["time_embed."] = argument
arg_dict["label_emb."] = argument
inputs.append(io.Float.Input("time_embed.", **argument))
inputs.append(io.Float.Input("label_emb.", **argument))
for i in range(12):
arg_dict["input_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("input_blocks.{}.".format(i), **argument))
for i in range(3):
arg_dict["middle_block.{}.".format(i)] = argument
inputs.append(io.Float.Input("middle_block.{}.".format(i), **argument))
for i in range(12):
arg_dict["output_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("output_blocks.{}.".format(i), **argument))
arg_dict["out."] = argument
inputs.append(io.Float.Input("out.", **argument))
return {"required": arg_dict}
return io.Schema(
node_id="ModelMergeSD1",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
class ModelMergeSD2(ModelMergeSD1):
# SD1 and SD2 have the same blocks
@classmethod
def define_schema(cls):
schema = ModelMergeSD1.define_schema()
schema.node_id = "ModelMergeSD2"
return schema
class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["time_embed."] = argument
arg_dict["label_emb."] = argument
inputs.append(io.Float.Input("time_embed.", **argument))
inputs.append(io.Float.Input("label_emb.", **argument))
for i in range(9):
arg_dict["input_blocks.{}".format(i)] = argument
inputs.append(io.Float.Input("input_blocks.{}".format(i), **argument))
for i in range(3):
arg_dict["middle_block.{}".format(i)] = argument
inputs.append(io.Float.Input("middle_block.{}".format(i), **argument))
for i in range(9):
arg_dict["output_blocks.{}".format(i)] = argument
inputs.append(io.Float.Input("output_blocks.{}".format(i), **argument))
arg_dict["out."] = argument
inputs.append(io.Float.Input("out.", **argument))
return io.Schema(
node_id="ModelMergeSDXL",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_embed."] = argument
arg_dict["x_embedder."] = argument
arg_dict["context_embedder."] = argument
arg_dict["y_embedder."] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("pos_embed.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("context_embedder.", **argument))
inputs.append(io.Float.Input("y_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
for i in range(24):
arg_dict["joint_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("joint_blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return {"required": arg_dict}
return io.Schema(
node_id="ModelMergeSD3_2B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
class ModelMergeAuraflow(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["init_x_linear."] = argument
arg_dict["positional_encoding"] = argument
arg_dict["cond_seq_linear."] = argument
arg_dict["register_tokens"] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("init_x_linear.", **argument))
inputs.append(io.Float.Input("positional_encoding", **argument))
inputs.append(io.Float.Input("cond_seq_linear.", **argument))
inputs.append(io.Float.Input("register_tokens", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
for i in range(4):
arg_dict["double_layers.{}.".format(i)] = argument
inputs.append(io.Float.Input("double_layers.{}.".format(i), **argument))
for i in range(32):
arg_dict["single_layers.{}.".format(i)] = argument
inputs.append(io.Float.Input("single_layers.{}.".format(i), **argument))
arg_dict["modF."] = argument
arg_dict["final_linear."] = argument
inputs.append(io.Float.Input("modF.", **argument))
inputs.append(io.Float.Input("final_linear.", **argument))
return io.Schema(
node_id="ModelMergeAuraflow",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["img_in."] = argument
arg_dict["time_in."] = argument
arg_dict["guidance_in"] = argument
arg_dict["vector_in."] = argument
arg_dict["txt_in."] = argument
inputs.append(io.Float.Input("img_in.", **argument))
inputs.append(io.Float.Input("time_in.", **argument))
inputs.append(io.Float.Input("guidance_in", **argument))
inputs.append(io.Float.Input("vector_in.", **argument))
inputs.append(io.Float.Input("txt_in.", **argument))
for i in range(19):
arg_dict["double_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("double_blocks.{}.".format(i), **argument))
for i in range(38):
arg_dict["single_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("single_blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeFlux1",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_embed."] = argument
arg_dict["x_embedder."] = argument
arg_dict["context_embedder."] = argument
arg_dict["y_embedder."] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("pos_embed.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("context_embedder.", **argument))
inputs.append(io.Float.Input("y_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
for i in range(38):
arg_dict["joint_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("joint_blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeSD35_Large",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeMochiPreview(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_frequencies."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t5_y_embedder."] = argument
arg_dict["t5_yproj."] = argument
inputs.append(io.Float.Input("pos_frequencies.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("t5_y_embedder.", **argument))
inputs.append(io.Float.Input("t5_yproj.", **argument))
for i in range(48):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeMochiPreview",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["patchify_proj."] = argument
arg_dict["adaln_single."] = argument
arg_dict["caption_projection."] = argument
inputs.append(io.Float.Input("patchify_proj.", **argument))
inputs.append(io.Float.Input("adaln_single.", **argument))
inputs.append(io.Float.Input("caption_projection.", **argument))
for i in range(28):
arg_dict["transformer_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("transformer_blocks.{}.".format(i), **argument))
arg_dict["scale_shift_table"] = argument
arg_dict["proj_out."] = argument
inputs.append(io.Float.Input("scale_shift_table", **argument))
inputs.append(io.Float.Input("proj_out.", **argument))
return io.Schema(
node_id="ModelMergeLTXV",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmos7B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["extra_pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["affline_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("extra_pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("affline_norm.", **argument))
for i in range(28):
arg_dict["blocks.block{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.block{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmos7B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["extra_pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["affline_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("extra_pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("affline_norm.", **argument))
for i in range(36):
arg_dict["blocks.block{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.block{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmos14B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["patch_embedding."] = argument
arg_dict["time_embedding."] = argument
arg_dict["time_projection."] = argument
arg_dict["text_embedding."] = argument
arg_dict["img_emb."] = argument
inputs.append(io.Float.Input("patch_embedding.", **argument))
inputs.append(io.Float.Input("time_embedding.", **argument))
inputs.append(io.Float.Input("time_projection.", **argument))
inputs.append(io.Float.Input("text_embedding.", **argument))
inputs.append(io.Float.Input("img_emb.", **argument))
for i in range(40):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["head."] = argument
inputs.append(io.Float.Input("head.", **argument))
return io.Schema(
node_id="ModelMergeWAN2_1",
category="advanced/model_merging/model_specific",
description="1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb.",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmosPredict2_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t_embedding_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("t_embedding_norm.", **argument))
for i in range(28):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmosPredict2_2B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t_embedding_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("t_embedding_norm.", **argument))
for i in range(36):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmosPredict2_14B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_embeds."] = argument
arg_dict["img_in."] = argument
arg_dict["txt_norm."] = argument
arg_dict["txt_in."] = argument
arg_dict["time_text_embed."] = argument
inputs.append(io.Float.Input("pos_embeds.", **argument))
inputs.append(io.Float.Input("img_in.", **argument))
inputs.append(io.Float.Input("txt_norm.", **argument))
inputs.append(io.Float.Input("txt_in.", **argument))
inputs.append(io.Float.Input("time_text_embed.", **argument))
for i in range(60):
arg_dict["transformer_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("transformer_blocks.{}.".format(i), **argument))
arg_dict["proj_out."] = argument
inputs.append(io.Float.Input("proj_out.", **argument))
return {"required": arg_dict}
return io.Schema(
node_id="ModelMergeQwenImage",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
"ModelMergeSDXL": ModelMergeSDXL,
"ModelMergeSD3_2B": ModelMergeSD3_2B,
"ModelMergeAuraflow": ModelMergeAuraflow,
"ModelMergeFlux1": ModelMergeFlux1,
"ModelMergeSD35_Large": ModelMergeSD35_Large,
"ModelMergeMochiPreview": ModelMergeMochiPreview,
"ModelMergeLTXV": ModelMergeLTXV,
"ModelMergeCosmos7B": ModelMergeCosmos7B,
"ModelMergeCosmos14B": ModelMergeCosmos14B,
"ModelMergeWAN2_1": ModelMergeWAN2_1,
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
"ModelMergeQwenImage": ModelMergeQwenImage,
}
class ModelMergingModelSpecificExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ModelMergeSD1,
ModelMergeSD2,
ModelMergeSDXL,
ModelMergeSD3_2B,
ModelMergeAuraflow,
ModelMergeFlux1,
ModelMergeSD35_Large,
ModelMergeMochiPreview,
ModelMergeLTXV,
ModelMergeCosmos7B,
ModelMergeCosmos14B,
ModelMergeWAN2_1,
ModelMergeCosmosPredict2_2B,
ModelMergeCosmosPredict2_14B,
ModelMergeQwenImage,
]
async def comfy_entrypoint() -> ModelMergingModelSpecificExtension:
return ModelMergingModelSpecificExtension()

View File

@ -1,127 +0,0 @@
from __future__ import annotations
import hashlib
import os
import numpy as np
import torch
from PIL import Image
import folder_paths
import node_helpers
from comfy_api.latest import ComfyExtension, io, UI
from typing_extensions import override
def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
hex_color = hex_color.lstrip("#")
if len(hex_color) != 6:
return (0.0, 0.0, 0.0)
r = int(hex_color[0:2], 16) / 255.0
g = int(hex_color[2:4], 16) / 255.0
b = int(hex_color[4:6], 16) / 255.0
return (r, g, b)
class PainterNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Painter",
display_name="Painter",
category="image",
inputs=[
io.Image.Input(
"image",
optional=True,
tooltip="Optional base image to paint over",
),
io.String.Input(
"mask",
default="",
socketless=True,
extra_dict={"widgetType": "PAINTER", "image_upload": True},
),
io.Int.Input(
"width",
default=512,
min=64,
max=4096,
step=64,
socketless=True,
extra_dict={"hidden": True},
),
io.Int.Input(
"height",
default=512,
min=64,
max=4096,
step=64,
socketless=True,
extra_dict={"hidden": True},
),
io.Color.Input("bg_color", default="#000000"),
],
outputs=[
io.Image.Output("IMAGE"),
io.Mask.Output("MASK"),
],
)
@classmethod
def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
if image is not None:
base_image = image[:1]
h, w = base_image.shape[1], base_image.shape[2]
else:
h, w = height, width
r, g, b = hex_to_rgb(bg_color)
base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
base_image[0, :, :, 0] = r
base_image[0, :, :, 1] = g
base_image[0, :, :, 2] = b
if mask and mask.strip():
mask_path = folder_paths.get_annotated_filepath(mask)
painter_img = node_helpers.pillow(Image.open, mask_path)
painter_img = painter_img.convert("RGBA")
if painter_img.size != (w, h):
painter_img = painter_img.resize((w, h), Image.LANCZOS)
painter_np = np.array(painter_img).astype(np.float32) / 255.0
painter_rgb = painter_np[:, :, :3]
painter_alpha = painter_np[:, :, 3:4]
mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0)
base_np = base_image[0].cpu().numpy()
composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha)
out_image = torch.from_numpy(composited).unsqueeze(0)
else:
mask_tensor = torch.zeros((1, h, w), dtype=torch.float32)
out_image = base_image
return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))
@classmethod
def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
if mask and mask.strip():
mask_path = folder_paths.get_annotated_filepath(mask)
if os.path.exists(mask_path):
m = hashlib.sha256()
with open(mask_path, "rb") as f:
m.update(f.read())
return m.digest().hex()
return ""
class PainterExtension(ComfyExtension):
@override
async def get_node_list(self):
return [PainterNode]
async def comfy_entrypoint():
return PainterExtension()

View File

@ -6,44 +6,62 @@ import folder_paths
import comfy_extras.nodes_model_merging
import node_helpers
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override
class ImageOnlyCheckpointLoader:
class ImageOnlyCheckpointLoader(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"
def define_schema(cls):
return io.Schema(
node_id="ImageOnlyCheckpointLoader",
display_name="Image Only Checkpoint Loader (img2vid model)",
category="loaders/video_models",
inputs=[
io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("checkpoints")),
],
outputs=[
io.Model.Output(),
io.ClipVision.Output(),
io.Vae.Output(),
],
)
CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
@classmethod
def execute(cls, ckpt_name, output_vae=True, output_clip=True) -> io.NodeOutput:
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])
return io.NodeOutput(out[0], out[3], out[2])
load_checkpoint = execute # TODO: remove
class SVD_img2vid_Conditioning:
class SVD_img2vid_Conditioning(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023, "advanced": True}),
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01, "advanced": True})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
def define_schema(cls):
return io.Schema(
node_id="SVD_img2vid_Conditioning",
category="conditioning/video_models",
inputs=[
io.ClipVision.Input("clip_vision"),
io.Image.Input("init_image"),
io.Vae.Input("vae"),
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("video_frames", default=14, min=1, max=4096),
io.Int.Input("motion_bucket_id", default=127, min=1, max=1023, advanced=True),
io.Int.Input("fps", default=6, min=1, max=1024),
io.Float.Input("augmentation_level", default=0.0, min=0.0, max=10.0, step=0.01, advanced=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
@classmethod
def execute(cls, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level) -> io.NodeOutput:
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@ -54,20 +72,28 @@ class SVD_img2vid_Conditioning:
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
return io.NodeOutput(positive, negative, {"samples":latent})
class VideoLinearCFGGuidance:
encode = execute # TODO: remove
class VideoLinearCFGGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01, "advanced": True}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="VideoLinearCFGGuidance",
category="sampling/video_models",
inputs=[
io.Model.Input("model"),
io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01, advanced=True),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
@classmethod
def execute(cls, model, min_cfg) -> io.NodeOutput:
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
@ -78,20 +104,28 @@ class VideoLinearCFGGuidance:
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
return io.NodeOutput(m)
class VideoTriangleCFGGuidance:
patch = execute # TODO: remove
class VideoTriangleCFGGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01, "advanced": True}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="VideoTriangleCFGGuidance",
category="sampling/video_models",
inputs=[
io.Model.Input("model"),
io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01, advanced=True),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
@classmethod
def execute(cls, model, min_cfg) -> io.NodeOutput:
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
@ -105,57 +139,79 @@ class VideoTriangleCFGGuidance:
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
return io.NodeOutput(m)
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
CATEGORY = "advanced/model_merging"
patch = execute # TODO: remove
class ImageOnlyCheckpointSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageOnlyCheckpointSave",
search_aliases=["save model", "export checkpoint", "merge save"],
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.ClipVision.Input("clip_vision"),
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip_vision": ("CLIP_VISION",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
def execute(cls, model, clip_vision, vae, filename_prefix) -> io.NodeOutput:
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()
def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
save = execute # TODO: remove
class ConditioningSetAreaPercentageVideo:
class ConditioningSetAreaPercentageVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"temporal": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"z": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
def define_schema(cls):
return io.Schema(
node_id="ConditioningSetAreaPercentageVideo",
category="conditioning",
inputs=[
io.Conditioning.Input("conditioning"),
io.Float.Input("width", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("height", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("temporal", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("x", default=0.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("y", default=0.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("z", default=0.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "conditioning"
def append(self, conditioning, width, height, temporal, x, y, z, strength):
@classmethod
def execute(cls, conditioning, width, height, temporal, x, y, z, strength) -> io.NodeOutput:
c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", temporal, height, width, z, y, x),
"strength": strength,
"set_area_to_bounds": False})
return (c, )
return io.NodeOutput(c)
append = execute # TODO: remove
NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
"ConditioningSetAreaPercentageVideo": ConditioningSetAreaPercentageVideo,
}
class VideoModelExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ImageOnlyCheckpointLoader,
SVD_img2vid_Conditioning,
VideoLinearCFGGuidance,
VideoTriangleCFGGuidance,
ImageOnlyCheckpointSave,
ConditioningSetAreaPercentageVideo,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
}
async def comfy_entrypoint() -> VideoModelExtension:
return VideoModelExtension()

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.17.0"
__version__ = "0.16.4"

View File

@ -40,7 +40,6 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
class ExecutionResult(Enum):
@ -127,15 +126,15 @@ class CacheSet:
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature, enable_providers=True)
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size, enable_providers=True)
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True)
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self):
@ -419,7 +418,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
cached = await caches.outputs.get(unique_id)
cached = caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None:
cached_ui = cached.ui or {}
@ -475,10 +474,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
obj = await caches.objects.get(unique_id)
obj = caches.objects.get(unique_id)
if obj is None:
obj = class_def()
await caches.objects.set(unique_id, obj)
caches.objects.set(unique_id, obj)
if issubclass(class_def, _ComfyNodeInternal):
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
@ -589,7 +588,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, cache_entry)
await caches.outputs.set(unique_id, cache_entry)
caches.outputs.set(unique_id, cache_entry)
except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
@ -685,19 +684,6 @@ class PromptExecutor:
}
self.add_message("execution_error", mes, broadcast=False)
def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
if not _has_cache_providers():
return
for provider in _get_cache_providers():
try:
if event == "start":
provider.on_prompt_start(prompt_id)
elif event == "end":
provider.on_prompt_end(prompt_id)
except Exception as e:
_cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
@ -714,75 +700,66 @@ class PromptExecutor:
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
self._notify_prompt_lifecycle("start", prompt_id)
with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused()
try:
with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused()
cached_nodes = []
for node_id in prompt:
if self.caches.outputs.get(node_id) is not None:
cached_nodes.append(node_id)
node_ids = list(prompt.keys())
cache_results = await asyncio.gather(
*(self.caches.outputs.get(node_id) for node_id in node_ids)
)
cached_nodes = [
node_id for node_id, result in zip(node_ids, cache_results)
if result is not None
]
comfy.model_management.cleanup_models_gc()
self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
for node_id in list(execute_outputs):
execution_list.add_node(node_id)
comfy.model_management.cleanup_models_gc()
self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
for node_id in list(execute_outputs):
execution_list.add_node(node_id)
while not execution_list.is_empty():
node_id, error, ex = await execution_list.stage_node_execution()
if error is not None:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
while not execution_list.is_empty():
node_id, error, ex = await execution_list.stage_node_execution()
if error is not None:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
elif result == ExecutionResult.PENDING:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
elif result == ExecutionResult.PENDING:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
ui_outputs = {}
meta_outputs = {}
for node_id, ui_info in ui_node_outputs.items():
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
}
self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
finally:
self._notify_prompt_lifecycle("end", prompt_id)
ui_outputs = {}
meta_outputs = {}
for node_id, ui_info in ui_node_outputs.items():
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
}
self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
async def validate_inputs(prompt_id, prompt, item, validated):

View File

@ -32,7 +32,7 @@ async def cache_control(
)
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
response.headers.setdefault("Cache-Control", "no-store")
response.headers.setdefault("Cache-Control", "no-cache")
return response
# Early return for non-image files - no cache headers needed

View File

@ -2450,7 +2450,6 @@ async def init_builtin_extra_nodes():
"nodes_nag.py",
"nodes_sdpose.py",
"nodes_math.py",
"nodes_painter.py",
]
import_failed = []

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.17.0"
version = "0.16.4"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.41.19
comfyui-workflow-templates==0.9.21
comfyui-frontend-package==1.39.19
comfyui-workflow-templates==0.9.18
comfyui-embedded-docs==0.4.3
torch
torchsde
@ -22,8 +22,8 @@ alembic
SQLAlchemy
filelock
av>=14.2.0
comfy-kitchen>=0.2.8
comfy-aimdo>=0.2.10
comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.9
requests
simpleeval>=1.0.0
blake3

View File

@ -310,7 +310,7 @@ class PromptServer():
@routes.get("/")
async def get_root(request):
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
response.headers['Cache-Control'] = 'no-store, must-revalidate'
response.headers['Cache-Control'] = 'no-cache'
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
return response

View File

@ -1,403 +0,0 @@
"""Tests for external cache provider API."""
import importlib.util
import pytest
from typing import Optional
def _torch_available() -> bool:
"""Check if PyTorch is available."""
return importlib.util.find_spec("torch") is not None
from comfy_execution.cache_provider import (
CacheProvider,
CacheContext,
CacheValue,
register_cache_provider,
unregister_cache_provider,
_get_cache_providers,
_has_cache_providers,
_clear_cache_providers,
_serialize_cache_key,
_contains_self_unequal,
_estimate_value_size,
_canonicalize,
)
class TestCanonicalize:
"""Test _canonicalize function for deterministic ordering."""
def test_frozenset_ordering_is_deterministic(self):
"""Frozensets should produce consistent canonical form regardless of iteration order."""
# Create two frozensets with same content
fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)])
fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)])
result1 = _canonicalize(fs1)
result2 = _canonicalize(fs2)
assert result1 == result2
def test_nested_frozenset_ordering(self):
"""Nested frozensets should also be deterministically ordered."""
inner1 = frozenset([1, 2, 3])
inner2 = frozenset([3, 2, 1])
fs1 = frozenset([("key", inner1)])
fs2 = frozenset([("key", inner2)])
result1 = _canonicalize(fs1)
result2 = _canonicalize(fs2)
assert result1 == result2
def test_dict_ordering(self):
"""Dicts should be sorted by key."""
d1 = {"z": 1, "a": 2, "m": 3}
d2 = {"a": 2, "m": 3, "z": 1}
result1 = _canonicalize(d1)
result2 = _canonicalize(d2)
assert result1 == result2
def test_tuple_preserved(self):
"""Tuples should be marked and preserved."""
t = (1, 2, 3)
result = _canonicalize(t)
assert result[0] == "__tuple__"
def test_list_preserved(self):
"""Lists should be recursively canonicalized."""
lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
result = _canonicalize(lst)
# First element should be canonicalized dict
assert "__dict__" in result[0]
# Second element should be canonicalized frozenset
assert result[1][0] == "__frozenset__"
def test_primitives_include_type(self):
"""Primitive types should include type name for disambiguation."""
assert _canonicalize(42) == ("int", 42)
assert _canonicalize(3.14) == ("float", 3.14)
assert _canonicalize("hello") == ("str", "hello")
assert _canonicalize(True) == ("bool", True)
assert _canonicalize(None) == ("NoneType", None)
def test_int_and_str_distinguished(self):
"""int 7 and str '7' must produce different canonical forms."""
assert _canonicalize(7) != _canonicalize("7")
def test_bytes_converted(self):
"""Bytes should be converted to hex string."""
b = b"\x00\xff"
result = _canonicalize(b)
assert result[0] == "__bytes__"
assert result[1] == "00ff"
def test_set_ordering(self):
"""Sets should be sorted like frozensets."""
s1 = {3, 1, 2}
s2 = {1, 2, 3}
result1 = _canonicalize(s1)
result2 = _canonicalize(s2)
assert result1 == result2
assert result1[0] == "__set__"
def test_unknown_type_raises(self):
"""Unknown types should raise ValueError (fail-closed)."""
class CustomObj:
pass
with pytest.raises(ValueError):
_canonicalize(CustomObj())
def test_object_with_value_attr_raises(self):
"""Objects with .value attribute (Unhashable-like) should raise ValueError."""
class FakeUnhashable:
def __init__(self):
self.value = float('nan')
with pytest.raises(ValueError):
_canonicalize(FakeUnhashable())
class TestSerializeCacheKey:
"""Test _serialize_cache_key for deterministic hashing."""
def test_same_content_same_hash(self):
"""Same content should produce same hash."""
key1 = frozenset([("node_1", frozenset([("input", "value")]))])
key2 = frozenset([("node_1", frozenset([("input", "value")]))])
hash1 = _serialize_cache_key(key1)
hash2 = _serialize_cache_key(key2)
assert hash1 == hash2
def test_different_content_different_hash(self):
"""Different content should produce different hash."""
key1 = frozenset([("node_1", "value_a")])
key2 = frozenset([("node_1", "value_b")])
hash1 = _serialize_cache_key(key1)
hash2 = _serialize_cache_key(key2)
assert hash1 != hash2
def test_returns_hex_string(self):
"""Should return hex string (SHA256 hex digest)."""
key = frozenset([("test", 123)])
result = _serialize_cache_key(key)
assert isinstance(result, str)
assert len(result) == 64 # SHA256 hex digest is 64 chars
def test_complex_nested_structure(self):
"""Complex nested structures should hash deterministically."""
# Note: frozensets can only contain hashable types, so we use
# nested frozensets of tuples to represent dict-like structures
key = frozenset([
("node_1", frozenset([
("input_a", ("tuple", "value")),
("input_b", frozenset([("nested", "dict")])),
])),
("node_2", frozenset([
("param", 42),
])),
])
# Hash twice to verify determinism
hash1 = _serialize_cache_key(key)
hash2 = _serialize_cache_key(key)
assert hash1 == hash2
def test_dict_in_cache_key(self):
"""Dicts passed directly to _serialize_cache_key should work."""
key = {"node_1": {"input": "value"}, "node_2": 42}
hash1 = _serialize_cache_key(key)
hash2 = _serialize_cache_key(key)
assert hash1 == hash2
assert isinstance(hash1, str)
assert len(hash1) == 64
def test_unknown_type_returns_none(self):
"""Non-cacheable types should return None (fail-closed)."""
class CustomObj:
pass
assert _serialize_cache_key(CustomObj()) is None
class TestContainsSelfUnequal:
"""Test _contains_self_unequal utility function."""
def test_nan_float_detected(self):
"""NaN floats should be detected (not equal to itself)."""
assert _contains_self_unequal(float('nan')) is True
def test_regular_float_not_detected(self):
"""Regular floats are equal to themselves."""
assert _contains_self_unequal(3.14) is False
assert _contains_self_unequal(0.0) is False
assert _contains_self_unequal(-1.5) is False
def test_infinity_not_detected(self):
"""Infinity is equal to itself."""
assert _contains_self_unequal(float('inf')) is False
assert _contains_self_unequal(float('-inf')) is False
def test_nan_in_list(self):
"""NaN in list should be detected."""
assert _contains_self_unequal([1, 2, float('nan'), 4]) is True
assert _contains_self_unequal([1, 2, 3, 4]) is False
def test_nan_in_tuple(self):
"""NaN in tuple should be detected."""
assert _contains_self_unequal((1, float('nan'))) is True
assert _contains_self_unequal((1, 2, 3)) is False
def test_nan_in_frozenset(self):
"""NaN in frozenset should be detected."""
assert _contains_self_unequal(frozenset([1, float('nan')])) is True
assert _contains_self_unequal(frozenset([1, 2, 3])) is False
def test_nan_in_dict_value(self):
"""NaN in dict value should be detected."""
assert _contains_self_unequal({"key": float('nan')}) is True
assert _contains_self_unequal({"key": 42}) is False
def test_nan_in_nested_structure(self):
"""NaN in deeply nested structure should be detected."""
nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
assert _contains_self_unequal(nested) is True
def test_non_numeric_types(self):
"""Non-numeric types should not be self-unequal."""
assert _contains_self_unequal("string") is False
assert _contains_self_unequal(None) is False
assert _contains_self_unequal(True) is False
def test_object_with_nan_value_attribute(self):
"""Objects wrapping NaN in .value should be detected."""
class NanWrapper:
def __init__(self):
self.value = float('nan')
assert _contains_self_unequal(NanWrapper()) is True
def test_custom_self_unequal_object(self):
"""Custom objects where not (x == x) should be detected."""
class NeverEqual:
def __eq__(self, other):
return False
assert _contains_self_unequal(NeverEqual()) is True
class TestEstimateValueSize:
"""Test _estimate_value_size utility function."""
def test_empty_outputs(self):
"""Empty outputs should have zero size."""
value = CacheValue(outputs=[])
assert _estimate_value_size(value) == 0
@pytest.mark.skipif(
not _torch_available(),
reason="PyTorch not available"
)
def test_tensor_size_estimation(self):
"""Tensor size should be estimated correctly."""
import torch
# 1000 float32 elements = 4000 bytes
tensor = torch.zeros(1000, dtype=torch.float32)
value = CacheValue(outputs=[[tensor]])
size = _estimate_value_size(value)
assert size == 4000
@pytest.mark.skipif(
not _torch_available(),
reason="PyTorch not available"
)
def test_nested_tensor_in_dict(self):
"""Tensors nested in dicts should be counted."""
import torch
tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes
value = CacheValue(outputs=[[{"samples": tensor}]])
size = _estimate_value_size(value)
assert size == 400
class TestProviderRegistry:
"""Test cache provider registration and retrieval."""
def setup_method(self):
"""Clear providers before each test."""
_clear_cache_providers()
def teardown_method(self):
"""Clear providers after each test."""
_clear_cache_providers()
def test_register_provider(self):
"""Provider should be registered successfully."""
provider = MockCacheProvider()
register_cache_provider(provider)
assert _has_cache_providers() is True
providers = _get_cache_providers()
assert len(providers) == 1
assert providers[0] is provider
def test_unregister_provider(self):
"""Provider should be unregistered successfully."""
provider = MockCacheProvider()
register_cache_provider(provider)
unregister_cache_provider(provider)
assert _has_cache_providers() is False
def test_multiple_providers(self):
"""Multiple providers can be registered."""
provider1 = MockCacheProvider()
provider2 = MockCacheProvider()
register_cache_provider(provider1)
register_cache_provider(provider2)
providers = _get_cache_providers()
assert len(providers) == 2
def test_duplicate_registration_ignored(self):
"""Registering same provider twice should be ignored."""
provider = MockCacheProvider()
register_cache_provider(provider)
register_cache_provider(provider) # Should be ignored
providers = _get_cache_providers()
assert len(providers) == 1
def test_clear_providers(self):
"""_clear_cache_providers should remove all providers."""
provider1 = MockCacheProvider()
provider2 = MockCacheProvider()
register_cache_provider(provider1)
register_cache_provider(provider2)
_clear_cache_providers()
assert _has_cache_providers() is False
assert len(_get_cache_providers()) == 0
class TestCacheContext:
"""Test CacheContext dataclass."""
def test_context_creation(self):
"""CacheContext should be created with all fields."""
context = CacheContext(
node_id="node-456",
class_type="KSampler",
cache_key_hash="a" * 64,
)
assert context.node_id == "node-456"
assert context.class_type == "KSampler"
assert context.cache_key_hash == "a" * 64
class TestCacheValue:
"""Test CacheValue dataclass."""
def test_value_creation(self):
"""CacheValue should be created with outputs."""
outputs = [[{"samples": "tensor_data"}]]
value = CacheValue(outputs=outputs)
assert value.outputs == outputs
class MockCacheProvider(CacheProvider):
"""Mock cache provider for testing."""
def __init__(self):
self.lookups = []
self.stores = []
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
self.lookups.append(context)
return None
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
self.stores.append((context, value))

View File

@ -28,31 +28,31 @@ CACHE_SCENARIOS = [
},
# JavaScript/CSS scenarios
{
"name": "js_no_store",
"name": "js_no_cache",
"path": "/script.js",
"status": 200,
"expected_cache": "no-store",
"expected_cache": "no-cache",
"should_have_header": True,
},
{
"name": "css_no_store",
"name": "css_no_cache",
"path": "/styles.css",
"status": 200,
"expected_cache": "no-store",
"expected_cache": "no-cache",
"should_have_header": True,
},
{
"name": "index_json_no_store",
"name": "index_json_no_cache",
"path": "/api/index.json",
"status": 200,
"expected_cache": "no-store",
"expected_cache": "no-cache",
"should_have_header": True,
},
{
"name": "localized_index_json_no_store",
"name": "localized_index_json_no_cache",
"path": "/templates/index.zh.json",
"status": 200,
"expected_cache": "no-store",
"expected_cache": "no-cache",
"should_have_header": True,
},
# Non-matching files