Compare commits

..

19 Commits

Author SHA1 Message Date
30b19c6872 Fix a bug if . in the teamplate field or id. 2026-06-27 10:18:40 +02:00
2dd281d8a6 Rename and refactor io.List to io.DynamicGroup 2026-06-27 10:18:40 +02:00
911e0b2acf Throw an error if frontend send more row than in max in List. 2026-06-27 10:18:40 +02:00
46c7e8055c Add io.List input 2026-06-27 10:18:39 +02:00
603d891eaf Update GLSL node to use ANGLE library (CORE-162) (#13195) 2026-06-27 08:40:31 +08:00
470ac36a0a Fix int8 loras causing lower quality requant with wrong settings. (#14650)
* Update comfy-kitchen

* Support requantizing with same settings as orig quant.
2026-06-26 16:41:29 -07:00
7cb784e0f4 Faster int8. (#14641) 2026-06-25 15:25:47 -07:00
1a510f0423 Support int8 models. (#14636) 2026-06-25 11:23:58 -07:00
639c8fa788 chore: update workflow templates to v0.10.7 (#14632) 2026-06-25 23:05:34 +08:00
e22f1500f9 [Partner Nodes] feat(ByteDance): add support for SeeDance-2.0-Mini video model (#14626)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-06-25 17:57:04 +03:00
dac4ea3a80 feat: Bounding boxes canvas and Ideogram JSON prompt (#14537) 2026-06-25 22:34:09 +08:00
b0ec19804f chore(openapi): sync shared API contract from cloud@4118910 (#14619) 2026-06-25 13:54:53 +08:00
64e1d740b8 Add advanced krea 2 model merging node. (#14621) 2026-06-24 20:37:30 -07:00
b22d0fb9c0 feat: Add Support For Simple Seed (CORE-295) (#14616) 2026-06-25 09:39:10 +08:00
5236cd02e6 [Partner Nodes] feat(ByteDance): add 4K resolution support for SeeDance 2.0 (#14614)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-06-24 17:57:46 +03:00
cabb7342d1 [Partner Nodes] feat(Grok): add 1080p resolution to Grok Image node (#14612)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-06-24 16:28:56 +03:00
12218db68a Update the template to bring the HH1.1 templates back (#14613) 2026-06-24 21:01:25 +08:00
44955d783b [Partner Nodes] feat(Alibaba): add support for HappyHorse 1.1 model (#14611)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-06-24 13:37:28 +03:00
1f275fcba6 chore(openapi): sync shared API contract from cloud@363764b (#14607) 2026-06-24 18:22:59 +08:00
18 changed files with 1300 additions and 367 deletions

View File

@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
if (want_requant and len(fns) == 0 or update_weight):
seed = comfy.utils.string_to_seed(s.seed_key)
if isinstance(orig, QuantizedTensor):
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
else:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
if want_requant and len(fns) == 0:
@ -1089,6 +1089,19 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat
if ts is None or bs is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
scales = {"scale": ts, "block_scale": bs}
elif module.quant_format == "int8_tensorwise":
scale = pop_scale("weight_scale")
if scale is None:
raise ValueError(f"Missing INT8 weight scale for layer {layer_name}")
scales = {"scale": scale}
params_conf = layer_conf.get("params", {})
if not isinstance(params_conf, dict):
params_conf = {}
if layer_conf.get("convrot", params_conf.get("convrot", False)):
scales["convrot"] = True
scales["convrot_groupsize"] = int(
layer_conf.get("convrot_groupsize", params_conf.get("convrot_groupsize", 256))
)
else:
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
@ -1131,6 +1144,10 @@ def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extr
quant_conf = {"format": module.quant_format}
if getattr(module, '_full_precision_mm_config', False):
quant_conf["full_precision_matrix_mult"] = True
params = getattr(module.weight, "_params", None)
if module.quant_format == "int8_tensorwise" and getattr(params, "convrot", False):
quant_conf["convrot"] = True
quant_conf["convrot_groupsize"] = getattr(params, "convrot_groupsize", 256)
if extra_quant_conf:
quant_conf.update(extra_quant_conf)
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
@ -1183,8 +1200,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
def forward_comfy_cast_weights(
self,
input,
compute_dtype=None,
want_requant=False,
weight_only_quant=False,
):
if weight_only_quant:
weight, bias, offload_stream = cast_bias_weight(
self,
input=None,
dtype=self.weight.dtype,
device=input.device,
bias_dtype=input.dtype,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=want_requant,
)
weight = weight.to(dtype=input.dtype)
else:
weight, bias, offload_stream = cast_bias_weight(
self,
input,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=want_requant,
)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@ -1203,9 +1245,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0
)
quantize_input = QUANT_ALGOS.get(getattr(self, 'quant_format', None), {}).get("quantize_input", True)
# Training path: quantized forward with compute_dtype backward via autograd function
if (input.requires_grad and _use_quantized):
if (input.requires_grad and _use_quantized and quantize_input):
weight, bias, offload_stream = cast_bias_weight(
self,
@ -1227,7 +1270,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
return output
# Inference path (unchanged)
if _use_quantized:
if _use_quantized and quantize_input:
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
@ -1241,7 +1284,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
weight_only_quant = _use_quantized and not quantize_input and isinstance(self.weight, QuantizedTensor)
output = self.forward_comfy_cast_weights(
input,
compute_dtype,
want_requant=isinstance(input, QuantizedTensor),
weight_only_quant=weight_only_quant,
)
# Reshape output back to 3D if input was 3D
if reshaped_3d:
@ -1257,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
# dtype is now implicit in the layout class
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
else:
weight = weight.to(self.weight.dtype)
if return_weight:

View File

@ -10,6 +10,7 @@ try:
QuantizedLayout,
TensorCoreFP8Layout as _CKFp8Layout,
TensorCoreNVFP4Layout as _CKNvfp4Layout,
TensorWiseINT8Layout as _CKTensorWiseINT8Layout,
register_layout_op,
register_layout_class,
get_layout_class,
@ -47,6 +48,9 @@ except ImportError as e:
class _CKNvfp4Layout:
pass
class _CKTensorWiseINT8Layout:
pass
def register_layout_class(name, cls):
pass
@ -174,6 +178,7 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
# Backward compatibility alias - default to E4M3
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
TensorWiseINT8Layout = _CKTensorWiseINT8Layout
# ==============================================================================
@ -184,6 +189,7 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
register_layout_class("TensorWiseINT8Layout", _CKTensorWiseINT8Layout)
if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
@ -214,6 +220,13 @@ if _CK_MXFP8_AVAILABLE:
"group_size": 32,
}
QUANT_ALGOS["int8_tensorwise"] = {
"storage_t": torch.int8,
"parameters": {"weight_scale"},
"comfy_tensor_layout": "TensorWiseINT8Layout",
"quantize_input": False,
}
# ==============================================================================
# Re-exports for backward compatibility
@ -226,6 +239,7 @@ __all__ = [
"TensorCoreFP8E4M3Layout",
"TensorCoreFP8E5M2Layout",
"TensorCoreNVFP4Layout",
"TensorWiseINT8Layout",
"QUANT_ALGOS",
"register_layout_op",
]

View File

@ -891,6 +891,14 @@ class Tracks(ComfyTypeIO):
track_visibility: torch.Tensor
Type = TrackDict
@comfytype(io_type="DICT")
class Dict(ComfyTypeIO):
Type = dict
@comfytype(io_type="ARRAY")
class Array(ComfyTypeIO):
Type = list
@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
Type = Any
@ -1253,6 +1261,155 @@ class DynamicSlot(ComfyTypeI):
out_dict[input_type][finalized_id] = value
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
@comfytype(io_type="COMFY_DYNAMICGROUP_V3")
class DynamicGroup(ComfyTypeI):
"""A repeatable group of widget inputs (e.g. lora_name + strength stacked into N rows).
At execution time the node receives a ``list[dict]`` where each element is a row.
Example::
io.DynamicGroup.Input(
"loras",
template=[
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras")),
io.Float.Input("strength", default=1.0, min=-100, max=100, step=0.01),
],
min=0,
max=50,
)
# execute receives: loras: list[dict] = [{"lora_name": "x.safetensors", "strength": 1.0}, ...]
"""
Type = list[dict[str, Any]]
_MaxRows = 100
class Input(DynamicInput):
def __init__(
self,
id: str,
template: list["Input"],
min: int = 0,
max: int = 50,
display_name: str = None,
optional: bool = False,
tooltip: str = None,
lazy: bool = None,
extra_dict=None,
):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
# Validate template entries: only WidgetInput subclasses, no nesting
assert len(template) > 0, "DynamicGroup template must have at least one field."
for t in template:
assert isinstance(t, WidgetInput), (
f"DynamicGroup template field '{t.id}' must be a WidgetInput subclass "
f"(Combo, Float, Int, String, Boolean, Color). Got {type(t).__name__}."
)
assert not isinstance(t, DynamicInput), (
f"DynamicGroup template field '{t.id}' must not be a DynamicInput. "
"Nesting dynamic inputs inside DynamicGroup is not supported."
)
# Enforce unique field ids within template
field_ids = [t.id for t in template]
assert len(field_ids) == len(set(field_ids)), (
f"DynamicGroup template field ids must be unique within a row. Got: {field_ids}"
)
# Reject "." in group id and template field ids: slot_id encoding uses "." as a
# delimiter (<group_id>.<row>.<field_id>), so any "." in these names would cause
# path.split(".") to produce the wrong number of segments during decoding.
assert "." not in id, (
f"DynamicGroup id must not contain '.'. Got: '{id}'"
)
for t in template:
assert "." not in t.id, (
f"DynamicGroup template field id must not contain '.'. Got: '{t.id}'"
)
assert min >= 0, "DynamicGroup min must be >= 0."
assert max >= 1, "DynamicGroup max must be >= 1."
assert max <= DynamicGroup._MaxRows, f"DynamicGroup max must be <= {DynamicGroup._MaxRows}."
assert min <= max, "DynamicGroup min must be <= max."
self.template = template
self.min = min
self.max = max
def get_all(self) -> list["Input"]:
return [self] + list(self.template)
def as_dict(self):
return super().as_dict() | prune_dict({
"template": create_input_dict_v1(self.template),
"min": self.min,
"max": self.max,
})
def validate(self):
for t in self.template:
t.validate()
@staticmethod
def _expand_schema_for_dynamic(
out_dict: dict[str, Any],
live_inputs: dict[str, Any],
value: tuple[str, dict[str, Any]],
input_type: str,
curr_prefix: list[str] | None,
):
info = value[1]
min_rows: int = info.get("min", 0)
max_rows: int = info.get("max", DynamicGroup._MaxRows)
template: dict[str, Any] = info.get("template", {})
# Collect all template field specs across required/optional sections
field_specs: list[tuple[str, tuple[str, dict[str, Any]], bool]] = []
for field_required_key in ("required", "optional"):
section = template.get(field_required_key, {})
is_required_field = field_required_key == "required"
for field_id, field_value in section.items():
field_specs.append((field_id, field_value, is_required_field))
# Determine how many rows are currently present by scanning live_inputs
finalized_prefix = finalize_prefix(curr_prefix)
present_rows = 0
for live_key in live_inputs:
# Keys look like "<prefix>.<row>.<field_id>"
if live_key.startswith(finalized_prefix + "."):
remainder = live_key[len(finalized_prefix) + 1:]
parts = remainder.split(".", 1)
if len(parts) >= 1:
try:
row_idx = int(parts[0])
present_rows = max(present_rows, row_idx + 1)
except ValueError:
pass
if present_rows > max_rows:
raise ValueError(
f"DynamicGroup input '{finalized_prefix}' received {present_rows} rows but max is {max_rows}."
)
row_count = max(min_rows, present_rows)
for row in range(row_count):
for field_id, field_value, is_required_field in field_specs:
slot_id = f"{finalized_prefix}.{row}.{field_id}"
# The first `min_rows` rows are required if the field itself is required
if row < min_rows and is_required_field:
out_dict["required"][slot_id] = field_value
else:
out_dict["optional"][slot_id] = field_value
# Register into dynamic_paths so build_nested_inputs places value at the right path
out_dict["dynamic_paths"][slot_id] = slot_id
# Track the list root path so build_nested_inputs can convert the index dict to a list
out_dict.setdefault("list_paths", set()).add(finalized_prefix)
# Handle the empty case (0 rows) emit an empty-list default for the parent.
# This must only fire when there are genuinely no rows; otherwise the parent
# path would clobber the per-row dict built from the slot ids above.
if row_count == 0:
out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix
out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_LIST
@comfytype(io_type="IMAGECOMPARE")
class ImageCompare(ComfyTypeI):
Type = dict
@ -1279,6 +1436,19 @@ class Color(ComfyTypeIO):
def as_dict(self):
return super().as_dict()
@comfytype(io_type="COLORS")
class Colors(ComfyTypeIO):
Type = list[Color.Type]
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: list[str]=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = []
@comfytype(io_type="BOUNDING_BOX")
class BoundingBox(ComfyTypeIO):
class BoundingBoxDict(TypedDict):
@ -1326,6 +1496,20 @@ class Curve(ComfyTypeIO):
return d
@comfytype(io_type="BOUNDING_BOXES")
class BoundingBoxes(ComfyTypeIO):
class BoundingBoxWithMetadata(BoundingBox.BoundingBoxDict):
metadata: dict
Type = list[BoundingBoxWithMetadata]
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: list[dict]=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = []
@comfytype(io_type="HISTOGRAM")
class Histogram(ComfyTypeIO):
"""A histogram represented as a list of bin counts."""
@ -1383,6 +1567,8 @@ def setup_dynamic_input_funcs():
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
# DynamicSlot.Input
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
# DynamicGroup.Input
register_dynamic_input_func(DynamicGroup.io_type, DynamicGroup._expand_schema_for_dynamic)
if len(DYNAMIC_INPUT_LOOKUP) == 0:
setup_dynamic_input_funcs()
@ -1394,6 +1580,8 @@ class V3Data(TypedDict):
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
dynamic_paths_default_value: dict[str, Any]
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
list_paths: set[str]
'Set of top-level keys whose index-keyed dict values should be converted to a sorted list[dict] after build_nested_inputs runs.'
create_dynamic_tuple: bool
'When True, the value of the dynamic input will be in the format (value, path_key).'
@ -1735,6 +1923,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
"optional": {},
"dynamic_paths": {},
"dynamic_paths_default_value": {},
"list_paths": set(),
}
d = d.copy()
# ignore hidden for parsing
@ -1750,6 +1939,10 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
# list_paths: keys whose nested dict should be post-converted to a sorted list[dict]
list_paths = out_dict.pop("list_paths", None)
if list_paths:
v3_data["list_paths"] = list_paths
return out_dict, hidden, v3_data
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
@ -1785,10 +1978,12 @@ def add_to_dict_v1(i: Input, d: dict):
class DynamicPathsDefaultValue:
EMPTY_DICT = "empty_dict"
EMPTY_LIST = "empty_list"
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
paths = v3_data.get("dynamic_paths", None)
default_value_dict = v3_data.get("dynamic_paths_default_value", {})
list_paths: set[str] = v3_data.get("list_paths", set()) or set()
if paths is None:
return values
values = values.copy()
@ -1811,6 +2006,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
default_option = default_value_dict.get(key, None)
if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
value = {}
elif default_option == DynamicPathsDefaultValue.EMPTY_LIST:
value = []
if create_tuple:
value = (value, key)
current[p] = value
@ -1818,6 +2015,34 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
current = current.setdefault(p, {})
values.update(result)
# Post-pass: convert index-keyed dicts to sorted lists for io.DynamicGroup fields
for list_path in list_paths:
parts = list_path.split(".")
# Navigate to the parent container, then convert the leaf
container = values
for part in parts[:-1]:
if not isinstance(container, dict) or part not in container:
container = None
break
container = container[part]
if container is None:
continue
leaf_key = parts[-1]
leaf = container.get(leaf_key, None)
if isinstance(leaf, dict):
try:
sorted_rows = [leaf[k] for k in sorted(leaf.keys(), key=int)]
container[leaf_key] = sorted_rows
except (ValueError, TypeError):
# Keys are not all integers; leave as-is
pass
elif isinstance(leaf, list):
# Already a list (e.g. the EMPTY_LIST default was applied above)
pass
elif leaf is None:
container[leaf_key] = []
return values
@ -2376,11 +2601,15 @@ __all__ = [
"AnyType",
"MultiType",
"Tracks",
"Dict",
"Array",
"Color",
# Dynamic Types
"MatchType",
"DynamicCombo",
"DynamicSlot",
"Autogrow",
"DynamicGroup",
# Other classes
"HiddenHolder",
"Hidden",
@ -2394,6 +2623,8 @@ __all__ = [
"PriceBadgeDepends",
"PriceBadge",
"BoundingBox",
"BoundingBoxes",
"Colors",
"Curve",
"Histogram",
"Range",

View File

@ -0,0 +1,23 @@
def hex_to_rgb(value: str) -> tuple[int, int, int]:
h = value.lstrip("#")
if len(h) != 6:
return (255, 255, 255)
try:
return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16))
except ValueError:
return (255, 255, 255)
def readable_color(rgb: tuple[int, int, int]) -> tuple[int, int, int]:
r, g, b = rgb
lum = 0.299 * r + 0.587 * g + 0.114 * b
if lum >= 130:
return (r, g, b)
t = (130 - lum) / (255 - lum)
return (round(r + (255 - r) * t), round(g + (255 - g) * t), round(b + (255 - b) * t))
def normalize_palette(colors) -> list[str]:
if isinstance(colors, dict):
colors = colors.values()
return [c.upper() for c in colors if isinstance(c, str) and c]

View File

@ -0,0 +1,253 @@
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageEnhance, ImageFont
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.color_util import hex_to_rgb, normalize_palette, readable_color
_PREVIEW_LONG_EDGE = 1024
_PREVIEW_DIM = 0.25
def pixels_to_fractions(box: dict, width: int, height: int) -> dict:
w = width or 1
h = height or 1
return {
"x": box.get("x", 0) / w,
"y": box.get("y", 0) / h,
"w": box.get("width", 0) / w,
"h": box.get("height", 0) / h,
}
def fractions_to_pixels(box: dict, width: int, height: int) -> dict:
x, y = box.get("x", 0.0), box.get("y", 0.0)
w, h = box.get("w", 0.0), box.get("h", 0.0)
if w < 0:
x, w = x + w, -w
if h < 0:
y, h = y + h, -h
return {
"x": round(x * width),
"y": round(y * height),
"width": round(w * width),
"height": round(h * height),
}
def fractions_to_bbox_frame(boxes: list, width: int, height: int) -> list:
pixels = [
fractions_to_pixels(box, width, height)
for box in boxes
if isinstance(box, dict)
]
return [pixels] if pixels else []
def _font(size: int):
try:
return ImageFont.load_default(size)
except Exception:
return ImageFont.load_default()
def _wrap(draw, text: str, font, max_w: float) -> list[str]:
lines = []
for para in text.split("\n"):
line = ""
for word in para.split():
test = word if not line else line + " " + word
if line and draw.textlength(test, font=font) > max_w:
lines.append(line)
line = word
else:
line = test
lines.append(line)
return lines
def _bg_from_image(image) -> Image.Image | None:
if image is None:
return None
try:
arr = (image[0].detach().cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
return Image.fromarray(arr)
except Exception:
return None
def render_preview(regions, width, height, bg=None):
if bg is not None:
iw, ih = bg.size
long_edge = max(iw, ih) or 1
scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge)
rw, rh = max(1, round(iw * scale)), max(1, round(ih * scale))
base = bg.convert("RGB").resize((rw, rh), Image.LANCZOS)
base = ImageEnhance.Brightness(base).enhance(_PREVIEW_DIM)
img = base.convert("RGBA")
else:
long_edge = max(width, height) or 1
scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge)
rw, rh = max(1, round(width * scale)), max(1, round(height * scale))
grey = round(_PREVIEW_DIM * 128)
img = Image.new("RGBA", (rw, rh), (grey, grey, grey, 255))
overlay = Image.new("RGBA", (rw, rh), (0, 0, 0, 0))
draw = ImageDraw.Draw(overlay)
fs = max(10, round(rh / 64))
font = _font(fs)
tag_font = _font(max(9, fs - 2))
line_h = fs + 2
for i, region in enumerate(regions):
if not isinstance(region, dict):
continue
palette = [c for c in (region.get("palette") or []) if c]
r, g, b = hex_to_rgb(palette[0]) if palette else (140, 140, 140)
x1 = max(0, min(rw, round(region.get("x", 0) * rw)))
y1 = max(0, min(rh, round(region.get("y", 0) * rh)))
x2 = max(0, min(rw, round((region.get("x", 0) + region.get("w", 0)) * rw)))
y2 = max(0, min(rh, round((region.get("y", 0) + region.get("h", 0)) * rh)))
if x2 < x1:
x1, x2 = x2, x1
if y2 < y1:
y1, y2 = y2, y1
draw.rectangle([x1, y1, x2, y2], outline=(r, g, b, 255), width=2)
swatches = palette[:5]
if swatches and (x2 - x1) > 2:
sh = max(5, fs // 2)
seg = (x2 - x1) / len(swatches)
for p, hexc in enumerate(swatches):
sx = x1 + round(p * seg)
draw.rectangle([sx, y1, x1 + round((p + 1) * seg), y1 + sh], fill=hex_to_rgb(hexc))
etype = "text" if region.get("type") == "text" else "obj"
tag = str(i + 1).zfill(2)
tw = draw.textlength(tag, font=tag_font)
draw.rectangle([x1, y1, x1 + tw + 6, y1 + fs + 2], fill=(r, g, b, 255))
tag_fill = (0, 0, 0, 255) if (0.299 * r + 0.587 * g + 0.114 * b) > 140 else (255, 255, 255, 255)
draw.text((x1 + 3, y1 + 1), tag, fill=tag_fill, font=tag_font)
body = region.get("desc", "") or ""
if etype == "text" and region.get("text"):
body = '"%s"%s' % (region["text"], "" + body if body else "")
if body and (x2 - x1) > 8:
ty = y1 + fs + 5
for line in _wrap(draw, body, font, x2 - x1 - 8):
if ty > y2:
break
draw.text((x1 + 4, ty), line, fill=readable_color((r, g, b)) + (255,), font=font)
ty += line_h
composed = Image.alpha_composite(img, overlay).convert("RGB")
arr = np.asarray(composed, dtype=np.float32) / 255.0
return torch.from_numpy(arr).unsqueeze(0)
def boxes_to_regions(boxes, width: int, height: int) -> list:
regions: list = []
if not isinstance(boxes, list):
return regions
for box in boxes:
if not isinstance(box, dict):
continue
meta = box.get("metadata")
meta = meta if isinstance(meta, dict) else {}
regions.append({
**pixels_to_fractions(box, width, height),
"type": meta.get("type", "obj"),
"text": meta.get("text", ""),
"desc": meta.get("desc", ""),
"palette": meta.get("palette", []),
})
return regions
def _norm_bbox(region: dict) -> list[int]:
def grid(value: float) -> int:
return max(0, min(1000, round(value * 1000)))
x, y = region.get("x", 0.0), region.get("y", 0.0)
w, h = region.get("w", 0.0), region.get("h", 0.0)
ymin, xmin, ymax, xmax = grid(y), grid(x), grid(y + h), grid(x + w)
if ymin > ymax:
ymin, ymax = ymax, ymin
if xmin > xmax:
xmin, xmax = xmax, xmin
return [ymin, xmin, ymax, xmax]
def build_elements(regions: list) -> list:
elements = []
for region in regions:
if not isinstance(region, dict):
continue
etype = "text" if region.get("type") == "text" else "obj"
element = {"type": etype}
element["bbox"] = _norm_bbox(region)
if etype == "text":
element["text"] = region.get("text", "")
element["desc"] = region.get("desc", "")
palette = normalize_palette(region.get("palette", []))
if palette:
element["color_palette"] = palette[:5]
elements.append(element)
return elements
class CreateBoundingBoxes(io.ComfyNode):
@classmethod
def define_schema(cls):
editor_state = io.BoundingBoxes.Input(
"editor_state",
socketless=False,
tooltip="Draw bounding boxes and set each box type, text, description, color palette. Start with background element first and foreground last.",
)
return io.Schema(
node_id="CreateBoundingBoxes",
display_name="Create Bounding Boxes",
category="utilities",
description="Draw bounding boxes in a canvas. Outputs Ideogram prompt elements, pixel-space bounding boxes, and a preview image.",
inputs=[
io.Image.Input(
"background",
optional=True,
tooltip="Optional image used as background in the canvas and preview.",
),
io.Int.Input("width", default=1024, min=64, max=16384, step=16,
tooltip="Width of the canvas and the pixel grid for the bounding boxes."),
io.Int.Input("height", default=1024, min=64, max=16384, step=16,
tooltip="Height of the canvas and the pixel grid for the bounding boxes."),
editor_state,
],
outputs=[
io.Image.Output(display_name="preview"),
io.BoundingBox.Output(display_name="bboxes"),
io.Array.Output(display_name="elements"),
],
is_experimental=True,
)
@classmethod
def execute(cls, width, height, editor_state=None, background=None) -> io.NodeOutput:
regions = boxes_to_regions(editor_state, width, height)
preview = render_preview(regions, width, height, _bg_from_image(background))
return io.NodeOutput(
preview,
fractions_to_bbox_frame(regions, width, height),
build_elements(regions),
ui={"dims": [width, height]},
)
class BoundingBoxesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [CreateBoundingBoxes]
async def comfy_entrypoint() -> BoundingBoxesExtension:
return BoundingBoxesExtension()

View File

@ -1,5 +1,6 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.color_util import hex_to_rgb
class ColorToRGBInt(io.ComfyNode):
@ -24,9 +25,11 @@ class ColorToRGBInt(io.ComfyNode):
# expect format #RRGGBB
if len(color) != 7 or color[0] != "#":
raise ValueError("Color must be in format #RRGGBB")
r = int(color[1:3], 16)
g = int(color[3:5], 16)
b = int(color[5:7], 16)
try:
int(color[1:], 16)
except ValueError:
raise ValueError("Color must be in format #RRGGBB") from None
r, g, b = hex_to_rgb(color)
rgb_int = r * 256 * 256 + g * 256 + b
return io.NodeOutput(rgb_int, color)

View File

@ -1,85 +1,68 @@
import os
import sys
import re
import ctypes
import logging
import ctypes.util
import importlib.util
from typing import TypedDict
import numpy as np
import torch
import nodes
import comfy_angle
from comfy_api.latest import ComfyExtension, io, ui
from typing_extensions import override
from utils.install_util import get_missing_requirements_message
logger = logging.getLogger(__name__)
def _check_opengl_availability():
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
logger.debug("_check_opengl_availability: starting")
missing = []
def _preload_angle():
egl_path = comfy_angle.get_egl_path()
gles_path = comfy_angle.get_glesv2_path()
# Check Python packages (using find_spec to avoid importing)
logger.debug("_check_opengl_availability: checking for glfw package")
if importlib.util.find_spec("glfw") is None:
missing.append("glfw")
if sys.platform == "win32":
angle_dir = comfy_angle.get_lib_dir()
os.add_dll_directory(angle_dir)
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
logger.debug("_check_opengl_availability: checking for OpenGL package")
if importlib.util.find_spec("OpenGL") is None:
missing.append("PyOpenGL")
if missing:
raise RuntimeError(
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
)
# On Linux without display, check if headless backends are available
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
if sys.platform.startswith("linux"):
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
if not has_display:
# Check for EGL or OSMesa libraries
logger.debug("_check_opengl_availability: checking for EGL library")
has_egl = ctypes.util.find_library("EGL")
logger.debug("_check_opengl_availability: checking for OSMesa library")
has_osmesa = ctypes.util.find_library("OSMesa")
# Error disabled for CI as it fails this check
# if not has_egl and not has_osmesa:
# raise RuntimeError(
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
# "See error below for installation instructions."
# )
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
logger.debug("_check_opengl_availability: completed")
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
ctypes.CDLL(str(egl_path), mode=mode)
ctypes.CDLL(str(gles_path), mode=mode)
# Run early check at import time
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
_check_opengl_availability()
# OpenGL modules - initialized lazily when context is created
gl = None
glfw = None
EGL = None
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
_preload_angle()
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
def _import_opengl():
"""Import OpenGL module. Called after context is created."""
global gl
if gl is None:
logger.debug("_import_opengl: importing OpenGL.GL")
import OpenGL.GL as _gl
gl = _gl
logger.debug("_import_opengl: import completed")
return gl
import OpenGL
OpenGL.USE_ACCELERATE = False
def _patch_find_library():
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
'libGLESv2'. Patch find_library to return the full ANGLE paths so
PyOpenGL loads the same libraries we pre-loaded."""
if sys.platform == "linux":
return
import ctypes.util
_orig = ctypes.util.find_library
def _patched(name):
if name == 'EGL':
return comfy_angle.get_egl_path()
if name == 'GLESv2':
return comfy_angle.get_glesv2_path()
return _orig(name)
ctypes.util.find_library = _patched
_patch_find_library()
from OpenGL import EGL
from OpenGL import GLES3 as gl
class SizeModeInput(TypedDict):
size_mode: str
width: int
@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# (-1,-1)---(3,-1)
#
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
VERTEX_SHADER = """#version 330 core
VERTEX_SHADER = """#version 300 es
out vec2 v_texCoord;
void main() {
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
@ -126,14 +109,99 @@ void main() {
"""
def _convert_es_to_desktop(source: str) -> str:
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
# Remove any existing #version directive
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
# Remove precision qualifiers (not needed in desktop GLSL)
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
# Prepend desktop GLSL version
return "#version 330 core\n" + source
def _egl_attribs(*values):
"""Build an EGL_NONE-terminated EGLint attribute array."""
vals = list(values) + [EGL.EGL_NONE]
return (ctypes.c_int32 * len(vals))(*vals)
# EGL platform extension constants
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
_eglGetPlatformDisplayEXT = None
def _get_egl_platform_display_ext(platform, native_display, attribs):
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
global _eglGetPlatformDisplayEXT
if _eglGetPlatformDisplayEXT is None:
from OpenGL import platform as _plat
egl_lib = _plat.PLATFORM.EGL
_get_proc = egl_lib.eglGetProcAddress
_get_proc.restype = ctypes.c_void_p
_get_proc.argtypes = [ctypes.c_char_p]
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
if not ptr:
return None
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
_eglGetPlatformDisplayEXT = func_type(ptr)
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
if not raw:
return None
return ctypes.cast(raw, EGL.EGLDisplay)
def _get_egl_display():
"""Get an EGL display, trying the default first then ANGLE's Vulkan
platform for headless environments without a display server."""
failures = []
# Try the default display first (works when X11/Wayland is available)
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
if display:
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
try:
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
return display, major.value, minor.value
except Exception as e:
failures.append(f"default: {e}")
logger.info("Default EGL display unavailable, trying headless fallbacks")
# Headless fallback strategies, tried in order:
headless_strategies = [
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
]
for name, platform, native_display, attribs in headless_strategies:
display = _get_egl_platform_display_ext(platform, native_display, attribs)
if not display:
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
continue
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
try:
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
logger.info(f"Using EGL {name} platform (headless)")
return display, major.value, minor.value
failures.append(f"{name}: eglInitialize returned false")
except Exception as e:
failures.append(f"{name}: {e}")
continue
details = "\n".join(f" - {f}" for f in failures)
raise RuntimeError(
"Failed to initialize EGL display.\n"
"No display server and no headless EGL platform available.\n"
f"Tried:\n{details}\n"
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
)
def _gl_str(name):
"""Get an OpenGL string parameter."""
v = gl.glGetString(name)
if not v:
return "Unknown"
if isinstance(v, bytes):
return v.decode(errors="replace")
return ctypes.string_at(v).decode(errors="replace")
def _detect_output_count(source: str) -> int:
@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int:
return 1
def _init_glfw():
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
logger.debug("_init_glfw: starting")
# On macOS, glfw.init() must be called from main thread or it hangs forever
if sys.platform == "darwin":
logger.debug("_init_glfw: skipping on macOS")
raise RuntimeError("GLFW backend not supported on macOS")
logger.debug("_init_glfw: importing glfw module")
import glfw as _glfw
logger.debug("_init_glfw: calling glfw.init()")
if not _glfw.init():
raise RuntimeError("glfw.init() failed")
try:
logger.debug("_init_glfw: setting window hints")
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
logger.debug("_init_glfw: calling create_window()")
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
if not window:
raise RuntimeError("glfw.create_window() failed")
logger.debug("_init_glfw: calling make_context_current()")
_glfw.make_context_current(window)
logger.debug("_init_glfw: completed successfully")
return window, _glfw
except Exception:
logger.debug("_init_glfw: failed, terminating glfw")
_glfw.terminate()
raise
def _init_egl():
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
logger.debug("_init_egl: starting")
from OpenGL import EGL as _EGL
from OpenGL.EGL import (
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
eglTerminate, eglDestroyContext, eglDestroySurface,
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
)
logger.debug("_init_egl: imports completed")
display = None
context = None
surface = None
try:
logger.debug("_init_egl: calling eglGetDisplay()")
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
if display == _EGL.EGL_NO_DISPLAY:
raise RuntimeError("eglGetDisplay() failed")
logger.debug("_init_egl: calling eglInitialize()")
major, minor = _EGL.EGLint(), _EGL.EGLint()
if not eglInitialize(display, major, minor):
display = None # Not initialized, don't terminate
raise RuntimeError("eglInitialize() failed")
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
config_attribs = [
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
EGL_DEPTH_SIZE, 0, EGL_NONE
]
configs = (_EGL.EGLConfig * 1)()
num_configs = _EGL.EGLint()
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
config = configs[0]
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
if not eglBindAPI(EGL_OPENGL_API):
raise RuntimeError("eglBindAPI() failed")
logger.debug("_init_egl: calling eglCreateContext()")
context_attribs = [
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
EGL_NONE
]
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
if context == EGL_NO_CONTEXT:
raise RuntimeError("eglCreateContext() failed")
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
if surface == _EGL.EGL_NO_SURFACE:
raise RuntimeError("eglCreatePbufferSurface() failed")
logger.debug("_init_egl: calling eglMakeCurrent()")
if not eglMakeCurrent(display, surface, surface, context):
raise RuntimeError("eglMakeCurrent() failed")
logger.debug("_init_egl: completed successfully")
return display, context, surface, _EGL
except Exception:
logger.debug("_init_egl: failed, cleaning up")
# Clean up any resources on failure
if surface is not None:
eglDestroySurface(display, surface)
if context is not None:
eglDestroyContext(display, context)
if display is not None:
eglTerminate(display)
raise
def _init_osmesa():
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
import ctypes
logger.debug("_init_osmesa: starting")
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
logger.debug("_init_osmesa: importing OpenGL.osmesa")
from OpenGL import GL as _gl
from OpenGL.osmesa import (
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
OSMESA_RGBA,
)
logger.debug("_init_osmesa: imports completed")
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
if not ctx:
raise RuntimeError("OSMesaCreateContextExt() failed")
width, height = 64, 64
buffer = (ctypes.c_ubyte * (width * height * 4))()
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
OSMesaDestroyContext(ctx)
raise RuntimeError("OSMesaMakeCurrent() failed")
logger.debug("_init_osmesa: completed successfully")
return ctx, buffer
class GLContext:
"""Manages OpenGL context and resources for shader execution.
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
"""
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
_instance = None
_initialized = False
@ -327,131 +240,105 @@ class GLContext:
def __init__(self):
if GLContext._initialized:
logger.debug("GLContext.__init__: already initialized, skipping")
return
logger.debug("GLContext.__init__: starting initialization")
global glfw, EGL
import time
start = time.perf_counter()
self._backend = None
self._window = None
self._egl_display = None
self._egl_context = None
self._egl_surface = None
self._osmesa_ctx = None
self._osmesa_buffer = None
self._display = None
self._surface = None
self._context = None
self._vao = None
# Try backends in order: GLFW → EGL → OSMesa
errors = []
logger.debug("GLContext.__init__: trying GLFW backend")
try:
self._window, glfw = _init_glfw()
self._backend = "glfw"
logger.debug("GLContext.__init__: GLFW backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
errors.append(("GLFW", e))
self._display, self._egl_major, self._egl_minor = _get_egl_display()
if self._backend is None:
logger.debug("GLContext.__init__: trying EGL backend")
try:
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
self._backend = "egl"
logger.debug("GLContext.__init__: EGL backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
errors.append(("EGL", e))
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
if self._backend is None:
logger.debug("GLContext.__init__: trying OSMesa backend")
try:
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
self._backend = "osmesa"
logger.debug("GLContext.__init__: OSMesa backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
errors.append(("OSMesa", e))
config = EGL.EGLConfig()
n_configs = ctypes.c_int32(0)
if not EGL.eglChooseConfig(
self._display,
_egl_attribs(
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
),
ctypes.byref(config), 1, ctypes.byref(n_configs),
) or n_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
if self._backend is None:
if sys.platform == "win32":
platform_help = (
"Windows: Ensure GPU drivers are installed and display is available.\n"
" CPU-only/headless mode is not supported on Windows."
)
elif sys.platform == "darwin":
platform_help = (
"macOS: GLFW is not supported.\n"
" Install OSMesa via Homebrew: brew install mesa\n"
" Then: pip install PyOpenGL PyOpenGL-accelerate"
)
else:
platform_help = (
"Linux: Install one of these backends:\n"
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
" Headless (CPU): sudo apt install libosmesa6"
)
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
raise RuntimeError(
f"Failed to create OpenGL context.\n\n"
f"Backend errors:\n{error_details}\n\n"
f"{platform_help}"
self._surface = EGL.eglCreatePbufferSurface(
self._display, config,
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
)
if not self._surface:
raise RuntimeError("eglCreatePbufferSurface() failed")
# Now import OpenGL.GL (after context is current)
logger.debug("GLContext.__init__: importing OpenGL.GL")
_import_opengl()
self._context = EGL.eglCreateContext(
self._display, config, EGL.EGL_NO_CONTEXT,
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
)
if not self._context:
raise RuntimeError("eglCreateContext() failed")
# Create VAO (required for core profile, but OSMesa may use compat profile)
logger.debug("GLContext.__init__: creating VAO")
try:
vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(vao)
self._vao = vao # Only store after successful bind
logger.debug("GLContext.__init__: VAO created successfully")
except Exception as e:
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
# OSMesa with older Mesa may not support VAOs
# Clean up if we created but couldn't bind
if vao:
try:
gl.glDeleteVertexArrays(1, [vao])
except Exception:
pass
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
raise RuntimeError("eglMakeCurrent() failed")
self._vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(self._vao)
except Exception:
self._cleanup()
raise
elapsed = (time.perf_counter() - start) * 1000
# Log device info
renderer = gl.glGetString(gl.GL_RENDERER)
vendor = gl.glGetString(gl.GL_VENDOR)
version = gl.glGetString(gl.GL_VERSION)
renderer = renderer.decode() if renderer else "Unknown"
vendor = vendor.decode() if vendor else "Unknown"
version = version.decode() if version else "Unknown"
renderer = _gl_str(gl.GL_RENDERER)
vendor = _gl_str(gl.GL_VENDOR)
version = _gl_str(gl.GL_VERSION)
GLContext._initialized = True
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
def make_current(self):
if self._backend == "glfw":
glfw.make_context_current(self._window)
elif self._backend == "egl":
from OpenGL.EGL import eglMakeCurrent
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
elif self._backend == "osmesa":
from OpenGL.osmesa import OSMesaMakeCurrent
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
err = EGL.eglGetError()
raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
if self._vao is not None:
gl.glBindVertexArray(self._vao)
def _cleanup(self):
if not self._display:
return
try:
if self._vao is not None:
gl.glDeleteVertexArrays(1, [self._vao])
self._vao = None
except Exception:
pass
try:
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
except Exception:
pass
try:
if self._context:
EGL.eglDestroyContext(self._display, self._context)
except Exception:
pass
try:
if self._surface:
EGL.eglDestroySurface(self._display, self._surface)
except Exception:
pass
try:
EGL.eglTerminate(self._display)
except Exception:
pass
self._display = None
def _compile_shader(source: str, shader_type: int) -> int:
"""Compile a shader and return its ID."""
@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
gl.glShaderSource(shader, source)
gl.glCompileShader(shader)
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
error = gl.glGetShaderInfoLog(shader).decode()
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
error = gl.glGetShaderInfoLog(shader)
if isinstance(error, bytes):
error = error.decode(errors="replace")
gl.glDeleteShader(shader)
raise RuntimeError(f"Shader compilation failed:\n{error}")
@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
gl.glDeleteShader(vertex_shader)
gl.glDeleteShader(fragment_shader)
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
error = gl.glGetProgramInfoLog(program).decode()
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
error = gl.glGetProgramInfoLog(program)
if isinstance(error, bytes):
error = error.decode(errors="replace")
gl.glDeleteProgram(program)
raise RuntimeError(f"Program linking failed:\n{error}")
@ -530,9 +421,6 @@ def _render_shader_batch(
ctx = GLContext()
ctx.make_current()
# Convert from GLSL ES to desktop GLSL 330
fragment_source = _convert_es_to_desktop(fragment_code)
# Detect how many outputs the shader actually uses
num_outputs = _detect_output_count(fragment_code)
@ -558,9 +446,9 @@ def _render_shader_batch(
try:
# Compile shaders (once for all batches)
try:
program = _create_program(VERTEX_SHADER, fragment_source)
program = _create_program(VERTEX_SHADER, fragment_code)
except RuntimeError:
logger.error(f"Fragment shader:\n{fragment_source}")
logger.error(f"Fragment shader:\n{fragment_code}")
raise
gl.glUseProgram(program)
@ -723,13 +611,13 @@ def _render_shader_batch(
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
# Read back outputs for this batch
# (glGetTexImage is synchronous, implicitly waits for rendering)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
batch_outputs = []
for tex in output_textures:
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
batch_outputs.append(img[::-1, :, :].copy())
for i in range(num_outputs):
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
buf = np.empty((height, width, 4), dtype=np.float32)
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
batch_outputs.append(buf[::-1, :, :].copy())
# Pad with black images for unused outputs
black_img = np.zeros((height, width, 4), dtype=np.float32)
@ -750,18 +638,18 @@ def _render_shader_batch(
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glUseProgram(0)
for tex in input_textures:
gl.glDeleteTextures(int(tex))
for tex in curve_textures:
gl.glDeleteTextures(int(tex))
for tex in output_textures:
gl.glDeleteTextures(int(tex))
for tex in ping_pong_textures:
gl.glDeleteTextures(int(tex))
if input_textures:
gl.glDeleteTextures(len(input_textures), input_textures)
if curve_textures:
gl.glDeleteTextures(len(curve_textures), curve_textures)
if output_textures:
gl.glDeleteTextures(len(output_textures), output_textures)
if ping_pong_textures:
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
if fbo is not None:
gl.glDeleteFramebuffers(1, [fbo])
for pp_fbo in ping_pong_fbos:
gl.glDeleteFramebuffers(1, [pp_fbo])
if ping_pong_fbos:
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
if program is not None:
gl.glDeleteProgram(program)

View File

@ -0,0 +1,77 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.color_util import normalize_palette
class BuildJsonPromptIdeogram(io.ComfyNode):
@classmethod
def define_schema(cls):
color_palette = io.Colors.Input(
"color_palette",
socketless=False,
tooltip="Hex color codes that steer the image's dominant colors. Up to 16 entries.",
)
return io.Schema(
node_id="BuildJsonPromptIdeogram",
display_name="Build JSON Prompt (Ideogram)",
category="text",
description="Build a JSON prompt for the Ideogram 4 model.",
inputs=[
io.Array.Input("element", tooltip="Prompt elements from the node Create Bounding Boxes."),
io.String.Input("high_level_description", multiline=True, default="",
tooltip="Optional description of the image in one or two sentences. Strongly recommended."),
io.String.Input("background", multiline=True, default="",
tooltip="Mandatory description of the image background or environment."),
io.DynamicCombo.Input("style", options=[
io.DynamicCombo.Option("none", []),
io.DynamicCombo.Option("photo", [io.String.Input("photo", default="", tooltip="Camera or lens details for photographic outputs (e.g. 35mm, f/1.4, bokeh).")]),
io.DynamicCombo.Option("art_style", [io.String.Input("art_style", default="", tooltip="Art style description (e.g. flat vector illustration, bold outlines).")]),
]),
io.String.Input("aesthetics", default="", tooltip="Mandatory aesthetic keywords (e.g. moody, cinematic, desaturated)."),
io.String.Input("lighting", default="", tooltip="Mandatory lighting description (e.g. golden hour, rim light, dramatic shadows)."),
io.String.Input("medium", default="", tooltip="Mandatory medium type (e.g. photograph, illustration, 3d_render, painting, graphic_design). When style = photo, set to photograph."),
color_palette,
],
outputs=[io.Dict.Output(display_name="prompt")],
is_experimental=True,
)
@classmethod
def execute(cls, element, style, high_level_description="", background="",
aesthetics="", lighting="", medium="", color_palette=None) -> io.NodeOutput:
elements = element if isinstance(element, list) else []
kind = style.get("style", "none") if isinstance(style, dict) else "none"
photo = style.get("photo", "") if isinstance(style, dict) else ""
art_style = style.get("art_style", "") if isinstance(style, dict) else ""
palette = normalize_palette(color_palette or [])
caption: dict = {}
if high_level_description.strip():
caption["high_level_description"] = high_level_description
if kind != "none":
style_desc: dict = {"aesthetics": aesthetics, "lighting": lighting}
if kind == "photo":
style_desc["photo"] = photo
style_desc["medium"] = medium
else:
style_desc["medium"] = medium
style_desc["art_style"] = art_style
if palette:
style_desc["color_palette"] = palette
caption["style_description"] = style_desc
caption["compositional_deconstruction"] = {
"background": background,
"elements": elements,
}
return io.NodeOutput(caption)
class JsonPromptExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [BuildJsonPromptIdeogram]
async def comfy_entrypoint() -> JsonPromptExtension:
return JsonPromptExtension()

View File

@ -337,6 +337,36 @@ class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeKrea2(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "model/merging/model specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["first."] = argument
arg_dict["tmlp."] = argument
arg_dict["txtmlp."] = argument
arg_dict["tproj."] = argument
for i in range(2):
arg_dict["txtfusion.layerwise_blocks.{}.".format(i)] = argument
arg_dict["txtfusion.projector."] = argument
for i in range(2):
arg_dict["txtfusion.refiner_blocks.{}.".format(i)] = argument
for i in range(28):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["last."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@ -353,4 +383,5 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
"ModelMergeQwenImage": ModelMergeQwenImage,
"ModelMergeKrea2": ModelMergeKrea2,
}

View File

@ -0,0 +1,33 @@
import sys
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class SeedNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SeedNode",
display_name="Seed",
search_aliases=["seed", "random"],
category="utilities",
inputs=[
io.Int.Input("seed", min=0, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
],
outputs=[io.Int.Output(display_name="seed")],
)
@classmethod
def execute(cls, seed: int) -> io.NodeOutput:
return io.NodeOutput(seed)
class SeedExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [SeedNode]
async def comfy_entrypoint() -> SeedExtension:
return SeedExtension()

View File

@ -440,6 +440,57 @@ class JsonExtractString(io.ComfyNode):
except (json.JSONDecodeError, TypeError):
return io.NodeOutput("")
def _dump_json(value, indent):
return json.dumps(value, ensure_ascii=False, indent=indent or None)
class ConvertDictionaryToString(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ConvertDictionaryToString",
display_name="Convert Dictionary to String",
category="text",
search_aliases=["json", "dict to json", "stringify", "serialize", "dict to string"],
inputs=[
io.Dict.Input("dictionary"),
io.Int.Input("indent", default=2, min=0, max=8,
tooltip="Spaces per indent level. 0 produces compact single-line string."),
],
outputs=[
io.String.Output(),
],
)
@classmethod
def execute(cls, dictionary, indent=2):
return io.NodeOutput(_dump_json(dictionary, indent))
class ConvertArrayToString(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ConvertArrayToString",
display_name="Convert Array to String",
category="text",
search_aliases=["json", "list to json", "stringify", "serialize", "list to string", "array to json"],
inputs=[
io.Array.Input("array"),
io.Int.Input("indent", default=2, min=0, max=8,
tooltip="Spaces per indent level. 0 produces compact single-line string."),
],
outputs=[
io.String.Output(),
],
)
@classmethod
def execute(cls, array, indent=2):
return io.NodeOutput(_dump_json(array, indent))
class StringExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -457,6 +508,8 @@ class StringExtension(ComfyExtension):
RegexExtract,
RegexReplace,
JsonExtractString,
ConvertDictionaryToString,
ConvertArrayToString,
]
async def comfy_entrypoint() -> StringExtension:

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.26.2"
__version__ = "0.26.0"

View File

@ -2374,6 +2374,8 @@ async def init_builtin_extra_nodes():
"nodes_images.py",
"nodes_video_model.py",
"nodes_ideogram4.py",
"nodes_bounding_boxes.py",
"nodes_json_prompt.py",
"nodes_train.py",
"nodes_dataset.py",
"nodes_sag.py",
@ -2473,6 +2475,7 @@ async def init_builtin_extra_nodes():
"nodes_gaussian_splat.py",
"nodes_triposplat.py",
"nodes_depth_anything_3.py",
"nodes_seed.py",
]
import_failed = []

View File

@ -1692,6 +1692,12 @@ paths:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Unsupported media type
"422":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Validation error (e.g., disallowed model_type tag)
"500":
content:
application/json:
@ -2137,6 +2143,12 @@ paths:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Source asset with given hash not found
"422":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Validation error (e.g., disallowed model_type tag)
"500":
content:
application/json:
@ -2357,6 +2369,10 @@ paths:
description: |
Returns a list of model folders available in the system.
This is an experimental endpoint that replaces the legacy /models endpoint.
Each folder's name is the identifier to pass to /api/experiment/models/{folder}.
Once the model_type migration is active the names are model_type folder_names
(e.g. `ultralytics_bbox`); a folder with no folder_name mapping is returned by
its directory path.
operationId: getModelFolders
responses:
"200":
@ -2988,7 +3004,7 @@ paths:
format: uuid
type: string
- description: |
When present, each output item in the response receives a `short_url` field containing an owner-gated durable link for that asset. Omit this parameter (the default) to receive a response identical to the no-param baseline. The value selects the link's lifetime: use `ephemeral_tool_chain` for short-lived machine-to-machine handoffs (~15 minutes); use `default` for durable human-revisitable links (30 days). Links are minted only for the authenticated request owner and are not resolvable by other users.
When present, each output item in the response receives a `short_url` field containing a short link for that asset. Omit this parameter (the default) to receive a response identical to the no-param baseline. The value selects the link's lifetime and auth model: use `ephemeral_tool_chain` for short-lived (≤5 minute) machine-to-machine handoffs — these are public bearer links where the link ID itself is the credential, so anyone holding the link can resolve it (intended for pasting into an agent/MCP tool chain); use `default` for durable (30 day) human-revisitable links, which are owner-gated and resolvable only by the authenticated owner. Links are always minted under the authenticated request owner's identity; the auth model is selected by the server and is never settable by the caller.
in: query
name: short_link
schema:

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.26.2"
version = "0.26.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -22,7 +22,7 @@ alembic
SQLAlchemy>=2.0.0
filelock
av>=16.0.0
comfy-kitchen==0.2.10
comfy-kitchen==0.2.13
comfy-aimdo==0.4.10
requests
simpleeval>=1.0.0
@ -33,5 +33,5 @@ kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
PyOpenGL
glfw
PyOpenGL>=3.1.8
comfy-angle

View File

@ -0,0 +1,204 @@
"""Unit tests for io.DynamicGroup: expansion/reconstruction (0-row and N-row cases)."""
import sys
import types
import pytest
# Stub torch (type-hint only in _io.py; real torch not available in unit-test env)
if "torch" not in sys.modules:
_torch_stub = types.ModuleType("torch")
_torch_stub.Tensor = object # type: ignore[attr-defined]
sys.modules["torch"] = _torch_stub
from comfy_api.latest._io import ( # noqa: E402
DynamicGroup,
Float,
Int,
String,
Boolean,
get_finalized_class_inputs,
build_nested_inputs,
create_input_dict_v1,
setup_dynamic_input_funcs,
)
# Make sure dynamic input funcs are registered (may already be done at import time)
setup_dynamic_input_funcs()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_class_inputs(group_input: DynamicGroup.Input) -> dict:
"""Wrap a DynamicGroup.Input into the required/optional dict structure."""
return create_input_dict_v1([group_input])
def _run(group_input: DynamicGroup.Input, live_values: dict) -> dict:
"""End-to-end helper: expand schema + reconstruct values.
Mirrors the production split in execution.py:
1. get_finalized_class_inputs (schema expansion, line 162)
2. build_nested_inputs (value reconstruction, line 281)
The two steps are separate in production because the engine resolves
linked node outputs between them, but in tests we supply values directly.
"""
class_inputs = _make_class_inputs(group_input)
_, _, v3_data = get_finalized_class_inputs(class_inputs, live_values)
return build_nested_inputs(dict(live_values), v3_data)
# ---------------------------------------------------------------------------
# Schema construction
# ---------------------------------------------------------------------------
class TestDynamicGroupInputConstruction:
def test_basic_construction(self):
inp = DynamicGroup.Input(
"loras",
template=[
Float.Input("strength", default=1.0),
String.Input("name"),
],
min=0,
max=10,
)
assert inp.id == "loras"
assert inp.min == 0
assert inp.max == 10
assert len(inp.template) == 2
def test_get_all_includes_self_and_template(self):
inp = DynamicGroup.Input(
"items",
template=[Float.Input("value")],
)
all_inputs = inp.get_all()
assert all_inputs[0] is inp
assert all_inputs[1].id == "value"
def test_as_dict_has_template_min_max(self):
inp = DynamicGroup.Input(
"items",
template=[Float.Input("val", default=0.5)],
min=1,
max=5,
)
d = inp.as_dict()
assert "template" in d
assert d["min"] == 1
assert d["max"] == 5
def test_duplicate_field_ids_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input(
"bad",
template=[Float.Input("x"), Float.Input("x")],
)
def test_empty_template_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input("bad", template=[])
def test_min_gt_max_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input("bad", template=[Float.Input("x")], min=5, max=3)
def test_max_exceeds_limit_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input("bad", template=[Float.Input("x")], max=101)
def test_dynamic_input_in_template_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input(
"bad",
template=[DynamicGroup.Input("nested", template=[Float.Input("x")])],
)
def test_validate_calls_through(self):
inp = DynamicGroup.Input("items", template=[Float.Input("val", min=-1.0, max=1.0)])
inp.validate() # should not raise
# ---------------------------------------------------------------------------
# 0-row case
# ---------------------------------------------------------------------------
class TestZeroRows:
def test_empty_live_inputs_produces_empty_list(self):
"""With min=0 and no live values, the result should be an empty list."""
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
assert _run(inp, {}).get("loras") == []
def test_min_zero_with_values(self):
"""min=0 but 2 rows of live data."""
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
result = _run(inp, {"loras.0.strength": 0.8, "loras.1.strength": 0.5})
assert result["loras"] == [{"strength": 0.8}, {"strength": 0.5}]
# ---------------------------------------------------------------------------
# N-row case
# ---------------------------------------------------------------------------
class TestNRows:
def test_two_rows_two_fields(self):
"""Two rows with two fields each produce a list[dict]."""
inp = DynamicGroup.Input(
"loras",
template=[String.Input("lora_name"), Float.Input("strength", default=1.0)],
min=0, max=50,
)
result = _run(inp, {
"loras.0.lora_name": "model_a.safetensors", "loras.0.strength": 0.9,
"loras.1.lora_name": "model_b.safetensors", "loras.1.strength": 0.4,
})
assert result["loras"] == [
{"lora_name": "model_a.safetensors", "strength": 0.9},
{"lora_name": "model_b.safetensors", "strength": 0.4},
]
def test_rows_are_sorted_by_index(self):
"""Rows must be in ascending index order even if dict iteration is unordered."""
inp = DynamicGroup.Input("items", template=[Int.Input("v", default=0)], min=0, max=10)
result = _run(inp, {"items.0.v": 10, "items.2.v": 30, "items.1.v": 20})
assert [row["v"] for row in result["items"]] == [10, 20, 30]
def test_min_rows_schema_slots(self):
"""With min=2 and no live data, 2 slots must appear in the expanded schema."""
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
out, _, _ = get_finalized_class_inputs(_make_class_inputs(inp), {})
all_slots = {**out.get("required", {}), **out.get("optional", {})}
assert "items.0.val" in all_slots
assert "items.1.val" in all_slots
def test_min_rows_reconstructs_when_no_values(self):
"""min=2 with NO live values must still yield a 2-element list,
not collapse to [] (regression: parent-path clobber)."""
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
result = _run(inp, {})
assert len(result["items"]) == 2
assert all("val" in row for row in result["items"])
def test_min_rows_reconstructs_with_partial_values(self):
"""min=2 with only the first row's value present still yields 2 rows."""
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
result = _run(inp, {"items.0.val": 0.7})
assert len(result["items"]) == 2
assert result["items"][0]["val"] == 0.7
assert result["items"][1]["val"] is None
def test_list_paths_in_v3_data(self):
"""list_paths must contain the group id so build_nested_inputs knows to convert."""
inp = DynamicGroup.Input("things", template=[Boolean.Input("flag")], min=0, max=5)
_, _, v3_data = get_finalized_class_inputs(_make_class_inputs(inp), {})
assert "things" in v3_data.get("list_paths", set())
def test_no_leftover_flat_keys(self):
"""Flat keys must be consumed; only the reconstructed list remains."""
inp = DynamicGroup.Input("rows", template=[Float.Input("x", default=0.0)], min=0, max=5)
result = _run(inp, {"rows.0.x": 1.0, "rows.1.x": 2.0})
assert "rows.0.x" not in result
assert "rows.1.x" not in result
assert isinstance(result["rows"], list)

View File

@ -228,6 +228,62 @@ class TestMixedPrecisionOps(unittest.TestCase):
with self.assertRaises(KeyError):
model.load_state_dict(state_dict, strict=False)
def test_int8_convrot_metadata_loads_into_params(self):
"""ConvRot metadata must reach TensorWiseINT8Layout params."""
torch.manual_seed(123)
layer_quant_config = {
"layer": {
"format": "int8_tensorwise",
"convrot": True,
"convrot_groupsize": 256,
}
}
weight = torch.randn(16, 256, dtype=torch.bfloat16)
bias = torch.randn(16, dtype=torch.bfloat16)
q_weight = QuantizedTensor.from_float(
weight,
"TensorWiseINT8Layout",
per_channel=True,
convrot=True,
convrot_groupsize=256,
)
state_dict = {
"layer.weight": q_weight._qdata,
"layer.bias": bias,
"layer.weight_scale": q_weight._params.scale,
}
state_dict, _ = comfy.utils.convert_old_quants(
state_dict,
metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})},
)
model = torch.nn.Module()
model.layer = ops.mixed_precision_ops({}).Linear(256, 16, device="cpu", dtype=torch.bfloat16)
model.load_state_dict(state_dict, strict=False)
self.assertIsInstance(model.layer.weight, QuantizedTensor)
self.assertEqual(model.layer.weight._layout_cls, "TensorWiseINT8Layout")
self.assertTrue(model.layer.weight._params.convrot)
self.assertEqual(model.layer.weight._params.convrot_groupsize, 256)
input_tensor = torch.randn(4, 256, dtype=torch.bfloat16)
loaded_out = model.layer(input_tensor)
ref_out = torch.nn.functional.linear(input_tensor, q_weight, bias)
self.assertTrue(torch.equal(loaded_out, ref_out))
fp16_input = input_tensor.to(torch.float16)
loaded_fp16_out = model.layer(fp16_input)
ref_fp16_out = torch.nn.functional.linear(
fp16_input,
q_weight.to(dtype=torch.float16),
bias.to(dtype=torch.float16),
)
self.assertTrue(torch.equal(loaded_fp16_out, ref_fp16_out))
saved = model.state_dict()
saved_conf = json.loads(saved["layer.comfy_quant"].numpy().tobytes())
self.assertTrue(saved_conf["convrot"])
self.assertEqual(saved_conf["convrot_groupsize"], 256)
if __name__ == "__main__":
unittest.main()