mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 16:06:53 +08:00
Compare commits
5 Commits
alexis/see
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 7cb784e0f4 | |||
| 1a510f0423 | |||
| 639c8fa788 | |||
| e22f1500f9 | |||
| dac4ea3a80 |
59
comfy/ops.py
59
comfy/ops.py
@ -1089,6 +1089,19 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat
|
||||
if ts is None or bs is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||
scales = {"scale": ts, "block_scale": bs}
|
||||
elif module.quant_format == "int8_tensorwise":
|
||||
scale = pop_scale("weight_scale")
|
||||
if scale is None:
|
||||
raise ValueError(f"Missing INT8 weight scale for layer {layer_name}")
|
||||
scales = {"scale": scale}
|
||||
params_conf = layer_conf.get("params", {})
|
||||
if not isinstance(params_conf, dict):
|
||||
params_conf = {}
|
||||
if layer_conf.get("convrot", params_conf.get("convrot", False)):
|
||||
scales["convrot"] = True
|
||||
scales["convrot_groupsize"] = int(
|
||||
layer_conf.get("convrot_groupsize", params_conf.get("convrot_groupsize", 256))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
|
||||
|
||||
@ -1131,6 +1144,10 @@ def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extr
|
||||
quant_conf = {"format": module.quant_format}
|
||||
if getattr(module, '_full_precision_mm_config', False):
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
params = getattr(module.weight, "_params", None)
|
||||
if module.quant_format == "int8_tensorwise" and getattr(params, "convrot", False):
|
||||
quant_conf["convrot"] = True
|
||||
quant_conf["convrot_groupsize"] = getattr(params, "convrot_groupsize", 256)
|
||||
if extra_quant_conf:
|
||||
quant_conf.update(extra_quant_conf)
|
||||
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
|
||||
@ -1183,8 +1200,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
|
||||
def forward_comfy_cast_weights(
|
||||
self,
|
||||
input,
|
||||
compute_dtype=None,
|
||||
want_requant=False,
|
||||
weight_only_quant=False,
|
||||
):
|
||||
if weight_only_quant:
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input=None,
|
||||
dtype=self.weight.dtype,
|
||||
device=input.device,
|
||||
bias_dtype=input.dtype,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=want_requant,
|
||||
)
|
||||
weight = weight.to(dtype=input.dtype)
|
||||
else:
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=want_requant,
|
||||
)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
@ -1203,9 +1245,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0
|
||||
)
|
||||
quantize_input = QUANT_ALGOS.get(getattr(self, 'quant_format', None), {}).get("quantize_input", True)
|
||||
|
||||
# Training path: quantized forward with compute_dtype backward via autograd function
|
||||
if (input.requires_grad and _use_quantized):
|
||||
if (input.requires_grad and _use_quantized and quantize_input):
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
@ -1227,7 +1270,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
return output
|
||||
|
||||
# Inference path (unchanged)
|
||||
if _use_quantized:
|
||||
if _use_quantized and quantize_input:
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
@ -1241,7 +1284,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
|
||||
weight_only_quant = _use_quantized and not quantize_input and isinstance(self.weight, QuantizedTensor)
|
||||
output = self.forward_comfy_cast_weights(
|
||||
input,
|
||||
compute_dtype,
|
||||
want_requant=isinstance(input, QuantizedTensor),
|
||||
weight_only_quant=weight_only_quant,
|
||||
)
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if reshaped_3d:
|
||||
|
||||
@ -10,6 +10,7 @@ try:
|
||||
QuantizedLayout,
|
||||
TensorCoreFP8Layout as _CKFp8Layout,
|
||||
TensorCoreNVFP4Layout as _CKNvfp4Layout,
|
||||
TensorWiseINT8Layout as _CKTensorWiseINT8Layout,
|
||||
register_layout_op,
|
||||
register_layout_class,
|
||||
get_layout_class,
|
||||
@ -47,6 +48,9 @@ except ImportError as e:
|
||||
class _CKNvfp4Layout:
|
||||
pass
|
||||
|
||||
class _CKTensorWiseINT8Layout:
|
||||
pass
|
||||
|
||||
def register_layout_class(name, cls):
|
||||
pass
|
||||
|
||||
@ -174,6 +178,7 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
|
||||
|
||||
# Backward compatibility alias - default to E4M3
|
||||
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
|
||||
TensorWiseINT8Layout = _CKTensorWiseINT8Layout
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
@ -184,6 +189,7 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||
register_layout_class("TensorWiseINT8Layout", _CKTensorWiseINT8Layout)
|
||||
if _CK_MXFP8_AVAILABLE:
|
||||
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||
|
||||
@ -214,6 +220,13 @@ if _CK_MXFP8_AVAILABLE:
|
||||
"group_size": 32,
|
||||
}
|
||||
|
||||
QUANT_ALGOS["int8_tensorwise"] = {
|
||||
"storage_t": torch.int8,
|
||||
"parameters": {"weight_scale"},
|
||||
"comfy_tensor_layout": "TensorWiseINT8Layout",
|
||||
"quantize_input": False,
|
||||
}
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Re-exports for backward compatibility
|
||||
@ -226,6 +239,7 @@ __all__ = [
|
||||
"TensorCoreFP8E4M3Layout",
|
||||
"TensorCoreFP8E5M2Layout",
|
||||
"TensorCoreNVFP4Layout",
|
||||
"TensorWiseINT8Layout",
|
||||
"QUANT_ALGOS",
|
||||
"register_layout_op",
|
||||
]
|
||||
|
||||
@ -272,14 +272,13 @@ class Int(ComfyTypeIO):
|
||||
'''Integer input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
|
||||
display_mode: NumberDisplay=None, component: str=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.step = step
|
||||
self.control_after_generate = control_after_generate
|
||||
self.display_mode = display_mode
|
||||
self.component = component
|
||||
self.default: int
|
||||
|
||||
def as_dict(self):
|
||||
@ -289,7 +288,6 @@ class Int(ComfyTypeIO):
|
||||
"step": self.step,
|
||||
"control_after_generate": self.control_after_generate,
|
||||
"display": self.display_mode.value if self.display_mode else None,
|
||||
"component": self.component,
|
||||
})
|
||||
|
||||
@comfytype(io_type="FLOAT")
|
||||
@ -893,6 +891,14 @@ class Tracks(ComfyTypeIO):
|
||||
track_visibility: torch.Tensor
|
||||
Type = TrackDict
|
||||
|
||||
@comfytype(io_type="DICT")
class Dict(ComfyTypeIO):
    """IO type registered under the ``"DICT"`` io_type; its Python-side value type is ``dict``."""
    Type = dict
|
||||
|
||||
@comfytype(io_type="ARRAY")
class Array(ComfyTypeIO):
    """IO type registered under the ``"ARRAY"`` io_type; its Python-side value type is ``list``."""
    Type = list
|
||||
|
||||
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
||||
class MultiType:
|
||||
Type = Any
|
||||
@ -1281,6 +1287,19 @@ class Color(ComfyTypeIO):
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
|
||||
|
||||
@comfytype(io_type="COLORS")
class Colors(ComfyTypeIO):
    """IO type registered under ``"COLORS"``: a list of ``Color.Type`` values."""
    Type = list[Color.Type]

    class Input(WidgetInput):
        """Widget input for a list of colors; the default falls back to an empty list."""

        def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
                     socketless: bool=True, default: list[str]=None, advanced: bool=None):
            # NOTE(review): the positional ``None`` slots are passed straight to
            # WidgetInput.__init__, whose signature is not visible here - confirm
            # the argument order against that base class before changing this call.
            super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
            # Replace a missing default with a fresh list so instances never share one.
            if default is None:
                self.default = []
|
||||
|
||||
|
||||
@comfytype(io_type="BOUNDING_BOX")
|
||||
class BoundingBox(ComfyTypeIO):
|
||||
class BoundingBoxDict(TypedDict):
|
||||
@ -1328,6 +1347,20 @@ class Curve(ComfyTypeIO):
|
||||
return d
|
||||
|
||||
|
||||
@comfytype(io_type="BOUNDING_BOXES")
class BoundingBoxes(ComfyTypeIO):
    """IO type registered under ``"BOUNDING_BOXES"``: a list of bounding boxes that each carry metadata."""

    class BoundingBoxWithMetadata(BoundingBox.BoundingBoxDict):
        # Extends the single-box dict with a free-form ``metadata`` mapping.
        metadata: dict
    Type = list[BoundingBoxWithMetadata]

    class Input(WidgetInput):
        """Widget input for a list of bounding boxes; the default falls back to an empty list."""

        def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
                     socketless: bool=True, default: list[dict]=None, advanced: bool=None):
            # NOTE(review): the positional ``None`` slots are passed straight to
            # WidgetInput.__init__, whose signature is not visible here - confirm
            # the argument order against that base class before changing this call.
            super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
            # Replace a missing default with a fresh list so instances never share one.
            if default is None:
                self.default = []
|
||||
|
||||
|
||||
@comfytype(io_type="HISTOGRAM")
|
||||
class Histogram(ComfyTypeIO):
|
||||
"""A histogram represented as a list of bin counts."""
|
||||
@ -2378,6 +2411,8 @@ __all__ = [
|
||||
"AnyType",
|
||||
"MultiType",
|
||||
"Tracks",
|
||||
"Dict",
|
||||
"Array",
|
||||
"Color",
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
@ -2396,6 +2431,8 @@ __all__ = [
|
||||
"PriceBadgeDepends",
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"BoundingBoxes",
|
||||
"Colors",
|
||||
"Curve",
|
||||
"Histogram",
|
||||
"Range",
|
||||
|
||||
@ -177,6 +177,10 @@ SEEDANCE2_PRICE_PER_1K_TOKENS = {
|
||||
("dreamina-seedance-2-0-fast-260128", True, "480p"): 0.0033,
|
||||
("dreamina-seedance-2-0-fast-260128", False, "720p"): 0.0056,
|
||||
("dreamina-seedance-2-0-fast-260128", True, "720p"): 0.0033,
|
||||
("dreamina-seedance-2-0-mini", False, "480p"): 0.0035,
|
||||
("dreamina-seedance-2-0-mini", True, "480p"): 0.0021,
|
||||
("dreamina-seedance-2-0-mini", False, "720p"): 0.0035,
|
||||
("dreamina-seedance-2-0-mini", True, "720p"): 0.0021,
|
||||
}
|
||||
|
||||
|
||||
@ -278,6 +282,10 @@ SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
|
||||
"480p": {"min": 409_600, "max": 927_408},
|
||||
"720p": {"min": 409_600, "max": 927_408},
|
||||
},
|
||||
"dreamina-seedance-2-0-mini": {
|
||||
"480p": {"min": 409_600, "max": 927_408},
|
||||
"720p": {"min": 409_600, "max": 927_408},
|
||||
},
|
||||
}
|
||||
|
||||
# The times in this dictionary are given for a 10-second duration.
|
||||
|
||||
@ -89,6 +89,7 @@ BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT = "/proxy/byteplus-seedance2/api/v3/cont
|
||||
SEEDANCE_MODELS = {
|
||||
"Seedance 2.0": "dreamina-seedance-2-0-260128",
|
||||
"Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128",
|
||||
"Seedance 2.0 Mini": "dreamina-seedance-2-0-mini",
|
||||
}
|
||||
|
||||
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
||||
@ -1623,8 +1624,10 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p", "4k"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Mini", _seedance2_text_inputs(["480p", "720p"])),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
|
||||
"Mini for the fastest, lowest-cost generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
@ -1666,6 +1669,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$pricePer1K := $res = "4k" ? 0.00572 :
|
||||
$res = "1080p" ? 0.011011 :
|
||||
$contains($m, "mini") ? 0.005005 :
|
||||
$contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$rate := $res = "4k" ? $rate4k :
|
||||
$res = "1080p" ? $rate1080 :
|
||||
@ -1734,8 +1738,13 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
"Seedance 2.0 Fast",
|
||||
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Mini",
|
||||
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
|
||||
"Mini for the fastest, lowest-cost generation.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"first_frame",
|
||||
@ -1801,6 +1810,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$pricePer1K := $res = "4k" ? 0.00572 :
|
||||
$res = "1080p" ? 0.011011 :
|
||||
$contains($m, "mini") ? 0.005005 :
|
||||
$contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$rate := $res = "4k" ? $rate4k :
|
||||
$res = "1080p" ? $rate1080 :
|
||||
@ -2024,8 +2034,13 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
"Seedance 2.0 Fast",
|
||||
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Mini",
|
||||
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
|
||||
"Mini for the fastest, lowest-cost generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
@ -2071,9 +2086,11 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$noVideoPricePer1K := $res = "4k" ? 0.00572 :
|
||||
$res = "1080p" ? 0.011011 :
|
||||
$contains($m, "mini") ? 0.005005 :
|
||||
$contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$videoPricePer1K := $res = "4k" ? 0.003432 :
|
||||
$res = "1080p" ? 0.006721 :
|
||||
$contains($m, "mini") ? 0.003003 :
|
||||
$contains($m, "fast") ? 0.004719 : 0.006149;
|
||||
$rate := $res = "4k" ? $rate4k :
|
||||
$res = "1080p" ? $rate1080 :
|
||||
|
||||
23
comfy_extras/color_util.py
Normal file
23
comfy_extras/color_util.py
Normal file
@ -0,0 +1,23 @@
|
||||
def hex_to_rgb(value: str) -> tuple[int, int, int]:
    """Convert a ``#RRGGBB`` (or bare ``RRGGBB``) string into an ``(r, g, b)`` tuple.

    Leading ``#`` characters are stripped before parsing.

    Args:
        value: Color string to parse.

    Returns:
        The three channel values as ints in ``0..255``. Any input that is not
        a string of exactly six hexadecimal digits (after stripping ``#``)
        yields the white fallback ``(255, 255, 255)``.
    """
    fallback = (255, 255, 255)
    if not isinstance(value, str):
        return fallback
    h = value.lstrip("#")
    # int(text, 16) also accepts signs and surrounding whitespace
    # (e.g. "+1", " f"), so validate the digits explicitly instead of
    # relying on ValueError.
    if len(h) != 6 or any(ch not in "0123456789abcdefABCDEF" for ch in h):
        return fallback
    return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16))
|
||||
|
||||
|
||||
def readable_color(rgb: tuple[int, int, int]) -> tuple[int, int, int]:
    """Return *rgb* unchanged if it is bright enough, otherwise a lightened version.

    Brightness is the weighted sum ``0.299*r + 0.587*g + 0.114*b``. Colors
    below 130 are blended toward white by the factor that brings that
    weighted sum up to 130 (before per-channel rounding).
    """
    red, green, blue = rgb
    luminance = 0.299 * red + 0.587 * green + 0.114 * blue
    if luminance < 130:
        # Fraction of the distance to white needed to reach the threshold.
        blend = (130 - luminance) / (255 - luminance)
        red, green, blue = (
            round(channel + (255 - channel) * blend)
            for channel in (red, green, blue)
        )
    return (red, green, blue)
|
||||
|
||||
|
||||
def normalize_palette(colors) -> list[str]:
    """Return the non-empty string entries of *colors*, upper-cased.

    Args:
        colors: An iterable of color strings, a dict (its values are used),
            a single color string, or ``None``.

    Returns:
        A list of upper-cased strings. Entries that are not strings, or are
        empty strings, are dropped. ``None`` yields an empty list.
    """
    if colors is None:
        # Callers may forward an explicit null palette; treat it as "no colors".
        return []
    if isinstance(colors, str):
        # A bare string is one color, not an iterable of characters.
        colors = [colors]
    elif isinstance(colors, dict):
        colors = colors.values()
    return [c.upper() for c in colors if isinstance(c, str) and c]
|
||||
253
comfy_extras/nodes_bounding_boxes.py
Normal file
253
comfy_extras/nodes_bounding_boxes.py
Normal file
@ -0,0 +1,253 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageEnhance, ImageFont
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.color_util import hex_to_rgb, normalize_palette, readable_color
|
||||
|
||||
_PREVIEW_LONG_EDGE = 1024
|
||||
_PREVIEW_DIM = 0.25
|
||||
|
||||
|
||||
def pixels_to_fractions(box: dict, width: int, height: int) -> dict:
    """Convert a pixel-space box (``x``/``y``/``width``/``height``) into fractions of the canvas.

    Missing box keys count as 0. The result uses the keys ``x``, ``y``,
    ``w`` and ``h``.
    """
    # A falsy dimension (e.g. 0) falls back to 1 so the divisions cannot hit zero.
    div_x = width if width else 1
    div_y = height if height else 1
    mapping = (
        ("x", "x", div_x),
        ("y", "y", div_y),
        ("w", "width", div_x),
        ("h", "height", div_y),
    )
    return {
        out_key: box.get(src_key, 0) / divisor
        for out_key, src_key, divisor in mapping
    }
|
||||
|
||||
|
||||
def fractions_to_pixels(box: dict, width: int, height: int) -> dict:
    """Convert a fractional box (``x``/``y``/``w``/``h``) into rounded pixel coordinates.

    A negative ``w`` or ``h`` is normalised first: the origin moves by that
    extent and the extent becomes positive. Missing keys count as 0.0.
    """
    left = box.get("x", 0.0)
    top = box.get("y", 0.0)
    span_w = box.get("w", 0.0)
    span_h = box.get("h", 0.0)

    # Flip boxes that were specified from right-to-left / bottom-to-top.
    if span_w < 0:
        left += span_w
        span_w = -span_w
    if span_h < 0:
        top += span_h
        span_h = -span_h

    return dict(
        x=round(left * width),
        y=round(top * height),
        width=round(span_w * width),
        height=round(span_h * height),
    )
|
||||
|
||||
|
||||
def fractions_to_bbox_frame(boxes: list, width: int, height: int) -> list:
    """Convert fractional boxes to pixel boxes and wrap them as a single frame.

    Non-dict entries are ignored. Returns ``[[box, ...]]`` when at least one
    box was converted, otherwise an empty list.
    """
    frame = []
    for entry in boxes:
        if isinstance(entry, dict):
            frame.append(fractions_to_pixels(entry, width, height))
    if not frame:
        return []
    return [frame]
|
||||
|
||||
|
||||
def _font(size: int):
    """Return ``ImageFont.load_default`` at *size*, retrying without a size if that call fails."""
    try:
        sized = ImageFont.load_default(size)
    except Exception:
        # NOTE(review): presumably covers Pillow builds whose load_default()
        # cannot honour a size argument - confirm against the supported versions.
        sized = ImageFont.load_default()
    return sized
|
||||
|
||||
|
||||
def _wrap(draw, text: str, font, max_w: float) -> list[str]:
|
||||
lines = []
|
||||
for para in text.split("\n"):
|
||||
line = ""
|
||||
for word in para.split():
|
||||
test = word if not line else line + " " + word
|
||||
if line and draw.textlength(test, font=font) > max_w:
|
||||
lines.append(line)
|
||||
line = word
|
||||
else:
|
||||
line = test
|
||||
lines.append(line)
|
||||
return lines
|
||||
|
||||
|
||||
def _bg_from_image(image) -> Image.Image | None:
    """Turn the first item of *image* into a PIL image, or return ``None``.

    ``None`` is returned when *image* is ``None`` or when any step of the
    conversion raises.
    """
    if image is None:
        return None
    try:
        # assumes image[0] is a float tensor with values in 0..1 - TODO confirm
        first = image[0].detach().cpu().numpy()
        pixels = (first * 255).clip(0, 255).astype(np.uint8)
        result = Image.fromarray(pixels)
    except Exception:
        # Best effort: an unusable background just means "no background".
        result = None
    return result
|
||||
|
||||
|
||||
def render_preview(regions, width, height, bg=None):
    """Render a preview image of *regions* drawn over a dimmed background.

    Args:
        regions: Iterable of region dicts with fractional ``x``/``y``/``w``/``h``
            plus optional ``type``, ``text``, ``desc`` and ``palette`` entries.
            Non-dict entries are skipped.
        width, height: Canvas size, used only when *bg* is ``None``.
        bg: Optional PIL image used as the background; when given, its own
            size determines the preview size.

    Returns:
        A float32 tensor built from the composed RGB image, scaled to 0..1,
        with a leading dimension added by ``unsqueeze(0)``.
    """
    if bg is not None:
        # Background given: shrink so the long edge is at most _PREVIEW_LONG_EDGE
        # (never enlarge), then dim it.
        iw, ih = bg.size
        long_edge = max(iw, ih) or 1
        scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge)
        rw, rh = max(1, round(iw * scale)), max(1, round(ih * scale))
        base = bg.convert("RGB").resize((rw, rh), Image.LANCZOS)
        base = ImageEnhance.Brightness(base).enhance(_PREVIEW_DIM)
        img = base.convert("RGBA")
    else:
        # No background: a flat grey canvas with the same scaling rule.
        long_edge = max(width, height) or 1
        scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge)
        rw, rh = max(1, round(width * scale)), max(1, round(height * scale))
        grey = round(_PREVIEW_DIM * 128)
        img = Image.new("RGBA", (rw, rh), (grey, grey, grey, 255))

    # All annotations go on a transparent overlay that is composited at the end.
    overlay = Image.new("RGBA", (rw, rh), (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    # Font size scales with preview height, with a floor of 10.
    fs = max(10, round(rh / 64))
    font = _font(fs)
    tag_font = _font(max(9, fs - 2))
    line_h = fs + 2

    for i, region in enumerate(regions):
        if not isinstance(region, dict):
            continue
        # The first palette entry colors the outline and tag; grey when there is none.
        palette = [c for c in (region.get("palette") or []) if c]
        r, g, b = hex_to_rgb(palette[0]) if palette else (140, 140, 140)
        # Fractional corners -> preview pixels, clamped to the canvas.
        x1 = max(0, min(rw, round(region.get("x", 0) * rw)))
        y1 = max(0, min(rh, round(region.get("y", 0) * rh)))
        x2 = max(0, min(rw, round((region.get("x", 0) + region.get("w", 0)) * rw)))
        y2 = max(0, min(rh, round((region.get("y", 0) + region.get("h", 0)) * rh)))
        # Negative extents produce swapped corners; reorder them.
        if x2 < x1:
            x1, x2 = x2, x1
        if y2 < y1:
            y1, y2 = y2, y1

        draw.rectangle([x1, y1, x2, y2], outline=(r, g, b, 255), width=2)

        # Up to five palette swatches as equal segments along the top edge.
        swatches = palette[:5]
        if swatches and (x2 - x1) > 2:
            sh = max(5, fs // 2)
            seg = (x2 - x1) / len(swatches)
            for p, hexc in enumerate(swatches):
                sx = x1 + round(p * seg)
                draw.rectangle([sx, y1, x1 + round((p + 1) * seg), y1 + sh], fill=hex_to_rgb(hexc))

        etype = "text" if region.get("type") == "text" else "obj"
        # Two-digit, 1-based index tag in the top-left corner (drawn after, so over, the swatches).
        tag = str(i + 1).zfill(2)
        tw = draw.textlength(tag, font=tag_font)
        draw.rectangle([x1, y1, x1 + tw + 6, y1 + fs + 2], fill=(r, g, b, 255))
        # Black tag text on bright fills, white on dark ones.
        tag_fill = (0, 0, 0, 255) if (0.299 * r + 0.587 * g + 0.114 * b) > 140 else (255, 255, 255, 255)
        draw.text((x1 + 3, y1 + 1), tag, fill=tag_fill, font=tag_font)

        body = region.get("desc", "") or ""
        if etype == "text" and region.get("text"):
            # Parses as (" — " + body) if body else "": quoted text, then the description if any.
            body = '"%s"%s' % (region["text"], " — " + body if body else "")
        if body and (x2 - x1) > 8:
            ty = y1 + fs + 5
            for line in _wrap(draw, body, font, x2 - x1 - 8):
                # Stop once the next line would start below the box.
                if ty > y2:
                    break
                draw.text((x1 + 4, ty), line, fill=readable_color((r, g, b)) + (255,), font=font)
                ty += line_h

    composed = Image.alpha_composite(img, overlay).convert("RGB")
    arr = np.asarray(composed, dtype=np.float32) / 255.0
    return torch.from_numpy(arr).unsqueeze(0)
|
||||
|
||||
|
||||
def boxes_to_regions(boxes, width: int, height: int) -> list:
    """Convert pixel-space boxes with metadata into fractional region dicts.

    Anything other than a list yields an empty list, and non-dict entries are
    skipped. Each region holds the fractional geometry followed by ``type``,
    ``text``, ``desc`` and ``palette`` read from the box's ``metadata`` dict
    (defaults: ``"obj"``, ``""``, ``""``, ``[]``).
    """
    regions: list = []
    if isinstance(boxes, list):
        for entry in (candidate for candidate in boxes if isinstance(candidate, dict)):
            raw_meta = entry.get("metadata")
            # Metadata that is missing or malformed counts as empty.
            meta = raw_meta if isinstance(raw_meta, dict) else {}
            region = dict(pixels_to_fractions(entry, width, height))
            region.update(
                type=meta.get("type", "obj"),
                text=meta.get("text", ""),
                desc=meta.get("desc", ""),
                palette=meta.get("palette", []),
            )
            regions.append(region)
    return regions
|
||||
|
||||
|
||||
def _norm_bbox(region: dict) -> list[int]:
|
||||
def grid(value: float) -> int:
|
||||
return max(0, min(1000, round(value * 1000)))
|
||||
|
||||
x, y = region.get("x", 0.0), region.get("y", 0.0)
|
||||
w, h = region.get("w", 0.0), region.get("h", 0.0)
|
||||
ymin, xmin, ymax, xmax = grid(y), grid(x), grid(y + h), grid(x + w)
|
||||
if ymin > ymax:
|
||||
ymin, ymax = ymax, ymin
|
||||
if xmin > xmax:
|
||||
xmin, xmax = xmax, xmin
|
||||
return [ymin, xmin, ymax, xmax]
|
||||
|
||||
|
||||
def build_elements(regions: list) -> list:
    """Build the list of prompt element dicts from fractional *regions*.

    Non-dict entries are skipped. Each element contains, in order:
    ``type`` (``"text"`` or ``"obj"``), ``bbox`` (from ``_norm_bbox``),
    ``text`` (text elements only), ``desc``, and ``color_palette`` (at most
    five entries, only when the region has a non-empty palette).
    """
    elements = []
    for region in regions:
        if not isinstance(region, dict):
            continue
        # Any type other than "text" is treated as an object.
        etype = "text" if region.get("type") == "text" else "obj"
        element = {"type": etype, "bbox": _norm_bbox(region)}
        if etype == "text":
            element["text"] = region.get("text", "")
        element["desc"] = region.get("desc", "")
        # The palette key can be present with an explicit None (null metadata);
        # treat that the same as "no palette" rather than iterating None.
        palette = normalize_palette(region.get("palette") or [])
        if palette:
            element["color_palette"] = palette[:5]
        elements.append(element)
    return elements
|
||||
|
||||
|
||||
class CreateBoundingBoxes(io.ComfyNode):
    """Node that turns editor-drawn bounding boxes into a preview image, pixel boxes and prompt elements."""

    @classmethod
    def define_schema(cls):
        """Declare the node's id, inputs (background, width, height, editor_state) and its three outputs."""
        # The editor widget input is built first and placed last in the input list.
        editor_state = io.BoundingBoxes.Input(
            "editor_state",
            socketless=False,
            tooltip="Draw bounding boxes and set each box type, text, description, color palette. Start with background element first and foreground last.",
        )
        return io.Schema(
            node_id="CreateBoundingBoxes",
            display_name="Create Bounding Boxes",
            category="utilities",
            description="Draw bounding boxes in a canvas. Outputs Ideogram prompt elements, pixel-space bounding boxes, and a preview image.",
            inputs=[
                io.Image.Input(
                    "background",
                    optional=True,
                    tooltip="Optional image used as background in the canvas and preview.",
                ),
                io.Int.Input("width", default=1024, min=64, max=16384, step=16,
                             tooltip="Width of the canvas and the pixel grid for the bounding boxes."),
                io.Int.Input("height", default=1024, min=64, max=16384, step=16,
                             tooltip="Height of the canvas and the pixel grid for the bounding boxes."),
                editor_state,
            ],
            outputs=[
                io.Image.Output(display_name="preview"),
                io.BoundingBox.Output(display_name="bboxes"),
                io.Array.Output(display_name="elements"),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, width, height, editor_state=None, background=None) -> io.NodeOutput:
        """Convert the editor state to regions and return (preview, pixel boxes, elements).

        The canvas dimensions are also reported through the ``ui`` payload
        under ``"dims"``.
        """
        # Editor boxes are in pixels; regions are fractions of width/height.
        regions = boxes_to_regions(editor_state, width, height)
        preview = render_preview(regions, width, height, _bg_from_image(background))
        return io.NodeOutput(
            preview,
            fractions_to_bbox_frame(regions, width, height),
            build_elements(regions),
            ui={"dims": [width, height]},
        )
|
||||
|
||||
|
||||
class BoundingBoxesExtension(ComfyExtension):
    """Extension that contributes the ``CreateBoundingBoxes`` node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        return [CreateBoundingBoxes]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> BoundingBoxesExtension:
    """Return a new ``BoundingBoxesExtension`` instance."""
    # NOTE(review): presumably the hook the extension loader awaits - confirm against the loader.
    return BoundingBoxesExtension()
|
||||
@ -1,5 +1,6 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.color_util import hex_to_rgb
|
||||
|
||||
|
||||
class ColorToRGBInt(io.ComfyNode):
|
||||
@ -24,9 +25,11 @@ class ColorToRGBInt(io.ComfyNode):
|
||||
# expect format #RRGGBB
|
||||
if len(color) != 7 or color[0] != "#":
|
||||
raise ValueError("Color must be in format #RRGGBB")
|
||||
r = int(color[1:3], 16)
|
||||
g = int(color[3:5], 16)
|
||||
b = int(color[5:7], 16)
|
||||
try:
|
||||
int(color[1:], 16)
|
||||
except ValueError:
|
||||
raise ValueError("Color must be in format #RRGGBB") from None
|
||||
r, g, b = hex_to_rgb(color)
|
||||
|
||||
rgb_int = r * 256 * 256 + g * 256 + b
|
||||
return io.NodeOutput(rgb_int, color)
|
||||
|
||||
77
comfy_extras/nodes_json_prompt.py
Normal file
77
comfy_extras/nodes_json_prompt.py
Normal file
@ -0,0 +1,77 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.color_util import normalize_palette
|
||||
|
||||
|
||||
class BuildJsonPromptIdeogram(io.ComfyNode):
    """Node that assembles a caption dict (JSON prompt) from elements and style inputs."""

    @classmethod
    def define_schema(cls):
        """Declare the node's id, its text/style/palette inputs and its single dict output."""
        # The palette widget input is built first and placed last in the input list.
        color_palette = io.Colors.Input(
            "color_palette",
            socketless=False,
            tooltip="Hex color codes that steer the image's dominant colors. Up to 16 entries.",
        )
        return io.Schema(
            node_id="BuildJsonPromptIdeogram",
            display_name="Build JSON Prompt (Ideogram)",
            category="text",
            description="Build a JSON prompt for the Ideogram 4 model.",
            inputs=[
                io.Array.Input("element", tooltip="Prompt elements from the node Create Bounding Boxes."),
                io.String.Input("high_level_description", multiline=True, default="",
                                tooltip="Optional description of the image in one or two sentences. Strongly recommended."),
                io.String.Input("background", multiline=True, default="",
                                tooltip="Mandatory description of the image background or environment."),
                io.DynamicCombo.Input("style", options=[
                    io.DynamicCombo.Option("none", []),
                    io.DynamicCombo.Option("photo", [io.String.Input("photo", default="", tooltip="Camera or lens details for photographic outputs (e.g. 35mm, f/1.4, bokeh).")]),
                    io.DynamicCombo.Option("art_style", [io.String.Input("art_style", default="", tooltip="Art style description (e.g. flat vector illustration, bold outlines).")]),
                ]),
                io.String.Input("aesthetics", default="", tooltip="Mandatory aesthetic keywords (e.g. moody, cinematic, desaturated)."),
                io.String.Input("lighting", default="", tooltip="Mandatory lighting description (e.g. golden hour, rim light, dramatic shadows)."),
                io.String.Input("medium", default="", tooltip="Mandatory medium type (e.g. photograph, illustration, 3d_render, painting, graphic_design). When style = photo, set to photograph."),
                color_palette,
            ],
            outputs=[io.Dict.Output(display_name="prompt")],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, element, style, high_level_description="", background="",
                aesthetics="", lighting="", medium="", color_palette=None) -> io.NodeOutput:
        """Build the caption dict and return it as the node output.

        The caption always contains ``compositional_deconstruction``.
        ``high_level_description`` is added only when it is not blank, and
        ``style_description`` only when the selected style is not ``"none"``.
        """
        # Inputs of the wrong shape degrade to empty values instead of raising.
        elements = element if isinstance(element, list) else []
        kind = style.get("style", "none") if isinstance(style, dict) else "none"
        photo = style.get("photo", "") if isinstance(style, dict) else ""
        art_style = style.get("art_style", "") if isinstance(style, dict) else ""
        palette = normalize_palette(color_palette or [])

        caption: dict = {}
        if high_level_description.strip():
            caption["high_level_description"] = high_level_description
        if kind != "none":
            style_desc: dict = {"aesthetics": aesthetics, "lighting": lighting}
            # Key order differs by style: photo puts "photo" before "medium",
            # every other style puts "medium" before "art_style".
            if kind == "photo":
                style_desc["photo"] = photo
                style_desc["medium"] = medium
            else:
                style_desc["medium"] = medium
                style_desc["art_style"] = art_style
            # The palette is emitted only inside style_description, so it is
            # dropped when the style is "none".
            if palette:
                style_desc["color_palette"] = palette
            caption["style_description"] = style_desc
        caption["compositional_deconstruction"] = {
            "background": background,
            "elements": elements,
        }
        return io.NodeOutput(caption)
|
||||
|
||||
|
||||
class JsonPromptExtension(ComfyExtension):
    """Extension that contributes the ``BuildJsonPromptIdeogram`` node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        return [BuildJsonPromptIdeogram]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> JsonPromptExtension:
    """Return a new ``JsonPromptExtension`` instance."""
    # NOTE(review): presumably the hook the extension loader awaits - confirm against the loader.
    return JsonPromptExtension()
|
||||
@ -13,7 +13,7 @@ class SeedNode(io.ComfyNode):
|
||||
search_aliases=["seed", "random"],
|
||||
category="utilities",
|
||||
inputs=[
|
||||
io.Int.Input("seed", min=0, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed, component="SetRandomInt"),
|
||||
io.Int.Input("seed", min=0, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
|
||||
],
|
||||
outputs=[io.Int.Output(display_name="seed")],
|
||||
)
|
||||
|
||||
@ -440,6 +440,57 @@ class JsonExtractString(io.ComfyNode):
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return io.NodeOutput("")
|
||||
|
||||
|
||||
def _dump_json(value, indent):
|
||||
return json.dumps(value, ensure_ascii=False, indent=indent or None)
|
||||
|
||||
|
||||
class ConvertDictionaryToString(io.ComfyNode):
    """Node that serializes a dictionary input to a JSON string."""

    @classmethod
    def define_schema(cls):
        """Declare the node's id, its ``dictionary`` and ``indent`` inputs and its string output."""
        return io.Schema(
            node_id="ConvertDictionaryToString",
            display_name="Convert Dictionary to String",
            category="text",
            search_aliases=["json", "dict to json", "stringify", "serialize", "dict to string"],
            inputs=[
                io.Dict.Input("dictionary"),
                io.Int.Input("indent", default=2, min=0, max=8,
                             tooltip="Spaces per indent level. 0 produces compact single-line string."),
            ],
            outputs=[
                io.String.Output(),
            ],
        )

    @classmethod
    def execute(cls, dictionary, indent=2):
        """Return the JSON text of *dictionary* produced by ``_dump_json``."""
        return io.NodeOutput(_dump_json(dictionary, indent))
|
||||
|
||||
|
||||
class ConvertArrayToString(io.ComfyNode):
    """Node that serializes an array input to a JSON string."""

    @classmethod
    def define_schema(cls):
        """Declare the node's id, its ``array`` and ``indent`` inputs and its string output."""
        return io.Schema(
            node_id="ConvertArrayToString",
            display_name="Convert Array to String",
            category="text",
            search_aliases=["json", "list to json", "stringify", "serialize", "list to string", "array to json"],
            inputs=[
                io.Array.Input("array"),
                io.Int.Input("indent", default=2, min=0, max=8,
                             tooltip="Spaces per indent level. 0 produces compact single-line string."),
            ],
            outputs=[
                io.String.Output(),
            ],
        )

    @classmethod
    def execute(cls, array, indent=2):
        """Return the JSON text of *array* produced by ``_dump_json``."""
        return io.NodeOutput(_dump_json(array, indent))
|
||||
|
||||
|
||||
class StringExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@ -457,6 +508,8 @@ class StringExtension(ComfyExtension):
|
||||
RegexExtract,
|
||||
RegexReplace,
|
||||
JsonExtractString,
|
||||
ConvertDictionaryToString,
|
||||
ConvertArrayToString,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> StringExtension:
|
||||
|
||||
2
nodes.py
2
nodes.py
@ -2374,6 +2374,8 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_images.py",
|
||||
"nodes_video_model.py",
|
||||
"nodes_ideogram4.py",
|
||||
"nodes_bounding_boxes.py",
|
||||
"nodes_json_prompt.py",
|
||||
"nodes_train.py",
|
||||
"nodes_dataset.py",
|
||||
"nodes_sag.py",
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.45.19
|
||||
comfyui-workflow-templates==0.10.2
|
||||
comfyui-workflow-templates==0.10.7
|
||||
comfyui-embedded-docs==0.5.5
|
||||
torch
|
||||
torchsde
|
||||
@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.10
|
||||
comfy-kitchen==0.2.12
|
||||
comfy-aimdo==0.4.10
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
|
||||
@ -228,6 +228,62 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
with self.assertRaises(KeyError):
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
def test_int8_convrot_metadata_loads_into_params(self):
    """ConvRot metadata must reach TensorWiseINT8Layout params."""
    torch.manual_seed(123)
    layer_quant_config = {
        "layer": {
            "format": "int8_tensorwise",
            "convrot": True,
            "convrot_groupsize": 256,
        }
    }
    weight = torch.randn(16, 256, dtype=torch.bfloat16)
    bias = torch.randn(16, dtype=torch.bfloat16)
    q_weight = QuantizedTensor.from_float(
        weight,
        "TensorWiseINT8Layout",
        per_channel=True,
        convrot=True,
        convrot_groupsize=256,
    )
    state_dict = {
        "layer.weight": q_weight._qdata,
        "layer.bias": bias,
        "layer.weight_scale": q_weight._params.scale,
    }

    # Feed the per-layer config through the metadata conversion path before
    # loading it into the module.
    state_dict, _ = comfy.utils.convert_old_quants(
        state_dict,
        metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})},
    )
    model = torch.nn.Module()
    model.layer = ops.mixed_precision_ops({}).Linear(256, 16, device="cpu", dtype=torch.bfloat16)
    model.load_state_dict(state_dict, strict=False)

    # The loaded weight must be quantized and carry the ConvRot settings.
    self.assertIsInstance(model.layer.weight, QuantizedTensor)
    self.assertEqual(model.layer.weight._layout_cls, "TensorWiseINT8Layout")
    self.assertTrue(model.layer.weight._params.convrot)
    self.assertEqual(model.layer.weight._params.convrot_groupsize, 256)

    # Forward output must match a linear op on the directly quantized weight.
    input_tensor = torch.randn(4, 256, dtype=torch.bfloat16)
    loaded_out = model.layer(input_tensor)
    ref_out = torch.nn.functional.linear(input_tensor, q_weight, bias)
    self.assertTrue(torch.equal(loaded_out, ref_out))

    # Same comparison with float16 inputs.
    fp16_input = input_tensor.to(torch.float16)
    loaded_fp16_out = model.layer(fp16_input)
    ref_fp16_out = torch.nn.functional.linear(
        fp16_input,
        q_weight.to(dtype=torch.float16),
        bias.to(dtype=torch.float16),
    )
    self.assertTrue(torch.equal(loaded_fp16_out, ref_fp16_out))

    # Saving must write the ConvRot settings back into the comfy_quant blob.
    saved = model.state_dict()
    saved_conf = json.loads(saved["layer.comfy_quant"].numpy().tobytes())
    self.assertTrue(saved_conf["convrot"])
    self.assertEqual(saved_conf["convrot_groupsize"], 256)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user