Compare commits

..

3 Commits

Author SHA1 Message Date
e6e75152e0 Merge branch 'master' into temp_pr 2026-05-21 13:38:11 +08:00
e715be9105 Apply suggestions from code review
Co-authored-by: Alexis Rolland <alexis@comfy.org>
2026-05-20 23:57:15 -04:00
d48a8d417b Save Image advanced node. 2026-05-20 23:57:15 -04:00
9 changed files with 497 additions and 179 deletions

View File

@ -543,7 +543,7 @@ class AudioConcat(IO.ComfyNode):
return IO.Schema(
node_id="AudioConcat",
search_aliases=["join audio", "combine audio", "append audio"],
display_name="Concatenate Audio",
display_name="Audio Concat",
description="Concatenates the audio1 to audio2 in the specified direction.",
category="audio",
inputs=[
@ -597,7 +597,7 @@ class AudioMerge(IO.ComfyNode):
return IO.Schema(
node_id="AudioMerge",
search_aliases=["mix audio", "overlay audio", "layer audio"],
display_name="Merge Audio",
display_name="Audio Merge",
description="Combine two audio tracks by overlaying their waveforms.",
category="audio",
inputs=[
@ -667,9 +667,8 @@ class AudioAdjustVolume(IO.ComfyNode):
return IO.Schema(
node_id="AudioAdjustVolume",
search_aliases=["audio gain", "loudness", "audio level"],
display_name="Adjust Audio Volume",
display_name="Audio Adjust Volume",
category="audio",
description="Adjust the volume of the audio by a specified amount in decibels (dB).",
inputs=[
IO.Audio.Input("audio"),
IO.Int.Input(

View File

@ -47,10 +47,8 @@ class LoadImageDataSetFromFolderNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LoadImageDataSetFromFolder",
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
display_name="Load Image (from Folder)",
category="image",
description="Load a dataset of images from a specified folder and return a list of images. Supported formats: PNG, JPG, JPEG, WEBP.",
display_name="Load Image Dataset from Folder",
category="dataset",
is_experimental=True,
inputs=[
io.Combo.Input(
@ -86,16 +84,14 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LoadImageTextDataSetFromFolder",
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
display_name="Load Image-Text (from Folder)",
category="image",
description="Load a dataset of pairs of images and text captions from a specified folder and return them as a list. Supported formats: PNG, JPG, JPEG, WEBP.",
display_name="Load Image and Text Dataset from Folder",
category="dataset",
is_experimental=True,
inputs=[
io.Combo.Input(
"folder",
options=folder_paths.get_input_subfolders(),
tooltip="The folder to load images and text captions from.",
tooltip="The folder to load images from.",
)
],
outputs=[
@ -210,10 +206,8 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveImageDataSetToFolder",
search_aliases=["save folder", "save to folder", "save dataset", "save images", "export dataset"],
display_name="Save Image (to Folder) (DEPRECATED)",
category="image",
description="Save a dataset of images to a specified folder. Supported formats: PNG.",
display_name="Save Image Dataset to Folder",
category="dataset",
is_experimental=True,
is_output_node=True,
is_input_list=True, # Receive images as list
@ -232,7 +226,6 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
),
],
outputs=[],
is_deprecated=True, # This node is redundant and superseded by existing Save Image nodes where the target folder can be specified in the filename_prefix
)
@classmethod
@ -253,20 +246,14 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveImageTextDataSetToFolder",
search_aliases=["save folder", "save to folder", "save dataset", "save images", "save text", "export dataset"],
display_name="Save Image-Text (to Folder)",
category="image",
description="Save a dataset of pairs of images and text captions to a specified folder. Images are saved as PNG files and captions are saved as TXT files with the same filename_prefix.",
display_name="Save Image and Text Dataset to Folder",
category="dataset",
is_experimental=True,
is_output_node=True,
is_input_list=True, # Receive both images and texts as lists
inputs=[
io.Image.Input("images", tooltip="List of images to save."),
io.String.Input("texts",
optional=True,
force_input=True,
tooltip="List of text captions to save."
),
io.String.Input("texts", tooltip="List of text captions to save."),
io.String.Input(
"folder_name",
default="dataset",
@ -283,7 +270,7 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
)
@classmethod
def execute(cls, images, folder_name, filename_prefix, texts=None):
def execute(cls, images, texts, folder_name, filename_prefix):
# Extract scalar values
folder_name = folder_name[0]
filename_prefix = filename_prefix[0]
@ -292,12 +279,11 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
saved_files = save_images_to_folder(images, output_dir, filename_prefix)
# Save captions
if texts:
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
caption_filename = filename.replace(".png", ".txt")
caption_path = os.path.join(output_dir, caption_filename)
with open(caption_path, "w", encoding="utf-8") as f:
f.write(caption)
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
caption_filename = filename.replace(".png", ".txt")
caption_path = os.path.join(output_dir, caption_filename)
with open(caption_path, "w", encoding="utf-8") as f:
f.write(caption)
logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
return io.NodeOutput()
@ -328,13 +314,11 @@ class ImageProcessingNode(io.ComfyNode):
Child classes should set:
node_id: Unique node identifier (required)
search_aliases: List of search aliases (optional)
display_name: Display name (optional, defaults to node_id)
description: Node description (optional)
extra_inputs: List of additional io.Input objects beyond "images" (optional)
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
is_output_list: True (list output) or False (single output) (optional, default True)
is_deprecated: True if the node is deprecated (optional, default False)
Child classes must implement ONE of:
_process(cls, image, **kwargs) -> tensor (for single-item processing)
@ -342,13 +326,12 @@ class ImageProcessingNode(io.ComfyNode):
"""
node_id = None
search_aliases = []
display_name = None
description = None
extra_inputs = []
is_group_process = None # None = auto-detect, True/False = explicit
is_output_list = None # None = auto-detect based on processing mode
is_deprecated = False
@classmethod
def _detect_processing_mode(cls):
"""Detect whether this node uses group or individual processing.
@ -419,10 +402,8 @@ class ImageProcessingNode(io.ComfyNode):
return io.Schema(
node_id=cls.node_id,
search_aliases=cls.search_aliases,
display_name=cls.display_name or cls.node_id,
category=cls.category,
description=cls.description,
category="dataset/image",
is_experimental=True,
is_input_list=is_group, # True for group, False for individual
inputs=inputs,
@ -491,13 +472,11 @@ class TextProcessingNode(io.ComfyNode):
Child classes should set:
node_id: Unique node identifier (required)
search_aliases: List of search aliases (optional)
display_name: Display name (optional, defaults to node_id)
description: Node description (optional)
extra_inputs: List of additional io.Input objects beyond "texts" (optional)
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
is_output_list: True (list output) or False (single output) (optional, default True)
is_deprecated: True if the node is deprecated (optional, default False)
Child classes must implement ONE of:
_process(cls, text, **kwargs) -> str (for single-item processing)
@ -505,13 +484,12 @@ class TextProcessingNode(io.ComfyNode):
"""
node_id = None
search_aliases = []
display_name = None
description = None
extra_inputs = []
is_group_process = None # None = auto-detect, True/False = explicit
is_output_list = None # None = auto-detect based on processing mode
is_deprecated = False
@classmethod
def _detect_processing_mode(cls):
"""Detect whether this node uses group or individual processing.
@ -649,17 +627,15 @@ class TextProcessingNode(io.ComfyNode):
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
node_id = "ResizeImagesByShorterEdge"
display_name = "Resize Images by Shorter Edge (DEPRECATED)"
category = "image/transform"
description = "Resize images so that the shorter edge matches the specified dimension while preserving aspect ratio."
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale shorter dimension
display_name = "Resize Images by Shorter Edge"
description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
extra_inputs = [
io.Int.Input(
"shorter_edge",
default=512,
min=1,
max=8192,
tooltip="Target dimension for the shorter edge.",
tooltip="Target length for the shorter edge.",
),
]
@ -679,17 +655,15 @@ class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
node_id = "ResizeImagesByLongerEdge"
display_name = "Resize Images by Longer Edge (DEPRECATED)"
category = "image/transform"
description = "Resize images so that the longer edge matches the specified dimension while preserving aspect ratio."
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale longer dimension
display_name = "Resize Images by Longer Edge"
description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio."
extra_inputs = [
io.Int.Input(
"longer_edge",
default=1024,
min=1,
max=8192,
tooltip="Target dimension for the longer edge.",
tooltip="Target length for the longer edge.",
),
]
@ -712,10 +686,8 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
class CenterCropImagesNode(ImageProcessingNode):
node_id = "CenterCropImages"
search_aliases=["crop", "cut", "trim"]
display_name="Crop Image (Center)"
category="image/transform"
description = "Center crop an image to the specified dimensions."
display_name = "Center Crop Images"
description = "Center crop all images to the specified dimensions."
extra_inputs = [
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
@ -734,11 +706,10 @@ class CenterCropImagesNode(ImageProcessingNode):
class RandomCropImagesNode(ImageProcessingNode):
node_id = "RandomCropImages"
search_aliases=["crop", "cut", "trim"]
display_name = "Crop Image (Random)"
category="image/transform"
description = "Randomly crop an image to the specified dimensions."
display_name = "Random Crop Images"
description = (
"Randomly crop all images to the specified dimensions (for data augmentation)."
)
extra_inputs = [
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
@ -763,9 +734,7 @@ class RandomCropImagesNode(ImageProcessingNode):
class NormalizeImagesNode(ImageProcessingNode):
node_id = "NormalizeImages"
search_aliases=["normalize", "normalize colors"]
display_name = "Normalize Image Colors"
category = "image/color"
display_name = "Normalize Images"
description = "Normalize images using mean and standard deviation."
extra_inputs = [
io.Float.Input(
@ -793,10 +762,8 @@ class NormalizeImagesNode(ImageProcessingNode):
class AdjustBrightnessNode(ImageProcessingNode):
node_id = "AdjustBrightness"
search_aliases=["brightness"]
display_name = "Adjust Brightness"
category="image/adjustments"
description = "Adjust the brightness of an image."
description = "Adjust brightness of all images."
extra_inputs = [
io.Float.Input(
"factor",
@ -814,10 +781,8 @@ class AdjustBrightnessNode(ImageProcessingNode):
class AdjustContrastNode(ImageProcessingNode):
node_id = "AdjustContrast"
search_aliases=["contrast"]
display_name = "Adjust Contrast"
category="image/adjustments"
description = "Adjust the contrast of an image."
description = "Adjust contrast of all images."
extra_inputs = [
io.Float.Input(
"factor",
@ -835,10 +800,8 @@ class AdjustContrastNode(ImageProcessingNode):
class ShuffleDatasetNode(ImageProcessingNode):
node_id = "ShuffleDataset"
search_aliases=["shuffle", "randomize", "mix"]
display_name = "Shuffle Images List"
category = "image/batch"
description = "Randomly shuffle the order of images in a list."
display_name = "Shuffle Image Dataset"
description = "Randomly shuffle the order of images in the dataset."
is_group_process = True # Requires full list to shuffle
extra_inputs = [
io.Int.Input(
@ -860,15 +823,13 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ShuffleImageTextDataset",
search_aliases=["shuffle", "randomize", "mix"],
display_name = "Shuffle Pairs of Image-Text",
category = "image/batch",
description = "Randomly shuffle the order of pairs of image-text in a list.",
display_name="Shuffle Image-Text Dataset",
category="dataset/image",
is_experimental=True,
is_input_list=True,
inputs=[
io.Image.Input("images", tooltip="List of images to shuffle."),
io.String.Input("texts", tooltip="List of texts to shuffle.", force_input=True),
io.String.Input("texts", tooltip="List of texts to shuffle."),
io.Int.Input(
"seed",
default=0,
@ -904,11 +865,8 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
class TextToLowercaseNode(TextProcessingNode):
node_id = "TextToLowercase"
search_aliases=["lowercase"]
display_name = "Convert Text to Lowercase (DEPRECATED)"
category = "text"
description = "Convert text to lowercase."
is_deprecated = True # This node is superseded by the Convert Text Case node
display_name = "Text to Lowercase"
description = "Convert all texts to lowercase."
@classmethod
def _process(cls, text):
@ -917,11 +875,8 @@ class TextToLowercaseNode(TextProcessingNode):
class TextToUppercaseNode(TextProcessingNode):
node_id = "TextToUppercase"
search_aliases=["uppercase"]
display_name = "Convert Text to Uppercase (DEPRECATED)"
category = "text"
description = "Convert text to uppercase."
is_deprecated = True # This node is superseded by the Convert Text Case node
display_name = "Text to Uppercase"
description = "Convert all texts to uppercase."
@classmethod
def _process(cls, text):
@ -930,10 +885,8 @@ class TextToUppercaseNode(TextProcessingNode):
class TruncateTextNode(TextProcessingNode):
node_id = "TruncateText"
search_aliases=["truncate", "cut", "shorten"]
display_name = "Truncate Text"
category = "text"
description = "Truncate text to a maximum length."
description = "Truncate all texts to a maximum length."
extra_inputs = [
io.Int.Input(
"max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
@ -947,10 +900,8 @@ class TruncateTextNode(TextProcessingNode):
class AddTextPrefixNode(TextProcessingNode):
node_id = "AddTextPrefix"
display_name = "Add Text Prefix (DEPRECATED)"
category = "text"
display_name = "Add Text Prefix"
description = "Add a prefix to all texts."
is_deprecated = True # This node is superseded by the Concatenate Text node
extra_inputs = [
io.String.Input("prefix", default="", tooltip="Prefix to add."),
]
@ -962,10 +913,8 @@ class AddTextPrefixNode(TextProcessingNode):
class AddTextSuffixNode(TextProcessingNode):
node_id = "AddTextSuffix"
display_name = "Add Text Suffix (DEPRECATED)"
category = "text"
display_name = "Add Text Suffix"
description = "Add a suffix to all texts."
is_deprecated = True # This node is superseded by the Concatenate Text node
extra_inputs = [
io.String.Input("suffix", default="", tooltip="Suffix to add."),
]
@ -977,10 +926,8 @@ class AddTextSuffixNode(TextProcessingNode):
class ReplaceTextNode(TextProcessingNode):
node_id = "ReplaceText"
display_name = "Replace Text (DEPRECATED)"
category = "text"
display_name = "Replace Text"
description = "Replace text in all texts."
is_deprecated = True # This node is superseded by the other Replace Text node
extra_inputs = [
io.String.Input("find", default="", tooltip="Text to find."),
io.String.Input("replace", default="", tooltip="Text to replace with."),
@ -993,10 +940,8 @@ class ReplaceTextNode(TextProcessingNode):
class StripWhitespaceNode(TextProcessingNode):
node_id = "StripWhitespace"
display_name = "Strip Whitespace (DEPRECATED)"
category = "text"
display_name = "Strip Whitespace"
description = "Strip leading and trailing whitespace from all texts."
is_deprecated = True # This node is superseded by the Trim Text node
@classmethod
def _process(cls, text):
@ -1007,13 +952,11 @@ class StripWhitespaceNode(TextProcessingNode):
class ImageDeduplicationNode(ImageProcessingNode):
"""Remove duplicate or very similar images from a list using perceptual hashing."""
"""Remove duplicate or very similar images from the dataset using perceptual hashing."""
node_id = "ImageDeduplication"
search_aliases=["deduplicate", "remove duplicates", "similarity filter"]
display_name = "Deduplicate Images"
category = "image/batch"
description = "Remove duplicate or very similar images from a list."
display_name = "Image Deduplication"
description = "Remove duplicate or very similar images from the dataset."
is_group_process = True # Requires full list to compare images
extra_inputs = [
io.Float.Input(
@ -1083,9 +1026,7 @@ class ImageGridNode(ImageProcessingNode):
"""Combine multiple images into a single grid/collage."""
node_id = "ImageGrid"
search_aliases=["grid", "collage", "combine"]
display_name = "Make Image Grid"
category="image/batch"
display_name = "Image Grid"
description = "Arrange multiple images into a grid layout."
is_group_process = True # Requires full list to create grid
is_output_list = False # Outputs single grid image
@ -1161,12 +1102,9 @@ class MergeImageListsNode(ImageProcessingNode):
"""Merge multiple image lists into a single list."""
node_id = "MergeImageLists"
search_aliases=["list", "merge list", "make list"]
display_name = "Merge Image Lists (DEPRECATED)"
category = "image/batch"
display_name = "Merge Image Lists"
description = "Concatenate multiple image lists into one."
is_group_process = True # Receives images as list
is_deprecated = True # This node is superseded by the Create List node
@classmethod
def _group_process(cls, images):
@ -1181,11 +1119,9 @@ class MergeTextListsNode(TextProcessingNode):
"""Merge multiple text lists into a single list."""
node_id = "MergeTextLists"
display_name = "Merge Text Lists (DEPRECATED)"
category = "text"
display_name = "Merge Text Lists"
description = "Concatenate multiple text lists into one."
is_group_process = True # Receives texts as list
is_deprecated = True # This node is superseded by the Create List node
@classmethod
def _group_process(cls, texts):
@ -1206,10 +1142,8 @@ class ResolutionBucket(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ResolutionBucket",
search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"],
display_name="Resolution Bucket",
category="training",
description="Group latents and conditionings into buckets",
category="dataset",
is_experimental=True,
is_input_list=True,
inputs=[
@ -1302,8 +1236,7 @@ class MakeTrainingDataset(io.ComfyNode):
node_id="MakeTrainingDataset",
search_aliases=["encode dataset"],
display_name="Make Training Dataset",
category="training",
description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.",
category="dataset",
is_experimental=True,
is_input_list=True, # images and texts as lists
inputs=[
@ -1318,7 +1251,6 @@ class MakeTrainingDataset(io.ComfyNode):
"texts",
optional=True,
tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
force_input=True
),
],
outputs=[
@ -1388,10 +1320,9 @@ class SaveTrainingDataset(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveTrainingDataset",
search_aliases=["export dataset", "save dataset"],
search_aliases=["export training data"],
display_name="Save Training Dataset",
category="training",
description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.",
category="dataset",
is_experimental=True,
is_output_node=True,
is_input_list=True, # Receive lists
@ -1493,8 +1424,7 @@ class LoadTrainingDataset(io.ComfyNode):
node_id="LoadTrainingDataset",
search_aliases=["import dataset", "training data"],
display_name="Load Training Dataset",
category="training",
description="Load encoded training dataset (latents + conditioning) from disk for use in training.",
category="dataset",
is_experimental=True,
inputs=[
io.String.Input(

View File

@ -419,17 +419,15 @@ class VoxelToMeshBasic(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="VoxelToMeshBasic",
display_name="Voxel to Mesh (Basic) (DEPRECATED)",
display_name="Voxel to Mesh (Basic)",
category="3d",
description="Converts a voxel grid to a mesh.",
is_deprecated=True, # This node is superseded by the Voxel To Mesh node
inputs=[
IO.Voxel.Input("voxel"),
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
],
outputs=[
IO.Mesh.Output(),
],
]
)
@classmethod
@ -455,10 +453,9 @@ class VoxelToMesh(IO.ComfyNode):
node_id="VoxelToMesh",
display_name="Voxel to Mesh",
category="3d",
description="Converts a voxel grid to a mesh.",
inputs=[
IO.Voxel.Input("voxel"),
IO.Combo.Input("algorithm", options=["surface net", "basic"]),
IO.Combo.Input("algorithm", options=["surface net", "basic"], advanced=True),
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
],
outputs=[

View File

@ -3,15 +3,23 @@ from __future__ import annotations
import nodes
import folder_paths
import av
import json
import os
import re
import math
import numpy as np
import struct
import torch
import zlib
import comfy.utils
from fractions import Fraction
from server import PromptServer
from comfy_api.latest import ComfyExtension, IO, UI
from comfy.cli_args import args
from typing_extensions import override
SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later.
@ -55,10 +63,9 @@ class ImageCropV2(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageCropV2",
search_aliases=["crop", "cut", "trim"],
search_aliases=["trim"],
display_name="Crop Image",
category="image/transform",
description = "Crop an image to the specified dimensions.",
essentials_category="Image Tools",
has_intermediate_output=True,
inputs=[
@ -835,6 +842,405 @@ class ImageMergeTileList(IO.ComfyNode):
return IO.NodeOutput(merged_image)
# ---------------------------------------------------------------------------
# Format specifications
# ---------------------------------------------------------------------------
# Maps (file_format, bit_depth, has_alpha) -> (numpy dtype scale, av pixel format,
# stream pix_fmt). Keeps the encode path declarative instead of branchy.
_FORMAT_SPECS = {
("png", "8-bit", False): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgb24", "stream_fmt": "rgb24"},
("png", "8-bit", True): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgba", "stream_fmt": "rgba"},
("png", "16-bit", False): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgb48le", "stream_fmt": "rgb48be"},
("png", "16-bit", True): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgba64le", "stream_fmt": "rgba64be"},
("exr", "32-bit float", False): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrpf32le", "stream_fmt": "gbrpf32le"},
("exr", "32-bit float", True): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrapf32le", "stream_fmt": "gbrapf32le"},
}
# ---------------------------------------------------------------------------
# Color transforms
# ---------------------------------------------------------------------------
def srgb_to_linear(t: torch.Tensor) -> torch.Tensor:
"""Inverse sRGB EOTF (IEC 61966-2-1). Operates on RGB channels only;
alpha (if present as the 4th channel) is passed through unchanged."""
if t.shape[-1] == 4:
rgb, alpha = t[..., :3], t[..., 3:]
return torch.cat([srgb_to_linear(rgb), alpha], dim=-1)
# Piecewise: linear toe below 0.04045, gamma curve above.
low = t / 12.92
high = ((t.clamp(min=0.0) + 0.055) / 1.055) ** 2.4
return torch.where(t <= 0.04045, low, high)
# HLG OETF constants from BT.2100 Table 5.
_HLG_A = 0.17883277
_HLG_B = 0.28466892
_HLG_C = 0.55991072928 # = 0.5 - a*ln(4*a)
def hlg_to_linear(t: torch.Tensor) -> torch.Tensor:
"""Inverse HLG OETF (BT.2100). Maps a non-linear HLG signal in [0, 1] to
*scene*-linear light in [0, 1]. Per BT.2100 Note 5a, this is the correct
transform when converting HLG to a linear scene-light representation
(rather than display-light, which would also involve the HLG OOTF).
Operates on RGB channels only; alpha is passed through unchanged."""
if t.shape[-1] == 4:
rgb, alpha = t[..., :3], t[..., 3:]
return torch.cat([hlg_to_linear(rgb), alpha], dim=-1)
# Piecewise: sqrt branch below 0.5, log branch above.
# Clamp inside the log branch so negative / out-of-range values don't blow up;
# values above 1.0 are allowed and extrapolate naturally.
low = (t ** 2) / 3.0
high = (torch.exp((t.clamp(min=_HLG_C) - _HLG_C) / _HLG_A) + _HLG_B) / 12.0
return torch.where(t <= 0.5, low, high)
# ---------------------------------------------------------------------------
# Metadata injection
# ---------------------------------------------------------------------------
_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
def _png_chunk(chunk_type: bytes, data: bytes) -> bytes:
"""Build a single PNG chunk: length | type | data | CRC32(type+data)."""
crc = zlib.crc32(chunk_type + data) & 0xFFFFFFFF
return struct.pack(">I", len(data)) + chunk_type + data + struct.pack(">I", crc)
def _png_text_chunk(keyword: str, text: str) -> bytes:
"""tEXt chunk: latin-1 keyword + NUL + latin-1 text."""
payload = keyword.encode("latin-1") + b"\x00" + text.encode("latin-1", errors="replace")
return _png_chunk(b"tEXt", payload)
def inject_png_metadata(png_bytes: bytes, prompt: dict | None, extra_pnginfo: dict | None) -> bytes:
"""Insert ComfyUI prompt/workflow as tEXt chunks right after IHDR."""
if not png_bytes.startswith(_PNG_SIGNATURE):
return png_bytes
chunks: list[bytes] = []
if prompt is not None:
chunks.append(_png_text_chunk("prompt", json.dumps(prompt)))
if extra_pnginfo:
for key, value in extra_pnginfo.items():
chunks.append(_png_text_chunk(key, json.dumps(value)))
if not chunks:
return png_bytes
# IHDR is always the first chunk; insert ours immediately after it.
ihdr_length = struct.unpack(">I", png_bytes[8:12])[0]
ihdr_end = 8 + 8 + ihdr_length + 4 # signature + (len+type) + data + crc
return png_bytes[:ihdr_end] + b"".join(chunks) + png_bytes[ihdr_end:]
# Standard chromaticities (CIE 1931 xy) for the colorspaces this node writes.
# Each tuple is (Rx, Ry, Gx, Gy, Bx, By, Wx, Wy). All share D65 white point.
_CHROMATICITIES = {
# ITU-R BT.709 / sRGB primaries
"Rec.709": (0.6400, 0.3300, 0.3000, 0.6000, 0.1500, 0.0600, 0.3127, 0.3290),
# ITU-R BT.2020 (UHDTV / wide-gamut HDR) primaries
"Rec.2020": (0.7080, 0.2920, 0.1700, 0.7970, 0.1310, 0.0460, 0.3127, 0.3290),
}
def _pack_chromaticities(primaries: tuple) -> bytes:
"""Serialize 8 chromaticity floats into the EXR `chromaticities` payload."""
return struct.pack("<8f", *primaries)
def _exr_attribute(name: str, attr_type: str, value: bytes) -> bytes:
"""Serialize one EXR header attribute: name\\0 type\\0 size:int32 value."""
return (
name.encode("utf-8") + b"\x00"
+ attr_type.encode("utf-8") + b"\x00"
+ struct.pack("<i", len(value))
+ value
)
def inject_exr_metadata(
exr_bytes: bytes,
prompt: dict | None,
extra_pnginfo: dict | None,
colorspace: str | None = None,
) -> bytes:
"""Insert ComfyUI metadata and color-space info into an EXR header.
Color: EXR pixels are linear by convention. The standard way to describe
their RGB→XYZ relationship is the `chromaticities` attribute. We pick the
primaries that match what the user told us their input was:
colorspace="sRGB" → Rec. 709 / sRGB primaries (D65)
colorspace="HDR" → Rec. 2020 / BT.2100 primaries (D65)
Pixels are always converted to linear scene light upstream (sRGB EOTF
inverse for sRGB; HLG OETF inverse for HDR), so the file content is
scene-linear in the indicated gamut. OpenEXR has no standard transfer-
function attribute (the OpenEXR TSC has discussed adding one but it
doesn't exist), so we don't invent one — `chromaticities` plus the EXR
linear-by-convention rule fully specifies the color.
Prompt/workflow: written as plain `string` attributes using the same keys
(`prompt`, `workflow`, ...) that Comfy uses for PNG tEXt chunks, so the
same readers can pull them out symmetrically.
Implementation note: the chunk-offset table that follows the header stores
*absolute* byte offsets into the file. Inserting N bytes into the header
means every offset must be incremented by N or the file becomes unreadable.
"""
if len(exr_bytes) < 8 or exr_bytes[:4] != b"\x76\x2f\x31\x01":
return exr_bytes
new_blob = b""
if prompt is not None:
new_blob += _exr_attribute("prompt", "string", json.dumps(prompt).encode("utf-8"))
if extra_pnginfo:
for key, value in extra_pnginfo.items():
new_blob += _exr_attribute(key, "string", json.dumps(value).encode("utf-8"))
if colorspace is not None:
# Map each colorspace option to the RGB primaries the linear pixels
# are now in. "sRGB" and "linear" both produce Rec. 709 linear; "HDR"
# (HLG-encoded Rec. 2020 input) produces Rec. 2020 linear.
primaries_name = {
"sRGB": "Rec.709",
"linear": "Rec.709",
"HDR": "Rec.2020",
}.get(colorspace, "Rec.709")
new_blob += _exr_attribute(
"chromaticities",
"chromaticities",
_pack_chromaticities(_CHROMATICITIES[primaries_name]),
)
if not new_blob:
return exr_bytes
# Walk header attributes to find the terminating null byte, and pick up
# dataWindow + compression so we know how many chunks the offset table has.
pos = 8 # past magic (4) + version (4)
data_window = None
compression = 0
while pos < len(exr_bytes) and exr_bytes[pos] != 0:
name_end = exr_bytes.index(b"\x00", pos)
attr_name = exr_bytes[pos:name_end].decode("latin-1", errors="replace")
type_end = exr_bytes.index(b"\x00", name_end + 1)
attr_type = exr_bytes[name_end + 1:type_end].decode("latin-1", errors="replace")
size = struct.unpack("<i", exr_bytes[type_end + 1:type_end + 5])[0]
value_start = type_end + 5
value = exr_bytes[value_start:value_start + size]
if attr_name == "dataWindow" and attr_type == "box2i":
data_window = struct.unpack("<iiii", value) # xMin, yMin, xMax, yMax
elif attr_name == "compression" and attr_type == "compression":
compression = value[0]
pos = value_start + size
if data_window is None:
return exr_bytes # required attribute missing — don't risk corrupting
# Scanlines per chunk by compression, from the OpenEXR spec.
scanlines_per_block = {
0: 1, # NO_COMPRESSION
1: 1, # RLE
2: 1, # ZIPS
3: 16, # ZIP
4: 32, # PIZ
5: 16, # PXR24
6: 32, # B44
7: 32, # B44A
8: 256, # DWAA
9: 256, # DWAB
}.get(compression, 1)
_, y_min, _, y_max = data_window
height = y_max - y_min + 1
num_chunks = (height + scanlines_per_block - 1) // scanlines_per_block
header_end = pos # position of the terminating null byte
table_start = header_end + 1
pixel_start = table_start + num_chunks * 8
delta = len(new_blob)
old_offsets = struct.unpack(f"<{num_chunks}Q", exr_bytes[table_start:pixel_start])
new_table = struct.pack(f"<{num_chunks}Q", *(o + delta for o in old_offsets))
return (
exr_bytes[:header_end] # header attributes
+ new_blob # our new attributes
+ exr_bytes[header_end:table_start] # terminating null byte
+ new_table # shifted offset table
+ exr_bytes[pixel_start:] # pixel data, untouched
)
# ---------------------------------------------------------------------------
# Encoding
# ---------------------------------------------------------------------------
def _encode_image(
img_tensor: torch.Tensor,
file_format: str,
bit_depth: str,
colorspace: str,
) -> bytes:
"""Encode a single HxWxC tensor to PNG or EXR bytes in memory.
For EXR the input is interpreted according to `colorspace` and converted
to scene-linear (EXR's convention) before writing:
"sRGB" → input is sRGB-encoded Rec. 709; apply inverse sRGB EOTF.
"HDR" → input is HLG-encoded Rec. 2020 (BT.2100); apply inverse HLG
OETF to get scene-linear, per BT.2100 Note 5a.
"linear" → input is already scene-linear (Rec. 709 primaries); write
through unchanged. Use this for renderer/compositor output.
For PNG, colorspace selection does not modify pixels — PNG is delivered
sRGB-encoded and there is no PNG path for wide-gamut HDR in this node.
"""
height, width, num_channels = img_tensor.shape
has_alpha = num_channels == 4
spec = _FORMAT_SPECS[(file_format, bit_depth, has_alpha)]
if spec["dtype"] == np.float32:
# EXR path: preserve full range, no clamp.
if colorspace == "sRGB":
img_tensor = srgb_to_linear(img_tensor)
elif colorspace == "HDR":
img_tensor = hlg_to_linear(img_tensor)
img_np = img_tensor.cpu().numpy().astype(np.float32)
else:
# PNG path: quantize to integer range.
scaled = (img_tensor * spec["scale"]).clamp(0, spec["scale"])
img_np = scaled.to(torch.int32).cpu().numpy().astype(spec["dtype"])
# Encode directly via CodecContext. PyAV's `image2` muxer does NOT write to
# BytesIO (it expects a real file path), so we bypass the container entirely.
# For single-frame PNG/EXR the raw codec output IS the file.
codec = av.CodecContext.create(file_format, "w")
codec.width = width
codec.height = height
codec.pix_fmt = spec["stream_fmt"]
codec.time_base = Fraction(1, 1)
frame = av.VideoFrame.from_ndarray(img_np, format=spec["frame_fmt"])
if spec["frame_fmt"] != spec["stream_fmt"]:
frame = frame.reformat(format=spec["stream_fmt"])
frame.pts = 0
frame.time_base = codec.time_base
packets = list(codec.encode(frame)) + list(codec.encode(None)) # flush with None
return b"".join(bytes(p) for p in packets)
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SaveImageAdvanced(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveImageAdvanced",
search_aliases=["save", "save image", "export image", "output image", "write image"],
display_name="Save Image (Advanced)",
description="Saves the input images to your ComfyUI output directory.",
category="image",
essentials_category="Basics",
inputs=[
IO.Image.Input("images", tooltip="The images to save."),
IO.String.Input(
"filename_prefix",
default="ComfyUI",
tooltip=(
"The prefix for the file to save. May include formatting tokens "
"such as %date:yyyy-MM-dd% or %Empty Latent Image.width%."
),
),
IO.DynamicCombo.Input(
"format",
options=[
IO.DynamicCombo.Option("png", [
IO.Combo.Input("bit_depth", options=["8-bit", "16-bit"],
default="8-bit", advanced=True),
IO.Combo.Input("input_color_space", options=["sRGB"],
default="sRGB", advanced=True),
]),
IO.DynamicCombo.Option("exr", [
IO.Combo.Input("bit_depth", options=["32-bit float"],
default="32-bit float", advanced=True),
IO.Combo.Input(
"input_color_space",
options=["sRGB", "HDR", "linear"],
default="sRGB",
advanced=True,
tooltip=(
"Colorspace of the input tensor. The EXR is "
"always written as scene-linear in the matching "
"gamut.\n"
" 'sRGB' — input is sRGB-encoded Rec.709; "
"the inverse sRGB EOTF is applied.\n"
" 'HDR' — input is HLG-encoded Rec.2020 "
"(BT.2100); the inverse HLG OETF is applied "
"to get scene-linear light.\n"
" 'linear' — input is already scene-linear "
"(Rec.709 primaries); written through unchanged. "
"Use this for renderer/compositor output."
),
),
]),
],
tooltip="The file format in which to save the image.",
),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, images, filename_prefix: str, format: dict) -> IO.NodeOutput:
file_format = format["format"]
bit_depth = format["bit_depth"]
colorspace = format.get("input_color_space", "sRGB")
output_dir = folder_paths.get_output_directory()
full_output_folder, filename, counter, subfolder, filename_prefix = (
folder_paths.get_save_image_path(
filename_prefix, output_dir, images[0].shape[1], images[0].shape[0]
)
)
prompt = cls.hidden.prompt
extra_pnginfo = cls.hidden.extra_pnginfo
write_metadata = not args.disable_metadata
results = []
for batch_number, image in enumerate(images):
encoded = _encode_image(image, file_format, bit_depth, colorspace)
if write_metadata:
if file_format == "png":
encoded = inject_png_metadata(encoded, prompt, extra_pnginfo)
elif file_format == "exr":
encoded = inject_exr_metadata(encoded, prompt, extra_pnginfo, colorspace)
name = filename.replace("%batch_num%", str(batch_number))
file = f"{name}_{counter:05}.{file_format}"
with open(os.path.join(full_output_folder, file), "wb") as f:
f.write(encoded)
results.append({"filename": file, "subfolder": subfolder, "type": "output"})
counter += 1
return IO.NodeOutput(ui={"images": results})
class ImagesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -847,6 +1253,7 @@ class ImagesExtension(ComfyExtension):
ImageAddNoise,
SaveAnimatedWEBP,
SaveAnimatedPNG,
SaveImageAdvanced,
SaveSVGNode,
ImageStitch,
ResizeAndPadImage,

View File

@ -11,8 +11,8 @@ class LTXVAudioVAELoader(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXVAudioVAELoader",
display_name="Load LTXV Audio VAE",
category="loaders",
display_name="LTXV Audio VAE Loader",
category="audio",
inputs=[
io.Combo.Input(
"ckpt_name",
@ -40,7 +40,7 @@ class LTXVAudioVAEEncode(VAEEncodeAudio):
return io.Schema(
node_id="LTXVAudioVAEEncode",
display_name="LTXV Audio VAE Encode",
category="latent/audio",
category="audio",
inputs=[
io.Audio.Input("audio", tooltip="The audio to be encoded."),
io.Vae.Input(
@ -63,7 +63,7 @@ class LTXVAudioVAEDecode(io.ComfyNode):
return io.Schema(
node_id="LTXVAudioVAEDecode",
display_name="LTXV Audio VAE Decode",
category="latent/audio",
category="audio",
inputs=[
io.Latent.Input("samples", tooltip="The latent to be decoded."),
io.Vae.Input(

View File

@ -28,7 +28,7 @@ from comfy_extras.mediapipe.face_landmarker import FaceLandmarker
from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection
FaceDetectionType = io.Custom("FACE_DETECTION_MODEL")
FaceLandmarkerType = io.Custom("FACE_LANDMARKER")
FaceLandmarksType = io.Custom("FACE_LANDMARKS")
_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights")
@ -204,19 +204,18 @@ class LoadMediaPipeFaceLandmarker(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LoadMediaPipeFaceLandmarker",
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
display_name="Load Face Detection Model (MediaPipe)",
display_name="Load MediaPipe Face Landmarker",
category="loaders",
inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("detection"),
tooltip="Face detection model from models/detection/."),
io.Combo.Input("model_name", options=folder_paths.get_filename_list("mediapipe"),
tooltip="Face Landmarker safetensors from models/mediapipe/."),
],
outputs=[FaceDetectionType.Output()],
outputs=[FaceLandmarkerType.Output()],
)
@classmethod
def execute(cls, model_name) -> io.NodeOutput:
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("detection", model_name), safe_load=True)
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("mediapipe", model_name), safe_load=True)
wrapper = FaceLandmarkerModel(sd)
return io.NodeOutput(wrapper)
@ -235,12 +234,10 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceLandmarker",
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
display_name="Detect Face Landmarks (MediaPipe)",
display_name="MediaPipe Face Landmarker",
category="image/detection",
description="Detects facial landmarks using MediaPipe model.",
inputs=[
FaceDetectionType.Input("face_detection_model"),
FaceLandmarkerType.Input("face_landmarker"),
io.Image.Input("image"),
io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short",
tooltip="Face detector range. 'short' is tuned for close-up faces "
@ -264,9 +261,9 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
)
@classmethod
def execute(cls, face_detection_model, image, detector_variant, num_faces, min_confidence,
def execute(cls, face_landmarker, image, detector_variant, num_faces, min_confidence,
missing_frame_fallback) -> io.NodeOutput:
canonical = face_detection_model.canonical_data
canonical = face_landmarker.canonical_data
img_np = _image_to_uint8(image)
B, H, W = img_np.shape[:3]
chunk = 16
@ -279,7 +276,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq:
for i in range(0, B, chunk):
end = min(i + chunk, B)
res.extend(face_detection_model.detect_batch(
res.extend(face_landmarker.detect_batch(
[img_np[bi] for bi in range(i, end)],
num_faces=int(num_faces),
score_thresh=float(min_confidence),
@ -309,7 +306,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])})
bboxes.append(per_bb)
return io.NodeOutput({"frames": frames, "image_size": (H, W),
"connection_sets": face_detection_model.connection_sets}, bboxes)
"connection_sets": face_landmarker.connection_sets}, bboxes)
# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose).
@ -335,10 +332,8 @@ class MediaPipeFaceMeshVisualize(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceMeshVisualize",
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection", "visualize"],
display_name="Visualize Face Landmarks (MediaPipe)",
display_name="MediaPipe Face Mesh Visualize",
category="image/detection",
description="Draws face landmarks mesh on the input image.",
inputs=[
FaceLandmarksType.Input("face_landmarks"),
io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."),
@ -448,10 +443,8 @@ class MediaPipeFaceMask(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceMask",
search_aliases=["face", "facial", "mediapipe", "face mask", "blazeface", "face detection", "visualize"],
display_name="Draw Face Mask (MediaPipe)",
display_name="MediaPipe Face Mask",
category="image/detection",
description="Draws a mask from face landmarks.",
inputs=[
FaceLandmarksType.Input("face_landmarks"),
io.DynamicCombo.Input(

View File

@ -103,10 +103,8 @@ class MoGePanoramaInference(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MoGePanoramaInference",
search_aliases=["moge", "panorama", "depth", "geometry", "depth estimation", "geometry estimation"],
display_name="Run MoGe Panorama Inference",
display_name="MoGe Panorama Inference",
category="image/geometry_estimation",
description="Run MoGe on an equirectangular panorama by splitting it into 12 perspective views, running inference on each, and merging the results into a single depth map.",
inputs=[
MoGeModelType.Input("moge_model"),
io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."),
@ -224,9 +222,7 @@ class MoGeInference(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MoGeInference",
search_aliases=["moge", "depth", "geometry", "depth estimation", "geometry estimation"],
display_name="Run MoGe Inference",
description="Run MoGe on a single image to estimate depth and geometry.",
display_name="MoGe Inference",
category="image/geometry_estimation",
inputs=[
MoGeModelType.Input("moge_model"),
@ -281,9 +277,7 @@ class MoGeRender(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MoGeRender",
search_aliases=["moge", "render", "geometry", "depth", "normal"],
display_name="Render MoGe Geometry",
description="Render a depth map or normal map from geometry data",
display_name="MoGe Render",
category="image/geometry_estimation",
inputs=[
MoGeGeometry.Input("moge_geometry"),
@ -348,9 +342,7 @@ class MoGePointMapToMesh(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MoGePointMapToMesh",
search_aliases=["moge", "mesh", "geometry", "point map"],
display_name="Convert MoGe Point Map to Mesh",
description="Convert a MoGe point map into a 3D mesh.",
display_name="MoGe Point Map to Mesh",
category="image/geometry_estimation",
inputs=[
MoGeGeometry.Input("moge_geometry"),

View File

@ -60,7 +60,7 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
folder_names_and_paths["detection"] = ([os.path.join(models_dir, "detection")], supported_pt_extensions)
folder_names_and_paths["mediapipe"] = ([os.path.join(models_dir, "mediapipe")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")