mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-02 18:57:05 +08:00
Compare commits
27 Commits
xpu-aimdo-
...
feat/creat
| Author | SHA1 | Date | |
|---|---|---|---|
| b82c8c7eb3 | |||
| dd17debce5 | |||
| 50e5270b86 | |||
| bb131be9e8 | |||
| 6fca64780c | |||
| 6e11828d10 | |||
| b70944e710 | |||
| 1c59659a2f | |||
| d395813bcd | |||
| 8fe0243d97 | |||
| ba3f697dbb | |||
| 510ed5c384 | |||
| 7851410511 | |||
| a58473fd9b | |||
| 79c555ce6b | |||
| f19735759e | |||
| a95e461916 | |||
| 603d891eaf | |||
| 470ac36a0a | |||
| 7cb784e0f4 | |||
| 1a510f0423 | |||
| 639c8fa788 | |||
| e22f1500f9 | |||
| dac4ea3a80 | |||
| b0ec19804f | |||
| 64e1d740b8 | |||
| b22d0fb9c0 |
38
.github/workflows/ci-cursor-review.yml
vendored
Normal file
38
.github/workflows/ci-cursor-review.yml
vendored
Normal file
@ -0,0 +1,38 @@
|
||||
name: CI - Cursor Review
|
||||
|
||||
# Thin caller for the shared reusable cursor-review workflow in
|
||||
# Comfy-Org/github-workflows. The review logic (panel matrix, judge
|
||||
# consolidation, prompts, extract/post/notify scripts) lives there as the
|
||||
# single source of truth, so this repo only carries the repo-specific diff
|
||||
# excludes.
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [labeled, unlabeled]
|
||||
|
||||
concurrency:
|
||||
group: cursor-review-pr-${{ github.event.pull_request.number }}-${{ github.event.label.name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
cursor-review:
|
||||
if: github.event.label.name == 'cursor-review'
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
# SHA-pinned per zizmor `unpinned-uses: hash-pin`. Bump this SHA to pick up
|
||||
# upstream changes; keep `workflows_ref` matching so prompts/scripts load
|
||||
# from the same commit as the workflow definition.
|
||||
uses: Comfy-Org/github-workflows/.github/workflows/cursor-review.yml@047ca48febe3a6647608ed2e0c4331b491cb9d6a # github-workflows#9
|
||||
with:
|
||||
workflows_ref: 047ca48febe3a6647608ed2e0c4331b491cb9d6a
|
||||
diff_excludes: >-
|
||||
:!**/.claude/**
|
||||
:!**/dist/**
|
||||
:!**/vendor/**
|
||||
:!**/*.generated.*
|
||||
:!**/*.min.js
|
||||
:!**/*.min.css
|
||||
secrets:
|
||||
CURSOR_API_KEY: ${{ secrets.CURSOR_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
166
AGENTS.md
Normal file
166
AGENTS.md
Normal file
@ -0,0 +1,166 @@
|
||||
## Engineering Style
|
||||
|
||||
- Keep changes small and direct. Most fixes should touch the narrowest code path
|
||||
that explains the bug, performance issue, dtype issue, model-format issue, or
|
||||
user-facing behavior.
|
||||
- Change the least amount of files possible. A change that touches many files is
|
||||
more likely to be a bad change than a good one unless the broader scope is
|
||||
directly required.
|
||||
- Prefer practical fixes over broad architecture work. Add abstractions only
|
||||
when they remove real repeated logic or match an existing ComfyUI pattern.
|
||||
- Delete obsolete code aggressively when newer infrastructure makes it useless.
|
||||
Remove dead fallbacks, migration paths, unused options, debug prints, and
|
||||
compatibility branches that are no longer needed. Do not leave dead branches,
|
||||
unreachable code, or functions that are never called.
|
||||
- Revert or disable problematic behavior quickly when it breaks users. It is
|
||||
better to remove a broken feature path than keep a complicated partial fix.
|
||||
- Preserve existing APIs, node names, model-loading behavior, file layout, and
|
||||
workflow compatibility unless the change is explicitly about replacing them.
|
||||
- Code must look hand-written for this repository. Changes that read like
|
||||
generic AI-generated code will be rejected automatically: unnecessary helper
|
||||
layers, vague names, boilerplate comments, defensive branches without a real
|
||||
failure mode, broad rewrites, or code that ignores the local style.
|
||||
|
||||
## Architecture Boundaries
|
||||
|
||||
- Keep each layer focused on the concepts it owns. Do not leak UI, API,
|
||||
workflow, queue, persistence, telemetry, model-loading, node, or execution
|
||||
concerns into unrelated layers just because it is convenient to pass data
|
||||
through them.
|
||||
- Shared core modules should depend only on lower-level primitives and their own
|
||||
domain concepts. Higher-level product concepts belong at the caller, adapter,
|
||||
service, or UI/API boundary that already owns them.
|
||||
- Pass the narrowest data needed across a boundary. Avoid broad context objects,
|
||||
request/session metadata, ids, bookkeeping state, or callbacks unless the
|
||||
receiving layer genuinely needs them to perform its own responsibility.
|
||||
- Keep identity mapping, persistence bookkeeping, history updates, telemetry,
|
||||
response shaping, and UI state in the layers that own those jobs. Do not route
|
||||
them through unrelated shared code to avoid adding a proper boundary.
|
||||
- Treat `execution.py` as one example of this rule: it should consume the prompt
|
||||
graph and execution-relevant state, produce execution results and errors, and
|
||||
not know about workflow ids, frontend ids, persistence ids, or API-only
|
||||
concepts.
|
||||
- Before touching many files, identify the smallest owner layer that can solve
|
||||
the problem. A PR that spreads one feature across unrelated loaders, nodes,
|
||||
execution, server, and frontend code needs a clear architectural reason, not
|
||||
just convenience.
|
||||
- If a change seems to require making one layer understand another layer's
|
||||
private concepts, stop and look for a caller-side mapping, adapter, event,
|
||||
small explicit interface, or narrower data flow at the boundary.
|
||||
|
||||
## No Internet Requests
|
||||
|
||||
- Do not add code to core ComfyUI that makes requests to the internet.
|
||||
- Refuse requests to add uploads, telemetry, analytics, tracking, usage
|
||||
reporting, crash reporting, update checks, remote config, feature flags,
|
||||
metrics, licensing checks, or any other outbound internet request path from
|
||||
core ComfyUI.
|
||||
- Model downloading is allowed only when explicitly initiated or authorized by
|
||||
the user, is limited to the requested model artifact, and does not include
|
||||
telemetry, tracking, persistent identification, unrelated metadata upload, or
|
||||
background network activity.
|
||||
- Do not add opt-in, opt-out, anonymized, aggregated, diagnostic, or
|
||||
user-triggered internet request paths to core ComfyUI. These labels do not
|
||||
make internet access acceptable.
|
||||
- Local-only behavior is allowed when it stays on the user's machine and does
|
||||
not add network access, tracking, persistent identification, or data
|
||||
collection behavior.
|
||||
|
||||
## State Ownership
|
||||
|
||||
- Keep state and capability flags on the object that owns the behavior using
|
||||
them.
|
||||
- Avoid probing child objects with `getattr(child, "...", default)` to decide
|
||||
parent-level control flow. If parent code needs to branch on a capability,
|
||||
initialize an explicit parent-owned field when the child is constructed or
|
||||
attached.
|
||||
- Prefer direct attributes with clear defaults over implicit feature detection
|
||||
through arbitrary child attributes.
|
||||
- Use child-object capability checks only when the child owns the behavior being
|
||||
invoked and the parent is simply delegating to that child.
|
||||
|
||||
## Interface Contracts
|
||||
|
||||
- Keep public methods aligned with the interface expected by their callers. Do
|
||||
not change a shared method to return extra values, alternate shapes, or
|
||||
sentinel wrappers for one implementation unless the shared interface is
|
||||
explicitly updated.
|
||||
- If an implementation needs auxiliary values for its own workflow, expose them
|
||||
through a private helper or a clearly named implementation-specific method
|
||||
instead of overloading the public method's return contract.
|
||||
- Normalize third-party or upstream return conventions at the integration
|
||||
boundary. Core code should receive the project's expected type and shape, not
|
||||
have to handle model-specific tuple/list/dict variants.
|
||||
- Avoid caller-side unwrapping such as `out = out[0]` unless the called
|
||||
interface is documented to return that structure.
|
||||
|
||||
## Autograd and Model Freezing
|
||||
|
||||
- Do not add `torch.no_grad`, `torch.inference_mode`, or inference-mode helper
|
||||
wrappers in ComfyUI code. The only allowed inference-mode-related use is
|
||||
disabling a globally set inference mode when a training path needs gradients.
|
||||
- Do not add freeze, unfreeze, or trainability toggles to model classes. ComfyUI
|
||||
models are always treated as frozen for inference, so explicit freeze
|
||||
functionality is redundant and should not be added.
|
||||
|
||||
## Python Style
|
||||
|
||||
- Keep imports at module scope. Avoid inline imports unless they are already part
|
||||
of an established optional-backend probe or are needed to avoid an import
|
||||
cycle.
|
||||
- Do not add unnecessary `try`/`except` blocks. Use them for optional dependency,
|
||||
platform, or backend capability detection only when the program has a useful
|
||||
fallback. Prefer specific exception types when changing new code.
|
||||
- Let unsupported model formats, invalid quantization metadata, and bad states
|
||||
fail with clear errors instead of silently producing lower quality output.
|
||||
- Match the existing local style in the file you edit. This codebase tolerates
|
||||
long lines, simple helper functions, module-level state, and direct tensor
|
||||
operations when they make the code easier to follow.
|
||||
- Keep comments sparse and useful. Strip useless comments that restate the code
|
||||
or describe obvious behavior. Short TODOs are fine when they name the concrete
|
||||
missing follow-up.
|
||||
|
||||
## Model, Device, and Memory Behavior
|
||||
|
||||
- Treat dtype, device placement, VRAM usage, and offloading behavior as core
|
||||
correctness concerns. Check CPU, CUDA, ROCm, MPS, DirectML, XPU, NPU, and low
|
||||
VRAM implications when touching shared execution or loading code.
|
||||
- Prefer native ComfyUI formats and existing quantization/offload helpers over
|
||||
adding parallel code paths. Use `comfy.quant_ops`, `comfy.model_management`,
|
||||
`comfy.memory_management`, `comfy.pinned_memory`, `comfy_aimdo`, and
|
||||
`comfy-kitchen` helpers where they already solve the problem.
|
||||
- Avoid unnecessary casts and transfers. Preserve the intended compute dtype,
|
||||
storage dtype, bias dtype, and original tensor shape metadata.
|
||||
- When optimizing, favor small measurable changes: fewer allocations, fewer
|
||||
device transfers, less peak memory, better batching, or use of a faster
|
||||
existing backend op.
|
||||
|
||||
## Nodes and User-Facing Behavior
|
||||
|
||||
- Follow existing node conventions: `INPUT_TYPES`, `RETURN_TYPES`, `FUNCTION`,
|
||||
`CATEGORY`, and registration through the local mapping used by that file.
|
||||
- Keep node changes backward compatible by default. Add inputs with sensible
|
||||
defaults and avoid changing output types unless the request requires it.
|
||||
- The official mascot of ComfyUI is a very cute anime girl with massive fennec
|
||||
ears, a big fluffy tail, long blonde wavy hair, and blue eyes. Feel free to
|
||||
use her in ComfyUI materials, UI text, examples, tests, generated assets, or
|
||||
comments, but do not disrespect her.
|
||||
- Warning and info messages should be short and actionable. Remove noisy or
|
||||
misleading messages rather than adding more logging.
|
||||
- Documentation and README edits should be concise, factual, and tied to the
|
||||
changed behavior.
|
||||
|
||||
## Commit and Review Habits
|
||||
|
||||
- If asked to write commit messages, use short direct subjects like the existing
|
||||
history: `Fix ...`, `Add ...`, `Support ...`, `Remove ...`, `Update ...`,
|
||||
`Make ...`, `Use ...`, `Disable ...`, `Bump ...`, or `Revert ...`.
|
||||
- Keep PR descriptions short and reviewable. State the problem, the behavioral
|
||||
change, and the tests run; avoid long narrative explanations, implementation
|
||||
diaries, or exhaustive file-by-file summaries unless the reviewer explicitly
|
||||
needs that context.
|
||||
- Prefer one coherent behavioral change per commit. Dependency pins, tests, and
|
||||
the code that needs them may be in the same commit when they are inseparable.
|
||||
- In reviews, prioritize real user impact: crashes, wrong dtype/device behavior,
|
||||
memory regressions, broken model loading, workflow incompatibility, and noisy
|
||||
or misleading user-facing output.
|
||||
@ -240,6 +240,7 @@ database_default_path = os.path.abspath(
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
|
||||
parser.add_argument("--enable-asset-hashing", action="store_true", help="Compute blake3 content hashes when scanning assets. Hashing enables future asset-portability features (deduplication, cross-machine model resolution) but adds startup cost and per-output cost on large models directories. Off by default; enable to opt in.")
|
||||
parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
|
||||
parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")
|
||||
|
||||
|
||||
@ -1274,148 +1274,13 @@ def force_channels_last():
|
||||
return False
|
||||
|
||||
|
||||
_INTEL_XPU_DISCRETE = None
|
||||
def is_intel_xpu_discrete():
|
||||
# Returns True only if the active Intel XPU is a discrete GPU. torch.xpu does
|
||||
# not expose the integrated-vs-discrete distinction, so we query Level Zero
|
||||
# directly via ctypes. Works on Windows (ze_loader.dll) and Linux
|
||||
# (libze_loader.so.1). Any failure or ambiguity returns False so a
|
||||
# discrete-only fast path is never enabled by mistake.
|
||||
global _INTEL_XPU_DISCRETE
|
||||
if _INTEL_XPU_DISCRETE is not None:
|
||||
return _INTEL_XPU_DISCRETE
|
||||
_INTEL_XPU_DISCRETE = False
|
||||
if not is_intel_xpu():
|
||||
return False
|
||||
|
||||
try:
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
|
||||
ZE_RESULT_SUCCESS = 0
|
||||
ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES = 0x3
|
||||
ZE_DEVICE_TYPE_GPU = 1
|
||||
ZE_DEVICE_PROPERTY_FLAG_INTEGRATED = 1 << 0
|
||||
ZE_MAX_DEVICE_NAME = 256
|
||||
|
||||
class ze_device_uuid_t(ctypes.Structure):
|
||||
_fields_ = [("id", ctypes.c_ubyte * 16)]
|
||||
|
||||
class ze_device_properties_t(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("stype", ctypes.c_uint32),
|
||||
("pNext", ctypes.c_void_p),
|
||||
("type", ctypes.c_uint32),
|
||||
("vendorId", ctypes.c_uint32),
|
||||
("deviceId", ctypes.c_uint32),
|
||||
("flags", ctypes.c_uint32),
|
||||
("subdeviceId", ctypes.c_uint32),
|
||||
("coreClockRate", ctypes.c_uint32),
|
||||
("maxMemAllocSize", ctypes.c_uint64),
|
||||
("maxHardwareContexts", ctypes.c_uint32),
|
||||
("maxCommandQueuePriority", ctypes.c_uint32),
|
||||
("numThreadsPerEU", ctypes.c_uint32),
|
||||
("physicalEUSimdWidth", ctypes.c_uint32),
|
||||
("numEUsPerSubslice", ctypes.c_uint32),
|
||||
("numSubslicesPerSlice", ctypes.c_uint32),
|
||||
("numSlices", ctypes.c_uint32),
|
||||
("timerResolution", ctypes.c_uint64),
|
||||
("timestampValidBits", ctypes.c_uint32),
|
||||
("kernelTimestampValidBits", ctypes.c_uint32),
|
||||
("uuid", ze_device_uuid_t),
|
||||
("name", ctypes.c_char * ZE_MAX_DEVICE_NAME),
|
||||
]
|
||||
|
||||
if sys.platform == "win32":
|
||||
loader_names = ["ze_loader.dll"]
|
||||
else:
|
||||
loader_names = [ctypes.util.find_library("ze_loader"), "libze_loader.so.1", "libze_loader.so"]
|
||||
|
||||
ze = None
|
||||
for name in loader_names:
|
||||
if not name:
|
||||
continue
|
||||
try:
|
||||
ze = ctypes.CDLL(name)
|
||||
break
|
||||
except OSError:
|
||||
pass
|
||||
if ze is None:
|
||||
return False
|
||||
|
||||
ze.zeInit.argtypes = [ctypes.c_uint32]
|
||||
ze.zeInit.restype = ctypes.c_uint32
|
||||
ze.zeDriverGet.argtypes = [ctypes.POINTER(ctypes.c_uint32), ctypes.POINTER(ctypes.c_void_p)]
|
||||
ze.zeDriverGet.restype = ctypes.c_uint32
|
||||
ze.zeDeviceGet.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_uint32), ctypes.POINTER(ctypes.c_void_p)]
|
||||
ze.zeDeviceGet.restype = ctypes.c_uint32
|
||||
ze.zeDeviceGetProperties.argtypes = [ctypes.c_void_p, ctypes.POINTER(ze_device_properties_t)]
|
||||
ze.zeDeviceGetProperties.restype = ctypes.c_uint32
|
||||
|
||||
if ze.zeInit(0) != ZE_RESULT_SUCCESS:
|
||||
return False
|
||||
|
||||
try:
|
||||
torch_device_id = int(torch.xpu.get_device_properties(torch.xpu.current_device()).device_id)
|
||||
except Exception:
|
||||
torch_device_id = None
|
||||
|
||||
driver_count = ctypes.c_uint32(0)
|
||||
if ze.zeDriverGet(ctypes.byref(driver_count), None) != ZE_RESULT_SUCCESS or driver_count.value == 0:
|
||||
return False
|
||||
allocated_drivers = driver_count.value
|
||||
drivers = (ctypes.c_void_p * allocated_drivers)()
|
||||
if ze.zeDriverGet(ctypes.byref(driver_count), drivers) != ZE_RESULT_SUCCESS:
|
||||
return False
|
||||
|
||||
gpu_devices = [] # (deviceId, is_integrated)
|
||||
for i in range(min(driver_count.value, allocated_drivers)):
|
||||
device_count = ctypes.c_uint32(0)
|
||||
if ze.zeDeviceGet(drivers[i], ctypes.byref(device_count), None) != ZE_RESULT_SUCCESS:
|
||||
return False
|
||||
if device_count.value == 0:
|
||||
continue
|
||||
allocated_devices = device_count.value
|
||||
devices = (ctypes.c_void_p * allocated_devices)()
|
||||
if ze.zeDeviceGet(drivers[i], ctypes.byref(device_count), devices) != ZE_RESULT_SUCCESS:
|
||||
return False
|
||||
for j in range(min(device_count.value, allocated_devices)):
|
||||
props = ze_device_properties_t()
|
||||
props.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES
|
||||
props.pNext = None
|
||||
if ze.zeDeviceGetProperties(devices[j], ctypes.byref(props)) != ZE_RESULT_SUCCESS:
|
||||
return False
|
||||
if props.type != ZE_DEVICE_TYPE_GPU:
|
||||
continue
|
||||
gpu_devices.append((int(props.deviceId), bool(props.flags & ZE_DEVICE_PROPERTY_FLAG_INTEGRATED)))
|
||||
|
||||
if not gpu_devices:
|
||||
return False
|
||||
|
||||
if torch_device_id is not None:
|
||||
matches = [integrated for device_id, integrated in gpu_devices if device_id == torch_device_id]
|
||||
if matches:
|
||||
# Fail closed if a duplicate PCI device id somehow mixes flags.
|
||||
_INTEL_XPU_DISCRETE = not any(matches)
|
||||
return _INTEL_XPU_DISCRETE
|
||||
|
||||
# No reliable match: only enable when every visible GPU is discrete so a
|
||||
# mixed iGPU+dGPU system never enables streams while running on the iGPU.
|
||||
_INTEL_XPU_DISCRETE = all(not integrated for _, integrated in gpu_devices)
|
||||
return _INTEL_XPU_DISCRETE
|
||||
except Exception as e:
|
||||
logging.info("Could not determine Intel XPU type via Level Zero: {}".format(e))
|
||||
_INTEL_XPU_DISCRETE = False
|
||||
return False
|
||||
|
||||
|
||||
STREAMS = {}
|
||||
NUM_STREAMS = 0
|
||||
if args.async_offload is not None:
|
||||
NUM_STREAMS = args.async_offload
|
||||
else:
|
||||
# Enable by default on Nvidia, AMD, and discrete Intel XPU
|
||||
if not args.disable_async_offload and (is_nvidia() or is_amd() or is_intel_xpu_discrete()):
|
||||
# Enable by default on Nvidia and AMD
|
||||
if is_nvidia() or is_amd():
|
||||
NUM_STREAMS = 2
|
||||
|
||||
if args.disable_async_offload:
|
||||
@ -1622,7 +1487,7 @@ PINNED_MEMORY = {}
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
MAX_PINNED_MEMORY = -1
|
||||
if not args.disable_pinned_memory:
|
||||
if is_nvidia() or is_amd() or is_intel_xpu():
|
||||
if is_nvidia() or is_amd():
|
||||
ram = get_total_memory(torch.device("cpu"))
|
||||
if WINDOWS:
|
||||
MAX_PINNED_MEMORY = ram * 0.40 # Windows limit is apparently 50%
|
||||
@ -1647,20 +1512,6 @@ def discard_cuda_async_error():
|
||||
#Dump it! We already know about it from the synchronous return
|
||||
pass
|
||||
|
||||
def host_register(ptr, size):
|
||||
# Intel XPU has no CUDA host-registration API. The pinnable buffers used by
|
||||
# the DynamicVRAM path are already Level Zero host USM (allocated through the
|
||||
# aimdo hostbuf / zeMemAllocHost), and pageable host memory is still usable
|
||||
# for transfers, so registration is a no-op success on XPU.
|
||||
if is_intel_xpu():
|
||||
return 0
|
||||
return torch.cuda.cudart().cudaHostRegister(ptr, size, 1)
|
||||
|
||||
def host_unregister(ptr):
|
||||
if is_intel_xpu():
|
||||
return 0
|
||||
return torch.cuda.cudart().cudaHostUnregister(ptr)
|
||||
|
||||
def pin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
@ -1689,7 +1540,7 @@ def pin_memory(tensor):
|
||||
if ptr == 0:
|
||||
return False
|
||||
|
||||
if host_register(ptr, size) == 0:
|
||||
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
|
||||
PINNED_MEMORY[ptr] = size
|
||||
TOTAL_PINNED_MEMORY += size
|
||||
return True
|
||||
@ -1719,7 +1570,7 @@ def unpin_memory(tensor):
|
||||
logging.warning("Size of pinned tensor changed")
|
||||
return False
|
||||
|
||||
if host_unregister(ptr) == 0:
|
||||
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
|
||||
size = PINNED_MEMORY.pop(ptr)
|
||||
TOTAL_PINNED_MEMORY -= size
|
||||
return True
|
||||
|
||||
@ -1961,7 +1961,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
if not module._pin_registered:
|
||||
continue
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
if comfy.model_management.host_unregister(module._pin.data_ptr()) != 0:
|
||||
if torch.cuda.cudart().cudaHostUnregister(module._pin.data_ptr()) != 0:
|
||||
comfy.model_management.discard_cuda_async_error()
|
||||
continue
|
||||
module._pin_registered = False
|
||||
|
||||
64
comfy/ops.py
64
comfy/ops.py
@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
|
||||
if (want_requant and len(fns) == 0 or update_weight):
|
||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||
if isinstance(orig, QuantizedTensor):
|
||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||
y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
|
||||
else:
|
||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
|
||||
if want_requant and len(fns) == 0:
|
||||
@ -1089,6 +1089,19 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat
|
||||
if ts is None or bs is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||
scales = {"scale": ts, "block_scale": bs}
|
||||
elif module.quant_format == "int8_tensorwise":
|
||||
scale = pop_scale("weight_scale")
|
||||
if scale is None:
|
||||
raise ValueError(f"Missing INT8 weight scale for layer {layer_name}")
|
||||
scales = {"scale": scale}
|
||||
params_conf = layer_conf.get("params", {})
|
||||
if not isinstance(params_conf, dict):
|
||||
params_conf = {}
|
||||
if layer_conf.get("convrot", params_conf.get("convrot", False)):
|
||||
scales["convrot"] = True
|
||||
scales["convrot_groupsize"] = int(
|
||||
layer_conf.get("convrot_groupsize", params_conf.get("convrot_groupsize", 256))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
|
||||
|
||||
@ -1131,6 +1144,10 @@ def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extr
|
||||
quant_conf = {"format": module.quant_format}
|
||||
if getattr(module, '_full_precision_mm_config', False):
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
params = getattr(module.weight, "_params", None)
|
||||
if module.quant_format == "int8_tensorwise" and getattr(params, "convrot", False):
|
||||
quant_conf["convrot"] = True
|
||||
quant_conf["convrot_groupsize"] = getattr(params, "convrot_groupsize", 256)
|
||||
if extra_quant_conf:
|
||||
quant_conf.update(extra_quant_conf)
|
||||
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
|
||||
@ -1183,8 +1200,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
|
||||
def forward_comfy_cast_weights(
|
||||
self,
|
||||
input,
|
||||
compute_dtype=None,
|
||||
want_requant=False,
|
||||
weight_only_quant=False,
|
||||
):
|
||||
if weight_only_quant:
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input=None,
|
||||
dtype=self.weight.dtype,
|
||||
device=input.device,
|
||||
bias_dtype=input.dtype,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=True,
|
||||
)
|
||||
weight = weight.to(dtype=input.dtype)
|
||||
else:
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=want_requant,
|
||||
)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
@ -1203,9 +1245,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0
|
||||
)
|
||||
quantize_input = QUANT_ALGOS.get(getattr(self, 'quant_format', None), {}).get("quantize_input", True)
|
||||
|
||||
# Training path: quantized forward with compute_dtype backward via autograd function
|
||||
if (input.requires_grad and _use_quantized):
|
||||
if (input.requires_grad and _use_quantized and quantize_input):
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
@ -1227,7 +1270,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
return output
|
||||
|
||||
# Inference path (unchanged)
|
||||
if _use_quantized:
|
||||
if _use_quantized and quantize_input:
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
@ -1241,7 +1284,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
|
||||
weight_only_quant = _use_quantized and not quantize_input and isinstance(self.weight, QuantizedTensor)
|
||||
output = self.forward_comfy_cast_weights(
|
||||
input,
|
||||
compute_dtype,
|
||||
want_requant=isinstance(input, QuantizedTensor),
|
||||
weight_only_quant=weight_only_quant,
|
||||
)
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if reshaped_3d:
|
||||
@ -1257,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
# dtype is now implicit in the layout class
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
|
||||
@ -53,7 +53,7 @@ def get_pin(module, subset="weights"):
|
||||
size = pin.nbytes
|
||||
comfy.model_management.ensure_pin_registerable(size)
|
||||
|
||||
if comfy.model_management.host_register(pin.data_ptr(), size) != 0:
|
||||
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
|
||||
comfy.model_management.discard_cuda_async_error()
|
||||
return pin
|
||||
|
||||
@ -95,10 +95,10 @@ def pin_memory(module, subset="weights", size=None):
|
||||
extended = True
|
||||
pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
||||
pin.untyped_storage()._comfy_hostbuf = hostbuf
|
||||
if comfy.model_management.host_register(pin.data_ptr(), size) != 0:
|
||||
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
|
||||
comfy.model_management.discard_cuda_async_error()
|
||||
comfy.model_management.free_registrations(size)
|
||||
if comfy.model_management.host_register(pin.data_ptr(), size) != 0:
|
||||
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
|
||||
comfy.model_management.discard_cuda_async_error()
|
||||
del pin
|
||||
hostbuf.truncate(offset, do_unregister=False)
|
||||
|
||||
@ -10,6 +10,7 @@ try:
|
||||
QuantizedLayout,
|
||||
TensorCoreFP8Layout as _CKFp8Layout,
|
||||
TensorCoreNVFP4Layout as _CKNvfp4Layout,
|
||||
TensorWiseINT8Layout as _CKTensorWiseINT8Layout,
|
||||
register_layout_op,
|
||||
register_layout_class,
|
||||
get_layout_class,
|
||||
@ -47,6 +48,9 @@ except ImportError as e:
|
||||
class _CKNvfp4Layout:
|
||||
pass
|
||||
|
||||
class _CKTensorWiseINT8Layout:
|
||||
pass
|
||||
|
||||
def register_layout_class(name, cls):
|
||||
pass
|
||||
|
||||
@ -174,6 +178,7 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
|
||||
|
||||
# Backward compatibility alias - default to E4M3
|
||||
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
|
||||
TensorWiseINT8Layout = _CKTensorWiseINT8Layout
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
@ -184,6 +189,7 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||
register_layout_class("TensorWiseINT8Layout", _CKTensorWiseINT8Layout)
|
||||
if _CK_MXFP8_AVAILABLE:
|
||||
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||
|
||||
@ -214,6 +220,13 @@ if _CK_MXFP8_AVAILABLE:
|
||||
"group_size": 32,
|
||||
}
|
||||
|
||||
QUANT_ALGOS["int8_tensorwise"] = {
|
||||
"storage_t": torch.int8,
|
||||
"parameters": {"weight_scale"},
|
||||
"comfy_tensor_layout": "TensorWiseINT8Layout",
|
||||
"quantize_input": False,
|
||||
}
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Re-exports for backward compatibility
|
||||
@ -226,6 +239,7 @@ __all__ = [
|
||||
"TensorCoreFP8E4M3Layout",
|
||||
"TensorCoreFP8E5M2Layout",
|
||||
"TensorCoreNVFP4Layout",
|
||||
"TensorWiseINT8Layout",
|
||||
"QUANT_ALGOS",
|
||||
"register_layout_op",
|
||||
]
|
||||
|
||||
@ -891,6 +891,14 @@ class Tracks(ComfyTypeIO):
|
||||
track_visibility: torch.Tensor
|
||||
Type = TrackDict
|
||||
|
||||
@comfytype(io_type="DICT")
|
||||
class Dict(ComfyTypeIO):
|
||||
Type = dict
|
||||
|
||||
@comfytype(io_type="ARRAY")
|
||||
class Array(ComfyTypeIO):
|
||||
Type = list
|
||||
|
||||
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
||||
class MultiType:
|
||||
Type = Any
|
||||
@ -1279,6 +1287,19 @@ class Color(ComfyTypeIO):
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
|
||||
|
||||
@comfytype(io_type="COLORS")
|
||||
class Colors(ComfyTypeIO):
|
||||
Type = list[Color.Type]
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: list[str]=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = []
|
||||
|
||||
|
||||
@comfytype(io_type="BOUNDING_BOX")
|
||||
class BoundingBox(ComfyTypeIO):
|
||||
class BoundingBoxDict(TypedDict):
|
||||
@ -1326,6 +1347,20 @@ class Curve(ComfyTypeIO):
|
||||
return d
|
||||
|
||||
|
||||
@comfytype(io_type="BOUNDING_BOXES")
|
||||
class BoundingBoxes(ComfyTypeIO):
|
||||
class BoundingBoxWithMetadata(BoundingBox.BoundingBoxDict):
|
||||
metadata: dict
|
||||
Type = list[BoundingBoxWithMetadata]
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: list[dict]=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = []
|
||||
|
||||
|
||||
@comfytype(io_type="HISTOGRAM")
|
||||
class Histogram(ComfyTypeIO):
|
||||
"""A histogram represented as a list of bin counts."""
|
||||
@ -2376,6 +2411,8 @@ __all__ = [
|
||||
"AnyType",
|
||||
"MultiType",
|
||||
"Tracks",
|
||||
"Dict",
|
||||
"Array",
|
||||
"Color",
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
@ -2394,6 +2431,8 @@ __all__ = [
|
||||
"PriceBadgeDepends",
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"BoundingBoxes",
|
||||
"Colors",
|
||||
"Curve",
|
||||
"Histogram",
|
||||
"Range",
|
||||
|
||||
@ -177,6 +177,10 @@ SEEDANCE2_PRICE_PER_1K_TOKENS = {
|
||||
("dreamina-seedance-2-0-fast-260128", True, "480p"): 0.0033,
|
||||
("dreamina-seedance-2-0-fast-260128", False, "720p"): 0.0056,
|
||||
("dreamina-seedance-2-0-fast-260128", True, "720p"): 0.0033,
|
||||
("dreamina-seedance-2-0-mini", False, "480p"): 0.0035,
|
||||
("dreamina-seedance-2-0-mini", True, "480p"): 0.0021,
|
||||
("dreamina-seedance-2-0-mini", False, "720p"): 0.0035,
|
||||
("dreamina-seedance-2-0-mini", True, "720p"): 0.0021,
|
||||
}
|
||||
|
||||
|
||||
@ -278,6 +282,10 @@ SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
|
||||
"480p": {"min": 409_600, "max": 927_408},
|
||||
"720p": {"min": 409_600, "max": 927_408},
|
||||
},
|
||||
"dreamina-seedance-2-0-mini": {
|
||||
"480p": {"min": 409_600, "max": 927_408},
|
||||
"720p": {"min": 409_600, "max": 927_408},
|
||||
},
|
||||
}
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
|
||||
@ -121,6 +121,7 @@ class GeminiGenerationConfig(BaseModel):
|
||||
topK: int | None = Field(None, ge=1)
|
||||
topP: float | None = Field(None, ge=0.0, le=1.0)
|
||||
thinkingConfig: GeminiThinkingConfig | None = Field(None)
|
||||
responseModalities: list[str] | None = Field(None)
|
||||
|
||||
|
||||
class GeminiImageOutputOptions(BaseModel):
|
||||
|
||||
@ -89,6 +89,7 @@ BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT = "/proxy/byteplus-seedance2/api/v3/cont
|
||||
SEEDANCE_MODELS = {
|
||||
"Seedance 2.0": "dreamina-seedance-2-0-260128",
|
||||
"Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128",
|
||||
"Seedance 2.0 Mini": "dreamina-seedance-2-0-mini",
|
||||
}
|
||||
|
||||
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
||||
@ -1623,8 +1624,10 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p", "4k"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Mini", _seedance2_text_inputs(["480p", "720p"])),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
|
||||
"Mini for the fastest, lowest-cost generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
@ -1666,6 +1669,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$pricePer1K := $res = "4k" ? 0.00572 :
|
||||
$res = "1080p" ? 0.011011 :
|
||||
$contains($m, "mini") ? 0.005005 :
|
||||
$contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$rate := $res = "4k" ? $rate4k :
|
||||
$res = "1080p" ? $rate1080 :
|
||||
@ -1734,8 +1738,13 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
"Seedance 2.0 Fast",
|
||||
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Mini",
|
||||
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
|
||||
"Mini for the fastest, lowest-cost generation.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"first_frame",
|
||||
@ -1801,6 +1810,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$pricePer1K := $res = "4k" ? 0.00572 :
|
||||
$res = "1080p" ? 0.011011 :
|
||||
$contains($m, "mini") ? 0.005005 :
|
||||
$contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$rate := $res = "4k" ? $rate4k :
|
||||
$res = "1080p" ? $rate1080 :
|
||||
@ -2024,8 +2034,13 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
"Seedance 2.0 Fast",
|
||||
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Mini",
|
||||
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; "
|
||||
"Mini for the fastest, lowest-cost generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
@ -2071,9 +2086,11 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$noVideoPricePer1K := $res = "4k" ? 0.00572 :
|
||||
$res = "1080p" ? 0.011011 :
|
||||
$contains($m, "mini") ? 0.005005 :
|
||||
$contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$videoPricePer1K := $res = "4k" ? 0.003432 :
|
||||
$res = "1080p" ? 0.006721 :
|
||||
$contains($m, "mini") ? 0.003003 :
|
||||
$contains($m, "fast") ? 0.004719 : 0.006149;
|
||||
$rate := $res = "4k" ? $rate4k :
|
||||
$res = "1080p" ? $rate1080 :
|
||||
|
||||
@ -13,7 +13,7 @@ import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl, Types
|
||||
from comfy_api_nodes.apis.gemini import (
|
||||
GeminiContent,
|
||||
GeminiFileData,
|
||||
@ -37,6 +37,7 @@ from comfy_api_nodes.util import (
|
||||
audio_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
@ -45,6 +46,7 @@ from comfy_api_nodes.util import (
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
validate_video_duration,
|
||||
video_to_base64_string,
|
||||
)
|
||||
|
||||
@ -229,10 +231,29 @@ async def get_image_from_response(response: GeminiGenerateContentResponse, thoug
|
||||
return torch.cat(image_tensors, dim=0)
|
||||
|
||||
|
||||
async def get_video_from_response(
|
||||
response: GeminiGenerateContentResponse, cls: type[IO.ComfyNode] | None = None
|
||||
) -> InputImpl.VideoFromFile:
|
||||
parts = get_parts_by_type(response, "video/*")
|
||||
for part in parts:
|
||||
if part.inlineData and part.inlineData.data:
|
||||
return InputImpl.VideoFromFile(BytesIO(base64.b64decode(part.inlineData.data)))
|
||||
if part.fileData and part.fileData.fileUri:
|
||||
return await download_url_to_video_output(part.fileData.fileUri, cls=cls)
|
||||
model_message = get_text_from_response(response).strip()
|
||||
if model_message:
|
||||
raise ValueError(f"Gemini did not generate a video. Model response: {model_message}")
|
||||
raise ValueError(
|
||||
"Gemini did not generate a video. Try rephrasing your prompt, "
|
||||
"shortening the requested duration, or reducing the number of input images/videos."
|
||||
)
|
||||
|
||||
|
||||
def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None:
|
||||
if not response.modelVersion:
|
||||
return None
|
||||
# Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing
|
||||
output_video_tokens_price = 0.0
|
||||
if response.modelVersion == "gemini-2.5-pro":
|
||||
input_tokens_price = 1.25
|
||||
output_text_tokens_price = 10.0
|
||||
@ -249,18 +270,27 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
||||
input_tokens_price = 2
|
||||
output_text_tokens_price = 12.0
|
||||
output_image_tokens_price = 0.0
|
||||
elif response.modelVersion == "gemini-3.1-flash-lite-preview":
|
||||
elif response.modelVersion in ("gemini-3.1-flash-lite-preview", "gemini-3.1-flash-lite"):
|
||||
input_tokens_price = 0.25
|
||||
output_text_tokens_price = 1.50
|
||||
output_image_tokens_price = 0.0
|
||||
elif response.modelVersion == "gemini-3-pro-image-preview":
|
||||
elif response.modelVersion in ("gemini-3-pro-image-preview", "gemini-3-pro-image"):
|
||||
input_tokens_price = 2
|
||||
output_text_tokens_price = 12.0
|
||||
output_image_tokens_price = 120.0
|
||||
elif response.modelVersion == "gemini-3.1-flash-image-preview":
|
||||
elif response.modelVersion in ("gemini-3.1-flash-image-preview", "gemini-3.1-flash-image"):
|
||||
input_tokens_price = 0.5
|
||||
output_text_tokens_price = 3.0
|
||||
output_image_tokens_price = 60.0
|
||||
elif response.modelVersion == "gemini-3.1-flash-lite-image":
|
||||
input_tokens_price = 0.25
|
||||
output_text_tokens_price = 1.50
|
||||
output_image_tokens_price = 30.0
|
||||
elif response.modelVersion == "gemini-omni-flash-preview":
|
||||
input_tokens_price = 2.145
|
||||
output_text_tokens_price = 12.87
|
||||
output_image_tokens_price = 0.0
|
||||
output_video_tokens_price = 25.025
|
||||
else:
|
||||
return None
|
||||
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
|
||||
@ -268,6 +298,8 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
||||
for i in response.usageMetadata.candidatesTokensDetails:
|
||||
if i.modality == Modality.IMAGE:
|
||||
final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models
|
||||
elif i.modality == Modality.VIDEO:
|
||||
final_price += output_video_tokens_price * i.tokenCount # for Omni Flash
|
||||
else:
|
||||
final_price += output_text_tokens_price * i.tokenCount
|
||||
if response.usageMetadata.thoughtsTokenCount:
|
||||
@ -1302,7 +1334,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
def _nano_banana_2_v2_model_inputs():
|
||||
def _nano_banana_2_v2_model_inputs(resolutions: list[str]):
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
@ -1329,8 +1361,8 @@ def _nano_banana_2_v2_model_inputs():
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["1K", "2K", "4K"],
|
||||
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
|
||||
options=resolutions,
|
||||
tooltip="Target output resolution.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"thinking_level",
|
||||
@ -1376,7 +1408,11 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"Nano Banana 2 (Gemini 3.1 Flash Image)",
|
||||
_nano_banana_2_v2_model_inputs(),
|
||||
_nano_banana_2_v2_model_inputs(resolutions=["1K", "2K", "4K"]),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Nano Banana 2 Lite",
|
||||
_nano_banana_2_v2_model_inputs(resolutions=["1K"]),
|
||||
),
|
||||
],
|
||||
),
|
||||
@ -1445,9 +1481,13 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$r := $lookup(widgets, "model.resolution");
|
||||
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||
$contains(widgets.model, "lite")
|
||||
? {"type":"usd","usd": 0.034, "format":{"suffix":"/Image","approximate":true}}
|
||||
: (
|
||||
$r := $lookup(widgets, "model.resolution");
|
||||
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||
)
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -1468,6 +1508,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
model_choice = model["model"]
|
||||
if model_choice == "Nano Banana 2 (Gemini 3.1 Flash Image)":
|
||||
model_id = "gemini-3.1-flash-image-preview"
|
||||
elif model_choice == "Nano Banana 2 Lite":
|
||||
model_id = "gemini-3.1-flash-lite-image"
|
||||
else:
|
||||
model_id = model_choice
|
||||
|
||||
@ -1517,6 +1559,149 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
OMNI_MAX_IMAGES = 14
|
||||
OMNI_MAX_VIDEOS = 3
|
||||
|
||||
OMNI_MODELS: dict[str, str] = {
|
||||
"Omni Flash": "gemini-omni-flash-preview",
|
||||
}
|
||||
|
||||
|
||||
def _omni_flash_inputs() -> list[Input]:
|
||||
"""Per-model inputs for the Omni video DynamicCombo (prompt + reference media + sampling)."""
|
||||
return [
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Describe the video to generate. Specify the length and aspect ratio directly in the "
|
||||
'prompt, e.g. "a 6-second clip in 16:9". Length may be 3-10 seconds; the aspect ratio must be '
|
||||
"16:9 (landscape) or 9:16 (portrait). The output is 720p, 24 FPS, with audio.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, OMNI_MAX_IMAGES + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional reference image(s) to guide or animate the video. Up to {OMNI_MAX_IMAGES} images.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"videos",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Video.Input("video"),
|
||||
names=[f"video_{i}" for i in range(1, OMNI_MAX_VIDEOS + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional reference video(s) to guide or edit. Up to {OMNI_MAX_VIDEOS} videos, "
|
||||
f"each up to 10 seconds long.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.01,
|
||||
tooltip="Controls randomness. Lower is more focused/deterministic, higher is more varied.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=0.95,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class GeminiVideoOmni(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiVideoOmni",
|
||||
display_name="Google Gemini Omni (Video)",
|
||||
category="partner/video/Gemini",
|
||||
essentials_category="Video Generation",
|
||||
description="Generate a video with audio from a text prompt using Google's Gemini Omni Flash model. "
|
||||
"Optionally provide reference images and/or videos to guide or edit the result. Describe the desired "
|
||||
"length (3-10s) and aspect ratio (16:9 or 9:16) directly in the prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Omni Flash", _omni_flash_inputs()),
|
||||
],
|
||||
tooltip="The Gemini video model used to generate the video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
IO.String.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr='{"type":"usd","usd":0.146,"format":{"suffix":"/second","approximate":true}}'
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, model: dict, seed: int) -> IO.NodeOutput:
|
||||
prompt = model.get("prompt") or ""
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_id = OMNI_MODELS[model["model"]]
|
||||
|
||||
images = [t for t in (model.get("images") or {}).values() if t is not None]
|
||||
videos = [v for v in (model.get("videos") or {}).values() if v is not None]
|
||||
if sum(get_number_of_images(t) for t in images) > OMNI_MAX_IMAGES:
|
||||
raise ValueError(f"The current maximum number of supported images is {OMNI_MAX_IMAGES}.")
|
||||
if len(videos) > OMNI_MAX_VIDEOS:
|
||||
raise ValueError(f"The current maximum number of supported videos is {OMNI_MAX_VIDEOS}.")
|
||||
for video in videos:
|
||||
validate_video_duration(video, max_duration=10)
|
||||
|
||||
parts: list[GeminiPart] = []
|
||||
if images or videos:
|
||||
parts.extend(await build_gemini_media_parts(cls, images, [], videos))
|
||||
parts.append(GeminiPart(text=prompt))
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"),
|
||||
data=GeminiGenerateContentRequest(
|
||||
contents=[GeminiContent(role=GeminiRole.user, parts=parts)],
|
||||
generationConfig=GeminiGenerationConfig(
|
||||
responseModalities=["TEXT", "VIDEO"],
|
||||
temperature=model.get("temperature", 1.0),
|
||||
topP=model.get("top_p", 0.95),
|
||||
),
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await get_video_from_response(response, cls=cls),
|
||||
get_text_from_response(response),
|
||||
)
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -1527,6 +1712,7 @@ class GeminiExtension(ComfyExtension):
|
||||
GeminiImage2,
|
||||
GeminiNanoBanana2,
|
||||
GeminiNanoBanana2V2,
|
||||
GeminiVideoOmni,
|
||||
GeminiInputFiles,
|
||||
]
|
||||
|
||||
|
||||
23
comfy_extras/color_util.py
Normal file
23
comfy_extras/color_util.py
Normal file
@ -0,0 +1,23 @@
|
||||
def hex_to_rgb(value: str) -> tuple[int, int, int]:
|
||||
h = value.lstrip("#")
|
||||
if len(h) != 6:
|
||||
return (255, 255, 255)
|
||||
try:
|
||||
return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16))
|
||||
except ValueError:
|
||||
return (255, 255, 255)
|
||||
|
||||
|
||||
def readable_color(rgb: tuple[int, int, int]) -> tuple[int, int, int]:
|
||||
r, g, b = rgb
|
||||
lum = 0.299 * r + 0.587 * g + 0.114 * b
|
||||
if lum >= 130:
|
||||
return (r, g, b)
|
||||
t = (130 - lum) / (255 - lum)
|
||||
return (round(r + (255 - r) * t), round(g + (255 - g) * t), round(b + (255 - b) * t))
|
||||
|
||||
|
||||
def normalize_palette(colors) -> list[str]:
|
||||
if isinstance(colors, dict):
|
||||
colors = colors.values()
|
||||
return [c.upper() for c in colors if isinstance(c, str) and c]
|
||||
302
comfy_extras/nodes_bounding_boxes.py
Normal file
302
comfy_extras/nodes_bounding_boxes.py
Normal file
@ -0,0 +1,302 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageEnhance, ImageFont
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.color_util import hex_to_rgb, normalize_palette, readable_color
|
||||
|
||||
_PREVIEW_LONG_EDGE = 1024
|
||||
_PREVIEW_DIM = 0.25
|
||||
|
||||
|
||||
def pixels_to_fractions(box: dict, width: int, height: int) -> dict:
|
||||
w = width or 1
|
||||
h = height or 1
|
||||
return {
|
||||
"x": box.get("x", 0) / w,
|
||||
"y": box.get("y", 0) / h,
|
||||
"w": box.get("width", 0) / w,
|
||||
"h": box.get("height", 0) / h,
|
||||
}
|
||||
|
||||
|
||||
def fractions_to_pixels(box: dict, width: int, height: int) -> dict:
|
||||
x, y = box.get("x", 0.0), box.get("y", 0.0)
|
||||
w, h = box.get("w", 0.0), box.get("h", 0.0)
|
||||
if w < 0:
|
||||
x, w = x + w, -w
|
||||
if h < 0:
|
||||
y, h = y + h, -h
|
||||
return {
|
||||
"x": round(x * width),
|
||||
"y": round(y * height),
|
||||
"width": round(w * width),
|
||||
"height": round(h * height),
|
||||
}
|
||||
|
||||
|
||||
def fractions_to_bbox_frame(boxes: list, width: int, height: int) -> list:
|
||||
pixels = [
|
||||
fractions_to_pixels(box, width, height)
|
||||
for box in boxes
|
||||
if isinstance(box, dict)
|
||||
]
|
||||
return [pixels] if pixels else []
|
||||
|
||||
|
||||
def _font(size: int):
|
||||
try:
|
||||
return ImageFont.load_default(size)
|
||||
except Exception:
|
||||
return ImageFont.load_default()
|
||||
|
||||
|
||||
def _wrap(draw, text: str, font, max_w: float) -> list[str]:
|
||||
lines = []
|
||||
for para in text.split("\n"):
|
||||
line = ""
|
||||
for word in para.split():
|
||||
test = word if not line else line + " " + word
|
||||
if line and draw.textlength(test, font=font) > max_w:
|
||||
lines.append(line)
|
||||
line = word
|
||||
else:
|
||||
line = test
|
||||
lines.append(line)
|
||||
return lines
|
||||
|
||||
|
||||
def _bg_from_image(image) -> Image.Image | None:
|
||||
if image is None:
|
||||
return None
|
||||
try:
|
||||
arr = (image[0].detach().cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
return Image.fromarray(arr)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def render_preview(regions, width, height, bg=None):
|
||||
if bg is not None:
|
||||
iw, ih = bg.size
|
||||
long_edge = max(iw, ih) or 1
|
||||
scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge)
|
||||
rw, rh = max(1, round(iw * scale)), max(1, round(ih * scale))
|
||||
base = bg.convert("RGB").resize((rw, rh), Image.LANCZOS)
|
||||
base = ImageEnhance.Brightness(base).enhance(_PREVIEW_DIM)
|
||||
img = base.convert("RGBA")
|
||||
else:
|
||||
long_edge = max(width, height) or 1
|
||||
scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge)
|
||||
rw, rh = max(1, round(width * scale)), max(1, round(height * scale))
|
||||
grey = round(_PREVIEW_DIM * 128)
|
||||
img = Image.new("RGBA", (rw, rh), (grey, grey, grey, 255))
|
||||
|
||||
overlay = Image.new("RGBA", (rw, rh), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(overlay)
|
||||
fs = max(10, round(rh / 64))
|
||||
font = _font(fs)
|
||||
tag_font = _font(max(9, fs - 2))
|
||||
line_h = fs + 2
|
||||
|
||||
for i, region in enumerate(regions):
|
||||
if not isinstance(region, dict):
|
||||
continue
|
||||
palette = [c for c in (region.get("palette") or []) if c]
|
||||
r, g, b = hex_to_rgb(palette[0]) if palette else (140, 140, 140)
|
||||
x1 = max(0, min(rw, round(region.get("x", 0) * rw)))
|
||||
y1 = max(0, min(rh, round(region.get("y", 0) * rh)))
|
||||
x2 = max(0, min(rw, round((region.get("x", 0) + region.get("w", 0)) * rw)))
|
||||
y2 = max(0, min(rh, round((region.get("y", 0) + region.get("h", 0)) * rh)))
|
||||
if x2 < x1:
|
||||
x1, x2 = x2, x1
|
||||
if y2 < y1:
|
||||
y1, y2 = y2, y1
|
||||
|
||||
draw.rectangle([x1, y1, x2, y2], outline=(r, g, b, 255), width=2)
|
||||
|
||||
swatches = palette[:5]
|
||||
if swatches and (x2 - x1) > 2:
|
||||
sh = max(5, fs // 2)
|
||||
seg = (x2 - x1) / len(swatches)
|
||||
for p, hexc in enumerate(swatches):
|
||||
sx = x1 + round(p * seg)
|
||||
draw.rectangle([sx, y1, x1 + round((p + 1) * seg), y1 + sh], fill=hex_to_rgb(hexc))
|
||||
|
||||
etype = "text" if region.get("type") == "text" else "obj"
|
||||
tag = str(i + 1).zfill(2)
|
||||
tw = draw.textlength(tag, font=tag_font)
|
||||
draw.rectangle([x1, y1, x1 + tw + 6, y1 + fs + 2], fill=(r, g, b, 255))
|
||||
tag_fill = (0, 0, 0, 255) if (0.299 * r + 0.587 * g + 0.114 * b) > 140 else (255, 255, 255, 255)
|
||||
draw.text((x1 + 3, y1 + 1), tag, fill=tag_fill, font=tag_font)
|
||||
|
||||
body = region.get("desc", "") or ""
|
||||
if etype == "text" and region.get("text"):
|
||||
body = '"%s"%s' % (region["text"], " — " + body if body else "")
|
||||
if body and (x2 - x1) > 8:
|
||||
ty = y1 + fs + 5
|
||||
for line in _wrap(draw, body, font, x2 - x1 - 8):
|
||||
if ty > y2:
|
||||
break
|
||||
draw.text((x1 + 4, ty), line, fill=readable_color((r, g, b)) + (255,), font=font)
|
||||
ty += line_h
|
||||
|
||||
composed = Image.alpha_composite(img, overlay).convert("RGB")
|
||||
arr = np.asarray(composed, dtype=np.float32) / 255.0
|
||||
return torch.from_numpy(arr).unsqueeze(0)
|
||||
|
||||
|
||||
def boxes_to_regions(boxes, width: int, height: int) -> list:
|
||||
regions: list = []
|
||||
if not isinstance(boxes, list):
|
||||
return regions
|
||||
for box in boxes:
|
||||
if not isinstance(box, dict):
|
||||
continue
|
||||
meta = box.get("metadata")
|
||||
meta = meta if isinstance(meta, dict) else {}
|
||||
regions.append({
|
||||
**pixels_to_fractions(box, width, height),
|
||||
"type": meta.get("type", "obj"),
|
||||
"text": meta.get("text", ""),
|
||||
"desc": meta.get("desc", ""),
|
||||
"palette": meta.get("palette", []),
|
||||
})
|
||||
return regions
|
||||
|
||||
|
||||
def normalize_incoming_boxes(bboxes) -> list:
|
||||
if isinstance(bboxes, dict):
|
||||
frame = [bboxes]
|
||||
elif not isinstance(bboxes, list) or not bboxes:
|
||||
frame = []
|
||||
elif isinstance(bboxes[0], dict):
|
||||
frame = bboxes
|
||||
else:
|
||||
frame = bboxes[0] if isinstance(bboxes[0], list) else []
|
||||
boxes = []
|
||||
for box in frame:
|
||||
if not isinstance(box, dict):
|
||||
continue
|
||||
norm = {
|
||||
"x": box.get("x", 0),
|
||||
"y": box.get("y", 0),
|
||||
"width": box.get("width", 0),
|
||||
"height": box.get("height", 0),
|
||||
}
|
||||
meta = box.get("metadata")
|
||||
if isinstance(meta, dict):
|
||||
norm["metadata"] = meta
|
||||
boxes.append(norm)
|
||||
return boxes
|
||||
|
||||
|
||||
def _norm_bbox(region: dict) -> list[int]:
|
||||
def grid(value: float) -> int:
|
||||
return max(0, min(1000, round(value * 1000)))
|
||||
|
||||
x, y = region.get("x", 0.0), region.get("y", 0.0)
|
||||
w, h = region.get("w", 0.0), region.get("h", 0.0)
|
||||
ymin, xmin, ymax, xmax = grid(y), grid(x), grid(y + h), grid(x + w)
|
||||
if ymin > ymax:
|
||||
ymin, ymax = ymax, ymin
|
||||
if xmin > xmax:
|
||||
xmin, xmax = xmax, xmin
|
||||
return [ymin, xmin, ymax, xmax]
|
||||
|
||||
|
||||
def build_elements(regions: list) -> list:
|
||||
elements = []
|
||||
for region in regions:
|
||||
if not isinstance(region, dict):
|
||||
continue
|
||||
etype = "text" if region.get("type") == "text" else "obj"
|
||||
element = {"type": etype}
|
||||
element["bbox"] = _norm_bbox(region)
|
||||
if etype == "text":
|
||||
element["text"] = region.get("text", "")
|
||||
element["desc"] = region.get("desc", "")
|
||||
palette = normalize_palette(region.get("palette", []))
|
||||
if palette:
|
||||
element["color_palette"] = palette[:5]
|
||||
elements.append(element)
|
||||
return elements
|
||||
|
||||
|
||||
class CreateBoundingBoxes(io.ComfyNode):
|
||||
_last_incoming: dict = {}
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
editor_state = io.BoundingBoxes.Input(
|
||||
"editor_state",
|
||||
socketless=False,
|
||||
tooltip="Draw bounding boxes and set each box type, text, description, color palette. Start with background element first and foreground last.",
|
||||
)
|
||||
return io.Schema(
|
||||
node_id="CreateBoundingBoxes",
|
||||
display_name="Create Bounding Boxes",
|
||||
category="utilities",
|
||||
description="Draw bounding boxes in a canvas. Outputs Ideogram prompt elements, pixel-space bounding boxes, and a preview image.",
|
||||
inputs=[
|
||||
io.Image.Input(
|
||||
"background",
|
||||
optional=True,
|
||||
tooltip="Optional image used as background in the canvas and preview.",
|
||||
),
|
||||
io.BoundingBox.Input(
|
||||
"bboxes",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
tooltip="Bounding boxes from an upstream node. A new upstream value seeds the canvas; edits you make on the canvas take priority and are kept until the upstream value changes again.",
|
||||
),
|
||||
io.Int.Input("width", default=1024, min=64, max=16384, step=16,
|
||||
tooltip="Width of the canvas and the pixel grid for the bounding boxes."),
|
||||
io.Int.Input("height", default=1024, min=64, max=16384, step=16,
|
||||
tooltip="Height of the canvas and the pixel grid for the bounding boxes."),
|
||||
editor_state,
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="preview"),
|
||||
io.BoundingBox.Output(display_name="bboxes"),
|
||||
io.Array.Output(display_name="elements"),
|
||||
],
|
||||
hidden=[io.Hidden.unique_id],
|
||||
is_output_node=True,
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, editor_state=None, background=None, bboxes=None) -> io.NodeOutput:
|
||||
incoming = normalize_incoming_boxes(bboxes)
|
||||
node_id = cls.hidden.unique_id
|
||||
if incoming:
|
||||
changed = cls._last_incoming.get(node_id) != incoming
|
||||
if changed:
|
||||
cls._last_incoming[node_id] = incoming
|
||||
else:
|
||||
changed = False
|
||||
cls._last_incoming.pop(node_id, None)
|
||||
source = incoming if changed else (editor_state or incoming)
|
||||
regions = boxes_to_regions(source, width, height)
|
||||
preview = render_preview(regions, width, height, _bg_from_image(background))
|
||||
ui = {"dims": [width, height]}
|
||||
if incoming:
|
||||
ui["input_bboxes"] = incoming
|
||||
return io.NodeOutput(
|
||||
preview,
|
||||
fractions_to_bbox_frame(regions, width, height),
|
||||
build_elements(regions),
|
||||
ui=ui,
|
||||
)
|
||||
|
||||
|
||||
class BoundingBoxesExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [CreateBoundingBoxes]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> BoundingBoxesExtension:
|
||||
return BoundingBoxesExtension()
|
||||
@ -1,5 +1,6 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.color_util import hex_to_rgb
|
||||
|
||||
|
||||
class ColorToRGBInt(io.ComfyNode):
|
||||
@ -24,9 +25,11 @@ class ColorToRGBInt(io.ComfyNode):
|
||||
# expect format #RRGGBB
|
||||
if len(color) != 7 or color[0] != "#":
|
||||
raise ValueError("Color must be in format #RRGGBB")
|
||||
r = int(color[1:3], 16)
|
||||
g = int(color[3:5], 16)
|
||||
b = int(color[5:7], 16)
|
||||
try:
|
||||
int(color[1:], 16)
|
||||
except ValueError:
|
||||
raise ValueError("Color must be in format #RRGGBB") from None
|
||||
r, g, b = hex_to_rgb(color)
|
||||
|
||||
rgb_int = r * 256 * 256 + g * 256 + b
|
||||
return io.NodeOutput(rgb_int, color)
|
||||
|
||||
@ -8,7 +8,8 @@ class CLIPTextEncodeControlnet(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeControlnet",
|
||||
category="experimental/conditioning",
|
||||
display_name="CLIP Text Encode (Controlnet)",
|
||||
category="model/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Conditioning.Input("conditioning"),
|
||||
@ -35,11 +36,12 @@ class T5TokenizerOptions(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="T5TokenizerOptions",
|
||||
category="experimental/conditioning",
|
||||
display_name="T5 Tokenizer Options",
|
||||
category="model/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
|
||||
io.Int.Input("min_length", default=0, min=0, max=10000, step=1, advanced=True),
|
||||
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1),
|
||||
io.Int.Input("min_length", default=0, min=0, max=10000, step=1),
|
||||
],
|
||||
outputs=[io.Clip.Output()],
|
||||
is_experimental=True,
|
||||
|
||||
@ -1070,7 +1070,7 @@ class AddNoise(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="AddNoise",
|
||||
category="experimental/custom_sampling/noise",
|
||||
category="model/sampling/noise",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
@ -1120,7 +1120,7 @@ class ManualSigmas(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ManualSigmas",
|
||||
search_aliases=["custom noise schedule", "define sigmas"],
|
||||
category="experimental/custom_sampling",
|
||||
category="model/sampling/sigmas",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input("sigmas", default="1, 0.5", multiline=False)
|
||||
|
||||
@ -1,85 +1,68 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import ctypes
|
||||
import logging
|
||||
import ctypes.util
|
||||
import importlib.util
|
||||
from typing import TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import nodes
|
||||
import comfy_angle
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from typing_extensions import override
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _check_opengl_availability():
|
||||
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
|
||||
logger.debug("_check_opengl_availability: starting")
|
||||
missing = []
|
||||
def _preload_angle():
|
||||
egl_path = comfy_angle.get_egl_path()
|
||||
gles_path = comfy_angle.get_glesv2_path()
|
||||
|
||||
# Check Python packages (using find_spec to avoid importing)
|
||||
logger.debug("_check_opengl_availability: checking for glfw package")
|
||||
if importlib.util.find_spec("glfw") is None:
|
||||
missing.append("glfw")
|
||||
if sys.platform == "win32":
|
||||
angle_dir = comfy_angle.get_lib_dir()
|
||||
os.add_dll_directory(angle_dir)
|
||||
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
|
||||
|
||||
logger.debug("_check_opengl_availability: checking for OpenGL package")
|
||||
if importlib.util.find_spec("OpenGL") is None:
|
||||
missing.append("PyOpenGL")
|
||||
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
|
||||
)
|
||||
|
||||
# On Linux without display, check if headless backends are available
|
||||
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
|
||||
if sys.platform.startswith("linux"):
|
||||
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
|
||||
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
|
||||
if not has_display:
|
||||
# Check for EGL or OSMesa libraries
|
||||
logger.debug("_check_opengl_availability: checking for EGL library")
|
||||
has_egl = ctypes.util.find_library("EGL")
|
||||
logger.debug("_check_opengl_availability: checking for OSMesa library")
|
||||
has_osmesa = ctypes.util.find_library("OSMesa")
|
||||
|
||||
# Error disabled for CI as it fails this check
|
||||
# if not has_egl and not has_osmesa:
|
||||
# raise RuntimeError(
|
||||
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
|
||||
# "See error below for installation instructions."
|
||||
# )
|
||||
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
|
||||
|
||||
logger.debug("_check_opengl_availability: completed")
|
||||
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
|
||||
ctypes.CDLL(str(egl_path), mode=mode)
|
||||
ctypes.CDLL(str(gles_path), mode=mode)
|
||||
|
||||
|
||||
# Run early check at import time
|
||||
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
|
||||
_check_opengl_availability()
|
||||
|
||||
# OpenGL modules - initialized lazily when context is created
|
||||
gl = None
|
||||
glfw = None
|
||||
EGL = None
|
||||
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
|
||||
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
|
||||
_preload_angle()
|
||||
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
|
||||
|
||||
def _import_opengl():
|
||||
"""Import OpenGL module. Called after context is created."""
|
||||
global gl
|
||||
if gl is None:
|
||||
logger.debug("_import_opengl: importing OpenGL.GL")
|
||||
import OpenGL.GL as _gl
|
||||
gl = _gl
|
||||
logger.debug("_import_opengl: import completed")
|
||||
return gl
|
||||
import OpenGL
|
||||
OpenGL.USE_ACCELERATE = False
|
||||
|
||||
|
||||
def _patch_find_library():
|
||||
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
|
||||
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
|
||||
'libGLESv2'. Patch find_library to return the full ANGLE paths so
|
||||
PyOpenGL loads the same libraries we pre-loaded."""
|
||||
if sys.platform == "linux":
|
||||
return
|
||||
import ctypes.util
|
||||
_orig = ctypes.util.find_library
|
||||
def _patched(name):
|
||||
if name == 'EGL':
|
||||
return comfy_angle.get_egl_path()
|
||||
if name == 'GLESv2':
|
||||
return comfy_angle.get_glesv2_path()
|
||||
return _orig(name)
|
||||
ctypes.util.find_library = _patched
|
||||
|
||||
|
||||
_patch_find_library()
|
||||
|
||||
from OpenGL import EGL
|
||||
from OpenGL import GLES3 as gl
|
||||
|
||||
class SizeModeInput(TypedDict):
|
||||
size_mode: str
|
||||
width: int
|
||||
@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
# (-1,-1)---(3,-1)
|
||||
#
|
||||
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
||||
VERTEX_SHADER = """#version 330 core
|
||||
VERTEX_SHADER = """#version 300 es
|
||||
out vec2 v_texCoord;
|
||||
void main() {
|
||||
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
||||
@ -126,14 +109,99 @@ void main() {
|
||||
"""
|
||||
|
||||
|
||||
def _convert_es_to_desktop(source: str) -> str:
|
||||
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
||||
# Remove any existing #version directive
|
||||
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
||||
# Remove precision qualifiers (not needed in desktop GLSL)
|
||||
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
||||
# Prepend desktop GLSL version
|
||||
return "#version 330 core\n" + source
|
||||
|
||||
def _egl_attribs(*values):
|
||||
"""Build an EGL_NONE-terminated EGLint attribute array."""
|
||||
vals = list(values) + [EGL.EGL_NONE]
|
||||
return (ctypes.c_int32 * len(vals))(*vals)
|
||||
|
||||
|
||||
# EGL platform extension constants
|
||||
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
|
||||
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
|
||||
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
|
||||
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
|
||||
|
||||
|
||||
_eglGetPlatformDisplayEXT = None
|
||||
|
||||
def _get_egl_platform_display_ext(platform, native_display, attribs):
|
||||
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
|
||||
global _eglGetPlatformDisplayEXT
|
||||
if _eglGetPlatformDisplayEXT is None:
|
||||
from OpenGL import platform as _plat
|
||||
egl_lib = _plat.PLATFORM.EGL
|
||||
_get_proc = egl_lib.eglGetProcAddress
|
||||
_get_proc.restype = ctypes.c_void_p
|
||||
_get_proc.argtypes = [ctypes.c_char_p]
|
||||
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
|
||||
if not ptr:
|
||||
return None
|
||||
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
|
||||
_eglGetPlatformDisplayEXT = func_type(ptr)
|
||||
|
||||
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
|
||||
if not raw:
|
||||
return None
|
||||
return ctypes.cast(raw, EGL.EGLDisplay)
|
||||
|
||||
|
||||
def _get_egl_display():
|
||||
"""Get an EGL display, trying the default first then ANGLE's Vulkan
|
||||
platform for headless environments without a display server."""
|
||||
failures = []
|
||||
|
||||
# Try the default display first (works when X11/Wayland is available)
|
||||
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
|
||||
if display:
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
try:
|
||||
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
return display, major.value, minor.value
|
||||
except Exception as e:
|
||||
failures.append(f"default: {e}")
|
||||
|
||||
logger.info("Default EGL display unavailable, trying headless fallbacks")
|
||||
|
||||
# Headless fallback strategies, tried in order:
|
||||
headless_strategies = [
|
||||
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
|
||||
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
|
||||
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
|
||||
]
|
||||
|
||||
for name, platform, native_display, attribs in headless_strategies:
|
||||
display = _get_egl_platform_display_ext(platform, native_display, attribs)
|
||||
if not display:
|
||||
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
|
||||
continue
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
try:
|
||||
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
logger.info(f"Using EGL {name} platform (headless)")
|
||||
return display, major.value, minor.value
|
||||
failures.append(f"{name}: eglInitialize returned false")
|
||||
except Exception as e:
|
||||
failures.append(f"{name}: {e}")
|
||||
continue
|
||||
|
||||
details = "\n".join(f" - {f}" for f in failures)
|
||||
raise RuntimeError(
|
||||
"Failed to initialize EGL display.\n"
|
||||
"No display server and no headless EGL platform available.\n"
|
||||
f"Tried:\n{details}\n"
|
||||
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
|
||||
)
|
||||
|
||||
|
||||
def _gl_str(name):
|
||||
"""Get an OpenGL string parameter."""
|
||||
v = gl.glGetString(name)
|
||||
if not v:
|
||||
return "Unknown"
|
||||
if isinstance(v, bytes):
|
||||
return v.decode(errors="replace")
|
||||
return ctypes.string_at(v).decode(errors="replace")
|
||||
|
||||
|
||||
def _detect_output_count(source: str) -> int:
|
||||
@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _init_glfw():
|
||||
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_glfw: starting")
|
||||
# On macOS, glfw.init() must be called from main thread or it hangs forever
|
||||
if sys.platform == "darwin":
|
||||
logger.debug("_init_glfw: skipping on macOS")
|
||||
raise RuntimeError("GLFW backend not supported on macOS")
|
||||
|
||||
logger.debug("_init_glfw: importing glfw module")
|
||||
import glfw as _glfw
|
||||
|
||||
logger.debug("_init_glfw: calling glfw.init()")
|
||||
if not _glfw.init():
|
||||
raise RuntimeError("glfw.init() failed")
|
||||
|
||||
try:
|
||||
logger.debug("_init_glfw: setting window hints")
|
||||
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
|
||||
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
|
||||
|
||||
logger.debug("_init_glfw: calling create_window()")
|
||||
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
|
||||
if not window:
|
||||
raise RuntimeError("glfw.create_window() failed")
|
||||
|
||||
logger.debug("_init_glfw: calling make_context_current()")
|
||||
_glfw.make_context_current(window)
|
||||
logger.debug("_init_glfw: completed successfully")
|
||||
return window, _glfw
|
||||
except Exception:
|
||||
logger.debug("_init_glfw: failed, terminating glfw")
|
||||
_glfw.terminate()
|
||||
raise
|
||||
|
||||
|
||||
def _init_egl():
|
||||
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_egl: starting")
|
||||
from OpenGL import EGL as _EGL
|
||||
from OpenGL.EGL import (
|
||||
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
|
||||
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
|
||||
eglTerminate, eglDestroyContext, eglDestroySurface,
|
||||
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
|
||||
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
|
||||
)
|
||||
logger.debug("_init_egl: imports completed")
|
||||
|
||||
display = None
|
||||
context = None
|
||||
surface = None
|
||||
|
||||
try:
|
||||
logger.debug("_init_egl: calling eglGetDisplay()")
|
||||
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
|
||||
if display == _EGL.EGL_NO_DISPLAY:
|
||||
raise RuntimeError("eglGetDisplay() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglInitialize()")
|
||||
major, minor = _EGL.EGLint(), _EGL.EGLint()
|
||||
if not eglInitialize(display, major, minor):
|
||||
display = None # Not initialized, don't terminate
|
||||
raise RuntimeError("eglInitialize() failed")
|
||||
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
|
||||
|
||||
config_attribs = [
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
|
||||
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
|
||||
EGL_DEPTH_SIZE, 0, EGL_NONE
|
||||
]
|
||||
configs = (_EGL.EGLConfig * 1)()
|
||||
num_configs = _EGL.EGLint()
|
||||
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
config = configs[0]
|
||||
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
|
||||
|
||||
if not eglBindAPI(EGL_OPENGL_API):
|
||||
raise RuntimeError("eglBindAPI() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreateContext()")
|
||||
context_attribs = [
|
||||
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
|
||||
EGL_NONE
|
||||
]
|
||||
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
|
||||
if context == EGL_NO_CONTEXT:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
|
||||
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
|
||||
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
|
||||
if surface == _EGL.EGL_NO_SURFACE:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglMakeCurrent()")
|
||||
if not eglMakeCurrent(display, surface, surface, context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_egl: completed successfully")
|
||||
return display, context, surface, _EGL
|
||||
|
||||
except Exception:
|
||||
logger.debug("_init_egl: failed, cleaning up")
|
||||
# Clean up any resources on failure
|
||||
if surface is not None:
|
||||
eglDestroySurface(display, surface)
|
||||
if context is not None:
|
||||
eglDestroyContext(display, context)
|
||||
if display is not None:
|
||||
eglTerminate(display)
|
||||
raise
|
||||
|
||||
|
||||
def _init_osmesa():
|
||||
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
|
||||
import ctypes
|
||||
|
||||
logger.debug("_init_osmesa: starting")
|
||||
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
||||
|
||||
logger.debug("_init_osmesa: importing OpenGL.osmesa")
|
||||
from OpenGL import GL as _gl
|
||||
from OpenGL.osmesa import (
|
||||
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
|
||||
OSMESA_RGBA,
|
||||
)
|
||||
logger.debug("_init_osmesa: imports completed")
|
||||
|
||||
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
|
||||
if not ctx:
|
||||
raise RuntimeError("OSMesaCreateContextExt() failed")
|
||||
|
||||
width, height = 64, 64
|
||||
buffer = (ctypes.c_ubyte * (width * height * 4))()
|
||||
|
||||
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
|
||||
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
|
||||
OSMesaDestroyContext(ctx)
|
||||
raise RuntimeError("OSMesaMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_osmesa: completed successfully")
|
||||
return ctx, buffer
|
||||
|
||||
|
||||
class GLContext:
|
||||
"""Manages OpenGL context and resources for shader execution.
|
||||
|
||||
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
|
||||
"""
|
||||
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
@ -327,131 +240,105 @@ class GLContext:
|
||||
|
||||
def __init__(self):
|
||||
if GLContext._initialized:
|
||||
logger.debug("GLContext.__init__: already initialized, skipping")
|
||||
return
|
||||
|
||||
logger.debug("GLContext.__init__: starting initialization")
|
||||
|
||||
global glfw, EGL
|
||||
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
|
||||
self._backend = None
|
||||
self._window = None
|
||||
self._egl_display = None
|
||||
self._egl_context = None
|
||||
self._egl_surface = None
|
||||
self._osmesa_ctx = None
|
||||
self._osmesa_buffer = None
|
||||
self._display = None
|
||||
self._surface = None
|
||||
self._context = None
|
||||
self._vao = None
|
||||
|
||||
# Try backends in order: GLFW → EGL → OSMesa
|
||||
errors = []
|
||||
|
||||
logger.debug("GLContext.__init__: trying GLFW backend")
|
||||
try:
|
||||
self._window, glfw = _init_glfw()
|
||||
self._backend = "glfw"
|
||||
logger.debug("GLContext.__init__: GLFW backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
|
||||
errors.append(("GLFW", e))
|
||||
self._display, self._egl_major, self._egl_minor = _get_egl_display()
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying EGL backend")
|
||||
try:
|
||||
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
|
||||
self._backend = "egl"
|
||||
logger.debug("GLContext.__init__: EGL backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
|
||||
errors.append(("EGL", e))
|
||||
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
|
||||
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying OSMesa backend")
|
||||
try:
|
||||
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
|
||||
self._backend = "osmesa"
|
||||
logger.debug("GLContext.__init__: OSMesa backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
|
||||
errors.append(("OSMesa", e))
|
||||
config = EGL.EGLConfig()
|
||||
n_configs = ctypes.c_int32(0)
|
||||
if not EGL.eglChooseConfig(
|
||||
self._display,
|
||||
_egl_attribs(
|
||||
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
|
||||
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
|
||||
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
|
||||
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
|
||||
),
|
||||
ctypes.byref(config), 1, ctypes.byref(n_configs),
|
||||
) or n_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
|
||||
if self._backend is None:
|
||||
if sys.platform == "win32":
|
||||
platform_help = (
|
||||
"Windows: Ensure GPU drivers are installed and display is available.\n"
|
||||
" CPU-only/headless mode is not supported on Windows."
|
||||
)
|
||||
elif sys.platform == "darwin":
|
||||
platform_help = (
|
||||
"macOS: GLFW is not supported.\n"
|
||||
" Install OSMesa via Homebrew: brew install mesa\n"
|
||||
" Then: pip install PyOpenGL PyOpenGL-accelerate"
|
||||
)
|
||||
else:
|
||||
platform_help = (
|
||||
"Linux: Install one of these backends:\n"
|
||||
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
|
||||
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
|
||||
" Headless (CPU): sudo apt install libosmesa6"
|
||||
)
|
||||
|
||||
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
|
||||
raise RuntimeError(
|
||||
f"Failed to create OpenGL context.\n\n"
|
||||
f"Backend errors:\n{error_details}\n\n"
|
||||
f"{platform_help}"
|
||||
self._surface = EGL.eglCreatePbufferSurface(
|
||||
self._display, config,
|
||||
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
|
||||
)
|
||||
if not self._surface:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
# Now import OpenGL.GL (after context is current)
|
||||
logger.debug("GLContext.__init__: importing OpenGL.GL")
|
||||
_import_opengl()
|
||||
self._context = EGL.eglCreateContext(
|
||||
self._display, config, EGL.EGL_NO_CONTEXT,
|
||||
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
|
||||
)
|
||||
if not self._context:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
# Create VAO (required for core profile, but OSMesa may use compat profile)
|
||||
logger.debug("GLContext.__init__: creating VAO")
|
||||
try:
|
||||
vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(vao)
|
||||
self._vao = vao # Only store after successful bind
|
||||
logger.debug("GLContext.__init__: VAO created successfully")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
|
||||
# OSMesa with older Mesa may not support VAOs
|
||||
# Clean up if we created but couldn't bind
|
||||
if vao:
|
||||
try:
|
||||
gl.glDeleteVertexArrays(1, [vao])
|
||||
except Exception:
|
||||
pass
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
self._vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
except Exception:
|
||||
self._cleanup()
|
||||
raise
|
||||
|
||||
elapsed = (time.perf_counter() - start) * 1000
|
||||
|
||||
# Log device info
|
||||
renderer = gl.glGetString(gl.GL_RENDERER)
|
||||
vendor = gl.glGetString(gl.GL_VENDOR)
|
||||
version = gl.glGetString(gl.GL_VERSION)
|
||||
renderer = renderer.decode() if renderer else "Unknown"
|
||||
vendor = vendor.decode() if vendor else "Unknown"
|
||||
version = version.decode() if version else "Unknown"
|
||||
renderer = _gl_str(gl.GL_RENDERER)
|
||||
vendor = _gl_str(gl.GL_VENDOR)
|
||||
version = _gl_str(gl.GL_VERSION)
|
||||
|
||||
GLContext._initialized = True
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
|
||||
|
||||
def make_current(self):
|
||||
if self._backend == "glfw":
|
||||
glfw.make_context_current(self._window)
|
||||
elif self._backend == "egl":
|
||||
from OpenGL.EGL import eglMakeCurrent
|
||||
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
|
||||
elif self._backend == "osmesa":
|
||||
from OpenGL.osmesa import OSMesaMakeCurrent
|
||||
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
|
||||
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
err = EGL.eglGetError()
|
||||
raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
|
||||
if self._vao is not None:
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
def _cleanup(self):
|
||||
if not self._display:
|
||||
return
|
||||
try:
|
||||
if self._vao is not None:
|
||||
gl.glDeleteVertexArrays(1, [self._vao])
|
||||
self._vao = None
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._context:
|
||||
EGL.eglDestroyContext(self._display, self._context)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._surface:
|
||||
EGL.eglDestroySurface(self._display, self._surface)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglTerminate(self._display)
|
||||
except Exception:
|
||||
pass
|
||||
self._display = None
|
||||
|
||||
|
||||
def _compile_shader(source: str, shader_type: int) -> int:
|
||||
"""Compile a shader and return its ID."""
|
||||
@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
|
||||
gl.glShaderSource(shader, source)
|
||||
gl.glCompileShader(shader)
|
||||
|
||||
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetShaderInfoLog(shader).decode()
|
||||
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
|
||||
error = gl.glGetShaderInfoLog(shader)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteShader(shader)
|
||||
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
||||
|
||||
@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
|
||||
gl.glDeleteShader(vertex_shader)
|
||||
gl.glDeleteShader(fragment_shader)
|
||||
|
||||
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetProgramInfoLog(program).decode()
|
||||
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
|
||||
error = gl.glGetProgramInfoLog(program)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteProgram(program)
|
||||
raise RuntimeError(f"Program linking failed:\n{error}")
|
||||
|
||||
@ -530,9 +421,6 @@ def _render_shader_batch(
|
||||
ctx = GLContext()
|
||||
ctx.make_current()
|
||||
|
||||
# Convert from GLSL ES to desktop GLSL 330
|
||||
fragment_source = _convert_es_to_desktop(fragment_code)
|
||||
|
||||
# Detect how many outputs the shader actually uses
|
||||
num_outputs = _detect_output_count(fragment_code)
|
||||
|
||||
@ -558,9 +446,9 @@ def _render_shader_batch(
|
||||
try:
|
||||
# Compile shaders (once for all batches)
|
||||
try:
|
||||
program = _create_program(VERTEX_SHADER, fragment_source)
|
||||
program = _create_program(VERTEX_SHADER, fragment_code)
|
||||
except RuntimeError:
|
||||
logger.error(f"Fragment shader:\n{fragment_source}")
|
||||
logger.error(f"Fragment shader:\n{fragment_code}")
|
||||
raise
|
||||
|
||||
gl.glUseProgram(program)
|
||||
@ -723,13 +611,13 @@ def _render_shader_batch(
|
||||
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
||||
|
||||
# Read back outputs for this batch
|
||||
# (glGetTexImage is synchronous, implicitly waits for rendering)
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
||||
batch_outputs = []
|
||||
for tex in output_textures:
|
||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
|
||||
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
|
||||
batch_outputs.append(img[::-1, :, :].copy())
|
||||
for i in range(num_outputs):
|
||||
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
|
||||
buf = np.empty((height, width, 4), dtype=np.float32)
|
||||
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
|
||||
batch_outputs.append(buf[::-1, :, :].copy())
|
||||
|
||||
# Pad with black images for unused outputs
|
||||
black_img = np.zeros((height, width, 4), dtype=np.float32)
|
||||
@ -750,18 +638,18 @@ def _render_shader_batch(
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
||||
gl.glUseProgram(0)
|
||||
|
||||
for tex in input_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in curve_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in output_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in ping_pong_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
if input_textures:
|
||||
gl.glDeleteTextures(len(input_textures), input_textures)
|
||||
if curve_textures:
|
||||
gl.glDeleteTextures(len(curve_textures), curve_textures)
|
||||
if output_textures:
|
||||
gl.glDeleteTextures(len(output_textures), output_textures)
|
||||
if ping_pong_textures:
|
||||
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
|
||||
if fbo is not None:
|
||||
gl.glDeleteFramebuffers(1, [fbo])
|
||||
for pp_fbo in ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(1, [pp_fbo])
|
||||
if ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
|
||||
if program is not None:
|
||||
gl.glDeleteProgram(program)
|
||||
|
||||
|
||||
77
comfy_extras/nodes_json_prompt.py
Normal file
77
comfy_extras/nodes_json_prompt.py
Normal file
@ -0,0 +1,77 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.color_util import normalize_palette
|
||||
|
||||
|
||||
class BuildJsonPromptIdeogram(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
color_palette = io.Colors.Input(
|
||||
"color_palette",
|
||||
socketless=False,
|
||||
tooltip="Hex color codes that steer the image's dominant colors. Up to 16 entries.",
|
||||
)
|
||||
return io.Schema(
|
||||
node_id="BuildJsonPromptIdeogram",
|
||||
display_name="Build JSON Prompt (Ideogram)",
|
||||
category="text",
|
||||
description="Build a JSON prompt for the Ideogram 4 model.",
|
||||
inputs=[
|
||||
io.Array.Input("element", tooltip="Prompt elements from the node Create Bounding Boxes."),
|
||||
io.String.Input("high_level_description", multiline=True, default="",
|
||||
tooltip="Optional description of the image in one or two sentences. Strongly recommended."),
|
||||
io.String.Input("background", multiline=True, default="",
|
||||
tooltip="Mandatory description of the image background or environment."),
|
||||
io.DynamicCombo.Input("style", options=[
|
||||
io.DynamicCombo.Option("none", []),
|
||||
io.DynamicCombo.Option("photo", [io.String.Input("photo", default="", tooltip="Camera or lens details for photographic outputs (e.g. 35mm, f/1.4, bokeh).")]),
|
||||
io.DynamicCombo.Option("art_style", [io.String.Input("art_style", default="", tooltip="Art style description (e.g. flat vector illustration, bold outlines).")]),
|
||||
]),
|
||||
io.String.Input("aesthetics", default="", tooltip="Mandatory aesthetic keywords (e.g. moody, cinematic, desaturated)."),
|
||||
io.String.Input("lighting", default="", tooltip="Mandatory lighting description (e.g. golden hour, rim light, dramatic shadows)."),
|
||||
io.String.Input("medium", default="", tooltip="Mandatory medium type (e.g. photograph, illustration, 3d_render, painting, graphic_design). When style = photo, set to photograph."),
|
||||
color_palette,
|
||||
],
|
||||
outputs=[io.Dict.Output(display_name="prompt")],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, element, style, high_level_description="", background="",
|
||||
aesthetics="", lighting="", medium="", color_palette=None) -> io.NodeOutput:
|
||||
elements = element if isinstance(element, list) else []
|
||||
kind = style.get("style", "none") if isinstance(style, dict) else "none"
|
||||
photo = style.get("photo", "") if isinstance(style, dict) else ""
|
||||
art_style = style.get("art_style", "") if isinstance(style, dict) else ""
|
||||
palette = normalize_palette(color_palette or [])
|
||||
|
||||
caption: dict = {}
|
||||
if high_level_description.strip():
|
||||
caption["high_level_description"] = high_level_description
|
||||
if kind != "none":
|
||||
style_desc: dict = {"aesthetics": aesthetics, "lighting": lighting}
|
||||
if kind == "photo":
|
||||
style_desc["photo"] = photo
|
||||
style_desc["medium"] = medium
|
||||
else:
|
||||
style_desc["medium"] = medium
|
||||
style_desc["art_style"] = art_style
|
||||
if palette:
|
||||
style_desc["color_palette"] = palette
|
||||
caption["style_description"] = style_desc
|
||||
caption["compositional_deconstruction"] = {
|
||||
"background": background,
|
||||
"elements": elements,
|
||||
}
|
||||
return io.NodeOutput(caption)
|
||||
|
||||
|
||||
class JsonPromptExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [BuildJsonPromptIdeogram]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> JsonPromptExtension:
|
||||
return JsonPromptExtension()
|
||||
@ -337,6 +337,36 @@ class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeKrea2(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "model/merging/model specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["first."] = argument
|
||||
arg_dict["tmlp."] = argument
|
||||
arg_dict["txtmlp."] = argument
|
||||
arg_dict["tproj."] = argument
|
||||
|
||||
for i in range(2):
|
||||
arg_dict["txtfusion.layerwise_blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["txtfusion.projector."] = argument
|
||||
|
||||
for i in range(2):
|
||||
arg_dict["txtfusion.refiner_blocks.{}.".format(i)] = argument
|
||||
|
||||
for i in range(28):
|
||||
arg_dict["blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["last."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSD1": ModelMergeSD1,
|
||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
||||
@ -353,4 +383,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
|
||||
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
|
||||
"ModelMergeQwenImage": ModelMergeQwenImage,
|
||||
"ModelMergeKrea2": ModelMergeKrea2,
|
||||
}
|
||||
|
||||
@ -123,7 +123,8 @@ class PhotoMakerLoader(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PhotoMakerLoader",
|
||||
category="experimental/photomaker",
|
||||
display_name="Load PhotoMaker Model",
|
||||
category="model/loaders",
|
||||
inputs=[
|
||||
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
|
||||
],
|
||||
@ -149,7 +150,8 @@ class PhotoMakerEncode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PhotoMakerEncode",
|
||||
category="experimental/photomaker",
|
||||
display_name="PhotoMaker Encode",
|
||||
category="model/conditioning/photomaker",
|
||||
inputs=[
|
||||
io.Photomaker.Input("photomaker"),
|
||||
io.Image.Input("image"),
|
||||
|
||||
33
comfy_extras/nodes_seed.py
Normal file
33
comfy_extras/nodes_seed.py
Normal file
@ -0,0 +1,33 @@
|
||||
import sys
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class SeedNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SeedNode",
|
||||
display_name="Seed",
|
||||
search_aliases=["seed", "random"],
|
||||
category="utilities",
|
||||
inputs=[
|
||||
io.Int.Input("seed", min=0, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
|
||||
],
|
||||
outputs=[io.Int.Output(display_name="seed")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, seed: int) -> io.NodeOutput:
|
||||
return io.NodeOutput(seed)
|
||||
|
||||
|
||||
class SeedExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [SeedNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> SeedExtension:
|
||||
return SeedExtension()
|
||||
@ -119,7 +119,7 @@ class StableCascade_SuperResolutionControlnet(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_SuperResolutionControlnet",
|
||||
category="experimental/stable_cascade",
|
||||
category="experimental/stable cascade",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
|
||||
@ -440,6 +440,57 @@ class JsonExtractString(io.ComfyNode):
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return io.NodeOutput("")
|
||||
|
||||
|
||||
def _dump_json(value, indent):
|
||||
return json.dumps(value, ensure_ascii=False, indent=indent or None)
|
||||
|
||||
|
||||
class ConvertDictionaryToString(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ConvertDictionaryToString",
|
||||
display_name="Convert Dictionary to String",
|
||||
category="text",
|
||||
search_aliases=["json", "dict to json", "stringify", "serialize", "dict to string"],
|
||||
inputs=[
|
||||
io.Dict.Input("dictionary"),
|
||||
io.Int.Input("indent", default=2, min=0, max=8,
|
||||
tooltip="Spaces per indent level. 0 produces compact single-line string."),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, dictionary, indent=2):
|
||||
return io.NodeOutput(_dump_json(dictionary, indent))
|
||||
|
||||
|
||||
class ConvertArrayToString(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ConvertArrayToString",
|
||||
display_name="Convert Array to String",
|
||||
category="text",
|
||||
search_aliases=["json", "list to json", "stringify", "serialize", "list to string", "array to json"],
|
||||
inputs=[
|
||||
io.Array.Input("array"),
|
||||
io.Int.Input("indent", default=2, min=0, max=8,
|
||||
tooltip="Spaces per indent level. 0 produces compact single-line string."),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, array, indent=2):
|
||||
return io.NodeOutput(_dump_json(array, indent))
|
||||
|
||||
|
||||
class StringExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@ -457,6 +508,8 @@ class StringExtension(ComfyExtension):
|
||||
RegexExtract,
|
||||
RegexReplace,
|
||||
JsonExtractString,
|
||||
ConvertDictionaryToString,
|
||||
ConvertArrayToString,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> StringExtension:
|
||||
|
||||
@ -143,7 +143,7 @@ class VAEDecodeTripoSplat(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeTripoSplat",
|
||||
display_name="TripoSplat Decode",
|
||||
category="3d/latent",
|
||||
category="model/latent/triposplat",
|
||||
description="Decode the sampled TripoSplat latent into a 3D gaussian splat. "
|
||||
"Modify the number of gaussians to vary the density.",
|
||||
inputs=[
|
||||
@ -188,7 +188,7 @@ class TripoSplatSamplingPreview(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoSplatSamplingPreview",
|
||||
display_name="TripoSplat Sampling Preview",
|
||||
category="3d/latent",
|
||||
category="model/latent/triposplat",
|
||||
description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded "
|
||||
"gaussian splat preview at each step.",
|
||||
inputs=[
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.26.0"
|
||||
__version__ = "0.27.0"
|
||||
|
||||
6
main.py
6
main.py
@ -236,7 +236,7 @@ import hook_breaker_ac10a0
|
||||
import comfy.memory_management
|
||||
import comfy.model_patcher
|
||||
|
||||
if args.enable_dynamic_vram or (enables_dynamic_vram() and (comfy.model_management.is_nvidia() or comfy.model_management.is_intel_xpu()) and not comfy.model_management.is_wsl()):
|
||||
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
|
||||
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
|
||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
else:
|
||||
@ -403,7 +403,7 @@ def prompt_worker(q, server_instance):
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
if not asset_seeder.is_disabled():
|
||||
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=args.enable_asset_hashing)
|
||||
asset_seeder.resume()
|
||||
|
||||
|
||||
@ -458,7 +458,7 @@ def setup_database():
|
||||
if dependencies_available():
|
||||
init_db()
|
||||
if args.enable_assets:
|
||||
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True):
|
||||
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=args.enable_asset_hashing):
|
||||
logging.info("Background asset scan initiated for models, input, output")
|
||||
except Exception as e:
|
||||
if "database is locked" in str(e):
|
||||
|
||||
39
nodes.py
39
nodes.py
@ -159,6 +159,29 @@ class ConditioningConcat:
|
||||
|
||||
return (out, )
|
||||
|
||||
class ConditioningMultiply:
|
||||
SEARCH_ALIASES = ["scale conditioning", "scale prompt", "multiply conditioning", "multiply prompt"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||
"multiplier": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01})
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "multiply"
|
||||
CATEGORY = "model/conditioning/transform"
|
||||
|
||||
def multiply(self, conditioning, multiplier):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
values = {}
|
||||
pooled_output = t[1].get("pooled_output", None)
|
||||
if pooled_output is not None:
|
||||
values["pooled_output"] = pooled_output * multiplier
|
||||
scaled = node_helpers.conditioning_set_values([[t[0] * multiplier, t[1]]], values)[0]
|
||||
c.append(scaled)
|
||||
return (c,)
|
||||
|
||||
class ConditioningSetArea:
|
||||
SEARCH_ALIASES = ["regional prompt", "area prompt", "spatial conditioning", "localized prompt"]
|
||||
|
||||
@ -326,7 +349,7 @@ class VAEDecodeTiled:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "experimental"
|
||||
CATEGORY = "model/latent"
|
||||
|
||||
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
|
||||
if tile_size < overlap * 4:
|
||||
@ -373,7 +396,7 @@ class VAEEncodeTiled:
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "experimental"
|
||||
CATEGORY = "model/latent"
|
||||
|
||||
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
|
||||
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
|
||||
@ -491,7 +514,7 @@ class SaveLatent:
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "experimental"
|
||||
CATEGORY = "model/latent"
|
||||
|
||||
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
@ -536,7 +559,7 @@ class LoadLatent:
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
|
||||
return {"required": {"latent": [sorted(files), ]}, }
|
||||
|
||||
CATEGORY = "experimental"
|
||||
CATEGORY = "model/latent"
|
||||
|
||||
RETURN_TYPES = ("LATENT", )
|
||||
FUNCTION = "load"
|
||||
@ -2050,6 +2073,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ConditioningAverage": ConditioningAverage,
|
||||
"ConditioningCombine": ConditioningCombine,
|
||||
"ConditioningConcat": ConditioningConcat,
|
||||
"ConditioningMultiply": ConditioningMultiply,
|
||||
"ConditioningSetArea": ConditioningSetArea,
|
||||
"ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
|
||||
"ConditioningSetAreaStrength": ConditioningSetAreaStrength,
|
||||
@ -2121,6 +2145,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ConditioningAverage ": "Conditioning (Average)",
|
||||
"ConditioningAverage": "Conditioning (Average)",
|
||||
"ConditioningConcat": "Conditioning (Concat)",
|
||||
"ConditioningMultiply": "Conditioning (Multiply)",
|
||||
"ConditioningSetArea": "Conditioning (Set Area)",
|
||||
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
||||
"ConditioningSetAreaStrength": "Conditioning (Set Area Strength)",
|
||||
@ -2130,6 +2155,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"GLIGENTextBoxApply": "Apply GLIGEN Text Box",
|
||||
"ConditioningZeroOut": "Conditioning Zero Out",
|
||||
# Latent
|
||||
"LoadLatent": "Load Latent",
|
||||
"SaveLatent": "Save Latent",
|
||||
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
||||
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
||||
"VAEDecode": "VAE Decode",
|
||||
@ -2164,7 +2191,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageSharpen": "Sharpen Image",
|
||||
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
||||
"GetImageSize": "Get Image Size",
|
||||
# experimental
|
||||
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||
}
|
||||
@ -2374,6 +2400,8 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_images.py",
|
||||
"nodes_video_model.py",
|
||||
"nodes_ideogram4.py",
|
||||
"nodes_bounding_boxes.py",
|
||||
"nodes_json_prompt.py",
|
||||
"nodes_train.py",
|
||||
"nodes_dataset.py",
|
||||
"nodes_sag.py",
|
||||
@ -2473,6 +2501,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_gaussian_splat.py",
|
||||
"nodes_triposplat.py",
|
||||
"nodes_depth_anything_3.py",
|
||||
"nodes_seed.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
14
openapi.yaml
14
openapi.yaml
@ -1692,6 +1692,12 @@ paths:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Unsupported media type
|
||||
"422":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Validation error (e.g., disallowed model_type tag)
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
@ -2137,6 +2143,12 @@ paths:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Source asset with given hash not found
|
||||
"422":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Validation error (e.g., disallowed model_type tag)
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
@ -2992,7 +3004,7 @@ paths:
|
||||
format: uuid
|
||||
type: string
|
||||
- description: |
|
||||
When present, each output item in the response receives a `short_url` field containing an owner-gated durable link for that asset. Omit this parameter (the default) to receive a response identical to the no-param baseline. The value selects the link's lifetime: use `ephemeral_tool_chain` for short-lived machine-to-machine handoffs (~15 minutes); use `default` for durable human-revisitable links (30 days). Links are minted only for the authenticated request owner and are not resolvable by other users.
|
||||
When present, each output item in the response receives a `short_url` field containing a short link for that asset. Omit this parameter (the default) to receive a response identical to the no-param baseline. The value selects the link's lifetime and auth model: use `ephemeral_tool_chain` for short-lived (≤5 minute) machine-to-machine handoffs — these are public bearer links where the link ID itself is the credential, so anyone holding the link can resolve it (intended for pasting into an agent/MCP tool chain); use `default` for durable (30 day) human-revisitable links, which are owner-gated and resolvable only by the authenticated owner. Links are always minted under the authenticated request owner's identity; the auth model is selected by the server and is never settable by the caller.
|
||||
in: query
|
||||
name: short_link
|
||||
schema:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.26.0"
|
||||
version = "0.27.0"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.45.19
|
||||
comfyui-workflow-templates==0.10.2
|
||||
comfyui-embedded-docs==0.5.5
|
||||
comfyui-frontend-package==1.45.20
|
||||
comfyui-workflow-templates==0.11.1
|
||||
comfyui-embedded-docs==0.5.6
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.10
|
||||
comfy-kitchen==0.2.16
|
||||
comfy-aimdo==0.4.10
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
@ -33,5 +33,5 @@ kornia>=0.7.1
|
||||
spandrel
|
||||
pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
PyOpenGL
|
||||
glfw
|
||||
PyOpenGL>=3.1.8
|
||||
comfy-angle
|
||||
|
||||
@ -228,6 +228,62 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
with self.assertRaises(KeyError):
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
def test_int8_convrot_metadata_loads_into_params(self):
|
||||
"""ConvRot metadata must reach TensorWiseINT8Layout params."""
|
||||
torch.manual_seed(123)
|
||||
layer_quant_config = {
|
||||
"layer": {
|
||||
"format": "int8_tensorwise",
|
||||
"convrot": True,
|
||||
"convrot_groupsize": 256,
|
||||
}
|
||||
}
|
||||
weight = torch.randn(16, 256, dtype=torch.bfloat16)
|
||||
bias = torch.randn(16, dtype=torch.bfloat16)
|
||||
q_weight = QuantizedTensor.from_float(
|
||||
weight,
|
||||
"TensorWiseINT8Layout",
|
||||
per_channel=True,
|
||||
convrot=True,
|
||||
convrot_groupsize=256,
|
||||
)
|
||||
state_dict = {
|
||||
"layer.weight": q_weight._qdata,
|
||||
"layer.bias": bias,
|
||||
"layer.weight_scale": q_weight._params.scale,
|
||||
}
|
||||
|
||||
state_dict, _ = comfy.utils.convert_old_quants(
|
||||
state_dict,
|
||||
metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})},
|
||||
)
|
||||
model = torch.nn.Module()
|
||||
model.layer = ops.mixed_precision_ops({}).Linear(256, 16, device="cpu", dtype=torch.bfloat16)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
self.assertIsInstance(model.layer.weight, QuantizedTensor)
|
||||
self.assertEqual(model.layer.weight._layout_cls, "TensorWiseINT8Layout")
|
||||
self.assertTrue(model.layer.weight._params.convrot)
|
||||
self.assertEqual(model.layer.weight._params.convrot_groupsize, 256)
|
||||
|
||||
input_tensor = torch.randn(4, 256, dtype=torch.bfloat16)
|
||||
loaded_out = model.layer(input_tensor)
|
||||
ref_out = torch.nn.functional.linear(input_tensor, q_weight, bias)
|
||||
self.assertTrue(torch.equal(loaded_out, ref_out))
|
||||
|
||||
fp16_input = input_tensor.to(torch.float16)
|
||||
loaded_fp16_out = model.layer(fp16_input)
|
||||
ref_fp16_out = torch.nn.functional.linear(
|
||||
fp16_input,
|
||||
q_weight.to(dtype=torch.float16),
|
||||
bias.to(dtype=torch.float16),
|
||||
)
|
||||
self.assertTrue(torch.equal(loaded_fp16_out, ref_fp16_out))
|
||||
|
||||
saved = model.state_dict()
|
||||
saved_conf = json.loads(saved["layer.comfy_quant"].numpy().tobytes())
|
||||
self.assertTrue(saved_conf["convrot"])
|
||||
self.assertEqual(saved_conf["convrot_groupsize"], 256)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user