mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-27 08:26:55 +08:00
Compare commits
7 Commits
feat/ideog
...
fix/surfac
| Author | SHA1 | Date | |
|---|---|---|---|
| ffdc23c6dd | |||
| 91f3c0c4d9 | |||
| 7cb784e0f4 | |||
| 1a510f0423 | |||
| 639c8fa788 | |||
| e22f1500f9 | |||
| dac4ea3a80 |
59
comfy/ops.py
59
comfy/ops.py
@ -1089,6 +1089,19 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat
|
||||
if ts is None or bs is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||
scales = {"scale": ts, "block_scale": bs}
|
||||
elif module.quant_format == "int8_tensorwise":
|
||||
scale = pop_scale("weight_scale")
|
||||
if scale is None:
|
||||
raise ValueError(f"Missing INT8 weight scale for layer {layer_name}")
|
||||
scales = {"scale": scale}
|
||||
params_conf = layer_conf.get("params", {})
|
||||
if not isinstance(params_conf, dict):
|
||||
params_conf = {}
|
||||
if layer_conf.get("convrot", params_conf.get("convrot", False)):
|
||||
scales["convrot"] = True
|
||||
scales["convrot_groupsize"] = int(
|
||||
layer_conf.get("convrot_groupsize", params_conf.get("convrot_groupsize", 256))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
|
||||
|
||||
@ -1131,6 +1144,10 @@ def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extr
|
||||
quant_conf = {"format": module.quant_format}
|
||||
if getattr(module, '_full_precision_mm_config', False):
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
params = getattr(module.weight, "_params", None)
|
||||
if module.quant_format == "int8_tensorwise" and getattr(params, "convrot", False):
|
||||
quant_conf["convrot"] = True
|
||||
quant_conf["convrot_groupsize"] = getattr(params, "convrot_groupsize", 256)
|
||||
if extra_quant_conf:
|
||||
quant_conf.update(extra_quant_conf)
|
||||
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
|
||||
@ -1183,8 +1200,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
|
||||
def forward_comfy_cast_weights(
|
||||
self,
|
||||
input,
|
||||
compute_dtype=None,
|
||||
want_requant=False,
|
||||
weight_only_quant=False,
|
||||
):
|
||||
if weight_only_quant:
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input=None,
|
||||
dtype=self.weight.dtype,
|
||||
device=input.device,
|
||||
bias_dtype=input.dtype,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=want_requant,
|
||||
)
|
||||
weight = weight.to(dtype=input.dtype)
|
||||
else:
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=want_requant,
|
||||
)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
@ -1203,9 +1245,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0
|
||||
)
|
||||
quantize_input = QUANT_ALGOS.get(getattr(self, 'quant_format', None), {}).get("quantize_input", True)
|
||||
|
||||
# Training path: quantized forward with compute_dtype backward via autograd function
|
||||
if (input.requires_grad and _use_quantized):
|
||||
if (input.requires_grad and _use_quantized and quantize_input):
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
@ -1227,7 +1270,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
return output
|
||||
|
||||
# Inference path (unchanged)
|
||||
if _use_quantized:
|
||||
if _use_quantized and quantize_input:
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
@ -1241,7 +1284,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
|
||||
weight_only_quant = _use_quantized and not quantize_input and isinstance(self.weight, QuantizedTensor)
|
||||
output = self.forward_comfy_cast_weights(
|
||||
input,
|
||||
compute_dtype,
|
||||
want_requant=isinstance(input, QuantizedTensor),
|
||||
weight_only_quant=weight_only_quant,
|
||||
)
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if reshaped_3d:
|
||||
|
||||
@ -10,6 +10,7 @@ try:
|
||||
QuantizedLayout,
|
||||
TensorCoreFP8Layout as _CKFp8Layout,
|
||||
TensorCoreNVFP4Layout as _CKNvfp4Layout,
|
||||
TensorWiseINT8Layout as _CKTensorWiseINT8Layout,
|
||||
register_layout_op,
|
||||
register_layout_class,
|
||||
get_layout_class,
|
||||
@ -47,6 +48,9 @@ except ImportError as e:
|
||||
class _CKNvfp4Layout:
|
||||
pass
|
||||
|
||||
class _CKTensorWiseINT8Layout:
|
||||
pass
|
||||
|
||||
def register_layout_class(name, cls):
|
||||
pass
|
||||
|
||||
@ -174,6 +178,7 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
|
||||
|
||||
# Backward compatibility alias - default to E4M3
|
||||
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
|
||||
TensorWiseINT8Layout = _CKTensorWiseINT8Layout
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
@ -184,6 +189,7 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||
register_layout_class("TensorWiseINT8Layout", _CKTensorWiseINT8Layout)
|
||||
if _CK_MXFP8_AVAILABLE:
|
||||
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||
|
||||
@ -214,6 +220,13 @@ if _CK_MXFP8_AVAILABLE:
|
||||
"group_size": 32,
|
||||
}
|
||||
|
||||
QUANT_ALGOS["int8_tensorwise"] = {
|
||||
"storage_t": torch.int8,
|
||||
"parameters": {"weight_scale"},
|
||||
"comfy_tensor_layout": "TensorWiseINT8Layout",
|
||||
"quantize_input": False,
|
||||
}
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Re-exports for backward compatibility
|
||||
@ -226,6 +239,7 @@ __all__ = [
|
||||
"TensorCoreFP8E4M3Layout",
|
||||
"TensorCoreFP8E5M2Layout",
|
||||
"TensorCoreNVFP4Layout",
|
||||
"TensorWiseINT8Layout",
|
||||
"QUANT_ALGOS",
|
||||
"register_layout_op",
|
||||
]
|
||||
|
||||
@ -177,6 +177,10 @@ SEEDANCE2_PRICE_PER_1K_TOKENS = {
|
||||
("dreamina-seedance-2-0-fast-260128", True, "480p"): 0.0033,
|
||||
("dreamina-seedance-2-0-fast-260128", False, "720p"): 0.0056,
|
||||
("dreamina-seedance-2-0-fast-260128", True, "720p"): 0.0033,
|
||||
("dreamina-seedance-2-0-mini", False, "480p"): 0.0035,
|
||||
("dreamina-seedance-2-0-mini", True, "480p"): 0.0021,
|
||||
("dreamina-seedance-2-0-mini", False, "720p"): 0.0035,
|
||||
("dreamina-seedance-2-0-mini", True, "720p"): 0.0021,
|
||||
}
|
||||
|
||||
|
||||
@ -278,6 +282,10 @@ SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
|
||||
"480p": {"min": 409_600, "max": 927_408},
|
||||
"720p": {"min": 409_600, "max": 927_408},
|
||||
},
|
||||
"dreamina-seedance-2-0-mini": {
|
||||
"480p": {"min": 409_600, "max": 927_408},
|
||||
"720p": {"min": 409_600, "max": 927_408},
|
||||
},
|
||||
}
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
|
||||
@ -89,6 +89,7 @@ BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT = "/proxy/byteplus-seedance2/api/v3/cont
|
||||
SEEDANCE_MODELS = {
|
||||
"Seedance 2.0": "dreamina-seedance-2-0-260128",
|
||||
"Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128",
|
||||
"Seedance 2.0 Mini": "dreamina-seedance-2-0-mini",
|
||||
}
|
||||
|
||||
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
||||
@ -1623,8 +1624,10 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p", "4k"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Mini", _seedance2_text_inputs(["480p", "720p"])),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
|
||||
"Mini for the fastest, lowest-cost generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
@ -1666,6 +1669,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$pricePer1K := $res = "4k" ? 0.00572 :
|
||||
$res = "1080p" ? 0.011011 :
|
||||
$contains($m, "mini") ? 0.005005 :
|
||||
$contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$rate := $res = "4k" ? $rate4k :
|
||||
$res = "1080p" ? $rate1080 :
|
||||
@ -1734,8 +1738,13 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
"Seedance 2.0 Fast",
|
||||
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Mini",
|
||||
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
|
||||
"Mini for the fastest, lowest-cost generation.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"first_frame",
|
||||
@ -1801,6 +1810,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$pricePer1K := $res = "4k" ? 0.00572 :
|
||||
$res = "1080p" ? 0.011011 :
|
||||
$contains($m, "mini") ? 0.005005 :
|
||||
$contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$rate := $res = "4k" ? $rate4k :
|
||||
$res = "1080p" ? $rate1080 :
|
||||
@ -2024,8 +2034,13 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
"Seedance 2.0 Fast",
|
||||
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Mini",
|
||||
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
|
||||
"Mini for the fastest, lowest-cost generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
@ -2071,9 +2086,11 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$noVideoPricePer1K := $res = "4k" ? 0.00572 :
|
||||
$res = "1080p" ? 0.011011 :
|
||||
$contains($m, "mini") ? 0.005005 :
|
||||
$contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$videoPricePer1K := $res = "4k" ? 0.003432 :
|
||||
$res = "1080p" ? 0.006721 :
|
||||
$contains($m, "mini") ? 0.003003 :
|
||||
$contains($m, "fast") ? 0.004719 : 0.006149;
|
||||
$rate := $res = "4k" ? $rate4k :
|
||||
$res = "1080p" ? $rate1080 :
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Type, Literal
|
||||
import nodes
|
||||
import asyncio
|
||||
import inspect
|
||||
import traceback
|
||||
from comfy_execution.graph_utils import is_link, ExecutionBlocker
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||
|
||||
@ -263,7 +264,25 @@ class ExecutionList(TopologicalSort):
|
||||
}
|
||||
return None, error_details, ex
|
||||
|
||||
self.staged_node_id = self.ux_friendly_pick_node(available)
|
||||
try:
|
||||
self.staged_node_id = self.ux_friendly_pick_node(available)
|
||||
except Exception as ex:
|
||||
# Backstop: the ordering heuristics in ux_friendly_pick_node are
|
||||
# defensive, but should anything else there fail, surface it as an
|
||||
# execution error instead of letting it kill the prompt worker
|
||||
# thread. Blame an available node (best effort).
|
||||
blamed_node = self.dynprompt.get_display_node_id(available[0])
|
||||
exception_type = type(ex).__qualname__
|
||||
if type(ex).__module__ != "builtins":
|
||||
exception_type = type(ex).__module__ + "." + exception_type
|
||||
error_details = {
|
||||
"node_id": blamed_node,
|
||||
"exception_message": str(ex),
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(ex.__traceback__),
|
||||
"current_inputs": []
|
||||
}
|
||||
return None, error_details, ex
|
||||
return self.staged_node_id, None, None
|
||||
|
||||
def ux_friendly_pick_node(self, node_list):
|
||||
@ -271,19 +290,28 @@ class ExecutionList(TopologicalSort):
|
||||
# Technically this has no effect on the overall length of execution, but it feels better as a user
|
||||
# for a PreviewImage to display a result as soon as it can
|
||||
# Some other heuristics could probably be used here to improve the UX further.
|
||||
# These node-ordering heuristics only affect *order*, never correctness.
|
||||
# A malformed node (e.g. a FUNCTION typo, or a node whose schema-derived
|
||||
# attributes raise) must not crash scheduling: failing a heuristic just
|
||||
# means "not prioritized". The node then proceeds to normal execution,
|
||||
# where the real error is raised and reported against the correct node.
|
||||
def is_output(node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||
return True
|
||||
return False
|
||||
try:
|
||||
return hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# If an available node is async, do that first.
|
||||
# This will execute the asynchronous function earlier, reducing the overall time.
|
||||
def is_async(node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
|
||||
try:
|
||||
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
for node_id in node_list:
|
||||
if is_output(node_id) or is_async(node_id):
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.45.19
|
||||
comfyui-workflow-templates==0.10.2
|
||||
comfyui-workflow-templates==0.10.7
|
||||
comfyui-embedded-docs==0.5.5
|
||||
torch
|
||||
torchsde
|
||||
@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.10
|
||||
comfy-kitchen==0.2.12
|
||||
comfy-aimdo==0.4.10
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
|
||||
@ -228,6 +228,62 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
with self.assertRaises(KeyError):
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
def test_int8_convrot_metadata_loads_into_params(self):
|
||||
"""ConvRot metadata must reach TensorWiseINT8Layout params."""
|
||||
torch.manual_seed(123)
|
||||
layer_quant_config = {
|
||||
"layer": {
|
||||
"format": "int8_tensorwise",
|
||||
"convrot": True,
|
||||
"convrot_groupsize": 256,
|
||||
}
|
||||
}
|
||||
weight = torch.randn(16, 256, dtype=torch.bfloat16)
|
||||
bias = torch.randn(16, dtype=torch.bfloat16)
|
||||
q_weight = QuantizedTensor.from_float(
|
||||
weight,
|
||||
"TensorWiseINT8Layout",
|
||||
per_channel=True,
|
||||
convrot=True,
|
||||
convrot_groupsize=256,
|
||||
)
|
||||
state_dict = {
|
||||
"layer.weight": q_weight._qdata,
|
||||
"layer.bias": bias,
|
||||
"layer.weight_scale": q_weight._params.scale,
|
||||
}
|
||||
|
||||
state_dict, _ = comfy.utils.convert_old_quants(
|
||||
state_dict,
|
||||
metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})},
|
||||
)
|
||||
model = torch.nn.Module()
|
||||
model.layer = ops.mixed_precision_ops({}).Linear(256, 16, device="cpu", dtype=torch.bfloat16)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
self.assertIsInstance(model.layer.weight, QuantizedTensor)
|
||||
self.assertEqual(model.layer.weight._layout_cls, "TensorWiseINT8Layout")
|
||||
self.assertTrue(model.layer.weight._params.convrot)
|
||||
self.assertEqual(model.layer.weight._params.convrot_groupsize, 256)
|
||||
|
||||
input_tensor = torch.randn(4, 256, dtype=torch.bfloat16)
|
||||
loaded_out = model.layer(input_tensor)
|
||||
ref_out = torch.nn.functional.linear(input_tensor, q_weight, bias)
|
||||
self.assertTrue(torch.equal(loaded_out, ref_out))
|
||||
|
||||
fp16_input = input_tensor.to(torch.float16)
|
||||
loaded_fp16_out = model.layer(fp16_input)
|
||||
ref_fp16_out = torch.nn.functional.linear(
|
||||
fp16_input,
|
||||
q_weight.to(dtype=torch.float16),
|
||||
bias.to(dtype=torch.float16),
|
||||
)
|
||||
self.assertTrue(torch.equal(loaded_fp16_out, ref_fp16_out))
|
||||
|
||||
saved = model.state_dict()
|
||||
saved_conf = json.loads(saved["layer.comfy_quant"].numpy().tobytes())
|
||||
self.assertTrue(saved_conf["convrot"])
|
||||
self.assertEqual(saved_conf["convrot_groupsize"], 256)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
|
||||
97
tests-unit/execution_test/scheduler_malformed_node_test.py
Normal file
97
tests-unit/execution_test/scheduler_malformed_node_test.py
Normal file
@ -0,0 +1,97 @@
|
||||
"""Regression tests for scheduler resilience to malformed nodes.
|
||||
|
||||
A node whose FUNCTION points at a method that does not exist (e.g. a typo in a
|
||||
custom node) used to raise inside the scheduling heuristic, escaping the prompt
|
||||
worker's error handling and silently killing the worker thread. Scheduling must
|
||||
instead either proceed (so the error surfaces through normal execution) or report
|
||||
the failure as an execution error.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
import nodes
|
||||
from comfy_execution.graph import DynamicPrompt, ExecutionList
|
||||
|
||||
|
||||
class _MalformedV1Node:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "invert" # the actual method below is misspelled
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "Test"
|
||||
|
||||
def invvert(self):
|
||||
return (None,)
|
||||
|
||||
|
||||
class _RaisingDescriptor:
|
||||
def __get__(self, obj, owner):
|
||||
raise RuntimeError("schema error")
|
||||
|
||||
|
||||
class _SchemaRaisesNode:
|
||||
"""A node whose schema-derived attribute access raises, as a broken V3 node would."""
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "run"
|
||||
OUTPUT_NODE = _RaisingDescriptor()
|
||||
CATEGORY = "Test"
|
||||
|
||||
def run(self):
|
||||
return (None,)
|
||||
|
||||
|
||||
class _FakeOutputCache:
|
||||
def all_node_ids(self):
|
||||
return set()
|
||||
|
||||
async def get(self, node_id):
|
||||
return None
|
||||
|
||||
|
||||
def _make_execution_list(class_type, class_def):
|
||||
nodes.NODE_CLASS_MAPPINGS[class_type] = class_def
|
||||
prompt = {"1": {"class_type": class_type, "inputs": {}}}
|
||||
execution_list = ExecutionList(DynamicPrompt(prompt), _FakeOutputCache())
|
||||
execution_list.add_node("1")
|
||||
return execution_list
|
||||
|
||||
|
||||
def test_malformed_function_does_not_crash_scheduler():
|
||||
"""A FUNCTION-typo node schedules without raising; the error surfaces later."""
|
||||
execution_list = _make_execution_list("MalformedV1Node", _MalformedV1Node)
|
||||
node_id, error, ex = asyncio.run(execution_list.stage_node_execution())
|
||||
assert ex is None
|
||||
assert error is None
|
||||
assert node_id == "1"
|
||||
|
||||
|
||||
def test_schema_attribute_error_does_not_crash_scheduler():
|
||||
"""A node whose attribute access raises during heuristics still schedules."""
|
||||
execution_list = _make_execution_list("SchemaRaisesNode", _SchemaRaisesNode)
|
||||
node_id, error, ex = asyncio.run(execution_list.stage_node_execution())
|
||||
assert ex is None
|
||||
assert error is None
|
||||
assert node_id == "1"
|
||||
|
||||
|
||||
def test_pick_node_failure_is_reported_not_raised():
|
||||
"""An unexpected scheduling error is returned as an error, not raised."""
|
||||
execution_list = _make_execution_list("MalformedV1Node", _MalformedV1Node)
|
||||
|
||||
def raise_on_pick(_available):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
execution_list.ux_friendly_pick_node = raise_on_pick
|
||||
node_id, error, ex = asyncio.run(execution_list.stage_node_execution())
|
||||
assert node_id is None
|
||||
assert isinstance(ex, RuntimeError)
|
||||
assert error["node_id"] == "1"
|
||||
assert error["exception_type"] == "RuntimeError"
|
||||
assert error["exception_message"] == "boom"
|
||||
assert error["traceback"]
|
||||
Reference in New Issue
Block a user