mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-29 17:27:56 +08:00
Compare commits
21 Commits
v0.26.2
...
add-cla-wo
| Author | SHA1 | Date | |
|---|---|---|---|
| 585267fa59 | |||
| 7c18d53e2c | |||
| 79c555ce6b | |||
| f19735759e | |||
| a95e461916 | |||
| 603d891eaf | |||
| 470ac36a0a | |||
| 7cb784e0f4 | |||
| 1a510f0423 | |||
| 639c8fa788 | |||
| e22f1500f9 | |||
| dac4ea3a80 | |||
| b0ec19804f | |||
| 64e1d740b8 | |||
| b22d0fb9c0 | |||
| 5236cd02e6 | |||
| cabb7342d1 | |||
| 12218db68a | |||
| 44955d783b | |||
| 1f275fcba6 | |||
| c06a3f060b |
38
.github/workflows/ci-cursor-review.yml
vendored
Normal file
38
.github/workflows/ci-cursor-review.yml
vendored
Normal file
@ -0,0 +1,38 @@
|
||||
name: CI - Cursor Review
|
||||
|
||||
# Thin caller for the shared reusable cursor-review workflow in
|
||||
# Comfy-Org/github-workflows. The review logic (panel matrix, judge
|
||||
# consolidation, prompts, extract/post/notify scripts) lives there as the
|
||||
# single source of truth, so this repo only carries the repo-specific diff
|
||||
# excludes.
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [labeled, unlabeled]
|
||||
|
||||
concurrency:
|
||||
group: cursor-review-pr-${{ github.event.pull_request.number }}-${{ github.event.label.name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
cursor-review:
|
||||
if: github.event.label.name == 'cursor-review'
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
# SHA-pinned per zizmor `unpinned-uses: hash-pin`. Bump this SHA to pick up
|
||||
# upstream changes; keep `workflows_ref` matching so prompts/scripts load
|
||||
# from the same commit as the workflow definition.
|
||||
uses: Comfy-Org/github-workflows/.github/workflows/cursor-review.yml@047ca48febe3a6647608ed2e0c4331b491cb9d6a # github-workflows#9
|
||||
with:
|
||||
workflows_ref: 047ca48febe3a6647608ed2e0c4331b491cb9d6a
|
||||
diff_excludes: >-
|
||||
:!**/.claude/**
|
||||
:!**/dist/**
|
||||
:!**/vendor/**
|
||||
:!**/*.generated.*
|
||||
:!**/*.min.js
|
||||
:!**/*.min.css
|
||||
secrets:
|
||||
CURSOR_API_KEY: ${{ secrets.CURSOR_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
63
.github/workflows/cla.yml
vendored
Normal file
63
.github/workflows/cla.yml
vendored
Normal file
@ -0,0 +1,63 @@
|
||||
name: CLA Assistant
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
pull_request_target:
|
||||
types: [opened, synchronize, closed]
|
||||
merge_group:
|
||||
|
||||
permissions:
|
||||
actions: write
|
||||
contents: read # 'read' is enough because signatures live in a REMOTE repo
|
||||
pull-requests: write
|
||||
statuses: write
|
||||
|
||||
jobs:
|
||||
cla-assistant:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: CLA Assistant
|
||||
# Run on PR events, on "recheck" comment, or when someone posts the exact signing phrase.
|
||||
# IMPORTANT: this phrase must match `custom-pr-sign-comment` below.
|
||||
if: >
|
||||
github.event_name == 'pull_request_target' ||
|
||||
github.event.comment.body == 'recheck' ||
|
||||
github.event.comment.body == 'I have read and agree to the Contributor License Agreement'
|
||||
uses: contributor-assistant/github-action@ca4a40a7d1004f18d9960b404b97e5f30a505a08 # v2.6.1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
# PAT required to write to the centralized signatures repo.
|
||||
PERSONAL_ACCESS_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
|
||||
with:
|
||||
# Where the CLA document lives (shown to contributors)
|
||||
path-to-document: https://github.com/Comfy-Org/comfy-cla/blob/main/comfyui_icla.md
|
||||
|
||||
# Centralized signature storage
|
||||
remote-organization-name: comfy-org
|
||||
remote-repository-name: comfy-cla
|
||||
path-to-signatures: signatures/cla.json
|
||||
branch: main
|
||||
|
||||
# Allowlist bots so they don't need to sign (optional, comma-separated).
|
||||
# *[bot] is a catch-all for any GitHub App bot account.
|
||||
allowlist: action@github.com,actions-user,ampagent,claude,comfy-pr-bot,GitHub Action,github-actions,Glary Bot,Glary-Bot,*[bot]
|
||||
|
||||
# Custom PR comment messages
|
||||
custom-notsigned-prcomment: |
|
||||
🎉 Thank you for your contribution, we really appreciate it! 🎉
|
||||
|
||||
Like many open source projects, we require contributors to sign our [Contributor License Agreement (CLA)](https://github.com/Comfy-Org/comfy-cla/blob/main/comfyui_icla.md). A CLA makes the ownership of contributions explicit, so contributors and the project share a clear understanding of how the code can be used. By signing, you:
|
||||
|
||||
- Confirm that you own your contribution.
|
||||
- Keep the right to reuse your own code.
|
||||
- Grant us a copyright license to include and share it within our projects.
|
||||
|
||||
CLAs are standard practice across major open source projects including those under the Apache Software Foundation and the Linux Foundation. Ours is based on the Apache Software Foundation's CLA. Most importantly, it would enable us to relicense the project under a more permissive license in the future, giving the project and its community greater flexibility.
|
||||
|
||||
✍ **To sign, please post a new comment on this PR with exactly the following text:** ✍
|
||||
|
||||
custom-pr-sign-comment: I have read and agree to the Contributor License Agreement
|
||||
|
||||
custom-allsigned-prcomment: |
|
||||
✅ All contributors have signed the CLA. Thank you! This PR is ready to be merged.
|
||||
64
comfy/ops.py
64
comfy/ops.py
@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
|
||||
if (want_requant and len(fns) == 0 or update_weight):
|
||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||
if isinstance(orig, QuantizedTensor):
|
||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||
y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
|
||||
else:
|
||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
|
||||
if want_requant and len(fns) == 0:
|
||||
@ -1089,6 +1089,19 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat
|
||||
if ts is None or bs is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||
scales = {"scale": ts, "block_scale": bs}
|
||||
elif module.quant_format == "int8_tensorwise":
|
||||
scale = pop_scale("weight_scale")
|
||||
if scale is None:
|
||||
raise ValueError(f"Missing INT8 weight scale for layer {layer_name}")
|
||||
scales = {"scale": scale}
|
||||
params_conf = layer_conf.get("params", {})
|
||||
if not isinstance(params_conf, dict):
|
||||
params_conf = {}
|
||||
if layer_conf.get("convrot", params_conf.get("convrot", False)):
|
||||
scales["convrot"] = True
|
||||
scales["convrot_groupsize"] = int(
|
||||
layer_conf.get("convrot_groupsize", params_conf.get("convrot_groupsize", 256))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
|
||||
|
||||
@ -1131,6 +1144,10 @@ def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extr
|
||||
quant_conf = {"format": module.quant_format}
|
||||
if getattr(module, '_full_precision_mm_config', False):
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
params = getattr(module.weight, "_params", None)
|
||||
if module.quant_format == "int8_tensorwise" and getattr(params, "convrot", False):
|
||||
quant_conf["convrot"] = True
|
||||
quant_conf["convrot_groupsize"] = getattr(params, "convrot_groupsize", 256)
|
||||
if extra_quant_conf:
|
||||
quant_conf.update(extra_quant_conf)
|
||||
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
|
||||
@ -1183,8 +1200,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
|
||||
def forward_comfy_cast_weights(
|
||||
self,
|
||||
input,
|
||||
compute_dtype=None,
|
||||
want_requant=False,
|
||||
weight_only_quant=False,
|
||||
):
|
||||
if weight_only_quant:
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input=None,
|
||||
dtype=self.weight.dtype,
|
||||
device=input.device,
|
||||
bias_dtype=input.dtype,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=True,
|
||||
)
|
||||
weight = weight.to(dtype=input.dtype)
|
||||
else:
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=want_requant,
|
||||
)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
@ -1203,9 +1245,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0
|
||||
)
|
||||
quantize_input = QUANT_ALGOS.get(getattr(self, 'quant_format', None), {}).get("quantize_input", True)
|
||||
|
||||
# Training path: quantized forward with compute_dtype backward via autograd function
|
||||
if (input.requires_grad and _use_quantized):
|
||||
if (input.requires_grad and _use_quantized and quantize_input):
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
@ -1227,7 +1270,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
return output
|
||||
|
||||
# Inference path (unchanged)
|
||||
if _use_quantized:
|
||||
if _use_quantized and quantize_input:
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
@ -1241,7 +1284,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
|
||||
weight_only_quant = _use_quantized and not quantize_input and isinstance(self.weight, QuantizedTensor)
|
||||
output = self.forward_comfy_cast_weights(
|
||||
input,
|
||||
compute_dtype,
|
||||
want_requant=isinstance(input, QuantizedTensor),
|
||||
weight_only_quant=weight_only_quant,
|
||||
)
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if reshaped_3d:
|
||||
@ -1257,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
# dtype is now implicit in the layout class
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
|
||||
@ -10,6 +10,7 @@ try:
|
||||
QuantizedLayout,
|
||||
TensorCoreFP8Layout as _CKFp8Layout,
|
||||
TensorCoreNVFP4Layout as _CKNvfp4Layout,
|
||||
TensorWiseINT8Layout as _CKTensorWiseINT8Layout,
|
||||
register_layout_op,
|
||||
register_layout_class,
|
||||
get_layout_class,
|
||||
@ -47,6 +48,9 @@ except ImportError as e:
|
||||
class _CKNvfp4Layout:
|
||||
pass
|
||||
|
||||
class _CKTensorWiseINT8Layout:
|
||||
pass
|
||||
|
||||
def register_layout_class(name, cls):
|
||||
pass
|
||||
|
||||
@ -174,6 +178,7 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
|
||||
|
||||
# Backward compatibility alias - default to E4M3
|
||||
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
|
||||
TensorWiseINT8Layout = _CKTensorWiseINT8Layout
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
@ -184,6 +189,7 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||
register_layout_class("TensorWiseINT8Layout", _CKTensorWiseINT8Layout)
|
||||
if _CK_MXFP8_AVAILABLE:
|
||||
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||
|
||||
@ -214,6 +220,13 @@ if _CK_MXFP8_AVAILABLE:
|
||||
"group_size": 32,
|
||||
}
|
||||
|
||||
QUANT_ALGOS["int8_tensorwise"] = {
|
||||
"storage_t": torch.int8,
|
||||
"parameters": {"weight_scale"},
|
||||
"comfy_tensor_layout": "TensorWiseINT8Layout",
|
||||
"quantize_input": False,
|
||||
}
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Re-exports for backward compatibility
|
||||
@ -226,6 +239,7 @@ __all__ = [
|
||||
"TensorCoreFP8E4M3Layout",
|
||||
"TensorCoreFP8E5M2Layout",
|
||||
"TensorCoreNVFP4Layout",
|
||||
"TensorWiseINT8Layout",
|
||||
"QUANT_ALGOS",
|
||||
"register_layout_op",
|
||||
]
|
||||
|
||||
@ -891,6 +891,14 @@ class Tracks(ComfyTypeIO):
|
||||
track_visibility: torch.Tensor
|
||||
Type = TrackDict
|
||||
|
||||
@comfytype(io_type="DICT")
|
||||
class Dict(ComfyTypeIO):
|
||||
Type = dict
|
||||
|
||||
@comfytype(io_type="ARRAY")
|
||||
class Array(ComfyTypeIO):
|
||||
Type = list
|
||||
|
||||
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
||||
class MultiType:
|
||||
Type = Any
|
||||
@ -1279,6 +1287,19 @@ class Color(ComfyTypeIO):
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
|
||||
|
||||
@comfytype(io_type="COLORS")
|
||||
class Colors(ComfyTypeIO):
|
||||
Type = list[Color.Type]
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: list[str]=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = []
|
||||
|
||||
|
||||
@comfytype(io_type="BOUNDING_BOX")
|
||||
class BoundingBox(ComfyTypeIO):
|
||||
class BoundingBoxDict(TypedDict):
|
||||
@ -1326,6 +1347,20 @@ class Curve(ComfyTypeIO):
|
||||
return d
|
||||
|
||||
|
||||
@comfytype(io_type="BOUNDING_BOXES")
|
||||
class BoundingBoxes(ComfyTypeIO):
|
||||
class BoundingBoxWithMetadata(BoundingBox.BoundingBoxDict):
|
||||
metadata: dict
|
||||
Type = list[BoundingBoxWithMetadata]
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: list[dict]=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = []
|
||||
|
||||
|
||||
@comfytype(io_type="HISTOGRAM")
|
||||
class Histogram(ComfyTypeIO):
|
||||
"""A histogram represented as a list of bin counts."""
|
||||
@ -2376,6 +2411,8 @@ __all__ = [
|
||||
"AnyType",
|
||||
"MultiType",
|
||||
"Tracks",
|
||||
"Dict",
|
||||
"Array",
|
||||
"Color",
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
@ -2394,6 +2431,8 @@ __all__ = [
|
||||
"PriceBadgeDepends",
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"BoundingBoxes",
|
||||
"Colors",
|
||||
"Curve",
|
||||
"Histogram",
|
||||
"Range",
|
||||
|
||||
@ -163,15 +163,31 @@ class SeedanceVirtualLibraryCreateAssetRequest(BaseModel):
|
||||
asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.")
|
||||
|
||||
|
||||
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
|
||||
# Dollars per 1K tokens, keyed by (model_id, has_video_input, resolution).
|
||||
SEEDANCE2_PRICE_PER_1K_TOKENS = {
|
||||
("dreamina-seedance-2-0-260128", False): 0.007,
|
||||
("dreamina-seedance-2-0-260128", True): 0.0043,
|
||||
("dreamina-seedance-2-0-fast-260128", False): 0.0056,
|
||||
("dreamina-seedance-2-0-fast-260128", True): 0.0033,
|
||||
("dreamina-seedance-2-0-260128", False, "480p"): 0.007,
|
||||
("dreamina-seedance-2-0-260128", True, "480p"): 0.0043,
|
||||
("dreamina-seedance-2-0-260128", False, "720p"): 0.007,
|
||||
("dreamina-seedance-2-0-260128", True, "720p"): 0.0043,
|
||||
("dreamina-seedance-2-0-260128", False, "1080p"): 0.0077,
|
||||
("dreamina-seedance-2-0-260128", True, "1080p"): 0.0047,
|
||||
("dreamina-seedance-2-0-260128", False, "4k"): 0.004,
|
||||
("dreamina-seedance-2-0-260128", True, "4k"): 0.0024,
|
||||
("dreamina-seedance-2-0-fast-260128", False, "480p"): 0.0056,
|
||||
("dreamina-seedance-2-0-fast-260128", True, "480p"): 0.0033,
|
||||
("dreamina-seedance-2-0-fast-260128", False, "720p"): 0.0056,
|
||||
("dreamina-seedance-2-0-fast-260128", True, "720p"): 0.0033,
|
||||
("dreamina-seedance-2-0-mini", False, "480p"): 0.0035,
|
||||
("dreamina-seedance-2-0-mini", True, "480p"): 0.0021,
|
||||
("dreamina-seedance-2-0-mini", False, "720p"): 0.0035,
|
||||
("dreamina-seedance-2-0-mini", True, "720p"): 0.0021,
|
||||
}
|
||||
|
||||
|
||||
def seedance2_price_per_1k_tokens(model_id: str, has_video_input: bool, resolution: str) -> float | None:
|
||||
return SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input, resolution))
|
||||
|
||||
|
||||
RECOMMENDED_PRESETS = [
|
||||
("1024x1024 (1:1)", 1024, 1024),
|
||||
("864x1152 (3:4)", 864, 1152),
|
||||
@ -266,6 +282,10 @@ SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
|
||||
"480p": {"min": 409_600, "max": 927_408},
|
||||
"720p": {"min": 409_600, "max": 927_408},
|
||||
},
|
||||
"dreamina-seedance-2-0-mini": {
|
||||
"480p": {"min": 409_600, "max": 927_408},
|
||||
"720p": {"min": 409_600, "max": 927_408},
|
||||
},
|
||||
}
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
|
||||
@ -15,7 +15,6 @@ from comfy_api_nodes.apis.bytedance import (
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4_0,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4_5,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_5_LITE,
|
||||
SEEDANCE2_PRICE_PER_1K_TOKENS,
|
||||
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
|
||||
VIDEO_TASKS_EXECUTION_TIME,
|
||||
GetAssetResponse,
|
||||
@ -40,6 +39,7 @@ from comfy_api_nodes.apis.bytedance import (
|
||||
TaskVideoContentUrl,
|
||||
Text2ImageTaskCreationRequest,
|
||||
Text2VideoTaskCreationRequest,
|
||||
seedance2_price_per_1k_tokens,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
@ -89,6 +89,7 @@ BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT = "/proxy/byteplus-seedance2/api/v3/cont
|
||||
SEEDANCE_MODELS = {
|
||||
"Seedance 2.0": "dreamina-seedance-2-0-260128",
|
||||
"Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128",
|
||||
"Seedance 2.0 Mini": "dreamina-seedance-2-0-mini",
|
||||
}
|
||||
|
||||
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
||||
@ -141,7 +142,7 @@ SEEDANCE2_RATIO_WH = {
|
||||
"9:16": (9, 16),
|
||||
"21:9": (21, 9),
|
||||
}
|
||||
SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080}
|
||||
SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080, "4k": 2160}
|
||||
|
||||
|
||||
def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]:
|
||||
@ -377,9 +378,9 @@ async def _seedance_virtual_library_upload_video_asset(
|
||||
return f"asset://{create_resp.asset_id}"
|
||||
|
||||
|
||||
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
|
||||
def _seedance2_price_extractor(model_id: str, has_video_input: bool, resolution: str):
|
||||
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
|
||||
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
|
||||
rate = seedance2_price_per_1k_tokens(model_id, has_video_input, resolution)
|
||||
if rate is None:
|
||||
return None
|
||||
|
||||
@ -1621,10 +1622,12 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p", "4k"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Mini", _seedance2_text_inputs(["480p", "720p"])),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
|
||||
"Mini for the fastest, lowest-cost generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
@ -1660,11 +1663,16 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
$rate480 := 10044;
|
||||
$rate720 := 21600;
|
||||
$rate1080 := 48800;
|
||||
$rate4k := 195200;
|
||||
$m := widgets.model;
|
||||
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$rate := $res = "1080p" ? $rate1080 :
|
||||
$pricePer1K := $res = "4k" ? 0.00572 :
|
||||
$res = "1080p" ? 0.011011 :
|
||||
$contains($m, "mini") ? 0.005005 :
|
||||
$contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$rate := $res = "4k" ? $rate4k :
|
||||
$res = "1080p" ? $rate1080 :
|
||||
$res = "720p" ? $rate720 :
|
||||
$rate480;
|
||||
$cost := $dur * $rate * $pricePer1K / 1000;
|
||||
@ -1703,7 +1711,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
|
||||
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False, resolution=model["resolution"]),
|
||||
poll_interval=9,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
@ -1724,14 +1732,19 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0",
|
||||
_seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
|
||||
_seedance2_text_inputs(["480p", "720p", "1080p", "4k"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Fast",
|
||||
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Mini",
|
||||
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
|
||||
"Mini for the fastest, lowest-cost generation.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"first_frame",
|
||||
@ -1791,11 +1804,16 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
$rate480 := 10044;
|
||||
$rate720 := 21600;
|
||||
$rate1080 := 48800;
|
||||
$rate4k := 195200;
|
||||
$m := widgets.model;
|
||||
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$rate := $res = "1080p" ? $rate1080 :
|
||||
$pricePer1K := $res = "4k" ? 0.00572 :
|
||||
$res = "1080p" ? 0.011011 :
|
||||
$contains($m, "mini") ? 0.005005 :
|
||||
$contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$rate := $res = "4k" ? $rate4k :
|
||||
$res = "1080p" ? $rate1080 :
|
||||
$res = "720p" ? $rate720 :
|
||||
$rate480;
|
||||
$cost := $dur * $rate * $pricePer1K / 1000;
|
||||
@ -1913,7 +1931,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
|
||||
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False, resolution=model["resolution"]),
|
||||
poll_interval=9,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
@ -2010,14 +2028,19 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0",
|
||||
_seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
|
||||
_seedance2_reference_inputs(["480p", "720p", "1080p", "4k"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Fast",
|
||||
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Mini",
|
||||
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
|
||||
"Mini for the fastest, lowest-cost generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
@ -2056,13 +2079,21 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
$rate480 := 10044;
|
||||
$rate720 := 21600;
|
||||
$rate1080 := 48800;
|
||||
$rate4k := 195200;
|
||||
$m := widgets.model;
|
||||
$hasVideo := $lookup(inputGroups, "model.reference_videos") > 0;
|
||||
$noVideoPricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$videoPricePer1K := $contains($m, "fast") ? 0.004719 : 0.006149;
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$rate := $res = "1080p" ? $rate1080 :
|
||||
$noVideoPricePer1K := $res = "4k" ? 0.00572 :
|
||||
$res = "1080p" ? 0.011011 :
|
||||
$contains($m, "mini") ? 0.005005 :
|
||||
$contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$videoPricePer1K := $res = "4k" ? 0.003432 :
|
||||
$res = "1080p" ? 0.006721 :
|
||||
$contains($m, "mini") ? 0.003003 :
|
||||
$contains($m, "fast") ? 0.004719 : 0.006149;
|
||||
$rate := $res = "4k" ? $rate4k :
|
||||
$res = "1080p" ? $rate1080 :
|
||||
$res = "720p" ? $rate720 :
|
||||
$rate480;
|
||||
$noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000;
|
||||
@ -2258,7 +2289,9 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input),
|
||||
price_extractor=_seedance2_price_extractor(
|
||||
model_id, has_video_input=has_video_input, resolution=model["resolution"]
|
||||
),
|
||||
poll_interval=9,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
@ -30,7 +30,7 @@ from comfy_api_nodes.util import (
|
||||
|
||||
|
||||
_GROK_VIDEO_MODEL_API_IDS = {
|
||||
"grok-imagine-video-1.5": "grok-imagine-video-1.5-preview",
|
||||
"grok-imagine-video-1.5": "grok-imagine-video-1.5",
|
||||
}
|
||||
|
||||
|
||||
@ -521,8 +521,8 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["480p", "720p"],
|
||||
tooltip="The resolution of the output video.",
|
||||
options=["480p", "720p", "1080p"],
|
||||
tooltip="The resolution of the output video. 1080p is only available for grok-imagine-video-1.5.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
@ -570,11 +570,12 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
(
|
||||
$is15 := $contains(widgets.model, "1.5");
|
||||
$rate := $is15
|
||||
? (widgets.resolution = "720p" ? 0.2002 : 0.1144)
|
||||
? (widgets.resolution = "1080p" ? 0.25 : (widgets.resolution = "720p" ? 0.14 : 0.08))
|
||||
: (widgets.resolution = "720p" ? 0.07 : 0.05);
|
||||
$imgCost := $is15 ? 0.0143 : 0.002;
|
||||
$imgCost := $is15 ? 0.01 : 0.002;
|
||||
$base := $rate * widgets.duration;
|
||||
{"type":"usd","usd": inputs.image.connected ? $base + $imgCost : $base}
|
||||
$total := inputs.image.connected ? $base + $imgCost : $base;
|
||||
{"type":"usd","usd": $is15 ? $total * 1.43 : $total}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -593,6 +594,8 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
if image is None and model == "grok-imagine-video-1.5":
|
||||
raise ValueError(f"The '{model}' model requires an input image; connect one to the 'image' input.")
|
||||
if resolution == "1080p" and model != "grok-imagine-video-1.5":
|
||||
raise ValueError(f"1080p resolution is only available for grok-imagine-video-1.5, not '{model}'.")
|
||||
image_url = None
|
||||
if image is not None:
|
||||
if get_number_of_images(image) != 1:
|
||||
|
||||
@ -48,10 +48,13 @@ from comfy_api_nodes.util import (
|
||||
upload_image_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_audio_duration,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
validate_video_duration,
|
||||
)
|
||||
|
||||
|
||||
RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)")
|
||||
|
||||
|
||||
@ -1657,6 +1660,44 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"happyhorse-1.1-t2v",
|
||||
[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt describing the elements and visual features. "
|
||||
"Supports English and Chinese.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720P", "1080P"],
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"ratio",
|
||||
options=[
|
||||
"16:9",
|
||||
"9:16",
|
||||
"1:1",
|
||||
"4:3",
|
||||
"3:4",
|
||||
"21:9",
|
||||
"9:21",
|
||||
"5:4",
|
||||
"4:5",
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=3,
|
||||
max=15,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"happyhorse-1.0-t2v",
|
||||
[
|
||||
@ -1719,7 +1760,9 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
|
||||
$ppsTable := $contains(widgets.model, "1.1")
|
||||
? { "720p": 0.2002, "1080p": 0.2574 }
|
||||
: { "720p": 0.14, "1080p": 0.24 };
|
||||
$pps := $lookup($ppsTable, $res);
|
||||
{ "type": "usd", "usd": $pps * $dur }
|
||||
)
|
||||
@ -1781,6 +1824,30 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"happyhorse-1.1-i2v",
|
||||
[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt describing the elements and visual features. "
|
||||
"Supports English and Chinese.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720P", "1080P"],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=3,
|
||||
max=15,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"happyhorse-1.0-i2v",
|
||||
[
|
||||
@ -1843,7 +1910,9 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
|
||||
$ppsTable := $contains(widgets.model, "1.1")
|
||||
? { "720p": 0.2002, "1080p": 0.2574 }
|
||||
: { "720p": 0.14, "1080p": 0.24 };
|
||||
$pps := $lookup($ppsTable, $res);
|
||||
{ "type": "usd", "usd": $pps * $dur }
|
||||
)
|
||||
@ -1859,6 +1928,8 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
|
||||
seed: int,
|
||||
watermark: bool,
|
||||
):
|
||||
validate_image_dimensions(first_frame, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1), strict=False)
|
||||
media = [
|
||||
Wan27MediaItem(
|
||||
type="first_frame",
|
||||
@ -2053,6 +2124,62 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"happyhorse-1.1-r2v",
|
||||
[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt describing the video. Use identifiers such as 'character1' and "
|
||||
"'character2' to refer to the reference characters.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720P", "1080P"],
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"ratio",
|
||||
options=[
|
||||
"16:9",
|
||||
"9:16",
|
||||
"1:1",
|
||||
"4:3",
|
||||
"3:4",
|
||||
"21:9",
|
||||
"9:21",
|
||||
"5:4",
|
||||
"4:5",
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=3,
|
||||
max=15,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("reference_image"),
|
||||
names=[
|
||||
"image1",
|
||||
"image2",
|
||||
"image3",
|
||||
"image4",
|
||||
"image5",
|
||||
"image6",
|
||||
"image7",
|
||||
"image8",
|
||||
"image9",
|
||||
],
|
||||
min=1,
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"happyhorse-1.0-r2v",
|
||||
[
|
||||
@ -2133,7 +2260,9 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
|
||||
$ppsTable := $contains(widgets.model, "1.1")
|
||||
? { "720p": 0.2002, "1080p": 0.2574 }
|
||||
: { "720p": 0.14, "1080p": 0.24 };
|
||||
$pps := $lookup($ppsTable, $res);
|
||||
{ "type": "usd", "usd": $pps * $dur }
|
||||
)
|
||||
@ -2149,8 +2278,11 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
|
||||
watermark: bool,
|
||||
):
|
||||
validate_string(model["prompt"], strip_whitespace=False, min_length=1)
|
||||
media = []
|
||||
reference_images = model.get("reference_images", {})
|
||||
for key in reference_images:
|
||||
validate_image_dimensions(reference_images[key], min_width=400, min_height=400)
|
||||
validate_image_aspect_ratio(reference_images[key], (1, 2.5), (2.5, 1), strict=False)
|
||||
media = []
|
||||
for key in reference_images:
|
||||
media.append(
|
||||
Wan27MediaItem(
|
||||
@ -2159,7 +2291,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
|
||||
)
|
||||
)
|
||||
if not media:
|
||||
raise ValueError("At least one reference reference image must be provided.")
|
||||
raise ValueError("At least one reference image must be provided.")
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
|
||||
23
comfy_extras/color_util.py
Normal file
23
comfy_extras/color_util.py
Normal file
@ -0,0 +1,23 @@
|
||||
def hex_to_rgb(value: str) -> tuple[int, int, int]:
|
||||
h = value.lstrip("#")
|
||||
if len(h) != 6:
|
||||
return (255, 255, 255)
|
||||
try:
|
||||
return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16))
|
||||
except ValueError:
|
||||
return (255, 255, 255)
|
||||
|
||||
|
||||
def readable_color(rgb: tuple[int, int, int]) -> tuple[int, int, int]:
|
||||
r, g, b = rgb
|
||||
lum = 0.299 * r + 0.587 * g + 0.114 * b
|
||||
if lum >= 130:
|
||||
return (r, g, b)
|
||||
t = (130 - lum) / (255 - lum)
|
||||
return (round(r + (255 - r) * t), round(g + (255 - g) * t), round(b + (255 - b) * t))
|
||||
|
||||
|
||||
def normalize_palette(colors) -> list[str]:
|
||||
if isinstance(colors, dict):
|
||||
colors = colors.values()
|
||||
return [c.upper() for c in colors if isinstance(c, str) and c]
|
||||
253
comfy_extras/nodes_bounding_boxes.py
Normal file
253
comfy_extras/nodes_bounding_boxes.py
Normal file
@ -0,0 +1,253 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageEnhance, ImageFont
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.color_util import hex_to_rgb, normalize_palette, readable_color
|
||||
|
||||
_PREVIEW_LONG_EDGE = 1024
|
||||
_PREVIEW_DIM = 0.25
|
||||
|
||||
|
||||
def pixels_to_fractions(box: dict, width: int, height: int) -> dict:
|
||||
w = width or 1
|
||||
h = height or 1
|
||||
return {
|
||||
"x": box.get("x", 0) / w,
|
||||
"y": box.get("y", 0) / h,
|
||||
"w": box.get("width", 0) / w,
|
||||
"h": box.get("height", 0) / h,
|
||||
}
|
||||
|
||||
|
||||
def fractions_to_pixels(box: dict, width: int, height: int) -> dict:
|
||||
x, y = box.get("x", 0.0), box.get("y", 0.0)
|
||||
w, h = box.get("w", 0.0), box.get("h", 0.0)
|
||||
if w < 0:
|
||||
x, w = x + w, -w
|
||||
if h < 0:
|
||||
y, h = y + h, -h
|
||||
return {
|
||||
"x": round(x * width),
|
||||
"y": round(y * height),
|
||||
"width": round(w * width),
|
||||
"height": round(h * height),
|
||||
}
|
||||
|
||||
|
||||
def fractions_to_bbox_frame(boxes: list, width: int, height: int) -> list:
|
||||
pixels = [
|
||||
fractions_to_pixels(box, width, height)
|
||||
for box in boxes
|
||||
if isinstance(box, dict)
|
||||
]
|
||||
return [pixels] if pixels else []
|
||||
|
||||
|
||||
def _font(size: int):
|
||||
try:
|
||||
return ImageFont.load_default(size)
|
||||
except Exception:
|
||||
return ImageFont.load_default()
|
||||
|
||||
|
||||
def _wrap(draw, text: str, font, max_w: float) -> list[str]:
|
||||
lines = []
|
||||
for para in text.split("\n"):
|
||||
line = ""
|
||||
for word in para.split():
|
||||
test = word if not line else line + " " + word
|
||||
if line and draw.textlength(test, font=font) > max_w:
|
||||
lines.append(line)
|
||||
line = word
|
||||
else:
|
||||
line = test
|
||||
lines.append(line)
|
||||
return lines
|
||||
|
||||
|
||||
def _bg_from_image(image) -> Image.Image | None:
|
||||
if image is None:
|
||||
return None
|
||||
try:
|
||||
arr = (image[0].detach().cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
return Image.fromarray(arr)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def render_preview(regions, width, height, bg=None):
|
||||
if bg is not None:
|
||||
iw, ih = bg.size
|
||||
long_edge = max(iw, ih) or 1
|
||||
scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge)
|
||||
rw, rh = max(1, round(iw * scale)), max(1, round(ih * scale))
|
||||
base = bg.convert("RGB").resize((rw, rh), Image.LANCZOS)
|
||||
base = ImageEnhance.Brightness(base).enhance(_PREVIEW_DIM)
|
||||
img = base.convert("RGBA")
|
||||
else:
|
||||
long_edge = max(width, height) or 1
|
||||
scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge)
|
||||
rw, rh = max(1, round(width * scale)), max(1, round(height * scale))
|
||||
grey = round(_PREVIEW_DIM * 128)
|
||||
img = Image.new("RGBA", (rw, rh), (grey, grey, grey, 255))
|
||||
|
||||
overlay = Image.new("RGBA", (rw, rh), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(overlay)
|
||||
fs = max(10, round(rh / 64))
|
||||
font = _font(fs)
|
||||
tag_font = _font(max(9, fs - 2))
|
||||
line_h = fs + 2
|
||||
|
||||
for i, region in enumerate(regions):
|
||||
if not isinstance(region, dict):
|
||||
continue
|
||||
palette = [c for c in (region.get("palette") or []) if c]
|
||||
r, g, b = hex_to_rgb(palette[0]) if palette else (140, 140, 140)
|
||||
x1 = max(0, min(rw, round(region.get("x", 0) * rw)))
|
||||
y1 = max(0, min(rh, round(region.get("y", 0) * rh)))
|
||||
x2 = max(0, min(rw, round((region.get("x", 0) + region.get("w", 0)) * rw)))
|
||||
y2 = max(0, min(rh, round((region.get("y", 0) + region.get("h", 0)) * rh)))
|
||||
if x2 < x1:
|
||||
x1, x2 = x2, x1
|
||||
if y2 < y1:
|
||||
y1, y2 = y2, y1
|
||||
|
||||
draw.rectangle([x1, y1, x2, y2], outline=(r, g, b, 255), width=2)
|
||||
|
||||
swatches = palette[:5]
|
||||
if swatches and (x2 - x1) > 2:
|
||||
sh = max(5, fs // 2)
|
||||
seg = (x2 - x1) / len(swatches)
|
||||
for p, hexc in enumerate(swatches):
|
||||
sx = x1 + round(p * seg)
|
||||
draw.rectangle([sx, y1, x1 + round((p + 1) * seg), y1 + sh], fill=hex_to_rgb(hexc))
|
||||
|
||||
etype = "text" if region.get("type") == "text" else "obj"
|
||||
tag = str(i + 1).zfill(2)
|
||||
tw = draw.textlength(tag, font=tag_font)
|
||||
draw.rectangle([x1, y1, x1 + tw + 6, y1 + fs + 2], fill=(r, g, b, 255))
|
||||
tag_fill = (0, 0, 0, 255) if (0.299 * r + 0.587 * g + 0.114 * b) > 140 else (255, 255, 255, 255)
|
||||
draw.text((x1 + 3, y1 + 1), tag, fill=tag_fill, font=tag_font)
|
||||
|
||||
body = region.get("desc", "") or ""
|
||||
if etype == "text" and region.get("text"):
|
||||
body = '"%s"%s' % (region["text"], " — " + body if body else "")
|
||||
if body and (x2 - x1) > 8:
|
||||
ty = y1 + fs + 5
|
||||
for line in _wrap(draw, body, font, x2 - x1 - 8):
|
||||
if ty > y2:
|
||||
break
|
||||
draw.text((x1 + 4, ty), line, fill=readable_color((r, g, b)) + (255,), font=font)
|
||||
ty += line_h
|
||||
|
||||
composed = Image.alpha_composite(img, overlay).convert("RGB")
|
||||
arr = np.asarray(composed, dtype=np.float32) / 255.0
|
||||
return torch.from_numpy(arr).unsqueeze(0)
|
||||
|
||||
|
||||
def boxes_to_regions(boxes, width: int, height: int) -> list:
|
||||
regions: list = []
|
||||
if not isinstance(boxes, list):
|
||||
return regions
|
||||
for box in boxes:
|
||||
if not isinstance(box, dict):
|
||||
continue
|
||||
meta = box.get("metadata")
|
||||
meta = meta if isinstance(meta, dict) else {}
|
||||
regions.append({
|
||||
**pixels_to_fractions(box, width, height),
|
||||
"type": meta.get("type", "obj"),
|
||||
"text": meta.get("text", ""),
|
||||
"desc": meta.get("desc", ""),
|
||||
"palette": meta.get("palette", []),
|
||||
})
|
||||
return regions
|
||||
|
||||
|
||||
def _norm_bbox(region: dict) -> list[int]:
|
||||
def grid(value: float) -> int:
|
||||
return max(0, min(1000, round(value * 1000)))
|
||||
|
||||
x, y = region.get("x", 0.0), region.get("y", 0.0)
|
||||
w, h = region.get("w", 0.0), region.get("h", 0.0)
|
||||
ymin, xmin, ymax, xmax = grid(y), grid(x), grid(y + h), grid(x + w)
|
||||
if ymin > ymax:
|
||||
ymin, ymax = ymax, ymin
|
||||
if xmin > xmax:
|
||||
xmin, xmax = xmax, xmin
|
||||
return [ymin, xmin, ymax, xmax]
|
||||
|
||||
|
||||
def build_elements(regions: list) -> list:
|
||||
elements = []
|
||||
for region in regions:
|
||||
if not isinstance(region, dict):
|
||||
continue
|
||||
etype = "text" if region.get("type") == "text" else "obj"
|
||||
element = {"type": etype}
|
||||
element["bbox"] = _norm_bbox(region)
|
||||
if etype == "text":
|
||||
element["text"] = region.get("text", "")
|
||||
element["desc"] = region.get("desc", "")
|
||||
palette = normalize_palette(region.get("palette", []))
|
||||
if palette:
|
||||
element["color_palette"] = palette[:5]
|
||||
elements.append(element)
|
||||
return elements
|
||||
|
||||
|
||||
class CreateBoundingBoxes(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
editor_state = io.BoundingBoxes.Input(
|
||||
"editor_state",
|
||||
socketless=False,
|
||||
tooltip="Draw bounding boxes and set each box type, text, description, color palette. Start with background element first and foreground last.",
|
||||
)
|
||||
return io.Schema(
|
||||
node_id="CreateBoundingBoxes",
|
||||
display_name="Create Bounding Boxes",
|
||||
category="utilities",
|
||||
description="Draw bounding boxes in a canvas. Outputs Ideogram prompt elements, pixel-space bounding boxes, and a preview image.",
|
||||
inputs=[
|
||||
io.Image.Input(
|
||||
"background",
|
||||
optional=True,
|
||||
tooltip="Optional image used as background in the canvas and preview.",
|
||||
),
|
||||
io.Int.Input("width", default=1024, min=64, max=16384, step=16,
|
||||
tooltip="Width of the canvas and the pixel grid for the bounding boxes."),
|
||||
io.Int.Input("height", default=1024, min=64, max=16384, step=16,
|
||||
tooltip="Height of the canvas and the pixel grid for the bounding boxes."),
|
||||
editor_state,
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="preview"),
|
||||
io.BoundingBox.Output(display_name="bboxes"),
|
||||
io.Array.Output(display_name="elements"),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, editor_state=None, background=None) -> io.NodeOutput:
|
||||
regions = boxes_to_regions(editor_state, width, height)
|
||||
preview = render_preview(regions, width, height, _bg_from_image(background))
|
||||
return io.NodeOutput(
|
||||
preview,
|
||||
fractions_to_bbox_frame(regions, width, height),
|
||||
build_elements(regions),
|
||||
ui={"dims": [width, height]},
|
||||
)
|
||||
|
||||
|
||||
class BoundingBoxesExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [CreateBoundingBoxes]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> BoundingBoxesExtension:
|
||||
return BoundingBoxesExtension()
|
||||
@ -1,5 +1,6 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.color_util import hex_to_rgb
|
||||
|
||||
|
||||
class ColorToRGBInt(io.ComfyNode):
|
||||
@ -24,9 +25,11 @@ class ColorToRGBInt(io.ComfyNode):
|
||||
# expect format #RRGGBB
|
||||
if len(color) != 7 or color[0] != "#":
|
||||
raise ValueError("Color must be in format #RRGGBB")
|
||||
r = int(color[1:3], 16)
|
||||
g = int(color[3:5], 16)
|
||||
b = int(color[5:7], 16)
|
||||
try:
|
||||
int(color[1:], 16)
|
||||
except ValueError:
|
||||
raise ValueError("Color must be in format #RRGGBB") from None
|
||||
r, g, b = hex_to_rgb(color)
|
||||
|
||||
rgb_int = r * 256 * 256 + g * 256 + b
|
||||
return io.NodeOutput(rgb_int, color)
|
||||
|
||||
@ -1,85 +1,68 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import ctypes
|
||||
import logging
|
||||
import ctypes.util
|
||||
import importlib.util
|
||||
from typing import TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import nodes
|
||||
import comfy_angle
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from typing_extensions import override
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _check_opengl_availability():
|
||||
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
|
||||
logger.debug("_check_opengl_availability: starting")
|
||||
missing = []
|
||||
def _preload_angle():
|
||||
egl_path = comfy_angle.get_egl_path()
|
||||
gles_path = comfy_angle.get_glesv2_path()
|
||||
|
||||
# Check Python packages (using find_spec to avoid importing)
|
||||
logger.debug("_check_opengl_availability: checking for glfw package")
|
||||
if importlib.util.find_spec("glfw") is None:
|
||||
missing.append("glfw")
|
||||
if sys.platform == "win32":
|
||||
angle_dir = comfy_angle.get_lib_dir()
|
||||
os.add_dll_directory(angle_dir)
|
||||
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
|
||||
|
||||
logger.debug("_check_opengl_availability: checking for OpenGL package")
|
||||
if importlib.util.find_spec("OpenGL") is None:
|
||||
missing.append("PyOpenGL")
|
||||
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
|
||||
)
|
||||
|
||||
# On Linux without display, check if headless backends are available
|
||||
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
|
||||
if sys.platform.startswith("linux"):
|
||||
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
|
||||
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
|
||||
if not has_display:
|
||||
# Check for EGL or OSMesa libraries
|
||||
logger.debug("_check_opengl_availability: checking for EGL library")
|
||||
has_egl = ctypes.util.find_library("EGL")
|
||||
logger.debug("_check_opengl_availability: checking for OSMesa library")
|
||||
has_osmesa = ctypes.util.find_library("OSMesa")
|
||||
|
||||
# Error disabled for CI as it fails this check
|
||||
# if not has_egl and not has_osmesa:
|
||||
# raise RuntimeError(
|
||||
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
|
||||
# "See error below for installation instructions."
|
||||
# )
|
||||
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
|
||||
|
||||
logger.debug("_check_opengl_availability: completed")
|
||||
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
|
||||
ctypes.CDLL(str(egl_path), mode=mode)
|
||||
ctypes.CDLL(str(gles_path), mode=mode)
|
||||
|
||||
|
||||
# Run early check at import time
|
||||
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
|
||||
_check_opengl_availability()
|
||||
|
||||
# OpenGL modules - initialized lazily when context is created
|
||||
gl = None
|
||||
glfw = None
|
||||
EGL = None
|
||||
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
|
||||
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
|
||||
_preload_angle()
|
||||
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
|
||||
|
||||
def _import_opengl():
|
||||
"""Import OpenGL module. Called after context is created."""
|
||||
global gl
|
||||
if gl is None:
|
||||
logger.debug("_import_opengl: importing OpenGL.GL")
|
||||
import OpenGL.GL as _gl
|
||||
gl = _gl
|
||||
logger.debug("_import_opengl: import completed")
|
||||
return gl
|
||||
import OpenGL
|
||||
OpenGL.USE_ACCELERATE = False
|
||||
|
||||
|
||||
def _patch_find_library():
|
||||
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
|
||||
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
|
||||
'libGLESv2'. Patch find_library to return the full ANGLE paths so
|
||||
PyOpenGL loads the same libraries we pre-loaded."""
|
||||
if sys.platform == "linux":
|
||||
return
|
||||
import ctypes.util
|
||||
_orig = ctypes.util.find_library
|
||||
def _patched(name):
|
||||
if name == 'EGL':
|
||||
return comfy_angle.get_egl_path()
|
||||
if name == 'GLESv2':
|
||||
return comfy_angle.get_glesv2_path()
|
||||
return _orig(name)
|
||||
ctypes.util.find_library = _patched
|
||||
|
||||
|
||||
_patch_find_library()
|
||||
|
||||
from OpenGL import EGL
|
||||
from OpenGL import GLES3 as gl
|
||||
|
||||
class SizeModeInput(TypedDict):
|
||||
size_mode: str
|
||||
width: int
|
||||
@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
# (-1,-1)---(3,-1)
|
||||
#
|
||||
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
||||
VERTEX_SHADER = """#version 330 core
|
||||
VERTEX_SHADER = """#version 300 es
|
||||
out vec2 v_texCoord;
|
||||
void main() {
|
||||
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
||||
@ -126,14 +109,99 @@ void main() {
|
||||
"""
|
||||
|
||||
|
||||
def _convert_es_to_desktop(source: str) -> str:
|
||||
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
||||
# Remove any existing #version directive
|
||||
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
||||
# Remove precision qualifiers (not needed in desktop GLSL)
|
||||
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
||||
# Prepend desktop GLSL version
|
||||
return "#version 330 core\n" + source
|
||||
|
||||
def _egl_attribs(*values):
|
||||
"""Build an EGL_NONE-terminated EGLint attribute array."""
|
||||
vals = list(values) + [EGL.EGL_NONE]
|
||||
return (ctypes.c_int32 * len(vals))(*vals)
|
||||
|
||||
|
||||
# EGL platform extension constants
|
||||
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
|
||||
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
|
||||
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
|
||||
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
|
||||
|
||||
|
||||
_eglGetPlatformDisplayEXT = None
|
||||
|
||||
def _get_egl_platform_display_ext(platform, native_display, attribs):
|
||||
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
|
||||
global _eglGetPlatformDisplayEXT
|
||||
if _eglGetPlatformDisplayEXT is None:
|
||||
from OpenGL import platform as _plat
|
||||
egl_lib = _plat.PLATFORM.EGL
|
||||
_get_proc = egl_lib.eglGetProcAddress
|
||||
_get_proc.restype = ctypes.c_void_p
|
||||
_get_proc.argtypes = [ctypes.c_char_p]
|
||||
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
|
||||
if not ptr:
|
||||
return None
|
||||
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
|
||||
_eglGetPlatformDisplayEXT = func_type(ptr)
|
||||
|
||||
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
|
||||
if not raw:
|
||||
return None
|
||||
return ctypes.cast(raw, EGL.EGLDisplay)
|
||||
|
||||
|
||||
def _get_egl_display():
|
||||
"""Get an EGL display, trying the default first then ANGLE's Vulkan
|
||||
platform for headless environments without a display server."""
|
||||
failures = []
|
||||
|
||||
# Try the default display first (works when X11/Wayland is available)
|
||||
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
|
||||
if display:
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
try:
|
||||
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
return display, major.value, minor.value
|
||||
except Exception as e:
|
||||
failures.append(f"default: {e}")
|
||||
|
||||
logger.info("Default EGL display unavailable, trying headless fallbacks")
|
||||
|
||||
# Headless fallback strategies, tried in order:
|
||||
headless_strategies = [
|
||||
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
|
||||
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
|
||||
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
|
||||
]
|
||||
|
||||
for name, platform, native_display, attribs in headless_strategies:
|
||||
display = _get_egl_platform_display_ext(platform, native_display, attribs)
|
||||
if not display:
|
||||
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
|
||||
continue
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
try:
|
||||
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
logger.info(f"Using EGL {name} platform (headless)")
|
||||
return display, major.value, minor.value
|
||||
failures.append(f"{name}: eglInitialize returned false")
|
||||
except Exception as e:
|
||||
failures.append(f"{name}: {e}")
|
||||
continue
|
||||
|
||||
details = "\n".join(f" - {f}" for f in failures)
|
||||
raise RuntimeError(
|
||||
"Failed to initialize EGL display.\n"
|
||||
"No display server and no headless EGL platform available.\n"
|
||||
f"Tried:\n{details}\n"
|
||||
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
|
||||
)
|
||||
|
||||
|
||||
def _gl_str(name):
|
||||
"""Get an OpenGL string parameter."""
|
||||
v = gl.glGetString(name)
|
||||
if not v:
|
||||
return "Unknown"
|
||||
if isinstance(v, bytes):
|
||||
return v.decode(errors="replace")
|
||||
return ctypes.string_at(v).decode(errors="replace")
|
||||
|
||||
|
||||
def _detect_output_count(source: str) -> int:
|
||||
@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _init_glfw():
|
||||
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_glfw: starting")
|
||||
# On macOS, glfw.init() must be called from main thread or it hangs forever
|
||||
if sys.platform == "darwin":
|
||||
logger.debug("_init_glfw: skipping on macOS")
|
||||
raise RuntimeError("GLFW backend not supported on macOS")
|
||||
|
||||
logger.debug("_init_glfw: importing glfw module")
|
||||
import glfw as _glfw
|
||||
|
||||
logger.debug("_init_glfw: calling glfw.init()")
|
||||
if not _glfw.init():
|
||||
raise RuntimeError("glfw.init() failed")
|
||||
|
||||
try:
|
||||
logger.debug("_init_glfw: setting window hints")
|
||||
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
|
||||
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
|
||||
|
||||
logger.debug("_init_glfw: calling create_window()")
|
||||
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
|
||||
if not window:
|
||||
raise RuntimeError("glfw.create_window() failed")
|
||||
|
||||
logger.debug("_init_glfw: calling make_context_current()")
|
||||
_glfw.make_context_current(window)
|
||||
logger.debug("_init_glfw: completed successfully")
|
||||
return window, _glfw
|
||||
except Exception:
|
||||
logger.debug("_init_glfw: failed, terminating glfw")
|
||||
_glfw.terminate()
|
||||
raise
|
||||
|
||||
|
||||
def _init_egl():
|
||||
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_egl: starting")
|
||||
from OpenGL import EGL as _EGL
|
||||
from OpenGL.EGL import (
|
||||
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
|
||||
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
|
||||
eglTerminate, eglDestroyContext, eglDestroySurface,
|
||||
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
|
||||
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
|
||||
)
|
||||
logger.debug("_init_egl: imports completed")
|
||||
|
||||
display = None
|
||||
context = None
|
||||
surface = None
|
||||
|
||||
try:
|
||||
logger.debug("_init_egl: calling eglGetDisplay()")
|
||||
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
|
||||
if display == _EGL.EGL_NO_DISPLAY:
|
||||
raise RuntimeError("eglGetDisplay() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglInitialize()")
|
||||
major, minor = _EGL.EGLint(), _EGL.EGLint()
|
||||
if not eglInitialize(display, major, minor):
|
||||
display = None # Not initialized, don't terminate
|
||||
raise RuntimeError("eglInitialize() failed")
|
||||
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
|
||||
|
||||
config_attribs = [
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
|
||||
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
|
||||
EGL_DEPTH_SIZE, 0, EGL_NONE
|
||||
]
|
||||
configs = (_EGL.EGLConfig * 1)()
|
||||
num_configs = _EGL.EGLint()
|
||||
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
config = configs[0]
|
||||
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
|
||||
|
||||
if not eglBindAPI(EGL_OPENGL_API):
|
||||
raise RuntimeError("eglBindAPI() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreateContext()")
|
||||
context_attribs = [
|
||||
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
|
||||
EGL_NONE
|
||||
]
|
||||
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
|
||||
if context == EGL_NO_CONTEXT:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
|
||||
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
|
||||
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
|
||||
if surface == _EGL.EGL_NO_SURFACE:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglMakeCurrent()")
|
||||
if not eglMakeCurrent(display, surface, surface, context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_egl: completed successfully")
|
||||
return display, context, surface, _EGL
|
||||
|
||||
except Exception:
|
||||
logger.debug("_init_egl: failed, cleaning up")
|
||||
# Clean up any resources on failure
|
||||
if surface is not None:
|
||||
eglDestroySurface(display, surface)
|
||||
if context is not None:
|
||||
eglDestroyContext(display, context)
|
||||
if display is not None:
|
||||
eglTerminate(display)
|
||||
raise
|
||||
|
||||
|
||||
def _init_osmesa():
|
||||
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
|
||||
import ctypes
|
||||
|
||||
logger.debug("_init_osmesa: starting")
|
||||
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
||||
|
||||
logger.debug("_init_osmesa: importing OpenGL.osmesa")
|
||||
from OpenGL import GL as _gl
|
||||
from OpenGL.osmesa import (
|
||||
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
|
||||
OSMESA_RGBA,
|
||||
)
|
||||
logger.debug("_init_osmesa: imports completed")
|
||||
|
||||
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
|
||||
if not ctx:
|
||||
raise RuntimeError("OSMesaCreateContextExt() failed")
|
||||
|
||||
width, height = 64, 64
|
||||
buffer = (ctypes.c_ubyte * (width * height * 4))()
|
||||
|
||||
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
|
||||
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
|
||||
OSMesaDestroyContext(ctx)
|
||||
raise RuntimeError("OSMesaMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_osmesa: completed successfully")
|
||||
return ctx, buffer
|
||||
|
||||
|
||||
class GLContext:
|
||||
"""Manages OpenGL context and resources for shader execution.
|
||||
|
||||
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
|
||||
"""
|
||||
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
@ -327,131 +240,105 @@ class GLContext:
|
||||
|
||||
def __init__(self):
|
||||
if GLContext._initialized:
|
||||
logger.debug("GLContext.__init__: already initialized, skipping")
|
||||
return
|
||||
|
||||
logger.debug("GLContext.__init__: starting initialization")
|
||||
|
||||
global glfw, EGL
|
||||
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
|
||||
self._backend = None
|
||||
self._window = None
|
||||
self._egl_display = None
|
||||
self._egl_context = None
|
||||
self._egl_surface = None
|
||||
self._osmesa_ctx = None
|
||||
self._osmesa_buffer = None
|
||||
self._display = None
|
||||
self._surface = None
|
||||
self._context = None
|
||||
self._vao = None
|
||||
|
||||
# Try backends in order: GLFW → EGL → OSMesa
|
||||
errors = []
|
||||
|
||||
logger.debug("GLContext.__init__: trying GLFW backend")
|
||||
try:
|
||||
self._window, glfw = _init_glfw()
|
||||
self._backend = "glfw"
|
||||
logger.debug("GLContext.__init__: GLFW backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
|
||||
errors.append(("GLFW", e))
|
||||
self._display, self._egl_major, self._egl_minor = _get_egl_display()
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying EGL backend")
|
||||
try:
|
||||
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
|
||||
self._backend = "egl"
|
||||
logger.debug("GLContext.__init__: EGL backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
|
||||
errors.append(("EGL", e))
|
||||
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
|
||||
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying OSMesa backend")
|
||||
try:
|
||||
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
|
||||
self._backend = "osmesa"
|
||||
logger.debug("GLContext.__init__: OSMesa backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
|
||||
errors.append(("OSMesa", e))
|
||||
config = EGL.EGLConfig()
|
||||
n_configs = ctypes.c_int32(0)
|
||||
if not EGL.eglChooseConfig(
|
||||
self._display,
|
||||
_egl_attribs(
|
||||
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
|
||||
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
|
||||
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
|
||||
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
|
||||
),
|
||||
ctypes.byref(config), 1, ctypes.byref(n_configs),
|
||||
) or n_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
|
||||
if self._backend is None:
|
||||
if sys.platform == "win32":
|
||||
platform_help = (
|
||||
"Windows: Ensure GPU drivers are installed and display is available.\n"
|
||||
" CPU-only/headless mode is not supported on Windows."
|
||||
)
|
||||
elif sys.platform == "darwin":
|
||||
platform_help = (
|
||||
"macOS: GLFW is not supported.\n"
|
||||
" Install OSMesa via Homebrew: brew install mesa\n"
|
||||
" Then: pip install PyOpenGL PyOpenGL-accelerate"
|
||||
)
|
||||
else:
|
||||
platform_help = (
|
||||
"Linux: Install one of these backends:\n"
|
||||
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
|
||||
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
|
||||
" Headless (CPU): sudo apt install libosmesa6"
|
||||
)
|
||||
|
||||
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
|
||||
raise RuntimeError(
|
||||
f"Failed to create OpenGL context.\n\n"
|
||||
f"Backend errors:\n{error_details}\n\n"
|
||||
f"{platform_help}"
|
||||
self._surface = EGL.eglCreatePbufferSurface(
|
||||
self._display, config,
|
||||
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
|
||||
)
|
||||
if not self._surface:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
# Now import OpenGL.GL (after context is current)
|
||||
logger.debug("GLContext.__init__: importing OpenGL.GL")
|
||||
_import_opengl()
|
||||
self._context = EGL.eglCreateContext(
|
||||
self._display, config, EGL.EGL_NO_CONTEXT,
|
||||
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
|
||||
)
|
||||
if not self._context:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
# Create VAO (required for core profile, but OSMesa may use compat profile)
|
||||
logger.debug("GLContext.__init__: creating VAO")
|
||||
try:
|
||||
vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(vao)
|
||||
self._vao = vao # Only store after successful bind
|
||||
logger.debug("GLContext.__init__: VAO created successfully")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
|
||||
# OSMesa with older Mesa may not support VAOs
|
||||
# Clean up if we created but couldn't bind
|
||||
if vao:
|
||||
try:
|
||||
gl.glDeleteVertexArrays(1, [vao])
|
||||
except Exception:
|
||||
pass
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
self._vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
except Exception:
|
||||
self._cleanup()
|
||||
raise
|
||||
|
||||
elapsed = (time.perf_counter() - start) * 1000
|
||||
|
||||
# Log device info
|
||||
renderer = gl.glGetString(gl.GL_RENDERER)
|
||||
vendor = gl.glGetString(gl.GL_VENDOR)
|
||||
version = gl.glGetString(gl.GL_VERSION)
|
||||
renderer = renderer.decode() if renderer else "Unknown"
|
||||
vendor = vendor.decode() if vendor else "Unknown"
|
||||
version = version.decode() if version else "Unknown"
|
||||
renderer = _gl_str(gl.GL_RENDERER)
|
||||
vendor = _gl_str(gl.GL_VENDOR)
|
||||
version = _gl_str(gl.GL_VERSION)
|
||||
|
||||
GLContext._initialized = True
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
|
||||
|
||||
def make_current(self):
|
||||
if self._backend == "glfw":
|
||||
glfw.make_context_current(self._window)
|
||||
elif self._backend == "egl":
|
||||
from OpenGL.EGL import eglMakeCurrent
|
||||
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
|
||||
elif self._backend == "osmesa":
|
||||
from OpenGL.osmesa import OSMesaMakeCurrent
|
||||
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
|
||||
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
err = EGL.eglGetError()
|
||||
raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
|
||||
if self._vao is not None:
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
def _cleanup(self):
|
||||
if not self._display:
|
||||
return
|
||||
try:
|
||||
if self._vao is not None:
|
||||
gl.glDeleteVertexArrays(1, [self._vao])
|
||||
self._vao = None
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._context:
|
||||
EGL.eglDestroyContext(self._display, self._context)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._surface:
|
||||
EGL.eglDestroySurface(self._display, self._surface)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglTerminate(self._display)
|
||||
except Exception:
|
||||
pass
|
||||
self._display = None
|
||||
|
||||
|
||||
def _compile_shader(source: str, shader_type: int) -> int:
|
||||
"""Compile a shader and return its ID."""
|
||||
@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
|
||||
gl.glShaderSource(shader, source)
|
||||
gl.glCompileShader(shader)
|
||||
|
||||
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetShaderInfoLog(shader).decode()
|
||||
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
|
||||
error = gl.glGetShaderInfoLog(shader)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteShader(shader)
|
||||
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
||||
|
||||
@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
|
||||
gl.glDeleteShader(vertex_shader)
|
||||
gl.glDeleteShader(fragment_shader)
|
||||
|
||||
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetProgramInfoLog(program).decode()
|
||||
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
|
||||
error = gl.glGetProgramInfoLog(program)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteProgram(program)
|
||||
raise RuntimeError(f"Program linking failed:\n{error}")
|
||||
|
||||
@ -530,9 +421,6 @@ def _render_shader_batch(
|
||||
ctx = GLContext()
|
||||
ctx.make_current()
|
||||
|
||||
# Convert from GLSL ES to desktop GLSL 330
|
||||
fragment_source = _convert_es_to_desktop(fragment_code)
|
||||
|
||||
# Detect how many outputs the shader actually uses
|
||||
num_outputs = _detect_output_count(fragment_code)
|
||||
|
||||
@ -558,9 +446,9 @@ def _render_shader_batch(
|
||||
try:
|
||||
# Compile shaders (once for all batches)
|
||||
try:
|
||||
program = _create_program(VERTEX_SHADER, fragment_source)
|
||||
program = _create_program(VERTEX_SHADER, fragment_code)
|
||||
except RuntimeError:
|
||||
logger.error(f"Fragment shader:\n{fragment_source}")
|
||||
logger.error(f"Fragment shader:\n{fragment_code}")
|
||||
raise
|
||||
|
||||
gl.glUseProgram(program)
|
||||
@ -723,13 +611,13 @@ def _render_shader_batch(
|
||||
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
||||
|
||||
# Read back outputs for this batch
|
||||
# (glGetTexImage is synchronous, implicitly waits for rendering)
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
||||
batch_outputs = []
|
||||
for tex in output_textures:
|
||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
|
||||
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
|
||||
batch_outputs.append(img[::-1, :, :].copy())
|
||||
for i in range(num_outputs):
|
||||
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
|
||||
buf = np.empty((height, width, 4), dtype=np.float32)
|
||||
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
|
||||
batch_outputs.append(buf[::-1, :, :].copy())
|
||||
|
||||
# Pad with black images for unused outputs
|
||||
black_img = np.zeros((height, width, 4), dtype=np.float32)
|
||||
@ -750,18 +638,18 @@ def _render_shader_batch(
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
||||
gl.glUseProgram(0)
|
||||
|
||||
for tex in input_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in curve_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in output_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in ping_pong_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
if input_textures:
|
||||
gl.glDeleteTextures(len(input_textures), input_textures)
|
||||
if curve_textures:
|
||||
gl.glDeleteTextures(len(curve_textures), curve_textures)
|
||||
if output_textures:
|
||||
gl.glDeleteTextures(len(output_textures), output_textures)
|
||||
if ping_pong_textures:
|
||||
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
|
||||
if fbo is not None:
|
||||
gl.glDeleteFramebuffers(1, [fbo])
|
||||
for pp_fbo in ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(1, [pp_fbo])
|
||||
if ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
|
||||
if program is not None:
|
||||
gl.glDeleteProgram(program)
|
||||
|
||||
|
||||
77
comfy_extras/nodes_json_prompt.py
Normal file
77
comfy_extras/nodes_json_prompt.py
Normal file
@ -0,0 +1,77 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.color_util import normalize_palette
|
||||
|
||||
|
||||
class BuildJsonPromptIdeogram(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
color_palette = io.Colors.Input(
|
||||
"color_palette",
|
||||
socketless=False,
|
||||
tooltip="Hex color codes that steer the image's dominant colors. Up to 16 entries.",
|
||||
)
|
||||
return io.Schema(
|
||||
node_id="BuildJsonPromptIdeogram",
|
||||
display_name="Build JSON Prompt (Ideogram)",
|
||||
category="text",
|
||||
description="Build a JSON prompt for the Ideogram 4 model.",
|
||||
inputs=[
|
||||
io.Array.Input("element", tooltip="Prompt elements from the node Create Bounding Boxes."),
|
||||
io.String.Input("high_level_description", multiline=True, default="",
|
||||
tooltip="Optional description of the image in one or two sentences. Strongly recommended."),
|
||||
io.String.Input("background", multiline=True, default="",
|
||||
tooltip="Mandatory description of the image background or environment."),
|
||||
io.DynamicCombo.Input("style", options=[
|
||||
io.DynamicCombo.Option("none", []),
|
||||
io.DynamicCombo.Option("photo", [io.String.Input("photo", default="", tooltip="Camera or lens details for photographic outputs (e.g. 35mm, f/1.4, bokeh).")]),
|
||||
io.DynamicCombo.Option("art_style", [io.String.Input("art_style", default="", tooltip="Art style description (e.g. flat vector illustration, bold outlines).")]),
|
||||
]),
|
||||
io.String.Input("aesthetics", default="", tooltip="Mandatory aesthetic keywords (e.g. moody, cinematic, desaturated)."),
|
||||
io.String.Input("lighting", default="", tooltip="Mandatory lighting description (e.g. golden hour, rim light, dramatic shadows)."),
|
||||
io.String.Input("medium", default="", tooltip="Mandatory medium type (e.g. photograph, illustration, 3d_render, painting, graphic_design). When style = photo, set to photograph."),
|
||||
color_palette,
|
||||
],
|
||||
outputs=[io.Dict.Output(display_name="prompt")],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, element, style, high_level_description="", background="",
|
||||
aesthetics="", lighting="", medium="", color_palette=None) -> io.NodeOutput:
|
||||
elements = element if isinstance(element, list) else []
|
||||
kind = style.get("style", "none") if isinstance(style, dict) else "none"
|
||||
photo = style.get("photo", "") if isinstance(style, dict) else ""
|
||||
art_style = style.get("art_style", "") if isinstance(style, dict) else ""
|
||||
palette = normalize_palette(color_palette or [])
|
||||
|
||||
caption: dict = {}
|
||||
if high_level_description.strip():
|
||||
caption["high_level_description"] = high_level_description
|
||||
if kind != "none":
|
||||
style_desc: dict = {"aesthetics": aesthetics, "lighting": lighting}
|
||||
if kind == "photo":
|
||||
style_desc["photo"] = photo
|
||||
style_desc["medium"] = medium
|
||||
else:
|
||||
style_desc["medium"] = medium
|
||||
style_desc["art_style"] = art_style
|
||||
if palette:
|
||||
style_desc["color_palette"] = palette
|
||||
caption["style_description"] = style_desc
|
||||
caption["compositional_deconstruction"] = {
|
||||
"background": background,
|
||||
"elements": elements,
|
||||
}
|
||||
return io.NodeOutput(caption)
|
||||
|
||||
|
||||
class JsonPromptExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [BuildJsonPromptIdeogram]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> JsonPromptExtension:
|
||||
return JsonPromptExtension()
|
||||
@ -337,6 +337,36 @@ class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeKrea2(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "model/merging/model specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["first."] = argument
|
||||
arg_dict["tmlp."] = argument
|
||||
arg_dict["txtmlp."] = argument
|
||||
arg_dict["tproj."] = argument
|
||||
|
||||
for i in range(2):
|
||||
arg_dict["txtfusion.layerwise_blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["txtfusion.projector."] = argument
|
||||
|
||||
for i in range(2):
|
||||
arg_dict["txtfusion.refiner_blocks.{}.".format(i)] = argument
|
||||
|
||||
for i in range(28):
|
||||
arg_dict["blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["last."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSD1": ModelMergeSD1,
|
||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
||||
@ -353,4 +383,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
|
||||
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
|
||||
"ModelMergeQwenImage": ModelMergeQwenImage,
|
||||
"ModelMergeKrea2": ModelMergeKrea2,
|
||||
}
|
||||
|
||||
33
comfy_extras/nodes_seed.py
Normal file
33
comfy_extras/nodes_seed.py
Normal file
@ -0,0 +1,33 @@
|
||||
import sys
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class SeedNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SeedNode",
|
||||
display_name="Seed",
|
||||
search_aliases=["seed", "random"],
|
||||
category="utilities",
|
||||
inputs=[
|
||||
io.Int.Input("seed", min=0, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
|
||||
],
|
||||
outputs=[io.Int.Output(display_name="seed")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, seed: int) -> io.NodeOutput:
|
||||
return io.NodeOutput(seed)
|
||||
|
||||
|
||||
class SeedExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [SeedNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> SeedExtension:
|
||||
return SeedExtension()
|
||||
@ -440,6 +440,57 @@ class JsonExtractString(io.ComfyNode):
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return io.NodeOutput("")
|
||||
|
||||
|
||||
def _dump_json(value, indent):
|
||||
return json.dumps(value, ensure_ascii=False, indent=indent or None)
|
||||
|
||||
|
||||
class ConvertDictionaryToString(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ConvertDictionaryToString",
|
||||
display_name="Convert Dictionary to String",
|
||||
category="text",
|
||||
search_aliases=["json", "dict to json", "stringify", "serialize", "dict to string"],
|
||||
inputs=[
|
||||
io.Dict.Input("dictionary"),
|
||||
io.Int.Input("indent", default=2, min=0, max=8,
|
||||
tooltip="Spaces per indent level. 0 produces compact single-line string."),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, dictionary, indent=2):
|
||||
return io.NodeOutput(_dump_json(dictionary, indent))
|
||||
|
||||
|
||||
class ConvertArrayToString(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ConvertArrayToString",
|
||||
display_name="Convert Array to String",
|
||||
category="text",
|
||||
search_aliases=["json", "list to json", "stringify", "serialize", "list to string", "array to json"],
|
||||
inputs=[
|
||||
io.Array.Input("array"),
|
||||
io.Int.Input("indent", default=2, min=0, max=8,
|
||||
tooltip="Spaces per indent level. 0 produces compact single-line string."),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, array, indent=2):
|
||||
return io.NodeOutput(_dump_json(array, indent))
|
||||
|
||||
|
||||
class StringExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@ -457,6 +508,8 @@ class StringExtension(ComfyExtension):
|
||||
RegexExtract,
|
||||
RegexReplace,
|
||||
JsonExtractString,
|
||||
ConvertDictionaryToString,
|
||||
ConvertArrayToString,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> StringExtension:
|
||||
|
||||
3
nodes.py
3
nodes.py
@ -2374,6 +2374,8 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_images.py",
|
||||
"nodes_video_model.py",
|
||||
"nodes_ideogram4.py",
|
||||
"nodes_bounding_boxes.py",
|
||||
"nodes_json_prompt.py",
|
||||
"nodes_train.py",
|
||||
"nodes_dataset.py",
|
||||
"nodes_sag.py",
|
||||
@ -2473,6 +2475,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_gaussian_splat.py",
|
||||
"nodes_triposplat.py",
|
||||
"nodes_depth_anything_3.py",
|
||||
"nodes_seed.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
18
openapi.yaml
18
openapi.yaml
@ -1692,6 +1692,12 @@ paths:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Unsupported media type
|
||||
"422":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Validation error (e.g., disallowed model_type tag)
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
@ -2137,6 +2143,12 @@ paths:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Source asset with given hash not found
|
||||
"422":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Validation error (e.g., disallowed model_type tag)
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
@ -2357,6 +2369,10 @@ paths:
|
||||
description: |
|
||||
Returns a list of model folders available in the system.
|
||||
This is an experimental endpoint that replaces the legacy /models endpoint.
|
||||
Each folder's name is the identifier to pass to /api/experiment/models/{folder}.
|
||||
Once the model_type migration is active the names are model_type folder_names
|
||||
(e.g. `ultralytics_bbox`); a folder with no folder_name mapping is returned by
|
||||
its directory path.
|
||||
operationId: getModelFolders
|
||||
responses:
|
||||
"200":
|
||||
@ -2988,7 +3004,7 @@ paths:
|
||||
format: uuid
|
||||
type: string
|
||||
- description: |
|
||||
When present, each output item in the response receives a `short_url` field containing an owner-gated durable link for that asset. Omit this parameter (the default) to receive a response identical to the no-param baseline. The value selects the link's lifetime: use `ephemeral_tool_chain` for short-lived machine-to-machine handoffs (~15 minutes); use `default` for durable human-revisitable links (30 days). Links are minted only for the authenticated request owner and are not resolvable by other users.
|
||||
When present, each output item in the response receives a `short_url` field containing a short link for that asset. Omit this parameter (the default) to receive a response identical to the no-param baseline. The value selects the link's lifetime and auth model: use `ephemeral_tool_chain` for short-lived (≤5 minute) machine-to-machine handoffs — these are public bearer links where the link ID itself is the credential, so anyone holding the link can resolve it (intended for pasting into an agent/MCP tool chain); use `default` for durable (30 day) human-revisitable links, which are owner-gated and resolvable only by the authenticated owner. Links are always minted under the authenticated request owner's identity; the auth model is selected by the server and is never settable by the caller.
|
||||
in: query
|
||||
name: short_link
|
||||
schema:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.45.19
|
||||
comfyui-workflow-templates==0.10.3
|
||||
comfyui-workflow-templates==0.10.7
|
||||
comfyui-embedded-docs==0.5.5
|
||||
torch
|
||||
torchsde
|
||||
@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.10
|
||||
comfy-kitchen==0.2.14
|
||||
comfy-aimdo==0.4.10
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
@ -33,5 +33,5 @@ kornia>=0.7.1
|
||||
spandrel
|
||||
pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
PyOpenGL
|
||||
glfw
|
||||
PyOpenGL>=3.1.8
|
||||
comfy-angle
|
||||
|
||||
@ -228,6 +228,62 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
with self.assertRaises(KeyError):
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
def test_int8_convrot_metadata_loads_into_params(self):
|
||||
"""ConvRot metadata must reach TensorWiseINT8Layout params."""
|
||||
torch.manual_seed(123)
|
||||
layer_quant_config = {
|
||||
"layer": {
|
||||
"format": "int8_tensorwise",
|
||||
"convrot": True,
|
||||
"convrot_groupsize": 256,
|
||||
}
|
||||
}
|
||||
weight = torch.randn(16, 256, dtype=torch.bfloat16)
|
||||
bias = torch.randn(16, dtype=torch.bfloat16)
|
||||
q_weight = QuantizedTensor.from_float(
|
||||
weight,
|
||||
"TensorWiseINT8Layout",
|
||||
per_channel=True,
|
||||
convrot=True,
|
||||
convrot_groupsize=256,
|
||||
)
|
||||
state_dict = {
|
||||
"layer.weight": q_weight._qdata,
|
||||
"layer.bias": bias,
|
||||
"layer.weight_scale": q_weight._params.scale,
|
||||
}
|
||||
|
||||
state_dict, _ = comfy.utils.convert_old_quants(
|
||||
state_dict,
|
||||
metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})},
|
||||
)
|
||||
model = torch.nn.Module()
|
||||
model.layer = ops.mixed_precision_ops({}).Linear(256, 16, device="cpu", dtype=torch.bfloat16)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
self.assertIsInstance(model.layer.weight, QuantizedTensor)
|
||||
self.assertEqual(model.layer.weight._layout_cls, "TensorWiseINT8Layout")
|
||||
self.assertTrue(model.layer.weight._params.convrot)
|
||||
self.assertEqual(model.layer.weight._params.convrot_groupsize, 256)
|
||||
|
||||
input_tensor = torch.randn(4, 256, dtype=torch.bfloat16)
|
||||
loaded_out = model.layer(input_tensor)
|
||||
ref_out = torch.nn.functional.linear(input_tensor, q_weight, bias)
|
||||
self.assertTrue(torch.equal(loaded_out, ref_out))
|
||||
|
||||
fp16_input = input_tensor.to(torch.float16)
|
||||
loaded_fp16_out = model.layer(fp16_input)
|
||||
ref_fp16_out = torch.nn.functional.linear(
|
||||
fp16_input,
|
||||
q_weight.to(dtype=torch.float16),
|
||||
bias.to(dtype=torch.float16),
|
||||
)
|
||||
self.assertTrue(torch.equal(loaded_fp16_out, ref_fp16_out))
|
||||
|
||||
saved = model.state_dict()
|
||||
saved_conf = json.loads(saved["layer.comfy_quant"].numpy().tobytes())
|
||||
self.assertTrue(saved_conf["convrot"])
|
||||
self.assertEqual(saved_conf["convrot_groupsize"], 256)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user