Compare commits

..

2 Commits

8 changed files with 16 additions and 285 deletions

View File

@ -1089,19 +1089,6 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat
if ts is None or bs is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
scales = {"scale": ts, "block_scale": bs}
elif module.quant_format == "int8_tensorwise":
scale = pop_scale("weight_scale")
if scale is None:
raise ValueError(f"Missing INT8 weight scale for layer {layer_name}")
scales = {"scale": scale}
params_conf = layer_conf.get("params", {})
if not isinstance(params_conf, dict):
params_conf = {}
if layer_conf.get("convrot", params_conf.get("convrot", False)):
scales["convrot"] = True
scales["convrot_groupsize"] = int(
layer_conf.get("convrot_groupsize", params_conf.get("convrot_groupsize", 256))
)
else:
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
@ -1144,10 +1131,6 @@ def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extr
quant_conf = {"format": module.quant_format}
if getattr(module, '_full_precision_mm_config', False):
quant_conf["full_precision_matrix_mult"] = True
params = getattr(module.weight, "_params", None)
if module.quant_format == "int8_tensorwise" and getattr(params, "convrot", False):
quant_conf["convrot"] = True
quant_conf["convrot_groupsize"] = getattr(params, "convrot_groupsize", 256)
if extra_quant_conf:
quant_conf.update(extra_quant_conf)
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
@ -1200,33 +1183,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(
self,
input,
compute_dtype=None,
want_requant=False,
weight_only_quant=False,
):
if weight_only_quant:
weight, bias, offload_stream = cast_bias_weight(
self,
input=None,
dtype=self.weight.dtype,
device=input.device,
bias_dtype=input.dtype,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=want_requant,
)
weight = weight.to(dtype=input.dtype)
else:
weight, bias, offload_stream = cast_bias_weight(
self,
input,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=want_requant,
)
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@ -1245,10 +1203,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0
)
quantize_input = QUANT_ALGOS.get(getattr(self, 'quant_format', None), {}).get("quantize_input", True)
# Training path: quantized forward with compute_dtype backward via autograd function
if (input.requires_grad and _use_quantized and quantize_input):
if (input.requires_grad and _use_quantized):
weight, bias, offload_stream = cast_bias_weight(
self,
@ -1270,7 +1227,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
return output
# Inference path (unchanged)
if _use_quantized and quantize_input:
if _use_quantized:
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
@ -1284,13 +1241,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
weight_only_quant = _use_quantized and not quantize_input and isinstance(self.weight, QuantizedTensor)
output = self.forward_comfy_cast_weights(
input,
compute_dtype,
want_requant=isinstance(input, QuantizedTensor),
weight_only_quant=weight_only_quant,
)
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
# Reshape output back to 3D if input was 3D
if reshaped_3d:

View File

@ -10,7 +10,6 @@ try:
QuantizedLayout,
TensorCoreFP8Layout as _CKFp8Layout,
TensorCoreNVFP4Layout as _CKNvfp4Layout,
TensorWiseINT8Layout as _CKTensorWiseINT8Layout,
register_layout_op,
register_layout_class,
get_layout_class,
@ -48,9 +47,6 @@ except ImportError as e:
class _CKNvfp4Layout:
pass
class _CKTensorWiseINT8Layout:
pass
def register_layout_class(name, cls):
pass
@ -178,7 +174,6 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
# Backward compatibility alias - default to E4M3
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
TensorWiseINT8Layout = _CKTensorWiseINT8Layout
# ==============================================================================
@ -189,7 +184,6 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
register_layout_class("TensorWiseINT8Layout", _CKTensorWiseINT8Layout)
if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
@ -220,13 +214,6 @@ if _CK_MXFP8_AVAILABLE:
"group_size": 32,
}
QUANT_ALGOS["int8_tensorwise"] = {
"storage_t": torch.int8,
"parameters": {"weight_scale"},
"comfy_tensor_layout": "TensorWiseINT8Layout",
"quantize_input": False,
}
# ==============================================================================
# Re-exports for backward compatibility
@ -239,7 +226,6 @@ __all__ = [
"TensorCoreFP8E4M3Layout",
"TensorCoreFP8E5M2Layout",
"TensorCoreNVFP4Layout",
"TensorWiseINT8Layout",
"QUANT_ALGOS",
"register_layout_op",
]

View File

@ -177,10 +177,6 @@ SEEDANCE2_PRICE_PER_1K_TOKENS = {
("dreamina-seedance-2-0-fast-260128", True, "480p"): 0.0033,
("dreamina-seedance-2-0-fast-260128", False, "720p"): 0.0056,
("dreamina-seedance-2-0-fast-260128", True, "720p"): 0.0033,
("dreamina-seedance-2-0-mini", False, "480p"): 0.0035,
("dreamina-seedance-2-0-mini", True, "480p"): 0.0021,
("dreamina-seedance-2-0-mini", False, "720p"): 0.0035,
("dreamina-seedance-2-0-mini", True, "720p"): 0.0021,
}
@ -282,10 +278,6 @@ SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
"480p": {"min": 409_600, "max": 927_408},
"720p": {"min": 409_600, "max": 927_408},
},
"dreamina-seedance-2-0-mini": {
"480p": {"min": 409_600, "max": 927_408},
"720p": {"min": 409_600, "max": 927_408},
},
}
# The time in this dictionary are given for 10 seconds duration.

View File

@ -89,7 +89,6 @@ BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT = "/proxy/byteplus-seedance2/api/v3/cont
SEEDANCE_MODELS = {
"Seedance 2.0": "dreamina-seedance-2-0-260128",
"Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128",
"Seedance 2.0 Mini": "dreamina-seedance-2-0-mini",
}
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
@ -1624,10 +1623,8 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p", "4k"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
IO.DynamicCombo.Option("Seedance 2.0 Mini", _seedance2_text_inputs(["480p", "720p"])),
],
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
"Mini for the fastest, lowest-cost generation.",
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),
IO.Int.Input(
"seed",
@ -1669,7 +1666,6 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
$dur := $lookup(widgets, "model.duration");
$pricePer1K := $res = "4k" ? 0.00572 :
$res = "1080p" ? 0.011011 :
$contains($m, "mini") ? 0.005005 :
$contains($m, "fast") ? 0.008008 : 0.01001;
$rate := $res = "4k" ? $rate4k :
$res = "1080p" ? $rate1080 :
@ -1738,13 +1734,8 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
"Seedance 2.0 Fast",
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Mini",
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
),
],
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
"Mini for the fastest, lowest-cost generation.",
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),
IO.Image.Input(
"first_frame",
@ -1810,7 +1801,6 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
$dur := $lookup(widgets, "model.duration");
$pricePer1K := $res = "4k" ? 0.00572 :
$res = "1080p" ? 0.011011 :
$contains($m, "mini") ? 0.005005 :
$contains($m, "fast") ? 0.008008 : 0.01001;
$rate := $res = "4k" ? $rate4k :
$res = "1080p" ? $rate1080 :
@ -2034,13 +2024,8 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
"Seedance 2.0 Fast",
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Mini",
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
),
],
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
"Mini for the fastest, lowest-cost generation.",
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),
IO.Int.Input(
"seed",
@ -2086,11 +2071,9 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
$dur := $lookup(widgets, "model.duration");
$noVideoPricePer1K := $res = "4k" ? 0.00572 :
$res = "1080p" ? 0.011011 :
$contains($m, "mini") ? 0.005005 :
$contains($m, "fast") ? 0.008008 : 0.01001;
$videoPricePer1K := $res = "4k" ? 0.003432 :
$res = "1080p" ? 0.006721 :
$contains($m, "mini") ? 0.003003 :
$contains($m, "fast") ? 0.004719 : 0.006149;
$rate := $res = "4k" ? $rate4k :
$res = "1080p" ? $rate1080 :

View File

@ -3,7 +3,6 @@ from typing import Type, Literal
import nodes
import asyncio
import inspect
import traceback
from comfy_execution.graph_utils import is_link, ExecutionBlocker
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
@ -264,25 +263,7 @@ class ExecutionList(TopologicalSort):
}
return None, error_details, ex
try:
self.staged_node_id = self.ux_friendly_pick_node(available)
except Exception as ex:
# Backstop: the ordering heuristics in ux_friendly_pick_node are
# defensive, but should anything else there fail, surface it as an
# execution error instead of letting it kill the prompt worker
# thread. Blame an available node (best effort).
blamed_node = self.dynprompt.get_display_node_id(available[0])
exception_type = type(ex).__qualname__
if type(ex).__module__ != "builtins":
exception_type = type(ex).__module__ + "." + exception_type
error_details = {
"node_id": blamed_node,
"exception_message": str(ex),
"exception_type": exception_type,
"traceback": traceback.format_tb(ex.__traceback__),
"current_inputs": []
}
return None, error_details, ex
self.staged_node_id = self.ux_friendly_pick_node(available)
return self.staged_node_id, None, None
def ux_friendly_pick_node(self, node_list):
@ -290,28 +271,19 @@ class ExecutionList(TopologicalSort):
# Technically this has no effect on the overall length of execution, but it feels better as a user
# for a PreviewImage to display a result as soon as it can
# Some other heuristics could probably be used here to improve the UX further.
# These node-ordering heuristics only affect *order*, never correctness.
# A malformed node (e.g. a FUNCTION typo, or a node whose schema-derived
# attributes raise) must not crash scheduling: failing a heuristic just
# means "not prioritized". The node then proceeds to normal execution,
# where the real error is raised and reported against the correct node.
def is_output(node_id):
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
try:
return hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True
except Exception:
return False
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
return True
return False
# If an available node is async, do that first.
# This will execute the asynchronous function earlier, reducing the overall time.
def is_async(node_id):
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
try:
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
except Exception:
return False
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
for node_id in node_list:
if is_output(node_id) or is_async(node_id):

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.45.19
comfyui-workflow-templates==0.10.7
comfyui-workflow-templates==0.10.2
comfyui-embedded-docs==0.5.5
torch
torchsde
@ -22,7 +22,7 @@ alembic
SQLAlchemy>=2.0.0
filelock
av>=16.0.0
comfy-kitchen==0.2.12
comfy-kitchen==0.2.10
comfy-aimdo==0.4.10
requests
simpleeval>=1.0.0

View File

@ -228,62 +228,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
with self.assertRaises(KeyError):
model.load_state_dict(state_dict, strict=False)
def test_int8_convrot_metadata_loads_into_params(self):
"""ConvRot metadata must reach TensorWiseINT8Layout params."""
torch.manual_seed(123)
layer_quant_config = {
"layer": {
"format": "int8_tensorwise",
"convrot": True,
"convrot_groupsize": 256,
}
}
weight = torch.randn(16, 256, dtype=torch.bfloat16)
bias = torch.randn(16, dtype=torch.bfloat16)
q_weight = QuantizedTensor.from_float(
weight,
"TensorWiseINT8Layout",
per_channel=True,
convrot=True,
convrot_groupsize=256,
)
state_dict = {
"layer.weight": q_weight._qdata,
"layer.bias": bias,
"layer.weight_scale": q_weight._params.scale,
}
state_dict, _ = comfy.utils.convert_old_quants(
state_dict,
metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})},
)
model = torch.nn.Module()
model.layer = ops.mixed_precision_ops({}).Linear(256, 16, device="cpu", dtype=torch.bfloat16)
model.load_state_dict(state_dict, strict=False)
self.assertIsInstance(model.layer.weight, QuantizedTensor)
self.assertEqual(model.layer.weight._layout_cls, "TensorWiseINT8Layout")
self.assertTrue(model.layer.weight._params.convrot)
self.assertEqual(model.layer.weight._params.convrot_groupsize, 256)
input_tensor = torch.randn(4, 256, dtype=torch.bfloat16)
loaded_out = model.layer(input_tensor)
ref_out = torch.nn.functional.linear(input_tensor, q_weight, bias)
self.assertTrue(torch.equal(loaded_out, ref_out))
fp16_input = input_tensor.to(torch.float16)
loaded_fp16_out = model.layer(fp16_input)
ref_fp16_out = torch.nn.functional.linear(
fp16_input,
q_weight.to(dtype=torch.float16),
bias.to(dtype=torch.float16),
)
self.assertTrue(torch.equal(loaded_fp16_out, ref_fp16_out))
saved = model.state_dict()
saved_conf = json.loads(saved["layer.comfy_quant"].numpy().tobytes())
self.assertTrue(saved_conf["convrot"])
self.assertEqual(saved_conf["convrot_groupsize"], 256)
if __name__ == "__main__":
unittest.main()

View File

@ -1,97 +0,0 @@
"""Regression tests for scheduler resilience to malformed nodes.
A node whose FUNCTION points at a method that does not exist (e.g. a typo in a
custom node) used to raise inside the scheduling heuristic, escaping the prompt
worker's error handling and silently killing the worker thread. Scheduling must
instead either proceed (so the error surfaces through normal execution) or report
the failure as an execution error.
"""
import asyncio
import nodes
from comfy_execution.graph import DynamicPrompt, ExecutionList
class _MalformedV1Node:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "invert" # the actual method below is misspelled
OUTPUT_NODE = True
CATEGORY = "Test"
def invvert(self):
return (None,)
class _RaisingDescriptor:
def __get__(self, obj, owner):
raise RuntimeError("schema error")
class _SchemaRaisesNode:
"""A node whose schema-derived attribute access raises, as a broken V3 node would."""
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "run"
OUTPUT_NODE = _RaisingDescriptor()
CATEGORY = "Test"
def run(self):
return (None,)
class _FakeOutputCache:
def all_node_ids(self):
return set()
async def get(self, node_id):
return None
def _make_execution_list(class_type, class_def):
nodes.NODE_CLASS_MAPPINGS[class_type] = class_def
prompt = {"1": {"class_type": class_type, "inputs": {}}}
execution_list = ExecutionList(DynamicPrompt(prompt), _FakeOutputCache())
execution_list.add_node("1")
return execution_list
def test_malformed_function_does_not_crash_scheduler():
"""A FUNCTION-typo node schedules without raising; the error surfaces later."""
execution_list = _make_execution_list("MalformedV1Node", _MalformedV1Node)
node_id, error, ex = asyncio.run(execution_list.stage_node_execution())
assert ex is None
assert error is None
assert node_id == "1"
def test_schema_attribute_error_does_not_crash_scheduler():
"""A node whose attribute access raises during heuristics still schedules."""
execution_list = _make_execution_list("SchemaRaisesNode", _SchemaRaisesNode)
node_id, error, ex = asyncio.run(execution_list.stage_node_execution())
assert ex is None
assert error is None
assert node_id == "1"
def test_pick_node_failure_is_reported_not_raised():
"""An unexpected scheduling error is returned as an error, not raised."""
execution_list = _make_execution_list("MalformedV1Node", _MalformedV1Node)
def raise_on_pick(_available):
raise RuntimeError("boom")
execution_list.ux_friendly_pick_node = raise_on_pick
node_id, error, ex = asyncio.run(execution_list.stage_node_execution())
assert node_id is None
assert isinstance(ex, RuntimeError)
assert error["node_id"] == "1"
assert error["exception_type"] == "RuntimeError"
assert error["exception_message"] == "boom"
assert error["traceback"]