mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-27 16:36:53 +08:00
Compare commits
19 Commits
release/v0
...
ListInput
| Author | SHA1 | Date | |
|---|---|---|---|
| 30b19c6872 | |||
| 2dd281d8a6 | |||
| 911e0b2acf | |||
| 46c7e8055c | |||
| 603d891eaf | |||
| 470ac36a0a | |||
| 7cb784e0f4 | |||
| 1a510f0423 | |||
| 639c8fa788 | |||
| e22f1500f9 | |||
| dac4ea3a80 | |||
| b0ec19804f | |||
| 64e1d740b8 | |||
| b22d0fb9c0 | |||
| 5236cd02e6 | |||
| cabb7342d1 | |||
| 12218db68a | |||
| 44955d783b | |||
| 1f275fcba6 |
64
comfy/ops.py
64
comfy/ops.py
@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
|
||||
if (want_requant and len(fns) == 0 or update_weight):
|
||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||
if isinstance(orig, QuantizedTensor):
|
||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||
y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
|
||||
else:
|
||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
|
||||
if want_requant and len(fns) == 0:
|
||||
@ -1089,6 +1089,19 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat
|
||||
if ts is None or bs is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||
scales = {"scale": ts, "block_scale": bs}
|
||||
elif module.quant_format == "int8_tensorwise":
|
||||
scale = pop_scale("weight_scale")
|
||||
if scale is None:
|
||||
raise ValueError(f"Missing INT8 weight scale for layer {layer_name}")
|
||||
scales = {"scale": scale}
|
||||
params_conf = layer_conf.get("params", {})
|
||||
if not isinstance(params_conf, dict):
|
||||
params_conf = {}
|
||||
if layer_conf.get("convrot", params_conf.get("convrot", False)):
|
||||
scales["convrot"] = True
|
||||
scales["convrot_groupsize"] = int(
|
||||
layer_conf.get("convrot_groupsize", params_conf.get("convrot_groupsize", 256))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
|
||||
|
||||
@ -1131,6 +1144,10 @@ def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extr
|
||||
quant_conf = {"format": module.quant_format}
|
||||
if getattr(module, '_full_precision_mm_config', False):
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
params = getattr(module.weight, "_params", None)
|
||||
if module.quant_format == "int8_tensorwise" and getattr(params, "convrot", False):
|
||||
quant_conf["convrot"] = True
|
||||
quant_conf["convrot_groupsize"] = getattr(params, "convrot_groupsize", 256)
|
||||
if extra_quant_conf:
|
||||
quant_conf.update(extra_quant_conf)
|
||||
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
|
||||
@ -1183,8 +1200,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
|
||||
def forward_comfy_cast_weights(
|
||||
self,
|
||||
input,
|
||||
compute_dtype=None,
|
||||
want_requant=False,
|
||||
weight_only_quant=False,
|
||||
):
|
||||
if weight_only_quant:
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input=None,
|
||||
dtype=self.weight.dtype,
|
||||
device=input.device,
|
||||
bias_dtype=input.dtype,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=want_requant,
|
||||
)
|
||||
weight = weight.to(dtype=input.dtype)
|
||||
else:
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=want_requant,
|
||||
)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
@ -1203,9 +1245,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0
|
||||
)
|
||||
quantize_input = QUANT_ALGOS.get(getattr(self, 'quant_format', None), {}).get("quantize_input", True)
|
||||
|
||||
# Training path: quantized forward with compute_dtype backward via autograd function
|
||||
if (input.requires_grad and _use_quantized):
|
||||
if (input.requires_grad and _use_quantized and quantize_input):
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
@ -1227,7 +1270,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
return output
|
||||
|
||||
# Inference path (unchanged)
|
||||
if _use_quantized:
|
||||
if _use_quantized and quantize_input:
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
@ -1241,7 +1284,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
|
||||
weight_only_quant = _use_quantized and not quantize_input and isinstance(self.weight, QuantizedTensor)
|
||||
output = self.forward_comfy_cast_weights(
|
||||
input,
|
||||
compute_dtype,
|
||||
want_requant=isinstance(input, QuantizedTensor),
|
||||
weight_only_quant=weight_only_quant,
|
||||
)
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if reshaped_3d:
|
||||
@ -1257,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
# dtype is now implicit in the layout class
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
|
||||
@ -10,6 +10,7 @@ try:
|
||||
QuantizedLayout,
|
||||
TensorCoreFP8Layout as _CKFp8Layout,
|
||||
TensorCoreNVFP4Layout as _CKNvfp4Layout,
|
||||
TensorWiseINT8Layout as _CKTensorWiseINT8Layout,
|
||||
register_layout_op,
|
||||
register_layout_class,
|
||||
get_layout_class,
|
||||
@ -47,6 +48,9 @@ except ImportError as e:
|
||||
class _CKNvfp4Layout:
|
||||
pass
|
||||
|
||||
class _CKTensorWiseINT8Layout:
|
||||
pass
|
||||
|
||||
def register_layout_class(name, cls):
|
||||
pass
|
||||
|
||||
@ -174,6 +178,7 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
|
||||
|
||||
# Backward compatibility alias - default to E4M3
|
||||
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
|
||||
TensorWiseINT8Layout = _CKTensorWiseINT8Layout
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
@ -184,6 +189,7 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||
register_layout_class("TensorWiseINT8Layout", _CKTensorWiseINT8Layout)
|
||||
if _CK_MXFP8_AVAILABLE:
|
||||
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||
|
||||
@ -214,6 +220,13 @@ if _CK_MXFP8_AVAILABLE:
|
||||
"group_size": 32,
|
||||
}
|
||||
|
||||
QUANT_ALGOS["int8_tensorwise"] = {
|
||||
"storage_t": torch.int8,
|
||||
"parameters": {"weight_scale"},
|
||||
"comfy_tensor_layout": "TensorWiseINT8Layout",
|
||||
"quantize_input": False,
|
||||
}
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Re-exports for backward compatibility
|
||||
@ -226,6 +239,7 @@ __all__ = [
|
||||
"TensorCoreFP8E4M3Layout",
|
||||
"TensorCoreFP8E5M2Layout",
|
||||
"TensorCoreNVFP4Layout",
|
||||
"TensorWiseINT8Layout",
|
||||
"QUANT_ALGOS",
|
||||
"register_layout_op",
|
||||
]
|
||||
|
||||
@ -891,6 +891,14 @@ class Tracks(ComfyTypeIO):
|
||||
track_visibility: torch.Tensor
|
||||
Type = TrackDict
|
||||
|
||||
@comfytype(io_type="DICT")
|
||||
class Dict(ComfyTypeIO):
|
||||
Type = dict
|
||||
|
||||
@comfytype(io_type="ARRAY")
|
||||
class Array(ComfyTypeIO):
|
||||
Type = list
|
||||
|
||||
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
||||
class MultiType:
|
||||
Type = Any
|
||||
@ -1253,6 +1261,155 @@ class DynamicSlot(ComfyTypeI):
|
||||
out_dict[input_type][finalized_id] = value
|
||||
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||
|
||||
@comfytype(io_type="COMFY_DYNAMICGROUP_V3")
|
||||
class DynamicGroup(ComfyTypeI):
|
||||
"""A repeatable group of widget inputs (e.g. lora_name + strength stacked into N rows).
|
||||
|
||||
At execution time the node receives a ``list[dict]`` where each element is a row.
|
||||
|
||||
Example::
|
||||
|
||||
io.DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[
|
||||
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras")),
|
||||
io.Float.Input("strength", default=1.0, min=-100, max=100, step=0.01),
|
||||
],
|
||||
min=0,
|
||||
max=50,
|
||||
)
|
||||
# execute receives: loras: list[dict] = [{"lora_name": "x.safetensors", "strength": 1.0}, ...]
|
||||
"""
|
||||
|
||||
Type = list[dict[str, Any]]
|
||||
_MaxRows = 100
|
||||
|
||||
class Input(DynamicInput):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
template: list["Input"],
|
||||
min: int = 0,
|
||||
max: int = 50,
|
||||
display_name: str = None,
|
||||
optional: bool = False,
|
||||
tooltip: str = None,
|
||||
lazy: bool = None,
|
||||
extra_dict=None,
|
||||
):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
# Validate template entries: only WidgetInput subclasses, no nesting
|
||||
assert len(template) > 0, "DynamicGroup template must have at least one field."
|
||||
for t in template:
|
||||
assert isinstance(t, WidgetInput), (
|
||||
f"DynamicGroup template field '{t.id}' must be a WidgetInput subclass "
|
||||
f"(Combo, Float, Int, String, Boolean, Color). Got {type(t).__name__}."
|
||||
)
|
||||
assert not isinstance(t, DynamicInput), (
|
||||
f"DynamicGroup template field '{t.id}' must not be a DynamicInput. "
|
||||
"Nesting dynamic inputs inside DynamicGroup is not supported."
|
||||
)
|
||||
# Enforce unique field ids within template
|
||||
field_ids = [t.id for t in template]
|
||||
assert len(field_ids) == len(set(field_ids)), (
|
||||
f"DynamicGroup template field ids must be unique within a row. Got: {field_ids}"
|
||||
)
|
||||
# Reject "." in group id and template field ids: slot_id encoding uses "." as a
|
||||
# delimiter (<group_id>.<row>.<field_id>), so any "." in these names would cause
|
||||
# path.split(".") to produce the wrong number of segments during decoding.
|
||||
assert "." not in id, (
|
||||
f"DynamicGroup id must not contain '.'. Got: '{id}'"
|
||||
)
|
||||
for t in template:
|
||||
assert "." not in t.id, (
|
||||
f"DynamicGroup template field id must not contain '.'. Got: '{t.id}'"
|
||||
)
|
||||
assert min >= 0, "DynamicGroup min must be >= 0."
|
||||
assert max >= 1, "DynamicGroup max must be >= 1."
|
||||
assert max <= DynamicGroup._MaxRows, f"DynamicGroup max must be <= {DynamicGroup._MaxRows}."
|
||||
assert min <= max, "DynamicGroup min must be <= max."
|
||||
self.template = template
|
||||
self.min = min
|
||||
self.max = max
|
||||
|
||||
def get_all(self) -> list["Input"]:
|
||||
return [self] + list(self.template)
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": create_input_dict_v1(self.template),
|
||||
"min": self.min,
|
||||
"max": self.max,
|
||||
})
|
||||
|
||||
def validate(self):
|
||||
for t in self.template:
|
||||
t.validate()
|
||||
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(
|
||||
out_dict: dict[str, Any],
|
||||
live_inputs: dict[str, Any],
|
||||
value: tuple[str, dict[str, Any]],
|
||||
input_type: str,
|
||||
curr_prefix: list[str] | None,
|
||||
):
|
||||
info = value[1]
|
||||
min_rows: int = info.get("min", 0)
|
||||
max_rows: int = info.get("max", DynamicGroup._MaxRows)
|
||||
template: dict[str, Any] = info.get("template", {})
|
||||
|
||||
# Collect all template field specs across required/optional sections
|
||||
field_specs: list[tuple[str, tuple[str, dict[str, Any]], bool]] = []
|
||||
for field_required_key in ("required", "optional"):
|
||||
section = template.get(field_required_key, {})
|
||||
is_required_field = field_required_key == "required"
|
||||
for field_id, field_value in section.items():
|
||||
field_specs.append((field_id, field_value, is_required_field))
|
||||
|
||||
# Determine how many rows are currently present by scanning live_inputs
|
||||
finalized_prefix = finalize_prefix(curr_prefix)
|
||||
present_rows = 0
|
||||
for live_key in live_inputs:
|
||||
# Keys look like "<prefix>.<row>.<field_id>"
|
||||
if live_key.startswith(finalized_prefix + "."):
|
||||
remainder = live_key[len(finalized_prefix) + 1:]
|
||||
parts = remainder.split(".", 1)
|
||||
if len(parts) >= 1:
|
||||
try:
|
||||
row_idx = int(parts[0])
|
||||
present_rows = max(present_rows, row_idx + 1)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if present_rows > max_rows:
|
||||
raise ValueError(
|
||||
f"DynamicGroup input '{finalized_prefix}' received {present_rows} rows but max is {max_rows}."
|
||||
)
|
||||
row_count = max(min_rows, present_rows)
|
||||
|
||||
for row in range(row_count):
|
||||
for field_id, field_value, is_required_field in field_specs:
|
||||
slot_id = f"{finalized_prefix}.{row}.{field_id}"
|
||||
# The first `min_rows` rows are required if the field itself is required
|
||||
if row < min_rows and is_required_field:
|
||||
out_dict["required"][slot_id] = field_value
|
||||
else:
|
||||
out_dict["optional"][slot_id] = field_value
|
||||
# Register into dynamic_paths so build_nested_inputs places value at the right path
|
||||
out_dict["dynamic_paths"][slot_id] = slot_id
|
||||
|
||||
# Track the list root path so build_nested_inputs can convert the index dict to a list
|
||||
out_dict.setdefault("list_paths", set()).add(finalized_prefix)
|
||||
|
||||
# Handle the empty case (0 rows) – emit an empty-list default for the parent.
|
||||
# This must only fire when there are genuinely no rows; otherwise the parent
|
||||
# path would clobber the per-row dict built from the slot ids above.
|
||||
if row_count == 0:
|
||||
out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix
|
||||
out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_LIST
|
||||
|
||||
|
||||
@comfytype(io_type="IMAGECOMPARE")
|
||||
class ImageCompare(ComfyTypeI):
|
||||
Type = dict
|
||||
@ -1279,6 +1436,19 @@ class Color(ComfyTypeIO):
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
|
||||
|
||||
@comfytype(io_type="COLORS")
|
||||
class Colors(ComfyTypeIO):
|
||||
Type = list[Color.Type]
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: list[str]=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = []
|
||||
|
||||
|
||||
@comfytype(io_type="BOUNDING_BOX")
|
||||
class BoundingBox(ComfyTypeIO):
|
||||
class BoundingBoxDict(TypedDict):
|
||||
@ -1326,6 +1496,20 @@ class Curve(ComfyTypeIO):
|
||||
return d
|
||||
|
||||
|
||||
@comfytype(io_type="BOUNDING_BOXES")
|
||||
class BoundingBoxes(ComfyTypeIO):
|
||||
class BoundingBoxWithMetadata(BoundingBox.BoundingBoxDict):
|
||||
metadata: dict
|
||||
Type = list[BoundingBoxWithMetadata]
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: list[dict]=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = []
|
||||
|
||||
|
||||
@comfytype(io_type="HISTOGRAM")
|
||||
class Histogram(ComfyTypeIO):
|
||||
"""A histogram represented as a list of bin counts."""
|
||||
@ -1383,6 +1567,8 @@ def setup_dynamic_input_funcs():
|
||||
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
|
||||
# DynamicSlot.Input
|
||||
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
|
||||
# DynamicGroup.Input
|
||||
register_dynamic_input_func(DynamicGroup.io_type, DynamicGroup._expand_schema_for_dynamic)
|
||||
|
||||
if len(DYNAMIC_INPUT_LOOKUP) == 0:
|
||||
setup_dynamic_input_funcs()
|
||||
@ -1394,6 +1580,8 @@ class V3Data(TypedDict):
|
||||
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||
dynamic_paths_default_value: dict[str, Any]
|
||||
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
|
||||
list_paths: set[str]
|
||||
'Set of top-level keys whose index-keyed dict values should be converted to a sorted list[dict] after build_nested_inputs runs.'
|
||||
create_dynamic_tuple: bool
|
||||
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||
|
||||
@ -1735,6 +1923,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
"optional": {},
|
||||
"dynamic_paths": {},
|
||||
"dynamic_paths_default_value": {},
|
||||
"list_paths": set(),
|
||||
}
|
||||
d = d.copy()
|
||||
# ignore hidden for parsing
|
||||
@ -1750,6 +1939,10 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
|
||||
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
|
||||
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
|
||||
# list_paths: keys whose nested dict should be post-converted to a sorted list[dict]
|
||||
list_paths = out_dict.pop("list_paths", None)
|
||||
if list_paths:
|
||||
v3_data["list_paths"] = list_paths
|
||||
return out_dict, hidden, v3_data
|
||||
|
||||
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||
@ -1785,10 +1978,12 @@ def add_to_dict_v1(i: Input, d: dict):
|
||||
|
||||
class DynamicPathsDefaultValue:
|
||||
EMPTY_DICT = "empty_dict"
|
||||
EMPTY_LIST = "empty_list"
|
||||
|
||||
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
paths = v3_data.get("dynamic_paths", None)
|
||||
default_value_dict = v3_data.get("dynamic_paths_default_value", {})
|
||||
list_paths: set[str] = v3_data.get("list_paths", set()) or set()
|
||||
if paths is None:
|
||||
return values
|
||||
values = values.copy()
|
||||
@ -1811,6 +2006,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
default_option = default_value_dict.get(key, None)
|
||||
if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
|
||||
value = {}
|
||||
elif default_option == DynamicPathsDefaultValue.EMPTY_LIST:
|
||||
value = []
|
||||
if create_tuple:
|
||||
value = (value, key)
|
||||
current[p] = value
|
||||
@ -1818,6 +2015,34 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
current = current.setdefault(p, {})
|
||||
|
||||
values.update(result)
|
||||
|
||||
# Post-pass: convert index-keyed dicts to sorted lists for io.DynamicGroup fields
|
||||
for list_path in list_paths:
|
||||
parts = list_path.split(".")
|
||||
# Navigate to the parent container, then convert the leaf
|
||||
container = values
|
||||
for part in parts[:-1]:
|
||||
if not isinstance(container, dict) or part not in container:
|
||||
container = None
|
||||
break
|
||||
container = container[part]
|
||||
if container is None:
|
||||
continue
|
||||
leaf_key = parts[-1]
|
||||
leaf = container.get(leaf_key, None)
|
||||
if isinstance(leaf, dict):
|
||||
try:
|
||||
sorted_rows = [leaf[k] for k in sorted(leaf.keys(), key=int)]
|
||||
container[leaf_key] = sorted_rows
|
||||
except (ValueError, TypeError):
|
||||
# Keys are not all integers; leave as-is
|
||||
pass
|
||||
elif isinstance(leaf, list):
|
||||
# Already a list (e.g. the EMPTY_LIST default was applied above)
|
||||
pass
|
||||
elif leaf is None:
|
||||
container[leaf_key] = []
|
||||
|
||||
return values
|
||||
|
||||
|
||||
@ -2376,11 +2601,15 @@ __all__ = [
|
||||
"AnyType",
|
||||
"MultiType",
|
||||
"Tracks",
|
||||
"Dict",
|
||||
"Array",
|
||||
"Color",
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
"DynamicCombo",
|
||||
"DynamicSlot",
|
||||
"Autogrow",
|
||||
"DynamicGroup",
|
||||
# Other classes
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
@ -2394,6 +2623,8 @@ __all__ = [
|
||||
"PriceBadgeDepends",
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"BoundingBoxes",
|
||||
"Colors",
|
||||
"Curve",
|
||||
"Histogram",
|
||||
"Range",
|
||||
|
||||
23
comfy_extras/color_util.py
Normal file
23
comfy_extras/color_util.py
Normal file
@ -0,0 +1,23 @@
|
||||
def hex_to_rgb(value: str) -> tuple[int, int, int]:
|
||||
h = value.lstrip("#")
|
||||
if len(h) != 6:
|
||||
return (255, 255, 255)
|
||||
try:
|
||||
return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16))
|
||||
except ValueError:
|
||||
return (255, 255, 255)
|
||||
|
||||
|
||||
def readable_color(rgb: tuple[int, int, int]) -> tuple[int, int, int]:
|
||||
r, g, b = rgb
|
||||
lum = 0.299 * r + 0.587 * g + 0.114 * b
|
||||
if lum >= 130:
|
||||
return (r, g, b)
|
||||
t = (130 - lum) / (255 - lum)
|
||||
return (round(r + (255 - r) * t), round(g + (255 - g) * t), round(b + (255 - b) * t))
|
||||
|
||||
|
||||
def normalize_palette(colors) -> list[str]:
|
||||
if isinstance(colors, dict):
|
||||
colors = colors.values()
|
||||
return [c.upper() for c in colors if isinstance(c, str) and c]
|
||||
253
comfy_extras/nodes_bounding_boxes.py
Normal file
253
comfy_extras/nodes_bounding_boxes.py
Normal file
@ -0,0 +1,253 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageEnhance, ImageFont
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.color_util import hex_to_rgb, normalize_palette, readable_color
|
||||
|
||||
_PREVIEW_LONG_EDGE = 1024
|
||||
_PREVIEW_DIM = 0.25
|
||||
|
||||
|
||||
def pixels_to_fractions(box: dict, width: int, height: int) -> dict:
|
||||
w = width or 1
|
||||
h = height or 1
|
||||
return {
|
||||
"x": box.get("x", 0) / w,
|
||||
"y": box.get("y", 0) / h,
|
||||
"w": box.get("width", 0) / w,
|
||||
"h": box.get("height", 0) / h,
|
||||
}
|
||||
|
||||
|
||||
def fractions_to_pixels(box: dict, width: int, height: int) -> dict:
|
||||
x, y = box.get("x", 0.0), box.get("y", 0.0)
|
||||
w, h = box.get("w", 0.0), box.get("h", 0.0)
|
||||
if w < 0:
|
||||
x, w = x + w, -w
|
||||
if h < 0:
|
||||
y, h = y + h, -h
|
||||
return {
|
||||
"x": round(x * width),
|
||||
"y": round(y * height),
|
||||
"width": round(w * width),
|
||||
"height": round(h * height),
|
||||
}
|
||||
|
||||
|
||||
def fractions_to_bbox_frame(boxes: list, width: int, height: int) -> list:
|
||||
pixels = [
|
||||
fractions_to_pixels(box, width, height)
|
||||
for box in boxes
|
||||
if isinstance(box, dict)
|
||||
]
|
||||
return [pixels] if pixels else []
|
||||
|
||||
|
||||
def _font(size: int):
|
||||
try:
|
||||
return ImageFont.load_default(size)
|
||||
except Exception:
|
||||
return ImageFont.load_default()
|
||||
|
||||
|
||||
def _wrap(draw, text: str, font, max_w: float) -> list[str]:
|
||||
lines = []
|
||||
for para in text.split("\n"):
|
||||
line = ""
|
||||
for word in para.split():
|
||||
test = word if not line else line + " " + word
|
||||
if line and draw.textlength(test, font=font) > max_w:
|
||||
lines.append(line)
|
||||
line = word
|
||||
else:
|
||||
line = test
|
||||
lines.append(line)
|
||||
return lines
|
||||
|
||||
|
||||
def _bg_from_image(image) -> Image.Image | None:
|
||||
if image is None:
|
||||
return None
|
||||
try:
|
||||
arr = (image[0].detach().cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
return Image.fromarray(arr)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def render_preview(regions, width, height, bg=None):
|
||||
if bg is not None:
|
||||
iw, ih = bg.size
|
||||
long_edge = max(iw, ih) or 1
|
||||
scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge)
|
||||
rw, rh = max(1, round(iw * scale)), max(1, round(ih * scale))
|
||||
base = bg.convert("RGB").resize((rw, rh), Image.LANCZOS)
|
||||
base = ImageEnhance.Brightness(base).enhance(_PREVIEW_DIM)
|
||||
img = base.convert("RGBA")
|
||||
else:
|
||||
long_edge = max(width, height) or 1
|
||||
scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge)
|
||||
rw, rh = max(1, round(width * scale)), max(1, round(height * scale))
|
||||
grey = round(_PREVIEW_DIM * 128)
|
||||
img = Image.new("RGBA", (rw, rh), (grey, grey, grey, 255))
|
||||
|
||||
overlay = Image.new("RGBA", (rw, rh), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(overlay)
|
||||
fs = max(10, round(rh / 64))
|
||||
font = _font(fs)
|
||||
tag_font = _font(max(9, fs - 2))
|
||||
line_h = fs + 2
|
||||
|
||||
for i, region in enumerate(regions):
|
||||
if not isinstance(region, dict):
|
||||
continue
|
||||
palette = [c for c in (region.get("palette") or []) if c]
|
||||
r, g, b = hex_to_rgb(palette[0]) if palette else (140, 140, 140)
|
||||
x1 = max(0, min(rw, round(region.get("x", 0) * rw)))
|
||||
y1 = max(0, min(rh, round(region.get("y", 0) * rh)))
|
||||
x2 = max(0, min(rw, round((region.get("x", 0) + region.get("w", 0)) * rw)))
|
||||
y2 = max(0, min(rh, round((region.get("y", 0) + region.get("h", 0)) * rh)))
|
||||
if x2 < x1:
|
||||
x1, x2 = x2, x1
|
||||
if y2 < y1:
|
||||
y1, y2 = y2, y1
|
||||
|
||||
draw.rectangle([x1, y1, x2, y2], outline=(r, g, b, 255), width=2)
|
||||
|
||||
swatches = palette[:5]
|
||||
if swatches and (x2 - x1) > 2:
|
||||
sh = max(5, fs // 2)
|
||||
seg = (x2 - x1) / len(swatches)
|
||||
for p, hexc in enumerate(swatches):
|
||||
sx = x1 + round(p * seg)
|
||||
draw.rectangle([sx, y1, x1 + round((p + 1) * seg), y1 + sh], fill=hex_to_rgb(hexc))
|
||||
|
||||
etype = "text" if region.get("type") == "text" else "obj"
|
||||
tag = str(i + 1).zfill(2)
|
||||
tw = draw.textlength(tag, font=tag_font)
|
||||
draw.rectangle([x1, y1, x1 + tw + 6, y1 + fs + 2], fill=(r, g, b, 255))
|
||||
tag_fill = (0, 0, 0, 255) if (0.299 * r + 0.587 * g + 0.114 * b) > 140 else (255, 255, 255, 255)
|
||||
draw.text((x1 + 3, y1 + 1), tag, fill=tag_fill, font=tag_font)
|
||||
|
||||
body = region.get("desc", "") or ""
|
||||
if etype == "text" and region.get("text"):
|
||||
body = '"%s"%s' % (region["text"], " — " + body if body else "")
|
||||
if body and (x2 - x1) > 8:
|
||||
ty = y1 + fs + 5
|
||||
for line in _wrap(draw, body, font, x2 - x1 - 8):
|
||||
if ty > y2:
|
||||
break
|
||||
draw.text((x1 + 4, ty), line, fill=readable_color((r, g, b)) + (255,), font=font)
|
||||
ty += line_h
|
||||
|
||||
composed = Image.alpha_composite(img, overlay).convert("RGB")
|
||||
arr = np.asarray(composed, dtype=np.float32) / 255.0
|
||||
return torch.from_numpy(arr).unsqueeze(0)
|
||||
|
||||
|
||||
def boxes_to_regions(boxes, width: int, height: int) -> list:
|
||||
regions: list = []
|
||||
if not isinstance(boxes, list):
|
||||
return regions
|
||||
for box in boxes:
|
||||
if not isinstance(box, dict):
|
||||
continue
|
||||
meta = box.get("metadata")
|
||||
meta = meta if isinstance(meta, dict) else {}
|
||||
regions.append({
|
||||
**pixels_to_fractions(box, width, height),
|
||||
"type": meta.get("type", "obj"),
|
||||
"text": meta.get("text", ""),
|
||||
"desc": meta.get("desc", ""),
|
||||
"palette": meta.get("palette", []),
|
||||
})
|
||||
return regions
|
||||
|
||||
|
||||
def _norm_bbox(region: dict) -> list[int]:
|
||||
def grid(value: float) -> int:
|
||||
return max(0, min(1000, round(value * 1000)))
|
||||
|
||||
x, y = region.get("x", 0.0), region.get("y", 0.0)
|
||||
w, h = region.get("w", 0.0), region.get("h", 0.0)
|
||||
ymin, xmin, ymax, xmax = grid(y), grid(x), grid(y + h), grid(x + w)
|
||||
if ymin > ymax:
|
||||
ymin, ymax = ymax, ymin
|
||||
if xmin > xmax:
|
||||
xmin, xmax = xmax, xmin
|
||||
return [ymin, xmin, ymax, xmax]
|
||||
|
||||
|
||||
def build_elements(regions: list) -> list:
|
||||
elements = []
|
||||
for region in regions:
|
||||
if not isinstance(region, dict):
|
||||
continue
|
||||
etype = "text" if region.get("type") == "text" else "obj"
|
||||
element = {"type": etype}
|
||||
element["bbox"] = _norm_bbox(region)
|
||||
if etype == "text":
|
||||
element["text"] = region.get("text", "")
|
||||
element["desc"] = region.get("desc", "")
|
||||
palette = normalize_palette(region.get("palette", []))
|
||||
if palette:
|
||||
element["color_palette"] = palette[:5]
|
||||
elements.append(element)
|
||||
return elements
|
||||
|
||||
|
||||
class CreateBoundingBoxes(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
editor_state = io.BoundingBoxes.Input(
|
||||
"editor_state",
|
||||
socketless=False,
|
||||
tooltip="Draw bounding boxes and set each box type, text, description, color palette. Start with background element first and foreground last.",
|
||||
)
|
||||
return io.Schema(
|
||||
node_id="CreateBoundingBoxes",
|
||||
display_name="Create Bounding Boxes",
|
||||
category="utilities",
|
||||
description="Draw bounding boxes in a canvas. Outputs Ideogram prompt elements, pixel-space bounding boxes, and a preview image.",
|
||||
inputs=[
|
||||
io.Image.Input(
|
||||
"background",
|
||||
optional=True,
|
||||
tooltip="Optional image used as background in the canvas and preview.",
|
||||
),
|
||||
io.Int.Input("width", default=1024, min=64, max=16384, step=16,
|
||||
tooltip="Width of the canvas and the pixel grid for the bounding boxes."),
|
||||
io.Int.Input("height", default=1024, min=64, max=16384, step=16,
|
||||
tooltip="Height of the canvas and the pixel grid for the bounding boxes."),
|
||||
editor_state,
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="preview"),
|
||||
io.BoundingBox.Output(display_name="bboxes"),
|
||||
io.Array.Output(display_name="elements"),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, editor_state=None, background=None) -> io.NodeOutput:
|
||||
regions = boxes_to_regions(editor_state, width, height)
|
||||
preview = render_preview(regions, width, height, _bg_from_image(background))
|
||||
return io.NodeOutput(
|
||||
preview,
|
||||
fractions_to_bbox_frame(regions, width, height),
|
||||
build_elements(regions),
|
||||
ui={"dims": [width, height]},
|
||||
)
|
||||
|
||||
|
||||
class BoundingBoxesExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [CreateBoundingBoxes]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> BoundingBoxesExtension:
|
||||
return BoundingBoxesExtension()
|
||||
@ -1,5 +1,6 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.color_util import hex_to_rgb
|
||||
|
||||
|
||||
class ColorToRGBInt(io.ComfyNode):
|
||||
@ -24,9 +25,11 @@ class ColorToRGBInt(io.ComfyNode):
|
||||
# expect format #RRGGBB
|
||||
if len(color) != 7 or color[0] != "#":
|
||||
raise ValueError("Color must be in format #RRGGBB")
|
||||
r = int(color[1:3], 16)
|
||||
g = int(color[3:5], 16)
|
||||
b = int(color[5:7], 16)
|
||||
try:
|
||||
int(color[1:], 16)
|
||||
except ValueError:
|
||||
raise ValueError("Color must be in format #RRGGBB") from None
|
||||
r, g, b = hex_to_rgb(color)
|
||||
|
||||
rgb_int = r * 256 * 256 + g * 256 + b
|
||||
return io.NodeOutput(rgb_int, color)
|
||||
|
||||
@ -1,85 +1,68 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import ctypes
|
||||
import logging
|
||||
import ctypes.util
|
||||
import importlib.util
|
||||
from typing import TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import nodes
|
||||
import comfy_angle
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from typing_extensions import override
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _check_opengl_availability():
|
||||
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
|
||||
logger.debug("_check_opengl_availability: starting")
|
||||
missing = []
|
||||
def _preload_angle():
|
||||
egl_path = comfy_angle.get_egl_path()
|
||||
gles_path = comfy_angle.get_glesv2_path()
|
||||
|
||||
# Check Python packages (using find_spec to avoid importing)
|
||||
logger.debug("_check_opengl_availability: checking for glfw package")
|
||||
if importlib.util.find_spec("glfw") is None:
|
||||
missing.append("glfw")
|
||||
if sys.platform == "win32":
|
||||
angle_dir = comfy_angle.get_lib_dir()
|
||||
os.add_dll_directory(angle_dir)
|
||||
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
|
||||
|
||||
logger.debug("_check_opengl_availability: checking for OpenGL package")
|
||||
if importlib.util.find_spec("OpenGL") is None:
|
||||
missing.append("PyOpenGL")
|
||||
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
|
||||
)
|
||||
|
||||
# On Linux without display, check if headless backends are available
|
||||
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
|
||||
if sys.platform.startswith("linux"):
|
||||
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
|
||||
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
|
||||
if not has_display:
|
||||
# Check for EGL or OSMesa libraries
|
||||
logger.debug("_check_opengl_availability: checking for EGL library")
|
||||
has_egl = ctypes.util.find_library("EGL")
|
||||
logger.debug("_check_opengl_availability: checking for OSMesa library")
|
||||
has_osmesa = ctypes.util.find_library("OSMesa")
|
||||
|
||||
# Error disabled for CI as it fails this check
|
||||
# if not has_egl and not has_osmesa:
|
||||
# raise RuntimeError(
|
||||
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
|
||||
# "See error below for installation instructions."
|
||||
# )
|
||||
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
|
||||
|
||||
logger.debug("_check_opengl_availability: completed")
|
||||
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
|
||||
ctypes.CDLL(str(egl_path), mode=mode)
|
||||
ctypes.CDLL(str(gles_path), mode=mode)
|
||||
|
||||
|
||||
# Run early check at import time
|
||||
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
|
||||
_check_opengl_availability()
|
||||
|
||||
# OpenGL modules - initialized lazily when context is created
|
||||
gl = None
|
||||
glfw = None
|
||||
EGL = None
|
||||
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
|
||||
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
|
||||
_preload_angle()
|
||||
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
|
||||
|
||||
def _import_opengl():
|
||||
"""Import OpenGL module. Called after context is created."""
|
||||
global gl
|
||||
if gl is None:
|
||||
logger.debug("_import_opengl: importing OpenGL.GL")
|
||||
import OpenGL.GL as _gl
|
||||
gl = _gl
|
||||
logger.debug("_import_opengl: import completed")
|
||||
return gl
|
||||
import OpenGL
|
||||
OpenGL.USE_ACCELERATE = False
|
||||
|
||||
|
||||
def _patch_find_library():
|
||||
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
|
||||
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
|
||||
'libGLESv2'. Patch find_library to return the full ANGLE paths so
|
||||
PyOpenGL loads the same libraries we pre-loaded."""
|
||||
if sys.platform == "linux":
|
||||
return
|
||||
import ctypes.util
|
||||
_orig = ctypes.util.find_library
|
||||
def _patched(name):
|
||||
if name == 'EGL':
|
||||
return comfy_angle.get_egl_path()
|
||||
if name == 'GLESv2':
|
||||
return comfy_angle.get_glesv2_path()
|
||||
return _orig(name)
|
||||
ctypes.util.find_library = _patched
|
||||
|
||||
|
||||
_patch_find_library()
|
||||
|
||||
from OpenGL import EGL
|
||||
from OpenGL import GLES3 as gl
|
||||
|
||||
class SizeModeInput(TypedDict):
|
||||
size_mode: str
|
||||
width: int
|
||||
@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
# (-1,-1)---(3,-1)
|
||||
#
|
||||
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
||||
VERTEX_SHADER = """#version 330 core
|
||||
VERTEX_SHADER = """#version 300 es
|
||||
out vec2 v_texCoord;
|
||||
void main() {
|
||||
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
||||
@ -126,14 +109,99 @@ void main() {
|
||||
"""
|
||||
|
||||
|
||||
def _convert_es_to_desktop(source: str) -> str:
|
||||
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
||||
# Remove any existing #version directive
|
||||
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
||||
# Remove precision qualifiers (not needed in desktop GLSL)
|
||||
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
||||
# Prepend desktop GLSL version
|
||||
return "#version 330 core\n" + source
|
||||
|
||||
def _egl_attribs(*values):
|
||||
"""Build an EGL_NONE-terminated EGLint attribute array."""
|
||||
vals = list(values) + [EGL.EGL_NONE]
|
||||
return (ctypes.c_int32 * len(vals))(*vals)
|
||||
|
||||
|
||||
# EGL platform extension constants
|
||||
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
|
||||
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
|
||||
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
|
||||
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
|
||||
|
||||
|
||||
_eglGetPlatformDisplayEXT = None
|
||||
|
||||
def _get_egl_platform_display_ext(platform, native_display, attribs):
|
||||
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
|
||||
global _eglGetPlatformDisplayEXT
|
||||
if _eglGetPlatformDisplayEXT is None:
|
||||
from OpenGL import platform as _plat
|
||||
egl_lib = _plat.PLATFORM.EGL
|
||||
_get_proc = egl_lib.eglGetProcAddress
|
||||
_get_proc.restype = ctypes.c_void_p
|
||||
_get_proc.argtypes = [ctypes.c_char_p]
|
||||
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
|
||||
if not ptr:
|
||||
return None
|
||||
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
|
||||
_eglGetPlatformDisplayEXT = func_type(ptr)
|
||||
|
||||
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
|
||||
if not raw:
|
||||
return None
|
||||
return ctypes.cast(raw, EGL.EGLDisplay)
|
||||
|
||||
|
||||
def _get_egl_display():
|
||||
"""Get an EGL display, trying the default first then ANGLE's Vulkan
|
||||
platform for headless environments without a display server."""
|
||||
failures = []
|
||||
|
||||
# Try the default display first (works when X11/Wayland is available)
|
||||
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
|
||||
if display:
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
try:
|
||||
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
return display, major.value, minor.value
|
||||
except Exception as e:
|
||||
failures.append(f"default: {e}")
|
||||
|
||||
logger.info("Default EGL display unavailable, trying headless fallbacks")
|
||||
|
||||
# Headless fallback strategies, tried in order:
|
||||
headless_strategies = [
|
||||
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
|
||||
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
|
||||
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
|
||||
]
|
||||
|
||||
for name, platform, native_display, attribs in headless_strategies:
|
||||
display = _get_egl_platform_display_ext(platform, native_display, attribs)
|
||||
if not display:
|
||||
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
|
||||
continue
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
try:
|
||||
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
logger.info(f"Using EGL {name} platform (headless)")
|
||||
return display, major.value, minor.value
|
||||
failures.append(f"{name}: eglInitialize returned false")
|
||||
except Exception as e:
|
||||
failures.append(f"{name}: {e}")
|
||||
continue
|
||||
|
||||
details = "\n".join(f" - {f}" for f in failures)
|
||||
raise RuntimeError(
|
||||
"Failed to initialize EGL display.\n"
|
||||
"No display server and no headless EGL platform available.\n"
|
||||
f"Tried:\n{details}\n"
|
||||
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
|
||||
)
|
||||
|
||||
|
||||
def _gl_str(name):
|
||||
"""Get an OpenGL string parameter."""
|
||||
v = gl.glGetString(name)
|
||||
if not v:
|
||||
return "Unknown"
|
||||
if isinstance(v, bytes):
|
||||
return v.decode(errors="replace")
|
||||
return ctypes.string_at(v).decode(errors="replace")
|
||||
|
||||
|
||||
def _detect_output_count(source: str) -> int:
|
||||
@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _init_glfw():
|
||||
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_glfw: starting")
|
||||
# On macOS, glfw.init() must be called from main thread or it hangs forever
|
||||
if sys.platform == "darwin":
|
||||
logger.debug("_init_glfw: skipping on macOS")
|
||||
raise RuntimeError("GLFW backend not supported on macOS")
|
||||
|
||||
logger.debug("_init_glfw: importing glfw module")
|
||||
import glfw as _glfw
|
||||
|
||||
logger.debug("_init_glfw: calling glfw.init()")
|
||||
if not _glfw.init():
|
||||
raise RuntimeError("glfw.init() failed")
|
||||
|
||||
try:
|
||||
logger.debug("_init_glfw: setting window hints")
|
||||
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
|
||||
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
|
||||
|
||||
logger.debug("_init_glfw: calling create_window()")
|
||||
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
|
||||
if not window:
|
||||
raise RuntimeError("glfw.create_window() failed")
|
||||
|
||||
logger.debug("_init_glfw: calling make_context_current()")
|
||||
_glfw.make_context_current(window)
|
||||
logger.debug("_init_glfw: completed successfully")
|
||||
return window, _glfw
|
||||
except Exception:
|
||||
logger.debug("_init_glfw: failed, terminating glfw")
|
||||
_glfw.terminate()
|
||||
raise
|
||||
|
||||
|
||||
def _init_egl():
|
||||
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_egl: starting")
|
||||
from OpenGL import EGL as _EGL
|
||||
from OpenGL.EGL import (
|
||||
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
|
||||
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
|
||||
eglTerminate, eglDestroyContext, eglDestroySurface,
|
||||
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
|
||||
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
|
||||
)
|
||||
logger.debug("_init_egl: imports completed")
|
||||
|
||||
display = None
|
||||
context = None
|
||||
surface = None
|
||||
|
||||
try:
|
||||
logger.debug("_init_egl: calling eglGetDisplay()")
|
||||
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
|
||||
if display == _EGL.EGL_NO_DISPLAY:
|
||||
raise RuntimeError("eglGetDisplay() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglInitialize()")
|
||||
major, minor = _EGL.EGLint(), _EGL.EGLint()
|
||||
if not eglInitialize(display, major, minor):
|
||||
display = None # Not initialized, don't terminate
|
||||
raise RuntimeError("eglInitialize() failed")
|
||||
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
|
||||
|
||||
config_attribs = [
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
|
||||
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
|
||||
EGL_DEPTH_SIZE, 0, EGL_NONE
|
||||
]
|
||||
configs = (_EGL.EGLConfig * 1)()
|
||||
num_configs = _EGL.EGLint()
|
||||
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
config = configs[0]
|
||||
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
|
||||
|
||||
if not eglBindAPI(EGL_OPENGL_API):
|
||||
raise RuntimeError("eglBindAPI() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreateContext()")
|
||||
context_attribs = [
|
||||
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
|
||||
EGL_NONE
|
||||
]
|
||||
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
|
||||
if context == EGL_NO_CONTEXT:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
|
||||
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
|
||||
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
|
||||
if surface == _EGL.EGL_NO_SURFACE:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglMakeCurrent()")
|
||||
if not eglMakeCurrent(display, surface, surface, context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_egl: completed successfully")
|
||||
return display, context, surface, _EGL
|
||||
|
||||
except Exception:
|
||||
logger.debug("_init_egl: failed, cleaning up")
|
||||
# Clean up any resources on failure
|
||||
if surface is not None:
|
||||
eglDestroySurface(display, surface)
|
||||
if context is not None:
|
||||
eglDestroyContext(display, context)
|
||||
if display is not None:
|
||||
eglTerminate(display)
|
||||
raise
|
||||
|
||||
|
||||
def _init_osmesa():
|
||||
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
|
||||
import ctypes
|
||||
|
||||
logger.debug("_init_osmesa: starting")
|
||||
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
||||
|
||||
logger.debug("_init_osmesa: importing OpenGL.osmesa")
|
||||
from OpenGL import GL as _gl
|
||||
from OpenGL.osmesa import (
|
||||
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
|
||||
OSMESA_RGBA,
|
||||
)
|
||||
logger.debug("_init_osmesa: imports completed")
|
||||
|
||||
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
|
||||
if not ctx:
|
||||
raise RuntimeError("OSMesaCreateContextExt() failed")
|
||||
|
||||
width, height = 64, 64
|
||||
buffer = (ctypes.c_ubyte * (width * height * 4))()
|
||||
|
||||
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
|
||||
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
|
||||
OSMesaDestroyContext(ctx)
|
||||
raise RuntimeError("OSMesaMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_osmesa: completed successfully")
|
||||
return ctx, buffer
|
||||
|
||||
|
||||
class GLContext:
|
||||
"""Manages OpenGL context and resources for shader execution.
|
||||
|
||||
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
|
||||
"""
|
||||
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
@ -327,131 +240,105 @@ class GLContext:
|
||||
|
||||
def __init__(self):
|
||||
if GLContext._initialized:
|
||||
logger.debug("GLContext.__init__: already initialized, skipping")
|
||||
return
|
||||
|
||||
logger.debug("GLContext.__init__: starting initialization")
|
||||
|
||||
global glfw, EGL
|
||||
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
|
||||
self._backend = None
|
||||
self._window = None
|
||||
self._egl_display = None
|
||||
self._egl_context = None
|
||||
self._egl_surface = None
|
||||
self._osmesa_ctx = None
|
||||
self._osmesa_buffer = None
|
||||
self._display = None
|
||||
self._surface = None
|
||||
self._context = None
|
||||
self._vao = None
|
||||
|
||||
# Try backends in order: GLFW → EGL → OSMesa
|
||||
errors = []
|
||||
|
||||
logger.debug("GLContext.__init__: trying GLFW backend")
|
||||
try:
|
||||
self._window, glfw = _init_glfw()
|
||||
self._backend = "glfw"
|
||||
logger.debug("GLContext.__init__: GLFW backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
|
||||
errors.append(("GLFW", e))
|
||||
self._display, self._egl_major, self._egl_minor = _get_egl_display()
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying EGL backend")
|
||||
try:
|
||||
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
|
||||
self._backend = "egl"
|
||||
logger.debug("GLContext.__init__: EGL backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
|
||||
errors.append(("EGL", e))
|
||||
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
|
||||
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying OSMesa backend")
|
||||
try:
|
||||
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
|
||||
self._backend = "osmesa"
|
||||
logger.debug("GLContext.__init__: OSMesa backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
|
||||
errors.append(("OSMesa", e))
|
||||
config = EGL.EGLConfig()
|
||||
n_configs = ctypes.c_int32(0)
|
||||
if not EGL.eglChooseConfig(
|
||||
self._display,
|
||||
_egl_attribs(
|
||||
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
|
||||
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
|
||||
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
|
||||
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
|
||||
),
|
||||
ctypes.byref(config), 1, ctypes.byref(n_configs),
|
||||
) or n_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
|
||||
if self._backend is None:
|
||||
if sys.platform == "win32":
|
||||
platform_help = (
|
||||
"Windows: Ensure GPU drivers are installed and display is available.\n"
|
||||
" CPU-only/headless mode is not supported on Windows."
|
||||
)
|
||||
elif sys.platform == "darwin":
|
||||
platform_help = (
|
||||
"macOS: GLFW is not supported.\n"
|
||||
" Install OSMesa via Homebrew: brew install mesa\n"
|
||||
" Then: pip install PyOpenGL PyOpenGL-accelerate"
|
||||
)
|
||||
else:
|
||||
platform_help = (
|
||||
"Linux: Install one of these backends:\n"
|
||||
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
|
||||
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
|
||||
" Headless (CPU): sudo apt install libosmesa6"
|
||||
)
|
||||
|
||||
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
|
||||
raise RuntimeError(
|
||||
f"Failed to create OpenGL context.\n\n"
|
||||
f"Backend errors:\n{error_details}\n\n"
|
||||
f"{platform_help}"
|
||||
self._surface = EGL.eglCreatePbufferSurface(
|
||||
self._display, config,
|
||||
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
|
||||
)
|
||||
if not self._surface:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
# Now import OpenGL.GL (after context is current)
|
||||
logger.debug("GLContext.__init__: importing OpenGL.GL")
|
||||
_import_opengl()
|
||||
self._context = EGL.eglCreateContext(
|
||||
self._display, config, EGL.EGL_NO_CONTEXT,
|
||||
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
|
||||
)
|
||||
if not self._context:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
# Create VAO (required for core profile, but OSMesa may use compat profile)
|
||||
logger.debug("GLContext.__init__: creating VAO")
|
||||
try:
|
||||
vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(vao)
|
||||
self._vao = vao # Only store after successful bind
|
||||
logger.debug("GLContext.__init__: VAO created successfully")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
|
||||
# OSMesa with older Mesa may not support VAOs
|
||||
# Clean up if we created but couldn't bind
|
||||
if vao:
|
||||
try:
|
||||
gl.glDeleteVertexArrays(1, [vao])
|
||||
except Exception:
|
||||
pass
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
self._vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
except Exception:
|
||||
self._cleanup()
|
||||
raise
|
||||
|
||||
elapsed = (time.perf_counter() - start) * 1000
|
||||
|
||||
# Log device info
|
||||
renderer = gl.glGetString(gl.GL_RENDERER)
|
||||
vendor = gl.glGetString(gl.GL_VENDOR)
|
||||
version = gl.glGetString(gl.GL_VERSION)
|
||||
renderer = renderer.decode() if renderer else "Unknown"
|
||||
vendor = vendor.decode() if vendor else "Unknown"
|
||||
version = version.decode() if version else "Unknown"
|
||||
renderer = _gl_str(gl.GL_RENDERER)
|
||||
vendor = _gl_str(gl.GL_VENDOR)
|
||||
version = _gl_str(gl.GL_VERSION)
|
||||
|
||||
GLContext._initialized = True
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
|
||||
|
||||
def make_current(self):
|
||||
if self._backend == "glfw":
|
||||
glfw.make_context_current(self._window)
|
||||
elif self._backend == "egl":
|
||||
from OpenGL.EGL import eglMakeCurrent
|
||||
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
|
||||
elif self._backend == "osmesa":
|
||||
from OpenGL.osmesa import OSMesaMakeCurrent
|
||||
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
|
||||
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
err = EGL.eglGetError()
|
||||
raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
|
||||
if self._vao is not None:
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
def _cleanup(self):
|
||||
if not self._display:
|
||||
return
|
||||
try:
|
||||
if self._vao is not None:
|
||||
gl.glDeleteVertexArrays(1, [self._vao])
|
||||
self._vao = None
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._context:
|
||||
EGL.eglDestroyContext(self._display, self._context)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._surface:
|
||||
EGL.eglDestroySurface(self._display, self._surface)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglTerminate(self._display)
|
||||
except Exception:
|
||||
pass
|
||||
self._display = None
|
||||
|
||||
|
||||
def _compile_shader(source: str, shader_type: int) -> int:
|
||||
"""Compile a shader and return its ID."""
|
||||
@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
|
||||
gl.glShaderSource(shader, source)
|
||||
gl.glCompileShader(shader)
|
||||
|
||||
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetShaderInfoLog(shader).decode()
|
||||
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
|
||||
error = gl.glGetShaderInfoLog(shader)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteShader(shader)
|
||||
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
||||
|
||||
@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
|
||||
gl.glDeleteShader(vertex_shader)
|
||||
gl.glDeleteShader(fragment_shader)
|
||||
|
||||
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetProgramInfoLog(program).decode()
|
||||
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
|
||||
error = gl.glGetProgramInfoLog(program)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteProgram(program)
|
||||
raise RuntimeError(f"Program linking failed:\n{error}")
|
||||
|
||||
@ -530,9 +421,6 @@ def _render_shader_batch(
|
||||
ctx = GLContext()
|
||||
ctx.make_current()
|
||||
|
||||
# Convert from GLSL ES to desktop GLSL 330
|
||||
fragment_source = _convert_es_to_desktop(fragment_code)
|
||||
|
||||
# Detect how many outputs the shader actually uses
|
||||
num_outputs = _detect_output_count(fragment_code)
|
||||
|
||||
@ -558,9 +446,9 @@ def _render_shader_batch(
|
||||
try:
|
||||
# Compile shaders (once for all batches)
|
||||
try:
|
||||
program = _create_program(VERTEX_SHADER, fragment_source)
|
||||
program = _create_program(VERTEX_SHADER, fragment_code)
|
||||
except RuntimeError:
|
||||
logger.error(f"Fragment shader:\n{fragment_source}")
|
||||
logger.error(f"Fragment shader:\n{fragment_code}")
|
||||
raise
|
||||
|
||||
gl.glUseProgram(program)
|
||||
@ -723,13 +611,13 @@ def _render_shader_batch(
|
||||
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
||||
|
||||
# Read back outputs for this batch
|
||||
# (glGetTexImage is synchronous, implicitly waits for rendering)
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
||||
batch_outputs = []
|
||||
for tex in output_textures:
|
||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
|
||||
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
|
||||
batch_outputs.append(img[::-1, :, :].copy())
|
||||
for i in range(num_outputs):
|
||||
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
|
||||
buf = np.empty((height, width, 4), dtype=np.float32)
|
||||
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
|
||||
batch_outputs.append(buf[::-1, :, :].copy())
|
||||
|
||||
# Pad with black images for unused outputs
|
||||
black_img = np.zeros((height, width, 4), dtype=np.float32)
|
||||
@ -750,18 +638,18 @@ def _render_shader_batch(
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
||||
gl.glUseProgram(0)
|
||||
|
||||
for tex in input_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in curve_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in output_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in ping_pong_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
if input_textures:
|
||||
gl.glDeleteTextures(len(input_textures), input_textures)
|
||||
if curve_textures:
|
||||
gl.glDeleteTextures(len(curve_textures), curve_textures)
|
||||
if output_textures:
|
||||
gl.glDeleteTextures(len(output_textures), output_textures)
|
||||
if ping_pong_textures:
|
||||
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
|
||||
if fbo is not None:
|
||||
gl.glDeleteFramebuffers(1, [fbo])
|
||||
for pp_fbo in ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(1, [pp_fbo])
|
||||
if ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
|
||||
if program is not None:
|
||||
gl.glDeleteProgram(program)
|
||||
|
||||
|
||||
77
comfy_extras/nodes_json_prompt.py
Normal file
77
comfy_extras/nodes_json_prompt.py
Normal file
@ -0,0 +1,77 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.color_util import normalize_palette
|
||||
|
||||
|
||||
class BuildJsonPromptIdeogram(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
color_palette = io.Colors.Input(
|
||||
"color_palette",
|
||||
socketless=False,
|
||||
tooltip="Hex color codes that steer the image's dominant colors. Up to 16 entries.",
|
||||
)
|
||||
return io.Schema(
|
||||
node_id="BuildJsonPromptIdeogram",
|
||||
display_name="Build JSON Prompt (Ideogram)",
|
||||
category="text",
|
||||
description="Build a JSON prompt for the Ideogram 4 model.",
|
||||
inputs=[
|
||||
io.Array.Input("element", tooltip="Prompt elements from the node Create Bounding Boxes."),
|
||||
io.String.Input("high_level_description", multiline=True, default="",
|
||||
tooltip="Optional description of the image in one or two sentences. Strongly recommended."),
|
||||
io.String.Input("background", multiline=True, default="",
|
||||
tooltip="Mandatory description of the image background or environment."),
|
||||
io.DynamicCombo.Input("style", options=[
|
||||
io.DynamicCombo.Option("none", []),
|
||||
io.DynamicCombo.Option("photo", [io.String.Input("photo", default="", tooltip="Camera or lens details for photographic outputs (e.g. 35mm, f/1.4, bokeh).")]),
|
||||
io.DynamicCombo.Option("art_style", [io.String.Input("art_style", default="", tooltip="Art style description (e.g. flat vector illustration, bold outlines).")]),
|
||||
]),
|
||||
io.String.Input("aesthetics", default="", tooltip="Mandatory aesthetic keywords (e.g. moody, cinematic, desaturated)."),
|
||||
io.String.Input("lighting", default="", tooltip="Mandatory lighting description (e.g. golden hour, rim light, dramatic shadows)."),
|
||||
io.String.Input("medium", default="", tooltip="Mandatory medium type (e.g. photograph, illustration, 3d_render, painting, graphic_design). When style = photo, set to photograph."),
|
||||
color_palette,
|
||||
],
|
||||
outputs=[io.Dict.Output(display_name="prompt")],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, element, style, high_level_description="", background="",
|
||||
aesthetics="", lighting="", medium="", color_palette=None) -> io.NodeOutput:
|
||||
elements = element if isinstance(element, list) else []
|
||||
kind = style.get("style", "none") if isinstance(style, dict) else "none"
|
||||
photo = style.get("photo", "") if isinstance(style, dict) else ""
|
||||
art_style = style.get("art_style", "") if isinstance(style, dict) else ""
|
||||
palette = normalize_palette(color_palette or [])
|
||||
|
||||
caption: dict = {}
|
||||
if high_level_description.strip():
|
||||
caption["high_level_description"] = high_level_description
|
||||
if kind != "none":
|
||||
style_desc: dict = {"aesthetics": aesthetics, "lighting": lighting}
|
||||
if kind == "photo":
|
||||
style_desc["photo"] = photo
|
||||
style_desc["medium"] = medium
|
||||
else:
|
||||
style_desc["medium"] = medium
|
||||
style_desc["art_style"] = art_style
|
||||
if palette:
|
||||
style_desc["color_palette"] = palette
|
||||
caption["style_description"] = style_desc
|
||||
caption["compositional_deconstruction"] = {
|
||||
"background": background,
|
||||
"elements": elements,
|
||||
}
|
||||
return io.NodeOutput(caption)
|
||||
|
||||
|
||||
class JsonPromptExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [BuildJsonPromptIdeogram]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> JsonPromptExtension:
|
||||
return JsonPromptExtension()
|
||||
@ -337,6 +337,36 @@ class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeKrea2(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "model/merging/model specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["first."] = argument
|
||||
arg_dict["tmlp."] = argument
|
||||
arg_dict["txtmlp."] = argument
|
||||
arg_dict["tproj."] = argument
|
||||
|
||||
for i in range(2):
|
||||
arg_dict["txtfusion.layerwise_blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["txtfusion.projector."] = argument
|
||||
|
||||
for i in range(2):
|
||||
arg_dict["txtfusion.refiner_blocks.{}.".format(i)] = argument
|
||||
|
||||
for i in range(28):
|
||||
arg_dict["blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["last."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSD1": ModelMergeSD1,
|
||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
||||
@ -353,4 +383,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
|
||||
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
|
||||
"ModelMergeQwenImage": ModelMergeQwenImage,
|
||||
"ModelMergeKrea2": ModelMergeKrea2,
|
||||
}
|
||||
|
||||
33
comfy_extras/nodes_seed.py
Normal file
33
comfy_extras/nodes_seed.py
Normal file
@ -0,0 +1,33 @@
|
||||
import sys
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class SeedNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SeedNode",
|
||||
display_name="Seed",
|
||||
search_aliases=["seed", "random"],
|
||||
category="utilities",
|
||||
inputs=[
|
||||
io.Int.Input("seed", min=0, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
|
||||
],
|
||||
outputs=[io.Int.Output(display_name="seed")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, seed: int) -> io.NodeOutput:
|
||||
return io.NodeOutput(seed)
|
||||
|
||||
|
||||
class SeedExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [SeedNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> SeedExtension:
|
||||
return SeedExtension()
|
||||
@ -440,6 +440,57 @@ class JsonExtractString(io.ComfyNode):
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return io.NodeOutput("")
|
||||
|
||||
|
||||
def _dump_json(value, indent):
|
||||
return json.dumps(value, ensure_ascii=False, indent=indent or None)
|
||||
|
||||
|
||||
class ConvertDictionaryToString(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ConvertDictionaryToString",
|
||||
display_name="Convert Dictionary to String",
|
||||
category="text",
|
||||
search_aliases=["json", "dict to json", "stringify", "serialize", "dict to string"],
|
||||
inputs=[
|
||||
io.Dict.Input("dictionary"),
|
||||
io.Int.Input("indent", default=2, min=0, max=8,
|
||||
tooltip="Spaces per indent level. 0 produces compact single-line string."),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, dictionary, indent=2):
|
||||
return io.NodeOutput(_dump_json(dictionary, indent))
|
||||
|
||||
|
||||
class ConvertArrayToString(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ConvertArrayToString",
|
||||
display_name="Convert Array to String",
|
||||
category="text",
|
||||
search_aliases=["json", "list to json", "stringify", "serialize", "list to string", "array to json"],
|
||||
inputs=[
|
||||
io.Array.Input("array"),
|
||||
io.Int.Input("indent", default=2, min=0, max=8,
|
||||
tooltip="Spaces per indent level. 0 produces compact single-line string."),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, array, indent=2):
|
||||
return io.NodeOutput(_dump_json(array, indent))
|
||||
|
||||
|
||||
class StringExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@ -457,6 +508,8 @@ class StringExtension(ComfyExtension):
|
||||
RegexExtract,
|
||||
RegexReplace,
|
||||
JsonExtractString,
|
||||
ConvertDictionaryToString,
|
||||
ConvertArrayToString,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> StringExtension:
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.26.2"
|
||||
__version__ = "0.26.0"
|
||||
|
||||
3
nodes.py
3
nodes.py
@ -2374,6 +2374,8 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_images.py",
|
||||
"nodes_video_model.py",
|
||||
"nodes_ideogram4.py",
|
||||
"nodes_bounding_boxes.py",
|
||||
"nodes_json_prompt.py",
|
||||
"nodes_train.py",
|
||||
"nodes_dataset.py",
|
||||
"nodes_sag.py",
|
||||
@ -2473,6 +2475,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_gaussian_splat.py",
|
||||
"nodes_triposplat.py",
|
||||
"nodes_depth_anything_3.py",
|
||||
"nodes_seed.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
18
openapi.yaml
18
openapi.yaml
@ -1692,6 +1692,12 @@ paths:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Unsupported media type
|
||||
"422":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Validation error (e.g., disallowed model_type tag)
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
@ -2137,6 +2143,12 @@ paths:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Source asset with given hash not found
|
||||
"422":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Validation error (e.g., disallowed model_type tag)
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
@ -2357,6 +2369,10 @@ paths:
|
||||
description: |
|
||||
Returns a list of model folders available in the system.
|
||||
This is an experimental endpoint that replaces the legacy /models endpoint.
|
||||
Each folder's name is the identifier to pass to /api/experiment/models/{folder}.
|
||||
Once the model_type migration is active the names are model_type folder_names
|
||||
(e.g. `ultralytics_bbox`); a folder with no folder_name mapping is returned by
|
||||
its directory path.
|
||||
operationId: getModelFolders
|
||||
responses:
|
||||
"200":
|
||||
@ -2988,7 +3004,7 @@ paths:
|
||||
format: uuid
|
||||
type: string
|
||||
- description: |
|
||||
When present, each output item in the response receives a `short_url` field containing an owner-gated durable link for that asset. Omit this parameter (the default) to receive a response identical to the no-param baseline. The value selects the link's lifetime: use `ephemeral_tool_chain` for short-lived machine-to-machine handoffs (~15 minutes); use `default` for durable human-revisitable links (30 days). Links are minted only for the authenticated request owner and are not resolvable by other users.
|
||||
When present, each output item in the response receives a `short_url` field containing a short link for that asset. Omit this parameter (the default) to receive a response identical to the no-param baseline. The value selects the link's lifetime and auth model: use `ephemeral_tool_chain` for short-lived (≤5 minute) machine-to-machine handoffs — these are public bearer links where the link ID itself is the credential, so anyone holding the link can resolve it (intended for pasting into an agent/MCP tool chain); use `default` for durable (30 day) human-revisitable links, which are owner-gated and resolvable only by the authenticated owner. Links are always minted under the authenticated request owner's identity; the auth model is selected by the server and is never settable by the caller.
|
||||
in: query
|
||||
name: short_link
|
||||
schema:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.26.2"
|
||||
version = "0.26.0"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.10
|
||||
comfy-kitchen==0.2.13
|
||||
comfy-aimdo==0.4.10
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
@ -33,5 +33,5 @@ kornia>=0.7.1
|
||||
spandrel
|
||||
pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
PyOpenGL
|
||||
glfw
|
||||
PyOpenGL>=3.1.8
|
||||
comfy-angle
|
||||
|
||||
204
tests-unit/comfy_api_test/io_dynamic_group_test.py
Normal file
204
tests-unit/comfy_api_test/io_dynamic_group_test.py
Normal file
@ -0,0 +1,204 @@
|
||||
"""Unit tests for io.DynamicGroup: expansion/reconstruction (0-row and N-row cases)."""
|
||||
import sys
|
||||
import types
|
||||
import pytest
|
||||
|
||||
# Stub torch (type-hint only in _io.py; real torch not available in unit-test env)
|
||||
if "torch" not in sys.modules:
|
||||
_torch_stub = types.ModuleType("torch")
|
||||
_torch_stub.Tensor = object # type: ignore[attr-defined]
|
||||
sys.modules["torch"] = _torch_stub
|
||||
|
||||
from comfy_api.latest._io import ( # noqa: E402
|
||||
DynamicGroup,
|
||||
Float,
|
||||
Int,
|
||||
String,
|
||||
Boolean,
|
||||
get_finalized_class_inputs,
|
||||
build_nested_inputs,
|
||||
create_input_dict_v1,
|
||||
setup_dynamic_input_funcs,
|
||||
)
|
||||
|
||||
# Make sure dynamic input funcs are registered (may already be done at import time)
|
||||
setup_dynamic_input_funcs()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_class_inputs(group_input: DynamicGroup.Input) -> dict:
|
||||
"""Wrap a DynamicGroup.Input into the required/optional dict structure."""
|
||||
return create_input_dict_v1([group_input])
|
||||
|
||||
|
||||
def _run(group_input: DynamicGroup.Input, live_values: dict) -> dict:
|
||||
"""End-to-end helper: expand schema + reconstruct values.
|
||||
|
||||
Mirrors the production split in execution.py:
|
||||
1. get_finalized_class_inputs (schema expansion, line 162)
|
||||
2. build_nested_inputs (value reconstruction, line 281)
|
||||
|
||||
The two steps are separate in production because the engine resolves
|
||||
linked node outputs between them, but in tests we supply values directly.
|
||||
"""
|
||||
class_inputs = _make_class_inputs(group_input)
|
||||
_, _, v3_data = get_finalized_class_inputs(class_inputs, live_values)
|
||||
return build_nested_inputs(dict(live_values), v3_data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDynamicGroupInputConstruction:
|
||||
def test_basic_construction(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[
|
||||
Float.Input("strength", default=1.0),
|
||||
String.Input("name"),
|
||||
],
|
||||
min=0,
|
||||
max=10,
|
||||
)
|
||||
assert inp.id == "loras"
|
||||
assert inp.min == 0
|
||||
assert inp.max == 10
|
||||
assert len(inp.template) == 2
|
||||
|
||||
def test_get_all_includes_self_and_template(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"items",
|
||||
template=[Float.Input("value")],
|
||||
)
|
||||
all_inputs = inp.get_all()
|
||||
assert all_inputs[0] is inp
|
||||
assert all_inputs[1].id == "value"
|
||||
|
||||
def test_as_dict_has_template_min_max(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"items",
|
||||
template=[Float.Input("val", default=0.5)],
|
||||
min=1,
|
||||
max=5,
|
||||
)
|
||||
d = inp.as_dict()
|
||||
assert "template" in d
|
||||
assert d["min"] == 1
|
||||
assert d["max"] == 5
|
||||
|
||||
def test_duplicate_field_ids_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input(
|
||||
"bad",
|
||||
template=[Float.Input("x"), Float.Input("x")],
|
||||
)
|
||||
|
||||
def test_empty_template_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[])
|
||||
|
||||
def test_min_gt_max_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[Float.Input("x")], min=5, max=3)
|
||||
|
||||
def test_max_exceeds_limit_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[Float.Input("x")], max=101)
|
||||
|
||||
def test_dynamic_input_in_template_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input(
|
||||
"bad",
|
||||
template=[DynamicGroup.Input("nested", template=[Float.Input("x")])],
|
||||
)
|
||||
|
||||
def test_validate_calls_through(self):
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", min=-1.0, max=1.0)])
|
||||
inp.validate() # should not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 0-row case
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestZeroRows:
|
||||
def test_empty_live_inputs_produces_empty_list(self):
|
||||
"""With min=0 and no live values, the result should be an empty list."""
|
||||
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
|
||||
assert _run(inp, {}).get("loras") == []
|
||||
|
||||
def test_min_zero_with_values(self):
|
||||
"""min=0 but 2 rows of live data."""
|
||||
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
|
||||
result = _run(inp, {"loras.0.strength": 0.8, "loras.1.strength": 0.5})
|
||||
assert result["loras"] == [{"strength": 0.8}, {"strength": 0.5}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# N-row case
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNRows:
|
||||
def test_two_rows_two_fields(self):
|
||||
"""Two rows with two fields each produce a list[dict]."""
|
||||
inp = DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[String.Input("lora_name"), Float.Input("strength", default=1.0)],
|
||||
min=0, max=50,
|
||||
)
|
||||
result = _run(inp, {
|
||||
"loras.0.lora_name": "model_a.safetensors", "loras.0.strength": 0.9,
|
||||
"loras.1.lora_name": "model_b.safetensors", "loras.1.strength": 0.4,
|
||||
})
|
||||
assert result["loras"] == [
|
||||
{"lora_name": "model_a.safetensors", "strength": 0.9},
|
||||
{"lora_name": "model_b.safetensors", "strength": 0.4},
|
||||
]
|
||||
|
||||
def test_rows_are_sorted_by_index(self):
|
||||
"""Rows must be in ascending index order even if dict iteration is unordered."""
|
||||
inp = DynamicGroup.Input("items", template=[Int.Input("v", default=0)], min=0, max=10)
|
||||
result = _run(inp, {"items.0.v": 10, "items.2.v": 30, "items.1.v": 20})
|
||||
assert [row["v"] for row in result["items"]] == [10, 20, 30]
|
||||
|
||||
def test_min_rows_schema_slots(self):
|
||||
"""With min=2 and no live data, 2 slots must appear in the expanded schema."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
out, _, _ = get_finalized_class_inputs(_make_class_inputs(inp), {})
|
||||
all_slots = {**out.get("required", {}), **out.get("optional", {})}
|
||||
assert "items.0.val" in all_slots
|
||||
assert "items.1.val" in all_slots
|
||||
|
||||
def test_min_rows_reconstructs_when_no_values(self):
|
||||
"""min=2 with NO live values must still yield a 2-element list,
|
||||
not collapse to [] (regression: parent-path clobber)."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
result = _run(inp, {})
|
||||
assert len(result["items"]) == 2
|
||||
assert all("val" in row for row in result["items"])
|
||||
|
||||
def test_min_rows_reconstructs_with_partial_values(self):
|
||||
"""min=2 with only the first row's value present still yields 2 rows."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
result = _run(inp, {"items.0.val": 0.7})
|
||||
assert len(result["items"]) == 2
|
||||
assert result["items"][0]["val"] == 0.7
|
||||
assert result["items"][1]["val"] is None
|
||||
|
||||
def test_list_paths_in_v3_data(self):
|
||||
"""list_paths must contain the group id so build_nested_inputs knows to convert."""
|
||||
inp = DynamicGroup.Input("things", template=[Boolean.Input("flag")], min=0, max=5)
|
||||
_, _, v3_data = get_finalized_class_inputs(_make_class_inputs(inp), {})
|
||||
assert "things" in v3_data.get("list_paths", set())
|
||||
|
||||
def test_no_leftover_flat_keys(self):
|
||||
"""Flat keys must be consumed; only the reconstructed list remains."""
|
||||
inp = DynamicGroup.Input("rows", template=[Float.Input("x", default=0.0)], min=0, max=5)
|
||||
result = _run(inp, {"rows.0.x": 1.0, "rows.1.x": 2.0})
|
||||
assert "rows.0.x" not in result
|
||||
assert "rows.1.x" not in result
|
||||
assert isinstance(result["rows"], list)
|
||||
@ -228,6 +228,62 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
with self.assertRaises(KeyError):
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
def test_int8_convrot_metadata_loads_into_params(self):
|
||||
"""ConvRot metadata must reach TensorWiseINT8Layout params."""
|
||||
torch.manual_seed(123)
|
||||
layer_quant_config = {
|
||||
"layer": {
|
||||
"format": "int8_tensorwise",
|
||||
"convrot": True,
|
||||
"convrot_groupsize": 256,
|
||||
}
|
||||
}
|
||||
weight = torch.randn(16, 256, dtype=torch.bfloat16)
|
||||
bias = torch.randn(16, dtype=torch.bfloat16)
|
||||
q_weight = QuantizedTensor.from_float(
|
||||
weight,
|
||||
"TensorWiseINT8Layout",
|
||||
per_channel=True,
|
||||
convrot=True,
|
||||
convrot_groupsize=256,
|
||||
)
|
||||
state_dict = {
|
||||
"layer.weight": q_weight._qdata,
|
||||
"layer.bias": bias,
|
||||
"layer.weight_scale": q_weight._params.scale,
|
||||
}
|
||||
|
||||
state_dict, _ = comfy.utils.convert_old_quants(
|
||||
state_dict,
|
||||
metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})},
|
||||
)
|
||||
model = torch.nn.Module()
|
||||
model.layer = ops.mixed_precision_ops({}).Linear(256, 16, device="cpu", dtype=torch.bfloat16)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
self.assertIsInstance(model.layer.weight, QuantizedTensor)
|
||||
self.assertEqual(model.layer.weight._layout_cls, "TensorWiseINT8Layout")
|
||||
self.assertTrue(model.layer.weight._params.convrot)
|
||||
self.assertEqual(model.layer.weight._params.convrot_groupsize, 256)
|
||||
|
||||
input_tensor = torch.randn(4, 256, dtype=torch.bfloat16)
|
||||
loaded_out = model.layer(input_tensor)
|
||||
ref_out = torch.nn.functional.linear(input_tensor, q_weight, bias)
|
||||
self.assertTrue(torch.equal(loaded_out, ref_out))
|
||||
|
||||
fp16_input = input_tensor.to(torch.float16)
|
||||
loaded_fp16_out = model.layer(fp16_input)
|
||||
ref_fp16_out = torch.nn.functional.linear(
|
||||
fp16_input,
|
||||
q_weight.to(dtype=torch.float16),
|
||||
bias.to(dtype=torch.float16),
|
||||
)
|
||||
self.assertTrue(torch.equal(loaded_fp16_out, ref_fp16_out))
|
||||
|
||||
saved = model.state_dict()
|
||||
saved_conf = json.loads(saved["layer.comfy_quant"].numpy().tobytes())
|
||||
self.assertTrue(saved_conf["convrot"])
|
||||
self.assertEqual(saved_conf["convrot_groupsize"], 256)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user