Compare commits

..

1 Commits

Author SHA1 Message Date
f1c07a72c4 convert model_merging and video_model nodes to V3 schema 2026-03-11 12:25:59 +02:00
9 changed files with 738 additions and 994 deletions

View File

@ -270,15 +270,10 @@ try:
except:
OOM_EXCEPTION = Exception
try:
ACCELERATOR_ERROR = torch.AcceleratorError
except AttributeError:
ACCELERATOR_ERROR = RuntimeError
def is_oom(e):
if isinstance(e, OOM_EXCEPTION):
return True
if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2:
discard_cuda_async_error()
return True
return False
@ -1280,7 +1275,7 @@ def discard_cuda_async_error():
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
synchronize()
except RuntimeError:
except torch.AcceleratorError:
#Dump it! We already know about it from the synchronous return
pass

View File

@ -1,68 +0,0 @@
from pydantic import BaseModel, Field
class RevePostprocessingOperation(BaseModel):
process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
upscale_factor: int | None = Field(
None,
description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
ge=2,
le=4,
)
class ReveImageCreateRequest(BaseModel):
prompt: str = Field(...)
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageEditRequest(BaseModel):
edit_instruction: str = Field(...)
reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageRemixRequest(BaseModel):
prompt: str = Field(...)
reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageResponse(BaseModel):
image: str | None = Field(None, description="The base64 encoded image data.")
request_id: str | None = Field(None, description="A unique id for the request.")
credits_used: float | None = Field(None, description="The number of credits used for this request.")
version: str | None = Field(None, description="The specific model version used.")
content_violation: bool | None = Field(
None, description="Indicates whether the generated image violates the content policy."
)

View File

@ -1,395 +0,0 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.reve import (
ReveImageCreateRequest,
ReveImageEditRequest,
ReveImageRemixRequest,
RevePostprocessingOperation,
)
from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
sync_op_raw,
tensor_to_base64_string,
validate_string,
)
def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
ops = []
if upscale["upscale"] == "enabled":
ops.append(
RevePostprocessingOperation(
process="upscale",
upscale_factor=upscale["upscale_factor"],
)
)
if remove_background:
ops.append(RevePostprocessingOperation(process="remove_background"))
return ops or None
def _postprocessing_inputs():
return [
IO.DynamicCombo.Input(
"upscale",
options=[
IO.DynamicCombo.Option("disabled", []),
IO.DynamicCombo.Option(
"enabled",
[
IO.Int.Input(
"upscale_factor",
default=2,
min=2,
max=4,
step=1,
tooltip="Upscale factor (2x, 3x, or 4x).",
),
],
),
],
tooltip="Upscale the generated image. May add additional cost.",
),
IO.Boolean.Input(
"remove_background",
default=False,
tooltip="Remove the background from the generated image. May add additional cost.",
),
]
def _reve_price_extractor(headers: dict) -> float | None:
credits_used = headers.get("x-reve-credits-used")
if credits_used is not None:
return float(credits_used) / 524.48
return None
def _reve_response_header_validator(headers: dict) -> None:
error_code = headers.get("x-reve-error-code")
if error_code:
raise ValueError(f"Reve API error: {error_code}")
if headers.get("x-reve-content-violation", "").lower() == "true":
raise ValueError("The generated image was flagged for content policy violation.")
def _model_inputs(versions: list[str], aspect_ratios: list[str]):
return [
IO.DynamicCombo.Option(
version,
[
IO.Combo.Input(
"aspect_ratio",
options=aspect_ratios,
tooltip="Aspect ratio of the output image.",
),
IO.Int.Input(
"test_time_scaling",
default=1,
min=1,
max=5,
step=1,
tooltip="Higher values produce better images but cost more credits.",
advanced=True,
),
],
)
for version in versions
]
class ReveImageCreateNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageCreateNode",
display_name="Reve Image Create",
category="api node/image/Reve",
description="Generate images from text descriptions using Reve.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-create@20250915"],
aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for generation.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2560)
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/create",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageCreateRequest(
prompt=prompt,
aspect_ratio=model["aspect_ratio"],
version=model["model"],
test_time_scaling=model["test_time_scaling"],
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveImageEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageEditNode",
display_name="Reve Image Edit",
category="api node/image/Reve",
description="Edit images using natural language instructions with Reve.",
inputs=[
IO.Image.Input("image", tooltip="The image to edit."),
IO.String.Input(
"edit_instruction",
multiline=True,
default="",
tooltip="Text description of how to edit the image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-edit@20250915", "reve-edit-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for editing.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
edit_instruction: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(edit_instruction, min_length=1, max_length=2560)
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/edit",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageEditRequest(
edit_instruction=edit_instruction,
reference_image=tensor_to_base64_string(image),
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveImageRemixNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageRemixNode",
display_name="Reve Image Remix",
category="api node/image/Reve",
description="Combine reference images with text prompts to create new images using Reve.",
inputs=[
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="image_",
min=1,
max=6,
),
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. "
"May include XML img tags to reference specific images by index, "
"e.g. <img>0</img>, <img>1</img>, etc.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-remix@20250915", "reve-remix-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for remixing.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
async def execute(
cls,
reference_images: IO.Autogrow.Type,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2560)
if not reference_images:
raise ValueError("At least one reference image is required.")
ref_base64_list = []
for key in reference_images:
ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
if len(ref_base64_list) > 6:
raise ValueError("Maximum 6 reference images are allowed.")
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/remix",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageRemixRequest(
prompt=prompt,
reference_images=ref_base64_list,
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ReveImageCreateNode,
ReveImageEditNode,
ReveImageRemixNode,
]
async def comfy_entrypoint() -> ReveExtension:
return ReveExtension()

View File

@ -67,7 +67,6 @@ class _RequestConfig:
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
is_rate_limited: Callable[[int, Any], bool] | None = None
response_header_validator: Callable[[dict[str, str]], None] | None = None
@dataclass
@ -203,13 +202,11 @@ async def sync_op_raw(
monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
response_header_validator: Callable[[dict[str, str]], None] | None = None,
) -> dict[str, Any] | bytes:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
- If as_binary=True: returns bytes.
- response_header_validator: optional callback receiving response headers dict
"""
if isinstance(data, BaseModel):
data = data.model_dump(exclude_none=True)
@ -235,7 +232,6 @@ async def sync_op_raw(
price_extractor=price_extractor,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
response_header_validator=response_header_validator,
)
return await _request_base(cfg, expect_binary=as_binary)
@ -773,12 +769,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
)
bytes_payload = bytes(buff)
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
if cfg.price_extractor:
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(resp_headers)
if cfg.response_header_validator:
cfg.response_header_validator(resp_headers)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
@ -786,7 +776,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=resp_headers,
response_headers=dict(resp.headers),
response_content=bytes_payload,
)
return bytes_payload

View File

@ -1,32 +1,32 @@
from comfy import model_management
from comfy_api.latest import ComfyExtension, IO
from typing_extensions import override
import math
class LTXVLatentUpsampler(IO.ComfyNode):
class LTXVLatentUpsampler:
"""
Upsamples a video latent by a factor of 2.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="LTXVLatentUpsampler",
category="latent/video",
is_experimental=True,
inputs=[
IO.Latent.Input("samples"),
IO.LatentUpscaleModel.Input("upscale_model"),
IO.Vae.Input("vae"),
],
outputs=[
IO.Latent.Output(),
],
)
def INPUT_TYPES(s):
return {
"required": {
"samples": ("LATENT",),
"upscale_model": ("LATENT_UPSCALE_MODEL",),
"vae": ("VAE",),
}
}
@classmethod
def execute(cls, samples, upscale_model, vae) -> IO.NodeOutput:
RETURN_TYPES = ("LATENT",)
FUNCTION = "upsample_latent"
CATEGORY = "latent/video"
EXPERIMENTAL = True
def upsample_latent(
self,
samples: dict,
upscale_model,
vae,
) -> tuple:
"""
Upsample the input latent using the provided model.
@ -34,6 +34,7 @@ class LTXVLatentUpsampler(IO.ComfyNode):
samples (dict): Input latent samples
upscale_model (LatentUpsampler): Loaded upscale model
vae: VAE model for normalization
auto_tiling (bool): Whether to automatically tile the input for processing
Returns:
tuple: Tuple containing the upsampled latent
@ -66,16 +67,9 @@ class LTXVLatentUpsampler(IO.ComfyNode):
return_dict = samples.copy()
return_dict["samples"] = upsampled_latents
return_dict.pop("noise_mask", None)
return IO.NodeOutput(return_dict)
upsample_latent = execute # TODO: remove
return (return_dict,)
class LTXVLatentUpsamplerExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [LTXVLatentUpsampler]
async def comfy_entrypoint() -> LTXVLatentUpsamplerExtension:
return LTXVLatentUpsamplerExtension()
NODE_CLASS_MAPPINGS = {
"LTXVLatentUpsampler": LTXVLatentUpsampler,
}

View File

@ -10,146 +10,198 @@ import json
import os
from comfy.cli_args import args
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override
class ModelMergeSimple:
class ModelMergeSimple(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeSimple",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2, ratio):
@classmethod
def execute(cls, model1, model2, ratio) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)
class ModelSubtract:
merge = execute # TODO: remove
class ModelSubtract(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeSubtract",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2, multiplier):
@classmethod
def execute(cls, model1, model2, multiplier) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, - multiplier, multiplier)
return (m, )
return io.NodeOutput(m)
class ModelAdd:
merge = execute # TODO: remove
class ModelAdd(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeAdd",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2):
@classmethod
def execute(cls, model1, model2) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, 1.0, 1.0)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class CLIPMergeSimple:
class CLIPMergeSimple(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeSimple",
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2, ratio):
@classmethod
def execute(cls, clip1, clip2, ratio) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class CLIPSubtract:
SEARCH_ALIASES = ["clip difference", "text encoder subtract"]
class CLIPSubtract(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeSubtract",
search_aliases=["clip difference", "text encoder subtract"],
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2, multiplier):
@classmethod
def execute(cls, clip1, clip2, multiplier) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, - multiplier, multiplier)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class CLIPAdd:
SEARCH_ALIASES = ["combine clip"]
class CLIPAdd(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeAdd",
search_aliases=["combine clip"],
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2):
@classmethod
def execute(cls, clip1, clip2) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0, 1.0)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class ModelMergeBlocks:
class ModelMergeBlocks(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeBlocks",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("input", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("middle", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("out", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2, **kwargs):
@classmethod
def execute(cls, model1, model2, **kwargs) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
default_ratio = next(iter(kwargs.values()))
@ -165,7 +217,10 @@ class ModelMergeBlocks:
last_arg_size = len(arg)
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
@ -226,59 +281,65 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
class CheckpointSave:
SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
class CheckpointSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CheckpointSave",
display_name="Save Checkpoint",
search_aliases=["save model", "export checkpoint", "merge save"],
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.Clip.Input("clip"),
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
def execute(cls, model, clip, vae, filename_prefix) -> io.NodeOutput:
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()
CATEGORY = "advanced/model_merging"
save = execute # TODO: remove
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
class CLIPSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
class CLIPSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPSave",
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip"),
io.String.Input("filename_prefix", default="clip/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip": ("CLIP",),
"filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
def execute(cls, clip, filename_prefix) -> io.NodeOutput:
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
if cls.hidden.prompt is not None:
prompt_info = json.dumps(cls.hidden.prompt)
metadata = {}
if not args.disable_metadata:
metadata["format"] = "pt"
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
clip_sd = clip.get_sd()
output_dir = folder_paths.get_output_directory()
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
current_clip_sd = {}
@ -295,7 +356,7 @@ class CLIPSave:
replace_prefix[prefix] = ""
replace_prefix["transformer."] = ""
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, output_dir)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
@ -303,76 +364,88 @@ class CLIPSave:
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
return {}
return io.NodeOutput()
class VAESave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
save = execute # TODO: remove
class VAESave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VAESave",
category="advanced/model_merging",
inputs=[
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="vae/ComfyUI_vae"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
def execute(cls, vae, filename_prefix) -> io.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
if cls.hidden.prompt is not None:
prompt_info = json.dumps(cls.hidden.prompt)
metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {}
return io.NodeOutput()
class ModelSave:
SEARCH_ALIASES = ["export model", "checkpoint save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
save = execute # TODO: remove
class ModelSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ModelSave",
search_aliases=["export model", "checkpoint save"],
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.String.Input("filename_prefix", default="diffusion_models/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
def execute(cls, model, filename_prefix) -> io.NodeOutput:
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()
CATEGORY = "advanced/model_merging"
save = execute # TODO: remove
def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks,
"ModelMergeSubtract": ModelSubtract,
"ModelMergeAdd": ModelAdd,
"CheckpointSave": CheckpointSave,
"CLIPMergeSimple": CLIPMergeSimple,
"CLIPMergeSubtract": CLIPSubtract,
"CLIPMergeAdd": CLIPAdd,
"CLIPSave": CLIPSave,
"VAESave": VAESave,
"ModelSave": ModelSave,
}
class ModelMergingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ModelMergeSimple,
ModelMergeBlocks,
ModelSubtract,
ModelAdd,
CheckpointSave,
CLIPMergeSimple,
CLIPSubtract,
CLIPAdd,
CLIPSave,
VAESave,
ModelSave,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointSave": "Save Checkpoint",
}
async def comfy_entrypoint() -> ModelMergingExtension:
return ModelMergingExtension()

View File

@ -1,356 +1,455 @@
import comfy_extras.nodes_model_merging
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override
class ModelMergeSD1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["time_embed."] = argument
arg_dict["label_emb."] = argument
inputs.append(io.Float.Input("time_embed.", **argument))
inputs.append(io.Float.Input("label_emb.", **argument))
for i in range(12):
arg_dict["input_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("input_blocks.{}.".format(i), **argument))
for i in range(3):
arg_dict["middle_block.{}.".format(i)] = argument
inputs.append(io.Float.Input("middle_block.{}.".format(i), **argument))
for i in range(12):
arg_dict["output_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("output_blocks.{}.".format(i), **argument))
arg_dict["out."] = argument
inputs.append(io.Float.Input("out.", **argument))
return {"required": arg_dict}
return io.Schema(
node_id="ModelMergeSD1",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
class ModelMergeSD2(ModelMergeSD1):
# SD1 and SD2 have the same blocks
@classmethod
def define_schema(cls):
schema = ModelMergeSD1.define_schema()
schema.node_id = "ModelMergeSD2"
return schema
class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["time_embed."] = argument
arg_dict["label_emb."] = argument
inputs.append(io.Float.Input("time_embed.", **argument))
inputs.append(io.Float.Input("label_emb.", **argument))
for i in range(9):
arg_dict["input_blocks.{}".format(i)] = argument
inputs.append(io.Float.Input("input_blocks.{}".format(i), **argument))
for i in range(3):
arg_dict["middle_block.{}".format(i)] = argument
inputs.append(io.Float.Input("middle_block.{}".format(i), **argument))
for i in range(9):
arg_dict["output_blocks.{}".format(i)] = argument
inputs.append(io.Float.Input("output_blocks.{}".format(i), **argument))
arg_dict["out."] = argument
inputs.append(io.Float.Input("out.", **argument))
return io.Schema(
node_id="ModelMergeSDXL",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_embed."] = argument
arg_dict["x_embedder."] = argument
arg_dict["context_embedder."] = argument
arg_dict["y_embedder."] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("pos_embed.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("context_embedder.", **argument))
inputs.append(io.Float.Input("y_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
for i in range(24):
arg_dict["joint_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("joint_blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return {"required": arg_dict}
return io.Schema(
node_id="ModelMergeSD3_2B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
class ModelMergeAuraflow(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["init_x_linear."] = argument
arg_dict["positional_encoding"] = argument
arg_dict["cond_seq_linear."] = argument
arg_dict["register_tokens"] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("init_x_linear.", **argument))
inputs.append(io.Float.Input("positional_encoding", **argument))
inputs.append(io.Float.Input("cond_seq_linear.", **argument))
inputs.append(io.Float.Input("register_tokens", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
for i in range(4):
arg_dict["double_layers.{}.".format(i)] = argument
inputs.append(io.Float.Input("double_layers.{}.".format(i), **argument))
for i in range(32):
arg_dict["single_layers.{}.".format(i)] = argument
inputs.append(io.Float.Input("single_layers.{}.".format(i), **argument))
arg_dict["modF."] = argument
arg_dict["final_linear."] = argument
inputs.append(io.Float.Input("modF.", **argument))
inputs.append(io.Float.Input("final_linear.", **argument))
return io.Schema(
node_id="ModelMergeAuraflow",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["img_in."] = argument
arg_dict["time_in."] = argument
arg_dict["guidance_in"] = argument
arg_dict["vector_in."] = argument
arg_dict["txt_in."] = argument
inputs.append(io.Float.Input("img_in.", **argument))
inputs.append(io.Float.Input("time_in.", **argument))
inputs.append(io.Float.Input("guidance_in", **argument))
inputs.append(io.Float.Input("vector_in.", **argument))
inputs.append(io.Float.Input("txt_in.", **argument))
for i in range(19):
arg_dict["double_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("double_blocks.{}.".format(i), **argument))
for i in range(38):
arg_dict["single_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("single_blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeFlux1",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_embed."] = argument
arg_dict["x_embedder."] = argument
arg_dict["context_embedder."] = argument
arg_dict["y_embedder."] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("pos_embed.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("context_embedder.", **argument))
inputs.append(io.Float.Input("y_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
for i in range(38):
arg_dict["joint_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("joint_blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeSD35_Large",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeMochiPreview(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_frequencies."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t5_y_embedder."] = argument
arg_dict["t5_yproj."] = argument
inputs.append(io.Float.Input("pos_frequencies.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("t5_y_embedder.", **argument))
inputs.append(io.Float.Input("t5_yproj.", **argument))
for i in range(48):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeMochiPreview",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["patchify_proj."] = argument
arg_dict["adaln_single."] = argument
arg_dict["caption_projection."] = argument
inputs.append(io.Float.Input("patchify_proj.", **argument))
inputs.append(io.Float.Input("adaln_single.", **argument))
inputs.append(io.Float.Input("caption_projection.", **argument))
for i in range(28):
arg_dict["transformer_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("transformer_blocks.{}.".format(i), **argument))
arg_dict["scale_shift_table"] = argument
arg_dict["proj_out."] = argument
inputs.append(io.Float.Input("scale_shift_table", **argument))
inputs.append(io.Float.Input("proj_out.", **argument))
return io.Schema(
node_id="ModelMergeLTXV",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmos7B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["extra_pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["affline_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("extra_pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("affline_norm.", **argument))
for i in range(28):
arg_dict["blocks.block{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.block{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmos7B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["extra_pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["affline_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("extra_pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("affline_norm.", **argument))
for i in range(36):
arg_dict["blocks.block{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.block{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmos14B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["patch_embedding."] = argument
arg_dict["time_embedding."] = argument
arg_dict["time_projection."] = argument
arg_dict["text_embedding."] = argument
arg_dict["img_emb."] = argument
inputs.append(io.Float.Input("patch_embedding.", **argument))
inputs.append(io.Float.Input("time_embedding.", **argument))
inputs.append(io.Float.Input("time_projection.", **argument))
inputs.append(io.Float.Input("text_embedding.", **argument))
inputs.append(io.Float.Input("img_emb.", **argument))
for i in range(40):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["head."] = argument
inputs.append(io.Float.Input("head.", **argument))
return io.Schema(
node_id="ModelMergeWAN2_1",
category="advanced/model_merging/model_specific",
description="1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb.",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmosPredict2_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t_embedding_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("t_embedding_norm.", **argument))
for i in range(28):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmosPredict2_2B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t_embedding_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("t_embedding_norm.", **argument))
for i in range(36):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmosPredict2_14B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_embeds."] = argument
arg_dict["img_in."] = argument
arg_dict["txt_norm."] = argument
arg_dict["txt_in."] = argument
arg_dict["time_text_embed."] = argument
inputs.append(io.Float.Input("pos_embeds.", **argument))
inputs.append(io.Float.Input("img_in.", **argument))
inputs.append(io.Float.Input("txt_norm.", **argument))
inputs.append(io.Float.Input("txt_in.", **argument))
inputs.append(io.Float.Input("time_text_embed.", **argument))
for i in range(60):
arg_dict["transformer_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("transformer_blocks.{}.".format(i), **argument))
arg_dict["proj_out."] = argument
inputs.append(io.Float.Input("proj_out.", **argument))
return {"required": arg_dict}
return io.Schema(
node_id="ModelMergeQwenImage",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
"ModelMergeSDXL": ModelMergeSDXL,
"ModelMergeSD3_2B": ModelMergeSD3_2B,
"ModelMergeAuraflow": ModelMergeAuraflow,
"ModelMergeFlux1": ModelMergeFlux1,
"ModelMergeSD35_Large": ModelMergeSD35_Large,
"ModelMergeMochiPreview": ModelMergeMochiPreview,
"ModelMergeLTXV": ModelMergeLTXV,
"ModelMergeCosmos7B": ModelMergeCosmos7B,
"ModelMergeCosmos14B": ModelMergeCosmos14B,
"ModelMergeWAN2_1": ModelMergeWAN2_1,
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
"ModelMergeQwenImage": ModelMergeQwenImage,
}
class ModelMergingModelSpecificExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ModelMergeSD1,
ModelMergeSD2,
ModelMergeSDXL,
ModelMergeSD3_2B,
ModelMergeAuraflow,
ModelMergeFlux1,
ModelMergeSD35_Large,
ModelMergeMochiPreview,
ModelMergeLTXV,
ModelMergeCosmos7B,
ModelMergeCosmos14B,
ModelMergeWAN2_1,
ModelMergeCosmosPredict2_2B,
ModelMergeCosmosPredict2_14B,
ModelMergeQwenImage,
]
async def comfy_entrypoint() -> ModelMergingModelSpecificExtension:
return ModelMergingModelSpecificExtension()

View File

@ -6,44 +6,62 @@ import folder_paths
import comfy_extras.nodes_model_merging
import node_helpers
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override
class ImageOnlyCheckpointLoader:
class ImageOnlyCheckpointLoader(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"
def define_schema(cls):
return io.Schema(
node_id="ImageOnlyCheckpointLoader",
display_name="Image Only Checkpoint Loader (img2vid model)",
category="loaders/video_models",
inputs=[
io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("checkpoints")),
],
outputs=[
io.Model.Output(),
io.ClipVision.Output(),
io.Vae.Output(),
],
)
CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
@classmethod
def execute(cls, ckpt_name, output_vae=True, output_clip=True) -> io.NodeOutput:
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])
return io.NodeOutput(out[0], out[3], out[2])
load_checkpoint = execute # TODO: remove
class SVD_img2vid_Conditioning:
class SVD_img2vid_Conditioning(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023, "advanced": True}),
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01, "advanced": True})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
def define_schema(cls):
return io.Schema(
node_id="SVD_img2vid_Conditioning",
category="conditioning/video_models",
inputs=[
io.ClipVision.Input("clip_vision"),
io.Image.Input("init_image"),
io.Vae.Input("vae"),
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("video_frames", default=14, min=1, max=4096),
io.Int.Input("motion_bucket_id", default=127, min=1, max=1023, advanced=True),
io.Int.Input("fps", default=6, min=1, max=1024),
io.Float.Input("augmentation_level", default=0.0, min=0.0, max=10.0, step=0.01, advanced=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
@classmethod
def execute(cls, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level) -> io.NodeOutput:
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@ -54,20 +72,28 @@ class SVD_img2vid_Conditioning:
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
return io.NodeOutput(positive, negative, {"samples":latent})
class VideoLinearCFGGuidance:
encode = execute # TODO: remove
class VideoLinearCFGGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01, "advanced": True}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="VideoLinearCFGGuidance",
category="sampling/video_models",
inputs=[
io.Model.Input("model"),
io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01, advanced=True),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
@classmethod
def execute(cls, model, min_cfg) -> io.NodeOutput:
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
@ -78,20 +104,28 @@ class VideoLinearCFGGuidance:
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
return io.NodeOutput(m)
class VideoTriangleCFGGuidance:
patch = execute # TODO: remove
class VideoTriangleCFGGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01, "advanced": True}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="VideoTriangleCFGGuidance",
category="sampling/video_models",
inputs=[
io.Model.Input("model"),
io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01, advanced=True),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
@classmethod
def execute(cls, model, min_cfg) -> io.NodeOutput:
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
@ -105,57 +139,79 @@ class VideoTriangleCFGGuidance:
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
return io.NodeOutput(m)
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
CATEGORY = "advanced/model_merging"
patch = execute # TODO: remove
class ImageOnlyCheckpointSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageOnlyCheckpointSave",
search_aliases=["save model", "export checkpoint", "merge save"],
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.ClipVision.Input("clip_vision"),
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip_vision": ("CLIP_VISION",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
def execute(cls, model, clip_vision, vae, filename_prefix) -> io.NodeOutput:
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()
def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
save = execute # TODO: remove
class ConditioningSetAreaPercentageVideo:
class ConditioningSetAreaPercentageVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"temporal": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"z": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
def define_schema(cls):
return io.Schema(
node_id="ConditioningSetAreaPercentageVideo",
category="conditioning",
inputs=[
io.Conditioning.Input("conditioning"),
io.Float.Input("width", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("height", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("temporal", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("x", default=0.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("y", default=0.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("z", default=0.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "conditioning"
def append(self, conditioning, width, height, temporal, x, y, z, strength):
@classmethod
def execute(cls, conditioning, width, height, temporal, x, y, z, strength) -> io.NodeOutput:
c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", temporal, height, width, z, y, x),
"strength": strength,
"set_area_to_bounds": False})
return (c, )
return io.NodeOutput(c)
append = execute # TODO: remove
NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
"ConditioningSetAreaPercentageVideo": ConditioningSetAreaPercentageVideo,
}
class VideoModelExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ImageOnlyCheckpointLoader,
SVD_img2vid_Conditioning,
VideoLinearCFGGuidance,
VideoTriangleCFGGuidance,
ImageOnlyCheckpointSave,
ConditioningSetAreaPercentageVideo,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
}
async def comfy_entrypoint() -> VideoModelExtension:
return VideoModelExtension()

View File

@ -23,7 +23,7 @@ SQLAlchemy
filelock
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.10
comfy-aimdo>=0.2.9
requests
simpleeval>=1.0.0
blake3