Compare commits

..

10 Commits

Author SHA1 Message Date
c4f4bd3fd9 Move logic nodes into utils category 2026-05-21 21:38:04 +08:00
4259a0c7c3 Update MoGe nodes display names, search aliases and descriptions (#14030) 2026-05-21 16:50:09 +08:00
af3d9b60af chore: Dataset nodes clean-up (CORE-237) (#14002) 2026-05-21 15:14:16 +08:00
7b7c5fed7c Update MediaPipe nodes to standardize with existing code base (CORE-242) (#14025) 2026-05-21 14:39:30 +08:00
1668aaf037 openapi: remove cloud-only job_ids query param from GET /api/assets (#14016)
The job_ids query parameter on GET /api/assets is tagged x-runtime:
[cloud] and only exists for cloud's variant of this endpoint. Cloud
removed all consumers and the cloud-side handler/codegen/tests in
Comfy-Org/cloud#3778. With cloud no longer accepting this parameter,
the [cloud-only] documentation here is wrong — drop it so the daily
sync to cloud/services/ingest/vendor/openapi.yaml propagates the
removal.
2026-05-20 21:32:08 -07:00
ea174d3f12 fix(openapi): correct POST /api/assets/import to importPublishedAssets (#14027)
The operation at POST /api/assets/import was defined as `importAssets`
with a URL-list body shape, but no runtime actually serves that
operation at this path. The cloud runtime serves a different operation
here — `importPublishedAssets` — which imports published-workflow
assets into the caller's library by ID, not by URL.

Cloud's URL-based asset ingestion lives at separate paths
(POST /assets/download + GET /assets/remote-metadata) tracked
elsewhere; nothing in this PR affects that work.

Changes:

- Replace the operation at POST /api/assets/import with
  `importPublishedAssets`, taking ImportPublishedAssetsRequest
  (published_asset_ids + optional share_id) and returning
  ImportPublishedAssetsResponse (list of AssetInfo).
- Remove the unused AssetImportRequest component schema (no other
  references in the spec).
- Operation and schemas tagged x-runtime: [cloud] with [cloud-only]
  description prefix, matching the existing convention for
  cloud-runtime-only operations elsewhere in the spec.

Spectral lint passes (0 errors); the two hint-level findings on
the spec are pre-existing and unrelated.

No FE consumer references AssetImportRequest today; this is a pure
spec correction to match what the cloud runtime actually serves.
2026-05-20 21:28:16 -07:00
9f9b32ed97 feat: add OAuth 2.1 + RFC 7591 DCR endpoints to openapi.yaml (#14026)
Add the OAuth 2.1 authorization flow and RFC 7591 Dynamic Client
Registration endpoints to the shared spec, alongside the existing
auth-tagged operations (/api/auth/session, /api/auth/token,
/.well-known/jwks.json). All tagged x-runtime: [cloud] with a
[cloud-only] description prefix, following the established
convention for cloud-runtime-only operations.

Endpoints:

- GET  /.well-known/oauth-authorization-server  (RFC 8414 metadata)
- GET  /.well-known/oauth-protected-resource    (RFC 9728 metadata)
- GET  /oauth/authorize                         (consent challenge)
- POST /oauth/authorize                         (consent submission)
- POST /oauth/token                             (RFC 6749 §3.2)
- POST /oauth/register                          (RFC 7591 §3.1 DCR)

Component schemas added:

- OAuthAuthorizationServerMetadata
- OAuthProtectedResourceMetadata
- OAuthConsentChallenge, OAuthConsentChallengeWorkspace
- OAuthAuthorizeRedirectResponse
- OAuthTokenResponse, OAuthTokenError
- OAuthRegisterRequest, OAuthRegisterResponse, OAuthRegisterError

These endpoints are implemented in the cloud runtime today and
are called by browser frontends rendering the consent UI and by
MCP-spec-compliant clients (Claude Desktop, Cursor, etc.) doing
auto-discovery + self-registration. Documenting them in the
shared spec lets the cloud frontend generate types directly from
this spec instead of maintaining a parallel definition.

Spectral lints clean (0 errors). The hint-level findings on
OAuthTokenError / OAuthRegisterError ("standard error schema")
match the same hint on CloudError — these are protocol-specific
RFC-shaped errors, not generic application errors.
2026-05-20 21:22:12 -07:00
95fdc6cf91 Repo security stuff. (#14019) 2026-05-20 17:17:55 -07:00
5aa5ccc9e0 Multi-threaded load of models from disk (big load time speedups & Offload to disk) (CORE-43,CORE-152,CORE-164,CORE-165,CORE-117) (#13802)
* model_management: disable non-dynamic smart memory

Disable smart memory outright for non dynamic models.

This is a minor step towards deprecation of --disable-dynamic-vram
and the legacy ModelPatcher.

This is needed for estimate-free model development, where new models
can opt-out of supplying a memory estimate and not have to worry
about hard VRAM allocations due to legacy non-dynamic model patchers

This is also a general stability increase for a lot of stray use cases
where estimates may still be off and going forward we are not going
to accurately maintain such estimates.

* pinned_memory: implement with aimdo growable buffer

Use a single growable buffer so we can do threaded pre-warming on
pinned memory.

* mm: use aimdo to do transfer from disk to pin

Aimdo implements a faster threaded loader.

* Add stream host pin buffer for AIMDO casts

Introduce per-offload-stream HostBuffer reuse for pinned staging,
include it in cast buffer reset synchronization.

Defer actual casts that go via this pin path to a separate pass
such that the buffer can be allocated monolithically (to avoid
cudaHostRegister thrash).

* remove old pin path

* Implement JIT pinned memory pressure

Replace the predictive pin pressure mechanism with JIT PIN memory
pressure.

* LowVRAMPatch: change to two-phase visit

* lora: re-implement as inplace swiss-army-knife operation

* prepare for multiple pin sets

* implement pinned loras

* requirements: comfy-aimdo 0.4.0

* ops: remove unused arg

This was defeatured in aimdo iteration

* ops: sync the CPU with only the offload stream activity

This was syncing with the offload stream which itself is synced with the
compute stream, so this was syncing CPU with compute transitively. Define
the event to sync it more gently.

* pins: implement freeing intermediate for pinned memory

Pinning is more important than inactive intermediates and the stream
pin buffer is more important than even active intermediates.

* execution: implement pin eviction on RAM presure

Add back proper pin freeing on RAM pressure

* implement pin registration swaps

Uncap the windows pins from 50% by extending the pool and have a pressure
mechanism to move the pin reservations om demand.

This unfortunately implies a GPU sync to do the freeing so significant
hysterisis needs to be added to consolidate these pressure events.

* cli_args/execution: Implement lower background cache-ram threshold

Limit the amount of RAM background intermediates can use, so that
switching workflows doesn't degrade performance too much.

* make default

* bump aimdo

* model-patcher: force-cast tiny weights

Flux 2 gets crazy stalls due to a mix of tiny and giant weights
creating lopsided steam buffer rotations which creates stalls.

* ops: refactor in prep for chunking

* mm: delegate pin-on-the-way to aimdo

Aimdo is able to chunk and slice this on the way for better CPU->GPU
overlap. The main advantage is the ability to shorten the bus contention
window between previous weight transfer and the next weights vbar
fault.

* bump aimdo

* pinning updates

* specify hostbuf max allocation size

There a signs of virtual memory exhaustion on some linux systems when
throwing 128GB for every little piece. Pass the actual to save aimdo
from over-estimates

* tests: update execution tests for caching

The default caching changed to ram-cache so update these tests
accordingly.

Remove the LRU 0 test as this also falls through to RAM cache.
2026-05-20 17:03:58 -07:00
4d6a058bf1 feat: MediaPipe face detection (CORE-235) (#14009)
* Initial mediapipe face detection support

* Update face_geometry.py

* Account for diff sized batch input

* Model folder placeholder
2026-05-20 16:07:48 -07:00
32 changed files with 2522 additions and 341 deletions

View File

@ -1,2 +1,5 @@
# Admins
* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai
/CODEOWNERS @comfyanonymous
/.ci/ @comfyanonymous
/.github/ @comfyanonymous

View File

@ -433,7 +433,7 @@ See also: [https://www.comfy.org/](https://www.comfy.org/)
## Frontend Development
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). The compiled JS files (from TS/Vue) are published to [pypi](https://pypi.org/project/comfyui-frontend-package) and installed as a dependency in ComfyUI.
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
### Reporting Issues and Requesting Features

View File

@ -110,13 +110,11 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
CACHE_RAM_AUTO_GB = -1.0
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metavar="GB", help="Use RAM pressure caching with the specified headroom thresholds. This is the default caching mode. The first value sets the active-cache threshold; the optional second value sets the inactive-cache/pin threshold. Defaults when no values are provided: active 25%% of system RAM (min 4GB, max 32GB), inactive 75%% of system RAM (min 12GB, max 96GB).")
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@ -245,6 +243,9 @@ if comfy.options.args_parsing:
else:
args = parser.parse_args([])
if args.cache_ram is not None and len(args.cache_ram) > 2:
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
if args.windows_standalone_build:
args.auto_launch = True

View File

@ -484,16 +484,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
return weight
def prefetch_prepared_value(value, allocate_buffer, stream):
def prefetch_prepared_value(value, counter, destination, stream, copy):
if isinstance(value, torch.Tensor):
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
size = comfy.memory_management.vram_aligned_size(value)
offset = counter[0]
counter[0] += size
if destination is None:
return value
dest = destination[offset:offset + size]
if copy:
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
elif isinstance(value, weight_adapter.WeightAdapterBase):
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy))
elif isinstance(value, tuple):
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value)
elif isinstance(value, list):
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value]
return value

View File

@ -15,7 +15,7 @@ class TensorFileSlice(NamedTuple):
size: int
def read_tensor_file_slice_into(tensor, destination):
def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
if isinstance(tensor, QuantizedTensor):
if not isinstance(destination, QuantizedTensor):
@ -23,12 +23,17 @@ def read_tensor_file_slice_into(tensor, destination):
if tensor._layout_cls != destination._layout_cls:
return False
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream,
destination2=(destination2._qdata if destination2 is not None else None)):
return False
dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
if destination2 is not None:
dst_orig_dtype = destination2._params.orig_dtype
destination2._params.copy_from(destination._params, non_blocking=True)
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
return True
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
@ -48,6 +53,17 @@ def read_tensor_file_slice_into(tensor, destination):
if info.size == 0:
return True
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
if hostbuf is not None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
device_ptr = destination2.data_ptr() if destination2 is not None else 0
hostbuf.read_file_slice(file_obj, info.offset, info.size,
offset=destination.data_ptr() - hostbuf.get_raw_address(),
stream=stream_ptr,
device_ptr=device_ptr,
device=None if destination2 is None else destination2.device.index)
return True
buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))
@ -151,7 +167,7 @@ def set_ram_cache_release_state(callback, headroom):
extra_ram_release_callback = callback
RAM_CACHE_HEADROOM = max(0, int(headroom))
def extra_ram_release(target):
def extra_ram_release(target, free_active=False):
if extra_ram_release_callback is None:
return 0
return extra_ram_release_callback(target)
return extra_ram_release_callback(target, free_active=free_active)

View File

@ -31,6 +31,7 @@ from contextlib import nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.host_buffer
import comfy_aimdo.vram_buffer
class VRAMState(Enum):
@ -495,6 +496,14 @@ except:
current_loaded_models = []
DIRTY_MMAPS = set()
PIN_PRESSURE_HYSTERESIS = 256 * 1024 * 1024
#Freeing registerables on pressure does imply a GPU sync, so go big on
#the hysteresis so each expensive sync gives us back a good chunk.
REGISTERABLE_PIN_HYSTERESIS = 2048 * 1024 * 1024
def module_size(module):
module_mem = 0
sd = module.state_dict()
@ -503,27 +512,46 @@ def module_size(module):
module_mem += t.nbytes
return module_mem
def module_mmap_residency(module, free=False):
mmap_touched_mem = 0
module_mem = 0
bounced_mmaps = set()
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nbytes
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
continue
mmap_touched_mem += t.nbytes
if not free:
continue
storage._comfy_tensor_mmap_touched = False
mmap_obj = storage._comfy_tensor_mmap_refs[0]
if mmap_obj in bounced_mmaps:
continue
mmap_obj.bounce()
bounced_mmaps.add(mmap_obj)
return mmap_touched_mem, module_mem
def mark_mmap_dirty(storage):
mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None)
if mmap_refs is not None:
DIRTY_MMAPS.add(mmap_refs[0])
def free_pins(size, evict_active=False):
freed_total = 0
for loaded_model in reversed(current_loaded_models):
if size <= 0:
return freed_total
model = loaded_model.model
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
freed = model.partially_unload_ram(size)
freed_total += freed
size -= freed
return freed_total
def ensure_pin_budget(size, evict_active=False):
shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available
if shortfall <= 0:
return True
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
return free_pins(to_free, evict_active=evict_active) >= shortfall
def ensure_pin_registerable(size, evict_active=False):
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if shortfall <= 0:
return True
shortfall += REGISTERABLE_PIN_HYSTERESIS
for loaded_model in reversed(current_loaded_models):
model = loaded_model.model
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
shortfall -= model.unregister_inactive_pins(shortfall)
if shortfall <= 0:
return True
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
class LoadedModel:
def __init__(self, model):
@ -553,9 +581,6 @@ class LoadedModel:
def model_memory(self):
return self.model.model_size()
def model_mmap_residency(self, free=False):
return self.model.model_mmap_residency(free=free)
def model_loaded_memory(self):
return self.model.loaded_size()
@ -635,15 +660,9 @@ WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
import comfy.windows
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
def get_free_ram():
return comfy.windows.get_free_ram()
else:
def get_free_ram():
return psutil.virtual_memory().available
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@ -657,7 +676,6 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
cleanup_models_gc()
comfy.memory_management.extra_ram_release(max(pins_required, ram_required))
unloaded_model = []
can_unload = []
unloaded_models = []
@ -673,11 +691,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for x in can_unload_sorted:
i = x[-1]
memory_to_free = 1e32
pins_to_free = 1e32
if not DISABLE_SMART_MEMORY or device is None:
if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
pins_to_free = pins_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
if for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size()
@ -685,18 +701,6 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i)
if pins_to_free > 0:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
for x in can_unload_sorted:
i = x[-1]
ram_to_free = ram_required - psutil.virtual_memory().available
if ram_to_free <= 0 and i not in unloaded_model:
continue
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
if resident_memory > 0:
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
@ -762,29 +766,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
model_to_unload.model.detach(unpatch_all=False)
model_to_unload.model_finalizer.detach()
total_memory_required = {}
total_pins_required = {}
total_ram_required = {}
for loaded_model in models_to_load:
device = loaded_model.device
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
resident_memory, model_memory = loaded_model.model_mmap_residency()
pinned_memory = loaded_model.model.pinned_memory_size()
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
#make this JIT to keep as much pinned as possible.
pins_required = model_memory - pinned_memory
ram_required = model_memory - resident_memory
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem,
device,
for_dynamic=free_for_dynamic,
pins_required=total_pins_required[device],
ram_required=total_ram_required[device])
for_dynamic=free_for_dynamic)
for device in total_memory_required:
if device != torch.device("cpu"):
@ -1180,6 +1171,7 @@ STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
STREAM_PIN_BUFFERS = {}
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
@ -1220,21 +1212,66 @@ def get_aimdo_cast_buffer(offload_stream, device):
if cast_buffer is None:
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer
def get_pin_buffer(offload_stream):
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
if pin_buffer is None:
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3))
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
elif offload_stream is not None:
event = getattr(pin_buffer, "_comfy_event", None)
if event is not None:
event.synchronize()
delattr(pin_buffer, "_comfy_event")
return pin_buffer
def resize_pin_buffer(pin_buffer, size):
global TOTAL_PINNED_MEMORY
old_size = pin_buffer.size
if size <= old_size:
return True
growth = size - old_size
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
ensure_pin_budget(growth, evict_active=True)
ensure_pin_registerable(growth, evict_active=True)
try:
pin_buffer.extend(size=size, reallocate=True)
except RuntimeError:
return False
TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
return True
def reset_cast_buffers():
global TOTAL_PINNED_MEMORY
global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0)
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
if offload_stream is not None:
offload_stream.synchronize()
synchronize()
for mmap_obj in DIRTY_MMAPS:
mmap_obj.bounce()
DIRTY_MMAPS.clear()
for pin_buffer in STREAM_PIN_BUFFERS.values():
TOTAL_PINNED_MEMORY -= pin_buffer.size
TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY)
for loaded_model in current_loaded_models:
model = loaded_model.model
if model is not None and model.is_dynamic():
model.model.dynamic_pins[model.load_device]["active"] = False
model.partially_unload_ram(1e30, subsets=[ "patches" ])
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0])
STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
STREAM_PIN_BUFFERS.clear()
soft_empty_cache()
def get_offload_stream(device):
@ -1280,7 +1317,7 @@ def sync_stream(device, stream):
current_stream(device).wait_stream(stream)
def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
wf_context = nullcontext()
if stream is not None:
wf_context = stream
@ -1288,17 +1325,20 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
wf_context = wf_context.as_context(stream)
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
with wf_context:
for tensor in tensors:
dest_view = dest_views.pop(0)
dest2_view = dest2_views.pop(0) if dest2_views is not None else None
if tensor is None:
continue
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view, stream=stream, destination2=dest2_view):
continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
if hasattr(storage, "_comfy_tensor_mmap_touched"):
storage._comfy_tensor_mmap_touched = True
mark_mmap_dirty(storage)
dest_view.copy_(tensor, non_blocking=non_blocking)
if dest2_view is not None:
dest2_view.copy_(dest_view, non_blocking=non_blocking)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
@ -1339,14 +1379,18 @@ TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
ram = get_total_memory(torch.device("cpu"))
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50%
MAX_PINNED_MEMORY = ram * 0.40 # Windows limit is apparently 50%
else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90
MAX_PINNED_MEMORY = ram * 0.90
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
def pinned_hostbuf_size(size):
return max(0, int(min(size, MAX_PINNED_MEMORY) * 2))
def discard_cuda_async_error():
try:
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
@ -1378,8 +1422,8 @@ def pin_memory(tensor):
return False
size = tensor.nbytes
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
ensure_pin_registerable(size)
ptr = tensor.data_ptr()
if ptr == 0:
@ -1416,7 +1460,8 @@ def unpin_memory(tensor):
return False
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
size = PINNED_MEMORY.pop(ptr)
TOTAL_PINNED_MEMORY -= size
return True
else:
logging.warning("Unpin error.")

View File

@ -35,6 +35,7 @@ import comfy.model_management
import comfy.ops
import comfy.patcher_extension
import comfy.utils
import comfy_aimdo.host_buffer
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@ -117,6 +118,8 @@ def string_to_seed(data):
return comfy.utils.string_to_seed(data)
class LowVramPatch:
is_lowvram_patch = True
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
@ -124,11 +127,21 @@ class LowVramPatch:
self.set_func = set_func
self.prepared_patches = None
def prepare(self, allocate_buffer, stream):
self.prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
def memory_required(self):
counter = [0]
for patch in self.patches[self.key]:
comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False)
return counter[0]
def prepare(self, destination, stream, copy=True, commit=True):
counter = [0]
prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4])
for patch in self.patches[self.key]
]
if commit:
self.prepared_patches = prepared_patches
return prepared_patches
def clear_prepared(self):
self.prepared_patches = None
@ -341,9 +354,6 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model)
return self.size
def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free)
def loaded_size(self):
return self.model.model_loaded_weight_memory
@ -1118,8 +1128,12 @@ class ModelPatcher:
# Pinned memory pressure tracking is only implemented for DynamicVram loading
return 0
def loaded_ram_size(self):
# Loaded RAM pressure tracking is only implemented for DynamicVram loading
return 0
def partially_unload_ram(self, ram_to_unload):
pass
return 0
def detach(self, unpatch_all=True):
self.eject_model()
@ -1550,6 +1564,16 @@ class ModelPatcherDynamic(ModelPatcher):
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {}
if not hasattr(self.model, "dynamic_pins"):
self.model.dynamic_pins = {}
if self.load_device not in self.model.dynamic_pins:
self.model.dynamic_pins[self.load_device] = {
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"hostbufs_initialized": False,
"failed": False,
"active": False,
}
self.non_dynamic_delegate_model = None
assert load_device is not None
@ -1611,6 +1635,14 @@ class ModelPatcherDynamic(ModelPatcher):
self.unpatch_hooks()
vbar = self._vbar_get(create=True)
pin_state = self.model.dynamic_pins[self.load_device]
if not pin_state["hostbufs_initialized"]:
hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size())
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["hostbufs_initialized"] = True
pin_state["failed"] = False
pin_state["active"] = True
if vbar is not None:
vbar.prioritize()
@ -1636,7 +1668,9 @@ class ModelPatcherDynamic(ModelPatcher):
if key in self.patches:
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
return (True, 0)
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
lowvram_patch = LowVramPatch(key, self.patches)
lowvram_patch._pin_state = pin_state
setattr(m, param_key + "_lowvram_function", lowvram_patch)
num_patches += 1
else:
setattr(m, param_key + "_lowvram_function", None)
@ -1653,6 +1687,9 @@ class ModelPatcherDynamic(ModelPatcher):
def force_load_param(self, param_key, device_to):
key = key_param_name_to_key(n, param_key)
weight, _, _ = get_key_weight(self.model, key)
if weight is None:
return
if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to, force_cast=True)
@ -1662,17 +1699,23 @@ class ModelPatcherDynamic(ModelPatcher):
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
m.pin_failed = False
m.seed_key = n
m._pin_state = pin_state
set_dirty(m, dirty)
force_load, v_weight_size = setup_param(self, m, n, "weight")
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
force_load = force_load or force_load_bias
v_weight_size += v_weight_bias
#Models that mix tiny and giant weights can causing lopsided stream buffer
#rotations and stall. force the tinys over.
if module_mem > 16 * 1024:
force_load, v_weight_size = setup_param(self, m, n, "weight")
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
force_load = force_load or force_load_bias
v_weight_size += v_weight_bias
if force_load:
logging.info(f"Module {n} has resizing Lora - force loading")
else:
force_load=True
if force_load:
logging.info(f"Module {n} has resizing Lora - force loading")
force_load_param(self, "weight", device_to)
force_load_param(self, "bias", device_to)
else:
@ -1740,23 +1783,58 @@ class ModelPatcherDynamic(ModelPatcher):
return freed
def pinned_memory_size(self):
total = 0
loading = self._load_list(for_dynamic=True)
for x in loading:
_, _, _, _, m, _ = x
pin = comfy.pinned_memory.get_pin(m)
if pin is not None:
total += pin.numel() * pin.element_size()
return total
def loaded_ram_size(self):
return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
self.model.dynamic_pins[self.load_device]["patches"][0].size)
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading:
*_, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
if ram_to_unload <= 0:
return
def pinned_memory_size(self):
return (self.model.dynamic_pins[self.load_device]["weights"][3][0] +
self.model.dynamic_pins[self.load_device]["patches"][3][0])
def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
split = stack_split[0]
while split >= 0:
module, offset = stack[split]
split -= 1
stack_split[0] = split
if not module._pin_registered:
continue
size = module._pin.numel() * module._pin.element_size()
if torch.cuda.cudart().cudaHostUnregister(module._pin.data_ptr()) != 0:
comfy.model_management.discard_cuda_async_error()
continue
module._pin_registered = False
comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size)
pinned_size[0] = max(0, pinned_size[0] - size)
freed += size
ram_to_unload -= size
if ram_to_unload <= 0:
return freed
return freed
def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
while len(stack) > 0:
module, offset = stack.pop()
size = module._pin.numel() * module._pin.element_size()
del module._pin
hostbuf.truncate(offset, do_unregister=module._pin_registered)
stack_split[0] = min(stack_split[0], len(stack) - 1)
if module._pin_registered:
comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size)
pinned_size[0] = max(0, pinned_size[0] - size)
freed += size
ram_to_unload -= size
if ram_to_unload <= 0:
return freed
return freed
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
#This isn't used by the core at all and can only be to load a model out of

View File

@ -75,6 +75,8 @@ except:
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
@ -91,6 +93,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
offload_stream = None
cast_buffer = None
cast_buffer_offset = 0
stream_pin_hostbuf = None
stream_pin_offset = 0
stream_pin_queue = []
def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream
@ -124,6 +129,22 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
cast_buffer_offset += buffer_size
return buffer
def get_stream_pin_buffer_offset(buffer_size):
nonlocal stream_pin_hostbuf
nonlocal stream_pin_offset
if buffer_size == 0 or offload_stream is None:
return None
if stream_pin_hostbuf is None:
stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
if stream_pin_hostbuf is None:
return None
offset = stream_pin_offset
stream_pin_offset += buffer_size
return offset
for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
@ -162,23 +183,47 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
if xfer_dest is None:
xfer_dest = get_cast_buffer(dest_size)
if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s)
pin = comfy.pinned_memory.get_pin(s)
else:
pin = None
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
if xfer_source is not None:
if getattr(xfer_source, "is_lowvram_patch", False):
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
else:
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
if pin is not None:
comfy.model_management.cast_to_gathered(xfer_source, pin)
xfer_source = [ pin ]
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
def handle_pin(m, pin, source, dest, subset="weights", size=None):
if pin is not None:
cast_maybe_lowvram_patch([pin], dest, offload_stream)
return
if signature is None:
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
pin = comfy.pinned_memory.get_pin(m, subset=subset)
if pin is not None:
if isinstance(source, list):
comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest)
else:
cast_maybe_lowvram_patch(source, pin, None)
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
return
if pin is None:
pin_offset = get_stream_pin_buffer_offset(size)
if pin_offset is not None:
stream_pin_queue.append((source, pin_offset, size, dest))
return
cast_maybe_lowvram_patch(source, dest, offload_stream)
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_source = getattr(s, param_key + "_lowvram_function", None)
if lowvram_source is not None:
ensure_offload_stream(s, cast_buffer_offset, False)
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
lowvram_size = lowvram_source.memory_required()
lowvram_dest = get_cast_buffer(lowvram_size)
lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)
pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
handle_pin(lowvram_source, pin, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size)
prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest
@ -186,6 +231,23 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch
if stream_pin_offset > 0:
if stream_pin_hostbuf.size < stream_pin_offset:
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
for xfer_source, _, _, xfer_dest in stream_pin_queue:
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
return offload_stream
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf)
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
if isinstance(xfer_source, list):
comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest)
else:
cast_maybe_lowvram_patch(xfer_source, pin, None)
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
stream_pin_hostbuf._comfy_event = offload_stream.record_event()
return offload_stream

View File

@ -2,42 +2,62 @@ import comfy.model_management
import comfy.memory_management
import comfy_aimdo.host_buffer
import comfy_aimdo.torch
import torch
from comfy.cli_args import args
def get_pin(module):
return getattr(module, "_pin", None)
def get_pin(module, subset="weights"):
pin = getattr(module, "_pin", None)
if pin is None or module._pin_registered or args.disable_pinned_memory:
return pin
def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
_, _, stack_split, pinned_size = module._pin_state[subset]
size = pin.nbytes
comfy.model_management.ensure_pin_registerable(size)
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
comfy.model_management.discard_cuda_async_error()
return pin
module._pin_registered = True
stack_split[0] = max(stack_split[0], module._pin_stack_index)
comfy.model_management.TOTAL_PINNED_MEMORY += size
pinned_size[0] += size
return pin
def pin_memory(module, subset="weights", size=None):
pin_state = module._pin_state
if args.disable_pinned_memory:
return
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
pin = get_pin(module, subset)
if pin is not None or pin_state["failed"]:
return
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
module.pin_failed = True
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
if size is None:
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
offset = hostbuf.size
registerable_size = size + max(0, hostbuf.size - pinned_size[0])
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
if (not comfy.model_management.ensure_pin_budget(size) or
not comfy.model_management.ensure_pin_registerable(registerable_size)):
pin_state["failed"] = True
return False
try:
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
hostbuf.extend(size=size)
except RuntimeError:
module.pin_failed = True
pin_state["failed"] = True
return False
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
module._pin_hostbuf = hostbuf
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
stack.append((module, offset))
module._pin_registered = True
module._pin_stack_index = len(stack) - 1
stack_split[0] = max(stack_split[0], module._pin_stack_index)
comfy.model_management.TOTAL_PINNED_MEMORY += size
pinned_size[0] += size
return True
def unpin_memory(module):
if get_pin(module) is None:
return 0
size = module._pin.numel() * module._pin.element_size()
comfy.model_management.TOTAL_PINNED_MEMORY -= size
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
comfy.model_management.TOTAL_PINNED_MEMORY = 0
del module._pin
del module._pin_hostbuf
return size

View File

@ -113,7 +113,6 @@ def load_safetensors(ckpt):
"_comfy_tensor_file_slice",
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
setattr(storage, "_comfy_tensor_mmap_touched", False)
sd[name] = tensor
return sd, header.get("__metadata__", {}),
@ -1451,4 +1450,3 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res
return res

View File

@ -1,52 +0,0 @@
import ctypes
import logging
import psutil
from ctypes import wintypes
import comfy_aimdo.control
psapi = ctypes.WinDLL("psapi")
kernel32 = ctypes.WinDLL("kernel32")
class PERFORMANCE_INFORMATION(ctypes.Structure):
_fields_ = [
("cb", wintypes.DWORD),
("CommitTotal", ctypes.c_size_t),
("CommitLimit", ctypes.c_size_t),
("CommitPeak", ctypes.c_size_t),
("PhysicalTotal", ctypes.c_size_t),
("PhysicalAvailable", ctypes.c_size_t),
("SystemCache", ctypes.c_size_t),
("KernelTotal", ctypes.c_size_t),
("KernelPaged", ctypes.c_size_t),
("KernelNonpaged", ctypes.c_size_t),
("PageSize", ctypes.c_size_t),
("HandleCount", wintypes.DWORD),
("ProcessCount", wintypes.DWORD),
("ThreadCount", wintypes.DWORD),
]
def get_free_ram():
#Windows is way too conservative and chalks recently used uncommitted model RAM
#as "in-use". So, calculate free RAM for the sake of general use as the greater of:
#
#1: What psutil says
#2: Total Memory - (Committed Memory - VRAM in use)
#
#We have to subtract VRAM in use from the comitted memory as WDDM creates a naked
#commit charge for all VRAM used just incase it wants to page it all out. This just
#isn't realistic so "overcommit" on our calculations by just subtracting it off.
pi = PERFORMANCE_INFORMATION()
pi.cb = ctypes.sizeof(pi)
if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal")
return psutil.virtual_memory().available
committed = pi.CommitTotal * pi.PageSize
total = pi.PhysicalTotal * pi.PageSize
return max(psutil.virtual_memory().available,
total - (committed - comfy_aimdo.control.get_total_vram_usage()))

View File

@ -0,0 +1,111 @@
"""Pure-numpy port of MediaPipe's face_geometry (FACE_LANDMARK_PIPELINE mode)
+ weighted Procrustes solver. Computes the 4x4 facial transformation matrix.
"""
from __future__ import annotations
import math
import numpy as np
def _solve_weighted_orthogonal_problem(src: np.ndarray, tgt: np.ndarray, weights: np.ndarray) -> np.ndarray:
"""Weighted orthogonal Procrustes (similarity). Returns 4x4 M with
`target ≈ M @ homogeneous(source)` in the weighted LS sense. fp64 for
SVD stability. Port of procrustes_solver.cc."""
sqrt_w = np.sqrt(weights.astype(np.float64))
w_total = float((sqrt_w ** 2).sum())
ws = src.astype(np.float64) * sqrt_w
wt = tgt.astype(np.float64) * sqrt_w
c_w = (ws @ sqrt_w) / w_total
centered = ws - np.outer(c_w, sqrt_w)
U, _S, Vt = np.linalg.svd(wt @ centered.T, full_matrices=True)
# Disallow reflection: flip the least-significant axis when det(U)·det(V)<0.
post, pre = U.copy(), Vt.T.copy()
if np.linalg.det(post) * np.linalg.det(pre) < 0:
post[:, 2] *= -1.0
R = post @ pre.T
denom = float((centered * ws).sum())
if denom < 1e-12:
raise ValueError("Procrustes denominator collapsed (degenerate source).")
scale = float((R @ centered * wt).sum()) / denom
translation = ((wt - scale * (R @ ws)) @ sqrt_w) / w_total
M = np.eye(4, dtype=np.float64)
M[:3, :3] = scale * R
M[:3, 3] = translation
return M
def _estimate_scale(canonical: np.ndarray, runtime: np.ndarray, weights: np.ndarray) -> float:
"""scale = ‖first column of M[:3]‖ per geometry_pipeline.cc::EstimateScale."""
return float(np.linalg.norm(_solve_weighted_orthogonal_problem(canonical, runtime, weights)[:3, 0]))
def solve_facial_transformation_matrix(
landmarks_normalized: np.ndarray,
canonical_vertices: np.ndarray,
procrustes_indices: np.ndarray,
procrustes_weights: np.ndarray,
image_width: int,
image_height: int,
# face_geometry_calculator_options.pbtxt defaults
vertical_fov_degrees: float = 63.0,
near: float = 1.0,
) -> np.ndarray:
"""4x4 facial transformation matrix via two-pass scale recovery
`landmarks_normalized` is (N, 3) in MediaPipe normalized convention: x, y
in [0,1] with TOP-LEFT origin, z in width-scaled units.
"""
h_near = 2.0 * near * math.tan(0.5 * math.radians(vertical_fov_degrees))
w_near = image_width * h_near / image_height
sub = procrustes_indices.astype(np.int64)
screen = landmarks_normalized[sub].T.astype(np.float64).copy()
canon = canonical_vertices[sub].T.astype(np.float64).copy()
weights = procrustes_weights.astype(np.float64)
# ProjectXY (TOP_LEFT y-flip, then scale all 3 axes; z uses x-scale).
screen[1] = 1.0 - screen[1]
screen[0] = screen[0] * w_near - 0.5 * w_near
screen[1] = screen[1] * h_near - 0.5 * h_near
screen[2] = screen[2] * w_near
depth_offset = float(screen[2].mean())
def _unproject(s: np.ndarray, scale: float) -> np.ndarray:
s = s.copy()
s[2] = (s[2] - depth_offset + near) / scale
s[0] *= s[2] / near
s[1] *= s[2] / near
s[2] *= -1.0
return s
first = screen.copy()
first[2] *= -1.0
s1 = _estimate_scale(canon, first, weights) # 1st pass: Procrustes on projected XY
s2 = _estimate_scale(canon, _unproject(screen, s1), weights) # 2nd pass: rescale z by s1, un-project XY
return _solve_weighted_orthogonal_problem(canon, _unproject(screen, s1 * s2), weights).astype(np.float32)
def transformation_matrix_from_detection(face_dict: dict, image_width: int, image_height: int, canonical_data: dict) -> np.ndarray:
"""Adapt a FaceLandmarker face dict to MP's normalized convention and solve.
FaceMesh emits (x, y, z) in 192-canonical units; MP's geometry expects
z_norm = z_canonical * scale_x / image_width"""
lmks_xy, lmks_3d = face_dict["landmarks_xy"], face_dict["landmarks_3d"]
aug = np.concatenate([lmks_3d[:, :2].astype(np.float64), np.ones((lmks_xy.shape[0], 1))], axis=1)
M, *_ = np.linalg.lstsq(aug, lmks_xy.astype(np.float64), rcond=None)
scale_x = float(np.linalg.norm(M[0]))
z_scale = scale_x / image_width if scale_x > 1e-6 else 1.0 / image_width
normalized = np.empty((lmks_xy.shape[0], 3), dtype=np.float32)
normalized[:, 0] = lmks_xy[:, 0] / image_width
normalized[:, 1] = lmks_xy[:, 1] / image_height
normalized[:, 2] = lmks_3d[:, 2] * z_scale
return solve_facial_transformation_matrix(
normalized, canonical_data["canonical_vertices"],
canonical_data["procrustes_indices"], canonical_data["procrustes_weights"],
image_width=image_width, image_height=image_height,
)

View File

@ -0,0 +1,682 @@
"""Pure-PyTorch port of MediaPipe's face_landmarker_v2_with_blendshapes.task:
BlazeFace detector → FaceMesh v2 → ARKit-52 blendshapes."""
from __future__ import annotations
import math
from functools import lru_cache
from typing import List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from scipy.special import expit
from torch import Tensor, nn
# Values below must stay verbatim with the published face_landmarker_v2 graph
# face_blendshapes_graph.cc::kLandmarksSubsetIdxs
_BS_INPUT_INDICES: Tuple[int, ...] = (
0, 1, 4, 5, 6, 7, 8, 10, 13, 14, 17, 21, 33, 37, 39, 40, 46, 52, 53, 54,
55, 58, 61, 63, 65, 66, 67, 70, 78, 80, 81, 82, 84, 87, 88, 91, 93, 95,
103, 105, 107, 109, 127, 132, 133, 136, 144, 145, 146, 148, 149, 150, 152,
153, 154, 155, 157, 158, 159, 160, 161, 162, 163, 168, 172, 173, 176, 178,
181, 185, 191, 195, 197, 234, 246, 249, 251, 263, 267, 269, 270, 276, 282,
283, 284, 285, 288, 291, 293, 295, 296, 297, 300, 308, 310, 311, 312, 314,
317, 318, 321, 323, 324, 332, 334, 336, 338, 356, 361, 362, 365, 373, 374,
375, 377, 378, 379, 380, 381, 382, 384, 385, 386, 387, 388, 389, 390, 397,
398, 400, 402, 405, 409, 415, 454, 466, 468, 469, 470, 471, 472, 473, 474,
475, 476, 477,
)
# face_blendshapes_graph.cc::kCategoryNames
BLENDSHAPE_NAMES: Tuple[str, ...] = (
"_neutral", "browDownLeft", "browDownRight", "browInnerUp", "browOuterUpLeft",
"browOuterUpRight", "cheekPuff", "cheekSquintLeft", "cheekSquintRight",
"eyeBlinkLeft", "eyeBlinkRight", "eyeLookDownLeft", "eyeLookDownRight",
"eyeLookInLeft", "eyeLookInRight", "eyeLookOutLeft", "eyeLookOutRight",
"eyeLookUpLeft", "eyeLookUpRight", "eyeSquintLeft", "eyeSquintRight",
"eyeWideLeft", "eyeWideRight", "jawForward", "jawLeft", "jawOpen",
"jawRight", "mouthClose", "mouthDimpleLeft", "mouthDimpleRight",
"mouthFrownLeft", "mouthFrownRight", "mouthFunnel", "mouthLeft",
"mouthLowerDownLeft", "mouthLowerDownRight", "mouthPressLeft",
"mouthPressRight", "mouthPucker", "mouthRight", "mouthRollLower",
"mouthRollUpper", "mouthShrugLower", "mouthShrugUpper", "mouthSmileLeft",
"mouthSmileRight", "mouthStretchLeft", "mouthStretchRight",
"mouthUpperUpLeft", "mouthUpperUpRight", "noseSneerLeft", "noseSneerRight",
)
# face_detection.pbtxt — short-range BlazeFace.
_BF_NUM_LAYERS = 4
_BF_INPUT_SIZE = 128
_BF_STRIDES = (8, 16, 16, 16)
_BF_ANCHOR_OFFSET_X = 0.5
_BF_ANCHOR_OFFSET_Y = 0.5
_BF_ASPECT_RATIOS = (1.0,)
_BF_INTERP_SCALE_AR = 1.0
_BF_BOX_SCALE = 128.0
_BF_KP_OFFSET = 4
_BF_SCORE_CLIP = 100.0
_BF_MIN_SCORE = 0.5
# face_detection_full_range.pbtxt — 48x48 grid at stride 4, 1 anchor/cell.
_BF_FR_INPUT_SIZE = 192
_BF_FR_GRID = 48
_BF_FR_NUM_ANCHORS = _BF_FR_GRID * _BF_FR_GRID
_BF_FR_BOX_SCALE = 192.0
_BF_FR_SCORE_CLIP = 100.0
_FM_INPUT_SIZE = 192
# Face ROI: 1.5xbbox rect warped anisotropically into 192x192.
_FACE_LEFT_EYE_KP = 0
_FACE_RIGHT_EYE_KP = 1
_FACE_ROI_SCALE_X = 1.5
_FACE_ROI_SCALE_Y = 1.5
_FACE_ROI_TARGET_ANGLE = 0.0
def _tf_same_pad(x: Tensor, kernel: int, stride: int) -> Tensor:
"""TF SAME pad (asymmetric on stride-2; PyTorch's symmetric pad undershoots by 1 px)."""
H, W = x.shape[-2], x.shape[-1]
pad_h = max(((H + stride - 1) // stride - 1) * stride + kernel - H, 0)
pad_w = max(((W + stride - 1) // stride - 1) * stride + kernel - W, 0)
if pad_h == 0 and pad_w == 0:
return x
return F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
# BlazeFace short-range: stem 5x5/s2 → 16 BlazeBlocks → parallel heads at
# 16²x88 (2 anchors/cell) and 8²x96 (6/cell) = 896 anchors. (in, out, stride):
_BLAZEFACE_BLOCKS = [
(24, 24, 1), (24, 28, 1), (28, 32, 2), (32, 36, 1),
(36, 42, 1), (42, 48, 2), (48, 56, 1), (56, 64, 1),
(64, 72, 1), (72, 80, 1), (80, 88, 1), (88, 96, 2),
(96, 96, 1), (96, 96, 1), (96, 96, 1), (96, 96, 1),
]
class BlazeFaceBlock(nn.Module):
"""DW 3x3 + PW + residual. Residual max-pools on stride>1, channel-pads on out_ch>in_ch."""
def __init__(self, in_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride
self.depthwise = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, device=device, dtype=dtype)
self.pointwise = ops.Conv2d(in_ch, out_ch, 1, padding=0, bias=True, device=device, dtype=dtype)
def forward(self, x: Tensor) -> Tensor:
residual = F.max_pool2d(x, 2, 2) if self.stride > 1 else x
if self.out_ch > self.in_ch:
residual = F.pad(residual, (0, 0, 0, 0, 0, self.out_ch - self.in_ch))
x = _tf_same_pad(x, 3, self.stride) if self.stride > 1 else F.pad(x, (1, 1, 1, 1))
return F.relu(self.pointwise(self.depthwise(x)) + residual)
class BlazeFace(nn.Module):
"""Short-range BlazeFace: (B, 3, 128, 128) in [-1, 1] → 896 anchors x 17."""
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.stem = ops.Conv2d(3, 24, 5, stride=2, padding=0, bias=True, **kw)
self.blocks = nn.ModuleList(BlazeFaceBlock(i, o, s, device=device, dtype=dtype, operations=operations)
for (i, o, s) in _BLAZEFACE_BLOCKS)
# 16²x2 + 8²x6 = 512 + 384 = 896 anchors.
self.cls_16 = ops.Conv2d(88, 2, 1, padding=0, bias=True, **kw)
self.cls_8 = ops.Conv2d(96, 6, 1, padding=0, bias=True, **kw)
self.reg_16 = ops.Conv2d(88, 32, 1, padding=0, bias=True, **kw)
self.reg_8 = ops.Conv2d(96, 96, 1, padding=0, bias=True, **kw)
def forward(self, image_chw_normalized: Tensor) -> tuple[Tensor, Tensor]:
x = F.relu(self.stem(_tf_same_pad(image_chw_normalized, 5, 2)))
# 16x16 tap is block-10 output (before the 88→96 stride-2 in block 11).
for i in range(11):
x = self.blocks[i](x)
feat_16 = x
for i in range(11, 16):
x = self.blocks[i](x)
feat_8 = x
def flat(t, a, k): # NHWC flatten → (B, H*W*A, K)
B, _, H, W = t.shape
return t.permute(0, 2, 3, 1).reshape(B, H * W * a, k)
cls = torch.cat([flat(self.cls_16(feat_16), 2, 1), flat(self.cls_8(feat_8), 6, 1)], dim=1)
reg = torch.cat([flat(self.reg_16(feat_16), 2, 16), flat(self.reg_8(feat_8), 6, 16)], dim=1)
return reg, cls
# BlazeFace full-range (face_detection_full_range_sparse.tflite): MobileNetV2-ish
# backbone + top-down FPN, 192² input → 2304 anchors at the 48x48 grid.
class FRBlock(nn.Module):
"""Double inverted residual: DW → PW(mid) → DW → PW(out) [+ residual].
Per source tflite: dw* have no fused activation, pw1 is always ReLU, pw2
is ReLU only when no residual (else ReLU fuses into the ADD).
"""
def __init__(self, in_ch: int, mid_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.has_residual = (in_ch == out_ch and stride == 1)
self.dw1 = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, **kw)
self.pw1 = ops.Conv2d(in_ch, mid_ch, 1, padding=0, bias=True, **kw)
self.dw2 = ops.Conv2d(mid_ch, mid_ch, 3, stride=1, padding=0, groups=mid_ch, bias=True, **kw)
self.pw2 = ops.Conv2d(mid_ch, out_ch, 1, padding=0, bias=True, **kw)
def forward(self, x: Tensor) -> Tensor:
residual = x if self.has_residual else None
x = F.relu(self.pw1(self.dw1(F.pad(x, (1, 1, 1, 1)))))
x = self.pw2(self.dw2(F.pad(x, (1, 1, 1, 1))))
return F.relu(x + residual) if residual is not None else F.relu(x)
# (in_ch, mid_ch, out_ch, stride). Stages downsample 96²x32 → 48²x64 → 24²x128
# → 12²x192 → 6²x384. Lateral taps at indices 4, 7, 10 (see _FR_LATERAL_*).
_FR_BACKBONE_BLOCKS = [
(32, 8, 32, 1), (32, 8, 32, 1), # 96²x32
(32, 16, 64, 2), (64, 16, 64, 1), (64, 16, 64, 1), # 48²x64 — tap[0]
(64, 32, 128, 2), (128, 32, 128, 1), (128, 32, 128, 1), # 24²x128 — tap[1]
(128, 48, 192, 2), (192, 48, 192, 1), (192, 48, 192, 1), # 12²x192 — tap[2]
(192, 96, 384, 2), (384, 96, 384, 1), (384, 96, 384, 1), (384, 96, 384, 1), # 6²x384
]
_FR_LATERAL_TAP_INDICES = (4, 7, 10)
_FR_LATERAL_CHANNELS = ((64, 48), (128, 64), (192, 96)) # (in, out) per side-conv
# Decoder blocks per FPN level (after upsample-and-merge with the lateral).
_FR_DECODER_BLOCKS = [
[(96, 48, 96, 1), (96, 48, 96, 1)], # 12²x96
[(64, 32, 64, 1), (64, 32, 64, 1)], # 24²x64
[(48, 24, 48, 1)], # 48²x48 — feeds the heads
]
def _dcr_depth_to_space(t: Tensor, r: int, c_out: int) -> Tensor:
"""TF DEPTH_TO_SPACE in DCR layout (input channels = (i, j, c_out)).
pixel_shuffle uses CRD which permutes output channels for c_out > 1."""
B_, _, H_, W_ = t.shape
t = t.reshape(B_, r, r, c_out, H_, W_)
t = t.permute(0, 3, 4, 1, 5, 2).contiguous()
return t.reshape(B_, c_out, H_ * r, W_ * r)
class BlazeFaceFullRange(nn.Module):
"""Full-range face detector: (B, 3, 192, 192) in [-1, 1] → 2304 anchors x 17 values."""
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
mk_block = lambda i, m, o, s: FRBlock(i, m, o, s, device=device, dtype=dtype, operations=operations)
self.stem = ops.Conv2d(3, 32, 3, stride=2, padding=0, bias=True, **kw)
self.backbone = nn.ModuleList(mk_block(i, m, o, s) for (i, m, o, s) in _FR_BACKBONE_BLOCKS)
self.lateral_convs = nn.ModuleList(ops.Conv2d(i, o, 1, padding=0, bias=True, **kw) for (i, o) in _FR_LATERAL_CHANNELS)
self.top_conv = ops.Conv2d(384, 96, 1, padding=0, bias=True, **kw)
self.decoder_levels = nn.ModuleList(
nn.ModuleList(mk_block(i, m, o, s) for (i, m, o, s) in lvl) for lvl in _FR_DECODER_BLOCKS
)
# 96→64 before 12→24, 64→48 before 24→48.
self.decoder_reduce_convs = nn.ModuleList([
ops.Conv2d(96, 64, 1, padding=0, bias=True, **kw),
ops.Conv2d(64, 48, 1, padding=0, bias=True, **kw),
])
# Heads mix 2x2-cell info via DW-stride-2 + depth_to_space block_size=2.
self.cls_conv = ops.Conv2d(48, 4, 1, padding=0, bias=True, **kw)
self.cls_dw = ops.Conv2d(4, 4, 3, stride=2, padding=0, groups=4, bias=True, **kw)
self.reg_conv = ops.Conv2d(48, 64, 1, padding=0, bias=True, **kw)
self.reg_dw = ops.Conv2d(64, 64, 3, stride=2, padding=0, groups=64, bias=True, **kw)
def forward(self, image_chw_normalized: Tensor) -> tuple[Tensor, Tensor]:
# Symmetric pad-1 throughout (full-range tflite uses explicit TF PAD, not SAME).
x = F.relu(self.stem(F.pad(image_chw_normalized, (1, 1, 1, 1))))
tap_set = set(_FR_LATERAL_TAP_INDICES)
laterals: list[Tensor] = []
for i, blk in enumerate(self.backbone):
x = blk(x)
if i in tap_set:
laterals.append(x)
# top_conv / lateral_convs / decoder_reduce_convs all have fused ReLU in the tflite.
p = F.relu(self.top_conv(x))
laterals_rev = list(reversed(laterals))
lateral_convs_rev = list(reversed(self.lateral_convs))
for level in range(len(self.decoder_levels)):
lateral = laterals_rev[level]
p = F.interpolate(p, size=lateral.shape[-2:], mode="bilinear", align_corners=False)
p = p + F.relu(lateral_convs_rev[level](lateral))
for blk in self.decoder_levels[level]:
p = blk(p)
if level < len(self.decoder_reduce_convs):
p = F.relu(self.decoder_reduce_convs[level](p))
c = self.cls_dw(F.pad(self.cls_conv(p), (1, 1, 1, 1)))
c = _dcr_depth_to_space(c, r=2, c_out=1)
r = self.reg_dw(F.pad(self.reg_conv(p), (1, 1, 1, 1)))
r = _dcr_depth_to_space(r, r=2, c_out=16)
B = c.shape[0]
cls_out = c.permute(0, 2, 3, 1).reshape(B, _BF_FR_NUM_ANCHORS, 1)
reg_out = r.permute(0, 2, 3, 1).reshape(B, _BF_FR_NUM_ANCHORS, 16)
return reg_out, cls_out
@lru_cache(maxsize=1)
def _blazeface_full_range_anchors() -> np.ndarray:
"""2304 anchors over 48x48; anchor_w=anchor_h=1 (fixed_anchor_size)."""
feat = _BF_FR_GRID
yy, xx = np.meshgrid(np.arange(feat, dtype=np.float32), np.arange(feat, dtype=np.float32), indexing="ij")
cx, cy, ones = (xx + 0.5) / feat, (yy + 0.5) / feat, np.ones_like(xx)
return np.stack([cx, cy, ones, ones], axis=-1).reshape(_BF_FR_NUM_ANCHORS, 4)
def _decode_blazeface_full_range(regressors: np.ndarray, classificators: np.ndarray,
score_thresh: float = _BF_MIN_SCORE) -> np.ndarray:
"""Same decode as short-range with 2304-anchor grid and box_scale=192."""
scores = expit(np.clip(classificators[:, 0], -_BF_FR_SCORE_CLIP, _BF_FR_SCORE_CLIP))
keep = scores >= score_thresh
if not keep.any():
return np.empty((0, 17), dtype=np.float32)
r = regressors[keep] / _BF_FR_BOX_SCALE
a = _blazeface_full_range_anchors()[keep]
cxs, cys, aws, ahs = a[:, 0:1], a[:, 1:2], a[:, 2:3], a[:, 3:4]
xc, yc = r[:, 0:1] * aws + cxs, r[:, 1:2] * ahs + cys
w, h = r[:, 2:3] * aws, r[:, 3:4] * ahs
out = np.empty((r.shape[0], 17), dtype=np.float32)
out[:, 0:1], out[:, 1:2], out[:, 2:3], out[:, 3:4] = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2
out[:, 4:16:2] = r[:, _BF_KP_OFFSET::2] * aws + cxs
out[:, 5:16:2] = r[:, _BF_KP_OFFSET + 1::2] * ahs + cys
out[:, 16] = scores[keep]
return out
# FaceMesh (face_landmarks_detector.tflite): PReLU variant of BlazeBlock,
# 17 blocks, heads for 478x3 landmarks + presence.
_FACEMESH_BLOCKS = [ # (in_ch, out_ch, stride)
(16, 16, 1), (16, 16, 1), (16, 32, 2), (32, 32, 1), (32, 32, 1), (32, 64, 2),
(64, 64, 1), (64, 64, 1), (64, 128, 2), (128, 128, 1), (128, 128, 1), (128, 128, 2),
(128, 128, 1), (128, 128, 1), (128, 128, 2), (128, 128, 1), (128, 128, 1),
]
class FaceMeshBlock(nn.Module):
"""PReLU BlazeBlock: PReLU between DW and PW, and after the residual add."""
def __init__(self, in_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride
self.depthwise = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, **kw)
self.prelu_dwise = nn.PReLU(num_parameters=in_ch, **kw)
self.pointwise = ops.Conv2d(in_ch, out_ch, 1, padding=0, bias=True, **kw)
self.prelu_out = nn.PReLU(num_parameters=out_ch, **kw)
def forward(self, x: Tensor) -> Tensor:
residual = F.max_pool2d(x, 2, 2) if self.stride > 1 else x
if self.out_ch > self.in_ch:
residual = F.pad(residual, (0, 0, 0, 0, 0, self.out_ch - self.in_ch))
x = _tf_same_pad(x, 3, self.stride) if self.stride > 1 else F.pad(x, (1, 1, 1, 1))
return self.prelu_out(self.pointwise(self.prelu_dwise(self.depthwise(x))) + residual)
class FaceMesh(nn.Module):
NUM_LANDMARKS = 478
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.stem = ops.Conv2d(3, 16, 3, stride=2, padding=0, bias=True, **kw)
self.prelu_stem = nn.PReLU(num_parameters=16, **kw)
self.blocks = nn.ModuleList(FaceMeshBlock(i, o, s, device=device, dtype=dtype, operations=operations)
for (i, o, s) in _FACEMESH_BLOCKS)
self.head_reduce = ops.Conv2d(128, 8, 1, padding=0, bias=True, **kw)
self.prelu_head_reduce = nn.PReLU(num_parameters=8, **kw)
self.head_block = FaceMeshBlock(8, 8, 1, device=device, dtype=dtype, operations=operations)
self.head_presence = ops.Conv2d(8, 1, 3, padding=0, bias=True, **kw)
self.head_landmarks = ops.Conv2d(8, self.NUM_LANDMARKS * 3, 3, padding=0, bias=True, **kw)
def forward(self, face_chw_normalized: Tensor) -> tuple[Tensor, Tensor]:
"""(B, 3, 192, 192) in [0, 1] → ((B, 478, 3) landmarks in 192-canonical, (B,) presence)."""
x = self.prelu_stem(self.stem(_tf_same_pad(face_chw_normalized, 3, 2)))
for blk in self.blocks:
x = blk(x)
x = self.prelu_head_reduce(self.head_reduce(x))
x = self.head_block(x)
B = x.shape[0]
presence = self.head_presence(x).reshape(B)
lmks = self.head_landmarks(x).reshape(B, self.NUM_LANDMARKS, 3)
return lmks, presence
# FaceBlendshapes (MLP-Mixer "GhumMarkerPoserMlpMixerGeneral"):
# 146x2 → token-reduce 146→96 → embed 2→64 → +cls token → 4x mixer → cls→52.
_BS_NUM_INPUT_LANDMARKS = 146
_BS_NUM_TOKENS_REDUCED = 96
_BS_NUM_TOKENS = 97 # +1 cls
_BS_TOKEN_DIM = 64
_BS_TOKEN_MIX_HIDDEN = 384
_BS_CHANNEL_MIX_HIDDEN = 256
_BS_NUM_BLENDSHAPES = 52
_BS_LN_EPS = 1e-6
class MlpMixerBlock(nn.Module):
"""MLP-Mixer block: token-mixing MLP (over tokens) → channel-mixing MLP (over dim).
Both pre-LN, both residual. LN has no beta (bias=False) to match MP."""
def __init__(self, num_tokens: int, token_dim: int, token_hidden: int, channel_hidden: int,
device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
# bias=False → no LN beta (matches MP).
self.ln1 = ops.LayerNorm(token_dim, eps=_BS_LN_EPS, bias=False, **kw)
self.ln2 = ops.LayerNorm(token_dim, eps=_BS_LN_EPS, bias=False, **kw)
self.token_mlp1 = ops.Linear(num_tokens, token_hidden, bias=True, **kw)
self.token_mlp2 = ops.Linear(token_hidden, num_tokens, bias=True, **kw)
self.channel_mlp1 = ops.Linear(token_dim, channel_hidden, bias=True, **kw)
self.channel_mlp2 = ops.Linear(channel_hidden, token_dim, bias=True, **kw)
def forward(self, x: Tensor) -> Tensor:
y = self.ln1(x).transpose(1, 2)
x = x + self.token_mlp2(F.relu(self.token_mlp1(y))).transpose(1, 2)
return x + self.channel_mlp2(F.relu(self.channel_mlp1(self.ln2(x))))
class FaceBlendshapes(nn.Module):
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.token_reduce = ops.Linear(_BS_NUM_INPUT_LANDMARKS, _BS_NUM_TOKENS_REDUCED, bias=True, **kw)
self.token_embed = ops.Linear(2, _BS_TOKEN_DIM, bias=True, **kw)
self.cls_token = nn.Parameter(torch.zeros(1, 1, _BS_TOKEN_DIM, **kw))
self.blocks = nn.ModuleList(
MlpMixerBlock(_BS_NUM_TOKENS, _BS_TOKEN_DIM, _BS_TOKEN_MIX_HIDDEN, _BS_CHANNEL_MIX_HIDDEN,
device=device, dtype=dtype, operations=operations) for _ in range(4)
)
self.head = ops.Linear(_BS_TOKEN_DIM, _BS_NUM_BLENDSHAPES, bias=True, **kw)
@staticmethod
def _input_normalize(landmarks_2d: Tensor) -> Tensor:
# Centroid-subtract → L2 scale → x0.5. The 0.5 is baked into training.
centroid = landmarks_2d.mean(dim=1, keepdim=True)
x = landmarks_2d - centroid
mag = torch.sqrt((x * x).sum(dim=-1, keepdim=True))
scale = mag.mean(dim=1, keepdim=True)
return (x / scale.clamp(min=1e-12)) * 0.5
def forward(self, landmarks_2d: Tensor) -> Tensor:
"""(B, 146, 2) → (B, 52) in [0, 1]. Input units don't matter (centroid + L2 normalize)."""
x = self._input_normalize(landmarks_2d)
x = self.token_reduce(x.transpose(1, 2)).transpose(1, 2)
x = self.token_embed(x)
cls = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat([cls, x], dim=1)
for blk in self.blocks:
x = blk(x)
return torch.sigmoid(self.head(x[:, 0]))
@lru_cache(maxsize=1)
def _blazeface_anchors() -> np.ndarray:
"""896 anchors per SsdAnchorsCalculator (fixed_anchor_size → anchor_w=anchor_h=1)."""
per_ar = len(_BF_ASPECT_RATIOS) + (1 if _BF_INTERP_SCALE_AR > 0 else 0)
layer_anchors: List[np.ndarray] = []
layer = 0
while layer < _BF_NUM_LAYERS:
stride = _BF_STRIDES[layer]
last = layer
while last < _BF_NUM_LAYERS and _BF_STRIDES[last] == stride:
last += 1
per_cell = per_ar * (last - layer)
feat = (_BF_INPUT_SIZE + stride - 1) // stride
yy, xx = np.meshgrid(np.arange(feat, dtype=np.float32), np.arange(feat, dtype=np.float32), indexing="ij")
cx, cy, ones = (xx + _BF_ANCHOR_OFFSET_X) / feat, (yy + _BF_ANCHOR_OFFSET_Y) / feat, np.ones_like(xx)
cell = np.stack([cx, cy, ones, ones], axis=-1).reshape(-1, 4)
layer_anchors.append(np.repeat(cell, per_cell, axis=0))
layer = last
out = np.concatenate(layer_anchors, axis=0)
assert out.shape == (896, 4), out.shape
return out
def _decode_blazeface(regressors: np.ndarray, classificators: np.ndarray,
score_thresh: float = _BF_MIN_SCORE) -> np.ndarray:
"""Decode (regs (896,16), cls (896,1)) → (N, 17) = [xyxy, kp0x..kp5y, score] in [0, 1]."""
scores = expit(np.clip(classificators[:, 0], -_BF_SCORE_CLIP, _BF_SCORE_CLIP))
keep = scores >= score_thresh
if not keep.any():
return np.empty((0, 17), dtype=np.float32)
r = regressors[keep] / _BF_BOX_SCALE
a = _blazeface_anchors()[keep] # (N, 4) cx, cy, 1, 1
cxs, cys, aws, ahs = a[:, 0:1], a[:, 1:2], a[:, 2:3], a[:, 3:4]
xc, yc = r[:, 0:1] * aws + cxs, r[:, 1:2] * ahs + cys
w, h = r[:, 2:3] * aws, r[:, 3:4] * ahs
out = np.empty((r.shape[0], 17), dtype=np.float32)
out[:, 0:1], out[:, 1:2], out[:, 2:3], out[:, 3:4] = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2
out[:, 4:16:2] = r[:, _BF_KP_OFFSET::2] * aws + cxs
out[:, 5:16:2] = r[:, _BF_KP_OFFSET + 1::2] * ahs + cys
out[:, 16] = scores[keep]
return out
def _weighted_nms(detections: np.ndarray, iou_thresh: float = 0.5) -> np.ndarray:
"""MP weighted NMS — kept boxes are score-weighted averages of overlapping detections."""
if detections.shape[0] == 0:
return detections
dets = detections[np.argsort(-detections[:, 16])]
N = dets.shape[0]
areas = np.clip(dets[:, 2] - dets[:, 0], 0, None) * np.clip(dets[:, 3] - dets[:, 1], 0, None)
kept: List[np.ndarray] = []
used = np.zeros(N, dtype=bool)
for i in range(N):
if used[i]:
continue
ax1, ay1, ax2, ay2 = dets[i, 0:4]
merge_idx = [i]
for j in range(i + 1, N):
if used[j]:
continue
bx1, by1, bx2, by2 = dets[j, 0:4]
iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
ih = max(0.0, min(ay2, by2) - max(ay1, by1))
inter = iw * ih
union = areas[i] + areas[j] - inter
if union > 0 and inter / union > iou_thresh: # strict > matches MP
merge_idx.append(j)
used[j] = True
used[i] = True
cluster = dets[merge_idx]
ws = cluster[:, 16:17]
ws_sum = ws.sum()
merged = np.copy(cluster[0])
if ws_sum > 0:
merged[:16] = (cluster[:, :16] * ws).sum(axis=0) / ws_sum
kept.append(merged)
return np.stack(kept, axis=0) if kept else np.empty((0, 17), dtype=np.float32)
def _detection_to_face_rect(detection: np.ndarray, image_w: int, image_h: int) -> Tuple[float, float, float, float, float]:
"""Detection (normalized) → rotated 1.5xbbox ROI in image pixels (anisotropic)."""
xmin, ymin, xmax, ymax = detection[0:4]
lx = detection[4 + _FACE_LEFT_EYE_KP * 2 + 0] * image_w
ly = detection[4 + _FACE_LEFT_EYE_KP * 2 + 1] * image_h
rx = detection[4 + _FACE_RIGHT_EYE_KP * 2 + 0] * image_w
ry = detection[4 + _FACE_RIGHT_EYE_KP * 2 + 1] * image_h
# Image-y-down convention: angle = target - atan2(-dy, dx).
angle = _FACE_ROI_TARGET_ANGLE - math.atan2(ly - ry, rx - lx)
return (float((xmin + xmax) * 0.5 * image_w),
float((ymin + ymax) * 0.5 * image_h),
float((xmax - xmin) * image_w * _FACE_ROI_SCALE_X),
float((ymax - ymin) * image_h * _FACE_ROI_SCALE_Y),
float(angle))
def _sample_warp(image_chw: Tensor, src_x: Tensor, src_y: Tensor, padding_mode: str) -> Tensor:
"""Bilinear-sample image_chw at corner-aligned (src_x, src_y)."""
H, W = int(image_chw.shape[-2]), int(image_chw.shape[-1])
grid = torch.stack([(2.0 * src_x + 1.0) / W - 1.0,
(2.0 * src_y + 1.0) / H - 1.0], dim=-1).unsqueeze(0)
return F.grid_sample(image_chw.unsqueeze(0), grid, mode="bilinear",
align_corners=False, padding_mode=padding_mode).squeeze(0)
def _warp_face_crop(image_chw: Tensor, cx: float, cy: float, width: float, height: float,
angle: float, output_size: int = _FM_INPUT_SIZE) -> Tensor:
"""Rotated rect → output_size² with BORDER_REPLICATE. image_chw must be in [0, 1]."""
s_x, s_y = width / output_size, height / output_size
cos_a, sin_a = math.cos(angle), math.sin(angle)
arange = torch.arange(output_size, dtype=image_chw.dtype, device=image_chw.device) - output_size * 0.5
v_grid, u_grid = torch.meshgrid(arange, arange, indexing="ij")
src_x = cx + u_grid * s_x * cos_a - v_grid * s_y * sin_a
src_y = cy + u_grid * s_x * sin_a + v_grid * s_y * cos_a
return _sample_warp(image_chw, src_x, src_y, "border")
def _blazeface_input_warp(image_chw_raw: Tensor, target: int = _BF_INPUT_SIZE) -> Tuple[Tensor, float, float, float]:
"""Centered max(W,H) square → target² with BORDER_ZERO + [-1, 1] norm.
Sub-pixel grid_sample matters; integer-pad-then-resize drifts the bbox ~5%.
Returns (warped, sub_rect_cx, sub_rect_cy, sub_rect_size) — the triplet maps
tensor-normalized [0,1] detections back to image pixels.
"""
H, W = int(image_chw_raw.shape[1]), int(image_chw_raw.shape[2])
sub_rect_size = float(max(W, H))
sub_rect_cx, sub_rect_cy = W * 0.5, H * 0.5
s = sub_rect_size / target
arange = torch.arange(target, dtype=image_chw_raw.dtype, device=image_chw_raw.device) - target * 0.5
v_grid, u_grid = torch.meshgrid(arange, arange, indexing="ij")
out = _sample_warp(image_chw_raw, sub_rect_cx + u_grid * s, sub_rect_cy + v_grid * s, "zeros")
return (out / 127.5) - 1.0, sub_rect_cx, sub_rect_cy, sub_rect_size
class FaceLandmarker(nn.Module):
"""BlazeFace → FaceMesh v2 → blendshapes. `detector_variant` selects 'short'
(128², ≤2m) or 'full' (192² FPN, ≤5m). State dict uses inner-module prefixes
`detector.*` / `mesh.*` / `blendshapes.*`; the outer FaceLandmarkerModel
wrapper rewrites `detector_{variant}.*` keys to `detector.*` before loading.
"""
def __init__(self, device=None, dtype=None, operations=None, detector_variant: str = "short"):
super().__init__()
det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}.get(detector_variant)
self.detector_variant = detector_variant
self.detector = det_cls(device=device, dtype=dtype, operations=operations)
self.mesh = FaceMesh(device=device, dtype=dtype, operations=operations)
self.blendshapes = FaceBlendshapes(device=device, dtype=dtype, operations=operations)
self.register_buffer("_bs_idx", torch.tensor(_BS_INPUT_INDICES, dtype=torch.long), persistent=False)
def run_detector_batch(self, images_rgb_uint8: List[np.ndarray],
score_thresh: float = _BF_MIN_SCORE,
iou_thresh: float = 0.5):
"""Batched detector pass. Returns (img_raws, sub_rects, sizes, per_frame_decoded)
where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords."""
if not images_rgb_uint8:
return [], [], [], []
device, dtype = self.detector.stem.weight.device, self.detector.stem.weight.dtype
det_input_size, decode_fn = ((_BF_FR_INPUT_SIZE, _decode_blazeface_full_range)
if self.detector_variant == "full"
else (_BF_INPUT_SIZE, _decode_blazeface))
# Same-size frames: stack once and transfer once. Variable size falls back
# to per-image (only triggers for SAM3DBody's head crops).
sizes = [tuple(img.shape[:2]) for img in images_rgb_uint8]
if len(set(sizes)) == 1:
batch_chw = torch.from_numpy(np.stack(images_rgb_uint8, axis=0)).to(device, dtype).movedim(-1, -3).contiguous()
img_raws = [batch_chw[bi] for bi in range(batch_chw.shape[0])]
else:
img_raws = [torch.from_numpy(img).to(device, dtype).movedim(-1, -3).contiguous() for img in images_rgb_uint8]
warps = [_blazeface_input_warp(img_raw, det_input_size) for img_raw in img_raws]
det_crops = [w[0] for w in warps]
sub_rects = [(w[1], w[2], w[3]) for w in warps]
regs_b, cls_b = self.detector(torch.stack(det_crops, dim=0))
regs_np, cls_np = regs_b.float().cpu().numpy(), cls_b.float().cpu().numpy()
per_frame = []
for b in range(len(images_rgb_uint8)):
decoded = decode_fn(regs_np[b], cls_np[b], score_thresh=score_thresh)
per_frame.append(_weighted_nms(decoded, iou_thresh=iou_thresh) if decoded.shape[0] > 0 else decoded)
return img_raws, sub_rects, sizes, per_frame
def detect_batch(self, images_rgb_uint8: List[np.ndarray], num_faces: int = 1,
score_thresh: float = _BF_MIN_SCORE) -> List[List[dict]]:
"""Full pipeline batched across `images_rgb_uint8`. Returns one face-dict
list per image (empty if nothing detected). Face dict:
bbox_xyxy (4,) image pixels, blendshapes {52} ∈ [0,1],
landmarks_xy (478, 2) image pixels, landmarks_3d (478, 3) in
192-canonical (pre-transformation) units, presence float (raw logit).
"""
img_raws, sub_rects, sizes, per_frame_dets = self.run_detector_batch(
images_rgb_uint8, score_thresh=score_thresh,
)
# tensor-normalized → image-normalized [0,1] for _detection_to_face_rect.
for b, decoded in enumerate(per_frame_dets):
if decoded.shape[0] == 0:
continue
cx, cy, size = sub_rects[b]
H, W = sizes[b]
sx0, sy0 = cx - size * 0.5, cy - size * 0.5
decoded[:, 0:16:2] = (sx0 + size * decoded[:, 0:16:2]) / W
decoded[:, 1:16:2] = (sy0 + size * decoded[:, 1:16:2]) / H
if num_faces > 0:
per_frame_dets[b] = decoded[: int(num_faces)]
# Collect every detected face across all frames into one mesh input.
face_params: List[Tuple[int, float, float, float, float, float, float]] = []
mesh_crops: List[Tensor] = []
for b, dets in enumerate(per_frame_dets):
if dets.shape[0] == 0:
continue
H, W = sizes[b]
img_for_mesh = img_raws[b] / 255.0
for det in dets:
cx, cy, w, h, angle = _detection_to_face_rect(det, W, H)
mesh_crops.append(_warp_face_crop(img_for_mesh, cx, cy, w, h, angle, _FM_INPUT_SIZE))
face_params.append((b, float(det[16]), cx, cy, w, h, angle))
results: List[List[dict]] = [[] for _ in range(len(images_rgb_uint8))]
if not mesh_crops:
return results
lmks_canon_b, presence_b = self.mesh(torch.stack(mesh_crops, dim=0))
bs_out_b = self.blendshapes(lmks_canon_b[:, self._bs_idx, :2])
# Batched canonical→image affine
params_t = torch.tensor(
[(cx, cy, w, h, math.cos(a), math.sin(a)) for (_b, _s, cx, cy, w, h, a) in face_params],
device=lmks_canon_b.device, dtype=lmks_canon_b.dtype,
)
cxs, cys, ws, hs, cos_a, sin_a = params_t.unbind(dim=1)
inv = 1.0 / _FM_INPUT_SIZE
u = lmks_canon_b[..., 0] - _FM_INPUT_SIZE * 0.5
v = lmks_canon_b[..., 1] - _FM_INPUT_SIZE * 0.5
lmks_xy_t = torch.stack([
cxs[:, None] + u * (ws * inv * cos_a)[:, None] - v * (hs * inv * sin_a)[:, None],
cys[:, None] + u * (ws * inv * sin_a)[:, None] + v * (hs * inv * cos_a)[:, None],
], dim=-1)
lmks_xy_np = lmks_xy_t.float().cpu().numpy()
lmks_canon_np = lmks_canon_b.float().cpu().numpy()
presence_np = presence_b.float().cpu().numpy()
bs_np = bs_out_b.float().cpu().numpy()
for i, (b, score, *_) in enumerate(face_params):
lmks_xy = lmks_xy_np[i]
mn, mx = lmks_xy.min(0), lmks_xy.max(0)
results[b].append({
"bbox_xyxy": np.array([mn[0], mn[1], mx[0], mx[1]], dtype=np.float32),
"blendshapes": dict(zip(BLENDSHAPE_NAMES, bs_np[i].tolist())),
"landmarks_xy": lmks_xy,
"landmarks_3d": lmks_canon_np[i],
"presence": float(presence_np[i]),
"score": score,
})
return results

View File

@ -543,7 +543,7 @@ class AudioConcat(IO.ComfyNode):
return IO.Schema(
node_id="AudioConcat",
search_aliases=["join audio", "combine audio", "append audio"],
display_name="Audio Concat",
display_name="Concatenate Audio",
description="Concatenates the audio1 to audio2 in the specified direction.",
category="audio",
inputs=[
@ -597,7 +597,7 @@ class AudioMerge(IO.ComfyNode):
return IO.Schema(
node_id="AudioMerge",
search_aliases=["mix audio", "overlay audio", "layer audio"],
display_name="Audio Merge",
display_name="Merge Audio",
description="Combine two audio tracks by overlaying their waveforms.",
category="audio",
inputs=[
@ -667,8 +667,9 @@ class AudioAdjustVolume(IO.ComfyNode):
return IO.Schema(
node_id="AudioAdjustVolume",
search_aliases=["audio gain", "loudness", "audio level"],
display_name="Audio Adjust Volume",
display_name="Adjust Audio Volume",
category="audio",
description="Adjust the volume of the audio by a specified amount in decibels (dB).",
inputs=[
IO.Audio.Input("audio"),
IO.Int.Input(

View File

@ -47,8 +47,10 @@ class LoadImageDataSetFromFolderNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LoadImageDataSetFromFolder",
display_name="Load Image Dataset from Folder",
category="dataset",
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
display_name="Load Image (from Folder)",
category="image",
description="Load a dataset of images from a specified folder and return a list of images. Supported formats: PNG, JPG, JPEG, WEBP.",
is_experimental=True,
inputs=[
io.Combo.Input(
@ -84,14 +86,16 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LoadImageTextDataSetFromFolder",
display_name="Load Image and Text Dataset from Folder",
category="dataset",
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
display_name="Load Image-Text (from Folder)",
category="image",
description="Load a dataset of pairs of images and text captions from a specified folder and return them as a list. Supported formats: PNG, JPG, JPEG, WEBP.",
is_experimental=True,
inputs=[
io.Combo.Input(
"folder",
options=folder_paths.get_input_subfolders(),
tooltip="The folder to load images from.",
tooltip="The folder to load images and text captions from.",
)
],
outputs=[
@ -206,8 +210,10 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveImageDataSetToFolder",
display_name="Save Image Dataset to Folder",
category="dataset",
search_aliases=["save folder", "save to folder", "save dataset", "save images", "export dataset"],
display_name="Save Image (to Folder) (DEPRECATED)",
category="image",
description="Save a dataset of images to a specified folder. Supported formats: PNG.",
is_experimental=True,
is_output_node=True,
is_input_list=True, # Receive images as list
@ -226,6 +232,7 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
),
],
outputs=[],
is_deprecated=True, # This node is redundant and superseded by existing Save Image nodes where the target folder can be specified in the filename_prefix
)
@classmethod
@ -246,14 +253,20 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveImageTextDataSetToFolder",
display_name="Save Image and Text Dataset to Folder",
category="dataset",
search_aliases=["save folder", "save to folder", "save dataset", "save images", "save text", "export dataset"],
display_name="Save Image-Text (to Folder)",
category="image",
description="Save a dataset of pairs of images and text captions to a specified folder. Images are saved as PNG files and captions are saved as TXT files with the same filename_prefix.",
is_experimental=True,
is_output_node=True,
is_input_list=True, # Receive both images and texts as lists
inputs=[
io.Image.Input("images", tooltip="List of images to save."),
io.String.Input("texts", tooltip="List of text captions to save."),
io.String.Input("texts",
optional=True,
force_input=True,
tooltip="List of text captions to save."
),
io.String.Input(
"folder_name",
default="dataset",
@ -270,7 +283,7 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
)
@classmethod
def execute(cls, images, texts, folder_name, filename_prefix):
def execute(cls, images, folder_name, filename_prefix, texts=None):
# Extract scalar values
folder_name = folder_name[0]
filename_prefix = filename_prefix[0]
@ -279,11 +292,12 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
saved_files = save_images_to_folder(images, output_dir, filename_prefix)
# Save captions
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
caption_filename = filename.replace(".png", ".txt")
caption_path = os.path.join(output_dir, caption_filename)
with open(caption_path, "w", encoding="utf-8") as f:
f.write(caption)
if texts:
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
caption_filename = filename.replace(".png", ".txt")
caption_path = os.path.join(output_dir, caption_filename)
with open(caption_path, "w", encoding="utf-8") as f:
f.write(caption)
logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
return io.NodeOutput()
@ -314,11 +328,13 @@ class ImageProcessingNode(io.ComfyNode):
Child classes should set:
node_id: Unique node identifier (required)
search_aliases: List of search aliases (optional)
display_name: Display name (optional, defaults to node_id)
description: Node description (optional)
extra_inputs: List of additional io.Input objects beyond "images" (optional)
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
is_output_list: True (list output) or False (single output) (optional, default True)
is_deprecated: True if the node is deprecated (optional, default False)
Child classes must implement ONE of:
_process(cls, image, **kwargs) -> tensor (for single-item processing)
@ -326,12 +342,13 @@ class ImageProcessingNode(io.ComfyNode):
"""
node_id = None
search_aliases = []
display_name = None
description = None
extra_inputs = []
is_group_process = None # None = auto-detect, True/False = explicit
is_output_list = None # None = auto-detect based on processing mode
is_deprecated = False
@classmethod
def _detect_processing_mode(cls):
"""Detect whether this node uses group or individual processing.
@ -402,8 +419,10 @@ class ImageProcessingNode(io.ComfyNode):
return io.Schema(
node_id=cls.node_id,
search_aliases=cls.search_aliases,
display_name=cls.display_name or cls.node_id,
category="dataset/image",
category=cls.category,
description=cls.description,
is_experimental=True,
is_input_list=is_group, # True for group, False for individual
inputs=inputs,
@ -472,11 +491,13 @@ class TextProcessingNode(io.ComfyNode):
Child classes should set:
node_id: Unique node identifier (required)
search_aliases: List of search aliases (optional)
display_name: Display name (optional, defaults to node_id)
description: Node description (optional)
extra_inputs: List of additional io.Input objects beyond "texts" (optional)
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
is_output_list: True (list output) or False (single output) (optional, default True)
is_deprecated: True if the node is deprecated (optional, default False)
Child classes must implement ONE of:
_process(cls, text, **kwargs) -> str (for single-item processing)
@ -484,12 +505,13 @@ class TextProcessingNode(io.ComfyNode):
"""
node_id = None
search_aliases = []
display_name = None
description = None
extra_inputs = []
is_group_process = None # None = auto-detect, True/False = explicit
is_output_list = None # None = auto-detect based on processing mode
is_deprecated = False
@classmethod
def _detect_processing_mode(cls):
"""Detect whether this node uses group or individual processing.
@ -627,15 +649,17 @@ class TextProcessingNode(io.ComfyNode):
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
node_id = "ResizeImagesByShorterEdge"
display_name = "Resize Images by Shorter Edge"
description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
display_name = "Resize Images by Shorter Edge (DEPRECATED)"
category = "image/transform"
description = "Resize images so that the shorter edge matches the specified dimension while preserving aspect ratio."
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale shorter dimension
extra_inputs = [
io.Int.Input(
"shorter_edge",
default=512,
min=1,
max=8192,
tooltip="Target length for the shorter edge.",
tooltip="Target dimension for the shorter edge.",
),
]
@ -655,15 +679,17 @@ class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
node_id = "ResizeImagesByLongerEdge"
display_name = "Resize Images by Longer Edge"
description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio."
display_name = "Resize Images by Longer Edge (DEPRECATED)"
category = "image/transform"
description = "Resize images so that the longer edge matches the specified dimension while preserving aspect ratio."
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale longer dimension
extra_inputs = [
io.Int.Input(
"longer_edge",
default=1024,
min=1,
max=8192,
tooltip="Target length for the longer edge.",
tooltip="Target dimension for the longer edge.",
),
]
@ -686,8 +712,10 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
class CenterCropImagesNode(ImageProcessingNode):
node_id = "CenterCropImages"
display_name = "Center Crop Images"
description = "Center crop all images to the specified dimensions."
search_aliases=["crop", "cut", "trim"]
display_name="Crop Image (Center)"
category="image/transform"
description = "Center crop an image to the specified dimensions."
extra_inputs = [
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
@ -706,10 +734,11 @@ class CenterCropImagesNode(ImageProcessingNode):
class RandomCropImagesNode(ImageProcessingNode):
node_id = "RandomCropImages"
display_name = "Random Crop Images"
description = (
"Randomly crop all images to the specified dimensions (for data augmentation)."
)
search_aliases=["crop", "cut", "trim"]
display_name = "Crop Image (Random)"
category="image/transform"
description = "Randomly crop an image to the specified dimensions."
extra_inputs = [
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
@ -734,7 +763,9 @@ class RandomCropImagesNode(ImageProcessingNode):
class NormalizeImagesNode(ImageProcessingNode):
node_id = "NormalizeImages"
display_name = "Normalize Images"
search_aliases=["normalize", "normalize colors"]
display_name = "Normalize Image Colors"
category = "image/color"
description = "Normalize images using mean and standard deviation."
extra_inputs = [
io.Float.Input(
@ -762,8 +793,10 @@ class NormalizeImagesNode(ImageProcessingNode):
class AdjustBrightnessNode(ImageProcessingNode):
node_id = "AdjustBrightness"
search_aliases=["brightness"]
display_name = "Adjust Brightness"
description = "Adjust brightness of all images."
category="image/adjustments"
description = "Adjust the brightness of an image."
extra_inputs = [
io.Float.Input(
"factor",
@ -781,8 +814,10 @@ class AdjustBrightnessNode(ImageProcessingNode):
class AdjustContrastNode(ImageProcessingNode):
node_id = "AdjustContrast"
search_aliases=["contrast"]
display_name = "Adjust Contrast"
description = "Adjust contrast of all images."
category="image/adjustments"
description = "Adjust the contrast of an image."
extra_inputs = [
io.Float.Input(
"factor",
@ -800,8 +835,10 @@ class AdjustContrastNode(ImageProcessingNode):
class ShuffleDatasetNode(ImageProcessingNode):
node_id = "ShuffleDataset"
display_name = "Shuffle Image Dataset"
description = "Randomly shuffle the order of images in the dataset."
search_aliases=["shuffle", "randomize", "mix"]
display_name = "Shuffle Images List"
category = "image/batch"
description = "Randomly shuffle the order of images in a list."
is_group_process = True # Requires full list to shuffle
extra_inputs = [
io.Int.Input(
@ -823,13 +860,15 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ShuffleImageTextDataset",
display_name="Shuffle Image-Text Dataset",
category="dataset/image",
search_aliases=["shuffle", "randomize", "mix"],
display_name = "Shuffle Pairs of Image-Text",
category = "image/batch",
description = "Randomly shuffle the order of pairs of image-text in a list.",
is_experimental=True,
is_input_list=True,
inputs=[
io.Image.Input("images", tooltip="List of images to shuffle."),
io.String.Input("texts", tooltip="List of texts to shuffle."),
io.String.Input("texts", tooltip="List of texts to shuffle.", force_input=True),
io.Int.Input(
"seed",
default=0,
@ -865,8 +904,11 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
class TextToLowercaseNode(TextProcessingNode):
node_id = "TextToLowercase"
display_name = "Text to Lowercase"
description = "Convert all texts to lowercase."
search_aliases=["lowercase"]
display_name = "Convert Text to Lowercase (DEPRECATED)"
category = "text"
description = "Convert text to lowercase."
is_deprecated = True # This node is superseded by the Convert Text Case node
@classmethod
def _process(cls, text):
@ -875,8 +917,11 @@ class TextToLowercaseNode(TextProcessingNode):
class TextToUppercaseNode(TextProcessingNode):
node_id = "TextToUppercase"
display_name = "Text to Uppercase"
description = "Convert all texts to uppercase."
search_aliases=["uppercase"]
display_name = "Convert Text to Uppercase (DEPRECATED)"
category = "text"
description = "Convert text to uppercase."
is_deprecated = True # This node is superseded by the Convert Text Case node
@classmethod
def _process(cls, text):
@ -885,8 +930,10 @@ class TextToUppercaseNode(TextProcessingNode):
class TruncateTextNode(TextProcessingNode):
node_id = "TruncateText"
search_aliases=["truncate", "cut", "shorten"]
display_name = "Truncate Text"
description = "Truncate all texts to a maximum length."
category = "text"
description = "Truncate text to a maximum length."
extra_inputs = [
io.Int.Input(
"max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
@ -900,8 +947,10 @@ class TruncateTextNode(TextProcessingNode):
class AddTextPrefixNode(TextProcessingNode):
node_id = "AddTextPrefix"
display_name = "Add Text Prefix"
display_name = "Add Text Prefix (DEPRECATED)"
category = "text"
description = "Add a prefix to all texts."
is_deprecated = True # This node is superseded by the Concatenate Text node
extra_inputs = [
io.String.Input("prefix", default="", tooltip="Prefix to add."),
]
@ -913,8 +962,10 @@ class AddTextPrefixNode(TextProcessingNode):
class AddTextSuffixNode(TextProcessingNode):
node_id = "AddTextSuffix"
display_name = "Add Text Suffix"
display_name = "Add Text Suffix (DEPRECATED)"
category = "text"
description = "Add a suffix to all texts."
is_deprecated = True # This node is superseded by the Concatenate Text node
extra_inputs = [
io.String.Input("suffix", default="", tooltip="Suffix to add."),
]
@ -926,8 +977,10 @@ class AddTextSuffixNode(TextProcessingNode):
class ReplaceTextNode(TextProcessingNode):
node_id = "ReplaceText"
display_name = "Replace Text"
display_name = "Replace Text (DEPRECATED)"
category = "text"
description = "Replace text in all texts."
is_deprecated = True # This node is superseded by the other Replace Text node
extra_inputs = [
io.String.Input("find", default="", tooltip="Text to find."),
io.String.Input("replace", default="", tooltip="Text to replace with."),
@ -940,8 +993,10 @@ class ReplaceTextNode(TextProcessingNode):
class StripWhitespaceNode(TextProcessingNode):
node_id = "StripWhitespace"
display_name = "Strip Whitespace"
display_name = "Strip Whitespace (DEPRECATED)"
category = "text"
description = "Strip leading and trailing whitespace from all texts."
is_deprecated = True # This node is superseded by the Trim Text node
@classmethod
def _process(cls, text):
@ -952,11 +1007,13 @@ class StripWhitespaceNode(TextProcessingNode):
class ImageDeduplicationNode(ImageProcessingNode):
"""Remove duplicate or very similar images from the dataset using perceptual hashing."""
"""Remove duplicate or very similar images from a list using perceptual hashing."""
node_id = "ImageDeduplication"
display_name = "Image Deduplication"
description = "Remove duplicate or very similar images from the dataset."
search_aliases=["deduplicate", "remove duplicates", "similarity filter"]
display_name = "Deduplicate Images"
category = "image/batch"
description = "Remove duplicate or very similar images from a list."
is_group_process = True # Requires full list to compare images
extra_inputs = [
io.Float.Input(
@ -1026,7 +1083,9 @@ class ImageGridNode(ImageProcessingNode):
"""Combine multiple images into a single grid/collage."""
node_id = "ImageGrid"
display_name = "Image Grid"
search_aliases=["grid", "collage", "combine"]
display_name = "Make Image Grid"
category="image/batch"
description = "Arrange multiple images into a grid layout."
is_group_process = True # Requires full list to create grid
is_output_list = False # Outputs single grid image
@ -1102,9 +1161,12 @@ class MergeImageListsNode(ImageProcessingNode):
"""Merge multiple image lists into a single list."""
node_id = "MergeImageLists"
display_name = "Merge Image Lists"
search_aliases=["list", "merge list", "make list"]
display_name = "Merge Image Lists (DEPRECATED)"
category = "image/batch"
description = "Concatenate multiple image lists into one."
is_group_process = True # Receives images as list
is_deprecated = True # This node is superseded by the Create List node
@classmethod
def _group_process(cls, images):
@ -1119,9 +1181,11 @@ class MergeTextListsNode(TextProcessingNode):
"""Merge multiple text lists into a single list."""
node_id = "MergeTextLists"
display_name = "Merge Text Lists"
display_name = "Merge Text Lists (DEPRECATED)"
category = "text"
description = "Concatenate multiple text lists into one."
is_group_process = True # Receives texts as list
is_deprecated = True # This node is superseded by the Create List node
@classmethod
def _group_process(cls, texts):
@ -1142,8 +1206,10 @@ class ResolutionBucket(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ResolutionBucket",
search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"],
display_name="Resolution Bucket",
category="dataset",
category="training",
description="Group latents and conditionings into buckets",
is_experimental=True,
is_input_list=True,
inputs=[
@ -1236,7 +1302,8 @@ class MakeTrainingDataset(io.ComfyNode):
node_id="MakeTrainingDataset",
search_aliases=["encode dataset"],
display_name="Make Training Dataset",
category="dataset",
category="training",
description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.",
is_experimental=True,
is_input_list=True, # images and texts as lists
inputs=[
@ -1251,6 +1318,7 @@ class MakeTrainingDataset(io.ComfyNode):
"texts",
optional=True,
tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
force_input=True
),
],
outputs=[
@ -1320,9 +1388,10 @@ class SaveTrainingDataset(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveTrainingDataset",
search_aliases=["export training data"],
search_aliases=["export dataset", "save dataset"],
display_name="Save Training Dataset",
category="dataset",
category="training",
description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.",
is_experimental=True,
is_output_node=True,
is_input_list=True, # Receive lists
@ -1424,7 +1493,8 @@ class LoadTrainingDataset(io.ComfyNode):
node_id="LoadTrainingDataset",
search_aliases=["import dataset", "training data"],
display_name="Load Training Dataset",
category="dataset",
category="training",
description="Load encoded training dataset (latents + conditioning) from disk for use in training.",
is_experimental=True,
inputs=[
io.String.Input(

View File

@ -419,15 +419,17 @@ class VoxelToMeshBasic(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="VoxelToMeshBasic",
display_name="Voxel to Mesh (Basic)",
display_name="Voxel to Mesh (Basic) (DEPRECATED)",
category="3d",
description="Converts a voxel grid to a mesh.",
is_deprecated=True, # This node is superseded by the Voxel To Mesh node
inputs=[
IO.Voxel.Input("voxel"),
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
],
outputs=[
IO.Mesh.Output(),
]
],
)
@classmethod
@ -453,9 +455,10 @@ class VoxelToMesh(IO.ComfyNode):
node_id="VoxelToMesh",
display_name="Voxel to Mesh",
category="3d",
description="Converts a voxel grid to a mesh.",
inputs=[
IO.Voxel.Input("voxel"),
IO.Combo.Input("algorithm", options=["surface net", "basic"], advanced=True),
IO.Combo.Input("algorithm", options=["surface net", "basic"]),
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
],
outputs=[

View File

@ -55,9 +55,10 @@ class ImageCropV2(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageCropV2",
search_aliases=["trim"],
search_aliases=["crop", "cut", "trim"],
display_name="Crop Image",
category="image/transform",
description = "Crop an image to the specified dimensions.",
essentials_category="Image Tools",
has_intermediate_output=True,
inputs=[

View File

@ -15,7 +15,7 @@ class SwitchNode(io.ComfyNode):
return io.Schema(
node_id="ComfySwitchNode",
display_name="Switch",
category="logic",
category="utils/logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
@ -46,7 +46,7 @@ class SoftSwitchNode(io.ComfyNode):
return io.Schema(
node_id="ComfySoftSwitchNode",
display_name="Soft Switch",
category="logic",
category="utils/logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
@ -136,7 +136,7 @@ class DCTestNode(io.ComfyNode):
return io.Schema(
node_id="DCTestNode",
display_name="DCTest",
category="logic",
category="utils/logic",
is_output_node=True,
inputs=[io.DynamicCombo.Input("combo", options=[
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
@ -174,7 +174,7 @@ class AutogrowNamesTestNode(io.ComfyNode):
return io.Schema(
node_id="AutogrowNamesTestNode",
display_name="AutogrowNamesTest",
category="logic",
category="utils/logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
@ -194,7 +194,7 @@ class AutogrowPrefixTestNode(io.ComfyNode):
return io.Schema(
node_id="AutogrowPrefixTestNode",
display_name="AutogrowPrefixTest",
category="logic",
category="utils/logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
@ -213,7 +213,7 @@ class ComboOutputTestNode(io.ComfyNode):
return io.Schema(
node_id="ComboOptionTestNode",
display_name="ComboOptionTest",
category="logic",
category="utils/logic",
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
outputs=[io.Combo.Output(), io.Combo.Output()],
@ -230,7 +230,7 @@ class ConvertStringToComboNode(io.ComfyNode):
node_id="ConvertStringToComboNode",
search_aliases=["string to dropdown", "text to combo"],
display_name="Convert String to Combo",
category="logic",
category="utils/logic",
inputs=[io.String.Input("string")],
outputs=[io.Combo.Output()],
)
@ -246,7 +246,7 @@ class InvertBooleanNode(io.ComfyNode):
node_id="InvertBooleanNode",
search_aliases=["not", "toggle", "negate", "flip boolean"],
display_name="Invert Boolean",
category="logic",
category="utils/logic",
inputs=[io.Boolean.Input("boolean")],
outputs=[io.Boolean.Output()],
)

View File

@ -11,8 +11,8 @@ class LTXVAudioVAELoader(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXVAudioVAELoader",
display_name="LTXV Audio VAE Loader",
category="audio",
display_name="Load LTXV Audio VAE",
category="loaders",
inputs=[
io.Combo.Input(
"ckpt_name",
@ -40,7 +40,7 @@ class LTXVAudioVAEEncode(VAEEncodeAudio):
return io.Schema(
node_id="LTXVAudioVAEEncode",
display_name="LTXV Audio VAE Encode",
category="audio",
category="latent/audio",
inputs=[
io.Audio.Input("audio", tooltip="The audio to be encoded."),
io.Vae.Input(
@ -63,7 +63,7 @@ class LTXVAudioVAEDecode(io.ComfyNode):
return io.Schema(
node_id="LTXVAudioVAEDecode",
display_name="LTXV Audio VAE Decode",
category="audio",
category="latent/audio",
inputs=[
io.Latent.Input("samples", tooltip="The latent to be decoded."),
io.Vae.Input(

View File

@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode):
return io.Schema(
node_id="ComfyMathExpression",
display_name="Math Expression",
category="logic",
category="utils",
search_aliases=[
"expression", "formula", "calculate", "calculator",
"eval", "math",

View File

@ -0,0 +1,509 @@
"""ComfyUI nodes for the pure-PyTorch MediaPipe Face Landmarker port.
Custom IO types:
FACE_LANDMARKER — FaceLandmarkerModel wrapper (ModelPatcher inside)
FACE_LANDMARKS — {"frames": List[List[face_dict]], "image_size": (H, W),
"connection_sets": dict[str, frozenset[(int, int)]]}
face_dict: bbox_xyxy, blendshapes, landmarks_xy,
landmarks_3d, presence, score, transformation_matrix
MediaPipeFaceLandmarker also emits the core BOUNDING_BOX type — pair with DrawBBoxes.
"""
from __future__ import annotations
import numpy as np
import torch
from PIL import Image, ImageColor, ImageDraw
from tqdm.auto import tqdm
from typing_extensions import override
import comfy.model_management
import comfy.model_patcher
import comfy.utils
import folder_paths
from comfy_api.latest import ComfyExtension, io
from comfy_extras.mediapipe.face_landmarker import FaceLandmarker
from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection
FaceDetectionType = io.Custom("FACE_DETECTION_MODEL")
FaceLandmarksType = io.Custom("FACE_LANDMARKS")
_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights")
_CONTOUR_PARTS = ("face_oval", "left_eye", "right_eye", "left_eyebrow", "right_eyebrow", "lips")
class FaceLandmarkerModel:
"""Loaded FaceLandmarker variants + ModelPatcher per variant.
Safetensors layout: `detector_short.*` / `detector_full.*` plus shared
`mesh.*`, `blendshapes.*`, `canonical_*`, and `topology.*`.
PReLU forces plain-nn / fp32 (manual_cast strands buffers across devices).
"""
def __init__(self, state_dict: dict):
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = torch.float32
# FACEMESH_* connection sets, embedded as int32 (N, 2) under topology.*.
base: dict[str, frozenset] = {}
for k in [k for k in state_dict if k.startswith("topology.")]:
base[k[len("topology."):]] = frozenset(map(tuple, state_dict.pop(k).tolist()))
base["contours"] = frozenset().union(*(base[p] for p in _CONTOUR_PARTS))
base["all"] = base["contours"] | base["irises"] | base["nose"]
self.connection_sets: dict[str, frozenset] = base
self.canonical_data: dict[str, np.ndarray] = {k: state_dict.pop(k).numpy() for k in _CANONICAL_KEYS}
shared = {k: v for k, v in state_dict.items() if k.startswith(("mesh.", "blendshapes."))}
self.models: dict[str, FaceLandmarker] = {}
self.patchers: dict[str, comfy.model_patcher.ModelPatcher] = {}
for variant in ("short", "full"):
prefix = f"detector_{variant}."
sub = dict(shared)
sub.update({f"detector.{k[len(prefix):]}": v for k, v in state_dict.items() if k.startswith(prefix)})
fl = FaceLandmarker(device=offload_device, dtype=self.dtype, operations=None, detector_variant=variant).eval()
fl.load_state_dict(sub, strict=False)
self.models[variant] = fl
self.patchers[variant] = comfy.model_patcher.CoreModelPatcher(
fl, load_device=self.load_device, offload_device=offload_device,
size=comfy.model_management.module_size(fl),
)
def detect_batch(self, images, num_faces: int, score_thresh: float, variant: str):
comfy.model_management.load_model_gpu(self.patchers[variant])
return self.models[variant].detect_batch(images, num_faces=num_faces, score_thresh=score_thresh)
def _image_to_uint8(image: torch.Tensor) -> np.ndarray:
return image[..., :3].mul(255.0).add_(0.5).clamp_(0, 255).to(torch.uint8).cpu().numpy()
def _parse_color(color: str) -> tuple[int, int, int]:
try:
return ImageColor.getrgb(color)[:3]
except ValueError:
return (0, 255, 0)
def _copy_face(face: dict) -> dict:
"""Shallow copy of a face_dict with array-fields cloned so callers can mutate."""
return {
"bbox_xyxy": face["bbox_xyxy"].copy(),
"blendshapes": dict(face["blendshapes"]),
"landmarks_xy": face["landmarks_xy"].copy(),
"landmarks_3d": face["landmarks_3d"].copy(),
"presence": face["presence"],
"score": face["score"],
}
def _lerp_face(a: dict, b: dict, t: float) -> dict:
return {
"bbox_xyxy": (1 - t) * a["bbox_xyxy"] + t * b["bbox_xyxy"],
"blendshapes": {k: (1 - t) * a["blendshapes"][k] + t * b["blendshapes"][k] for k in a["blendshapes"]},
"landmarks_xy": (1 - t) * a["landmarks_xy"] + t * b["landmarks_xy"],
"landmarks_3d": (1 - t) * a["landmarks_3d"] + t * b["landmarks_3d"],
"presence": (1 - t) * a["presence"] + t * b["presence"],
"score": (1 - t) * a["score"] + t * b["score"],
}
def _match_faces(a: list[dict], b: list[dict]) -> list[tuple[int, int]]:
"""Greedy nearest-neighbour pairing of faces between two frames by bbox
centre distance. Unmatched (when counts differ) are dropped."""
if not a or not b:
return []
centers_a = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]),
0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in a])
centers_b = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]),
0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in b])
dists = np.linalg.norm(centers_a[:, None] - centers_b[None], axis=-1)
pairs: list[tuple[int, int]] = []
used_a: set[int] = set()
used_b: set[int] = set()
candidates = sorted((dists[ia, ib], ia, ib) for ia in range(len(a)) for ib in range(len(b)))
for _, ia, ib in candidates:
if ia in used_a or ib in used_b:
continue
pairs.append((ia, ib))
used_a.add(ia)
used_b.add(ib)
return pairs
def _fill_missing_frames(frames: list[list[dict]], mode: str) -> None:
"""In-place fill empty frame slots from neighbouring detections. Multi-face
aware: pairs faces across bracketing frames by greedy bbox-centre NN.
When counts differ, unmatched faces are dropped from the synthesised frame."""
if mode == "empty":
return
valid = [i for i, fr in enumerate(frames) if fr]
if not valid:
return # nothing to fill from
if mode == "previous":
last: list[dict] = []
for i, fr in enumerate(frames):
if fr:
last = fr
elif last:
frames[i] = [_copy_face(f) for f in last]
return
# interpolate: lerp between bracketing valid frames; clamp at ends.
for i in range(len(frames)):
if frames[i]:
continue
prev_i = max((v for v in valid if v < i), default=None)
next_i = min((v for v in valid if v > i), default=None)
if prev_i is None:
frames[i] = [_copy_face(f) for f in frames[next_i]]
elif next_i is None:
frames[i] = [_copy_face(f) for f in frames[prev_i]]
else:
t = (i - prev_i) / (next_i - prev_i)
pairs = _match_faces(frames[prev_i], frames[next_i])
frames[i] = [_lerp_face(frames[prev_i][a], frames[next_i][b], t) for a, b in pairs]
def _ordered_rings(edges: frozenset[tuple[int, int]]) -> list[list[int]]:
"""Walk an unordered edge set into one or more closed-loop vertex rings
(handles multi-loop sets like FACEMESH_LIPS: outer + inner)."""
adj: dict[int, set[int]] = {}
for a, b in edges:
adj.setdefault(a, set()).add(b)
adj.setdefault(b, set()).add(a)
visited: set[int] = set()
rings: list[list[int]] = []
for start in adj:
if start in visited:
continue
ring = [start]
visited.add(start)
prev, cur = -1, start
while True:
nxt = next((v for v in adj[cur] if v != prev), None)
if nxt is None or nxt == start:
break
ring.append(nxt)
visited.add(nxt)
prev, cur = cur, nxt
rings.append(ring)
return rings
class LoadMediaPipeFaceLandmarker(io.ComfyNode):
"""Load MediaPipe Face Landmarker v2 weights. Contains both detector variants
(short / full), shared mesh, blendshapes, and canonical geometry."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadMediaPipeFaceLandmarker",
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
display_name="Load Face Detection Model (MediaPipe)",
category="loaders",
inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("detection"),
tooltip="Face detection model from models/detection/."),
],
outputs=[FaceDetectionType.Output()],
)
@classmethod
def execute(cls, model_name) -> io.NodeOutput:
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("detection", model_name), safe_load=True)
wrapper = FaceLandmarkerModel(sd)
return io.NodeOutput(wrapper)
# Per-frame fallback modes for detection failures in a batch.
_FALLBACK_MODES = ("empty", "previous", "interpolate")
class MediaPipeFaceLandmarker(io.ComfyNode):
"""BlazeFace → FaceMesh v2 → ARKit-52 blendshapes, batched across the
input. Also emits a BOUNDING_BOX list (landmark-extent bbox per face) —
pair with DrawBBoxes for detector-only viz or MediaPipeFaceMeshVisualize
for the mesh overlay."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceLandmarker",
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
display_name="Detect Face Landmarks (MediaPipe)",
category="image/detection",
description="Detects facial landmarks using MediaPipe model.",
inputs=[
FaceDetectionType.Input("face_detection_model"),
io.Image.Input("image"),
io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short",
tooltip="Face detector range. 'short' is tuned for close-up faces "
"(within ~2 m of the camera); 'full' covers farther / smaller "
"faces (up to ~5 m) but is slower. 'both' runs both detectors and "
"keeps whichever found more faces per frame (~2× detection cost)."),
io.Int.Input("num_faces", default=1, min=0, max=16, step=1,
tooltip="Maximum faces to return per frame. 0 = no cap (return all detected)."),
io.Float.Input("min_confidence", default=0.5, min=0.0, max=1.0, step=0.01, advanced=True,
tooltip="BlazeFace score threshold. Lower to catch small/occluded faces."),
io.Combo.Input("missing_frame_fallback", options=list(_FALLBACK_MODES), default="empty", advanced=True,
tooltip="Per-frame behaviour when detection fails in a batch. "
"'empty' leaves the frame faceless. 'previous' copies the most recent successful "
"detection. 'interpolate' lerps landmarks/bbox/blendshapes between bracketing "
"successful frames. Multi-face: pairs faces across frames by greedy bbox-centre NN."),
],
outputs=[
FaceLandmarksType.Output(display_name="face_landmarks"),
io.BoundingBox.Output("bboxes"),
],
)
@classmethod
def execute(cls, face_detection_model, image, detector_variant, num_faces, min_confidence,
missing_frame_fallback) -> io.NodeOutput:
canonical = face_detection_model.canonical_data
img_np = _image_to_uint8(image)
B, H, W = img_np.shape[:3]
chunk = 16
is_both = detector_variant == "both"
total_work = 2 * B if is_both else B
pbar = comfy.utils.ProgressBar(total_work)
def _run(variant: str) -> list[list[dict]]:
res: list[list[dict]] = []
with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq:
for i in range(0, B, chunk):
end = min(i + chunk, B)
res.extend(face_detection_model.detect_batch(
[img_np[bi] for bi in range(i, end)],
num_faces=int(num_faces),
score_thresh=float(min_confidence),
variant=variant,
))
pbar.update_absolute(min(pbar.current + (end - i), total_work))
tq.update(end - i)
return res
if is_both:
short_res = _run("short")
full_res = _run("full")
# Per-frame keep whichever found more faces (tie → short).
frames: list[list[dict]] = [
short_res[bi] if len(short_res[bi]) >= len(full_res[bi]) else full_res[bi]
for bi in range(B)
]
else:
frames = _run(detector_variant)
_fill_missing_frames(frames, missing_frame_fallback)
bboxes = []
for per_frame in frames:
per_bb = []
for f in per_frame:
f["transformation_matrix"] = transformation_matrix_from_detection(f, W, H, canonical)
x1, y1, x2, y2 = (float(v) for v in f["bbox_xyxy"])
per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])})
bboxes.append(per_bb)
return io.NodeOutput({"frames": frames, "image_size": (H, W),
"connection_sets": face_detection_model.connection_sets}, bboxes)
# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose).
_ALL_CONNECTION_PARTS: tuple[str, ...] = (*_CONTOUR_PARTS, "irises", "nose")
_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = (
("face_oval", True),
("lips", True),
("left_eye", True),
("right_eye", True),
("left_eyebrow", True),
("right_eyebrow", True),
("irises", True),
("nose", True),
("tesselation", False),
)
class MediaPipeFaceMeshVisualize(io.ComfyNode):
"""Draw a FACEMESH_* subset over an image. Topology travels with the
FACE_LANDMARKS payload (set at detection time)."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceMeshVisualize",
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection", "visualize"],
display_name="Visualize Face Landmarks (MediaPipe)",
category="image/detection",
description="Draws face landmarks mesh on the input image.",
inputs=[
FaceLandmarksType.Input("face_landmarks"),
io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."),
io.DynamicCombo.Input(
"connections",
tooltip="'all' = oval+eyes+brows+lips+irises+nose. 'fill' = solid face_oval polygon (silhouette mask). 'custom' = toggle each feature individually (including 'tesselation', the full 2547-edge wireframe).",
options=[
io.DynamicCombo.Option("all", []),
io.DynamicCombo.Option("fill", []),
io.DynamicCombo.Option("custom", [
io.Boolean.Input(feat, default=default,
tooltip=f"Draw the '{feat}' connection set.")
for feat, default in _CUSTOM_FEATURES
]),
],
),
io.Color.Input("color", default="#00ff00"),
io.Int.Input("thickness", default=1, min=0, max=8, step=1,
tooltip="Edge line thickness in pixels. 0 disables edge drawing."),
io.Int.Input("point_size", default=2, min=0, max=16, step=1,
tooltip="Landmark dot radius in pixels. 0 disables point drawing."),
],
outputs=[io.Image.Output()],
)
@classmethod
def execute(cls, face_landmarks, connections, color, thickness, point_size, image=None) -> io.NodeOutput:
sets = face_landmarks["connection_sets"]
sel = connections["connections"]
fill_rings: list[list[int]] | None = None
if sel == "fill":
fill_rings = _ordered_rings(sets["face_oval"])
edges = frozenset()
elif sel == "custom":
parts = [feat for feat, _ in _CUSTOM_FEATURES if connections.get(feat, False)]
edges = frozenset().union(*(sets[p] for p in parts))
else: # "all"
edges = frozenset().union(*(sets[p] for p in _ALL_CONNECTION_PARTS))
rgb, thick, psize = _parse_color(color), int(thickness), int(point_size)
frames = face_landmarks["frames"]
if image is None:
H, W = face_landmarks["image_size"]
img_np = np.zeros((len(frames), H, W, 3), dtype=np.uint8)
else:
img_np = _image_to_uint8(image)
B = img_np.shape[0]
n_frames = len(frames)
pbar = comfy.utils.ProgressBar(B)
out = np.empty_like(img_np)
for bi in range(B):
faces = frames[bi] if bi < n_frames else []
out[bi] = _draw_mesh(img_np[bi], faces, edges, rgb, thick, psize, fill_rings)
pbar.update_absolute(bi + 1)
return io.NodeOutput(torch.from_numpy(out).to(
device=comfy.model_management.intermediate_device(),
dtype=comfy.model_management.intermediate_dtype(),
).div_(255.0))
def _draw_mesh(image_rgb: np.ndarray, faces: list, edges,
rgb: tuple[int, int, int], thickness: int,
point_size: int, fill_rings: list[list[int]] | None = None) -> np.ndarray:
draw_edges = thickness > 0 and edges
if not faces or (fill_rings is None and not draw_edges and point_size <= 0):
return image_rgb.copy()
pil = Image.fromarray(image_rgb)
draw = ImageDraw.Draw(pil)
r = point_size * 0.5
if fill_rings is not None:
for f in faces:
lmks = f["landmarks_xy"]
for ring in fill_rings:
draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=rgb)
return np.asarray(pil)
for f in faces:
lmks = f["landmarks_xy"]
n = lmks.shape[0]
if draw_edges:
for a, b in edges:
if a < n and b < n:
draw.line([(float(lmks[a, 0]), float(lmks[a, 1])),
(float(lmks[b, 0]), float(lmks[b, 1]))], fill=rgb, width=thickness)
if point_size == 1:
draw.point(lmks.flatten().tolist(), fill=rgb)
elif point_size > 1:
for x, y in lmks:
draw.ellipse((float(x) - r, float(y) - r, float(x) + r, float(y) + r), fill=rgb)
return np.asarray(pil)
# Mask region presets — closed-loop topologies only.
_MASK_REGIONS: tuple[str, ...] = ("face_oval", "lips", "left_eye", "right_eye", "irises")
_MASK_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = (
("face_oval", True),
("lips", False),
("left_eye", False),
("right_eye", False),
("irises", False),
)
class MediaPipeFaceMask(io.ComfyNode):
"""Binary mask from face landmarks, filled polygon per face. One mask per
frame in the batch; faces in the same frame composite (union)."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceMask",
search_aliases=["face", "facial", "mediapipe", "face mask", "blazeface", "face detection", "visualize"],
display_name="Draw Face Mask (MediaPipe)",
category="image/detection",
description="Draws a mask from face landmarks.",
inputs=[
FaceLandmarksType.Input("face_landmarks"),
io.DynamicCombo.Input(
"regions",
tooltip="'all' = union of face_oval+lips+eyes+irises (which collapses to face_oval since it encloses the rest). 'custom' = toggle each region individually for combos like lips+eyes.",
options=[
io.DynamicCombo.Option("all", []),
io.DynamicCombo.Option("custom", [
io.Boolean.Input(reg, default=default,
tooltip=f"Include the '{reg}' region in the mask.")
for reg, default in _MASK_CUSTOM_FEATURES
]),
],
),
],
outputs=[io.Mask.Output()],
)
@classmethod
def execute(cls, face_landmarks, regions) -> io.NodeOutput:
sets = face_landmarks["connection_sets"]
sel = regions["regions"]
if sel == "custom":
picked = [reg for reg, _ in _MASK_CUSTOM_FEATURES if regions.get(reg, False)]
else:
picked = list(_MASK_REGIONS)
rings = [r for reg in picked for r in _ordered_rings(sets[reg])]
frames = face_landmarks["frames"]
H, W = face_landmarks["image_size"]
masks = np.zeros((len(frames), H, W), dtype=np.uint8)
pbar = comfy.utils.ProgressBar(len(frames))
for bi, per_frame in enumerate(frames):
if per_frame:
pil = Image.new("L", (W, H), 0)
draw = ImageDraw.Draw(pil)
for f in per_frame:
lmks = f["landmarks_xy"]
for ring in rings:
draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=255)
masks[bi] = np.asarray(pil)
pbar.update_absolute(bi + 1)
return io.NodeOutput(torch.from_numpy(masks).to(
device=comfy.model_management.intermediate_device(),
dtype=comfy.model_management.intermediate_dtype(),
).div_(255.0))
class MediaPipeFaceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [LoadMediaPipeFaceLandmarker, MediaPipeFaceLandmarker, MediaPipeFaceMeshVisualize, MediaPipeFaceMask]
async def comfy_entrypoint() -> MediaPipeFaceExtension:
return MediaPipeFaceExtension()

View File

@ -103,8 +103,10 @@ class MoGePanoramaInference(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MoGePanoramaInference",
display_name="MoGe Panorama Inference",
search_aliases=["moge", "panorama", "depth", "geometry", "depth estimation", "geometry estimation"],
display_name="Run MoGe Panorama Inference",
category="image/geometry_estimation",
description="Run MoGe on an equirectangular panorama by splitting it into 12 perspective views, running inference on each, and merging the results into a single depth map.",
inputs=[
MoGeModelType.Input("moge_model"),
io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."),
@ -222,7 +224,9 @@ class MoGeInference(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MoGeInference",
display_name="MoGe Inference",
search_aliases=["moge", "depth", "geometry", "depth estimation", "geometry estimation"],
display_name="Run MoGe Inference",
description="Run MoGe on a single image to estimate depth and geometry.",
category="image/geometry_estimation",
inputs=[
MoGeModelType.Input("moge_model"),
@ -277,7 +281,9 @@ class MoGeRender(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MoGeRender",
display_name="MoGe Render",
search_aliases=["moge", "render", "geometry", "depth", "normal"],
display_name="Render MoGe Geometry",
description="Render a depth map or normal map from geometry data",
category="image/geometry_estimation",
inputs=[
MoGeGeometry.Input("moge_geometry"),
@ -342,7 +348,9 @@ class MoGePointMapToMesh(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MoGePointMapToMesh",
display_name="MoGe Point Map to Mesh",
search_aliases=["moge", "mesh", "geometry", "point map"],
display_name="Convert MoGe Point Map to Mesh",
description="Convert a MoGe point map into a 3D mesh.",
category="image/geometry_estimation",
inputs=[
MoGeGeometry.Input("moge_geometry"),

View File

@ -14,7 +14,7 @@ class CreateList(io.ComfyNode):
return io.Schema(
node_id="CreateList",
display_name="Create List",
category="logic",
category="utils",
is_input_list=True,
search_aliases=["Image Iterator", "Text Iterator", "Iterator"],
inputs=[io.Autogrow.Input("inputs", template=template_autogrow)],

View File

@ -2,6 +2,7 @@ import copy
import heapq
import inspect
import logging
import psutil
import sys
import threading
import time
@ -727,6 +728,7 @@ class PromptExecutor:
self._notify_prompt_lifecycle("start", prompt_id)
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
ram_inactive_headroom = int(self.cache_args["ram_inactive"] * (1024 ** 3))
ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom)
@ -780,8 +782,14 @@ class PromptExecutor:
execution_list.complete_node_execution()
if self.cache_type == CacheType.RAM_PRESSURE:
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
ram_release_callback(ram_headroom, free_active=True)
ram_release_callback(ram_inactive_headroom)
ram_shortfall = ram_headroom - psutil.virtual_memory().available
freed = comfy.model_management.free_pins(ram_shortfall + 512 * (1024 ** 2))
if freed < ram_shortfall:
if freed > 64 * (1024 ** 2):
# AIMDO MEM_DECOMMIT can outrun psutil.available catching up.
time.sleep(0.05)
ram_release_callback(ram_headroom, free_active=True)
else:
# Only execute when the while-loop ends without break
# Send cached UI for intermediate output nodes that weren't executed

View File

@ -60,6 +60,8 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
folder_names_and_paths["detection"] = ([os.path.join(models_dir, "detection")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input")

20
main.py
View File

@ -283,19 +283,25 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
def prompt_worker(q, server_instance):
current_time: float = 0.0
cache_ram = args.cache_ram
if cache_ram < 0:
cache_ram = 0
cache_ram_inactive = 0
if not args.cache_classic and not args.cache_none and args.cache_lru <= 0:
cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0))
cache_ram_inactive = min(96.0, max(12.0, comfy.model_management.total_ram * 0.75 / 1024.0))
if len(args.cache_ram) > 0:
cache_ram = args.cache_ram[0]
if len(args.cache_ram) > 1:
cache_ram_inactive = args.cache_ram[1]
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.RAM_PRESSURE
if args.cache_classic:
cache_type = execution.CacheType.CLASSIC
elif args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none:
cache_type = execution.CacheType.NONE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } )
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram, "ram_inactive" : cache_ram_inactive } )
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0

View File

@ -2444,6 +2444,7 @@ async def init_builtin_extra_nodes():
"nodes_hidream_o1.py",
"nodes_save_3d.py",
"nodes_moge.py",
"nodes_mediapipe.py",
]
import_failed = []

View File

@ -1556,12 +1556,6 @@ paths:
type: string
enum: [asc, desc]
description: Sort direction
- name: job_ids
in: query
schema:
type: string
x-runtime: [cloud]
description: "[cloud-only] Comma-separated UUIDs to filter assets by associated job."
- name: include_public
in: query
schema:
@ -2514,37 +2508,25 @@ paths:
/api/assets/import:
post:
operationId: importAssets
operationId: importPublishedAssets
tags: [assets]
summary: Import assets from external URLs
description: "[cloud-only] Imports one or more assets from external URLs into the cloud asset store."
summary: "[cloud-only] Import published assets into the caller's library"
description: |
[cloud-only] Imports the specified published assets into the caller's asset library. New DB records reference the same storage objects; no file copying occurs. Assets the caller already owns (by hash) are deduplicated. The `id` field on each returned `AssetInfo` is the caller's newly-created private asset ID, not the published asset ID supplied in the request.
x-runtime: [cloud]
requestBody:
required: true
content:
application/json:
schema:
type: object
required:
- imports
properties:
imports:
type: array
items:
$ref: "#/components/schemas/AssetImportRequest"
description: Assets to import
$ref: "#/components/schemas/ImportPublishedAssetsRequest"
responses:
"200":
description: Import initiated
description: Successfully imported assets
content:
application/json:
schema:
type: object
properties:
assets:
type: array
items:
$ref: "#/components/schemas/Asset"
$ref: "#/components/schemas/ImportPublishedAssetsResponse"
"400":
description: Bad request
content:
@ -3790,6 +3772,295 @@ paths:
schema:
$ref: "#/components/schemas/JwksResponse"
# ---------------------------------------------------------------------------
# OAuth 2.1 / RFC 7591 Dynamic Client Registration (cloud)
# ---------------------------------------------------------------------------
/.well-known/oauth-authorization-server:
get:
operationId: getOAuthAuthorizationServer
tags: [auth]
summary: "[cloud-only] OAuth 2.1 authorization-server metadata (RFC 8414)"
description: "[cloud-only] Public metadata document for OAuth 2.1 clients. Cached 5 minutes."
x-runtime: [cloud]
security: []
responses:
"200":
description: Authorization-server metadata
content:
application/json:
schema:
$ref: "#/components/schemas/OAuthAuthorizationServerMetadata"
"404":
description: OAuth disabled
content:
application/json:
schema:
$ref: "#/components/schemas/CloudError"
/.well-known/oauth-protected-resource:
get:
operationId: getOAuthProtectedResource
tags: [auth]
summary: "[cloud-only] OAuth 2.1 protected-resource metadata (RFC 9728)"
description: "[cloud-only] Public metadata describing the currently advertised protected resource. Cached 5 minutes."
x-runtime: [cloud]
security: []
responses:
"200":
description: Protected-resource metadata
content:
application/json:
schema:
$ref: "#/components/schemas/OAuthProtectedResourceMetadata"
"404":
description: OAuth disabled or no active resource configured
content:
application/json:
schema:
$ref: "#/components/schemas/CloudError"
/oauth/authorize:
get:
operationId: getOAuthAuthorize
tags: [auth]
summary: "[cloud-only] Begin or resume an OAuth 2.1 authorization request"
description: |
[cloud-only] Two modes:
- **Initial entry** (OAuth params present): validates client/redirect/resource/scopes, persists a server-side authorization-request row, and either redirects (no session / unverified email) to the configured frontend login URL carrying only the opaque `oauth_request_id`, or returns the JSON consent challenge for the frontend to render.
- **Resume** (`oauth_request_id` present): loads the server-side row, fails closed if expired/consumed/unknown, returns the JSON consent challenge. Browser-replayed OAuth params are intentionally ignored.
The frontend renders the consent UI from the JSON payload and POSTs the user's decision back to this endpoint.
x-runtime: [cloud]
security: []
parameters:
- { name: response_type, in: query, required: false, schema: { type: string } }
- { name: client_id, in: query, required: false, schema: { type: string } }
- { name: redirect_uri, in: query, required: false, schema: { type: string } }
- { name: scope, in: query, required: false, schema: { type: string } }
- name: state
in: query
required: false
schema: { type: string }
description: |
RFC 6749 §10.12 marks `state` as RECOMMENDED. Cloud hardening makes it REQUIRED on the initial-entry path (omitted only on the resume path where `oauth_request_id` is supplied instead). This parameter is `required: false` at the spec level only because the operation is dual-mode (initial entry vs. resume); the runtime rejects empty `state` on the initial-entry path with a stable `invalid_request` 400.
- { name: code_challenge, in: query, required: false, schema: { type: string } }
- { name: code_challenge_method, in: query, required: false, schema: { type: string } }
- { name: resource, in: query, required: false, schema: { type: string } }
- { name: oauth_request_id, in: query, required: false, schema: { type: string } }
responses:
"200":
description: Consent challenge payload (session present, email verified). Frontend renders the consent UI from this payload and POSTs back to /oauth/authorize.
content:
application/json:
schema:
$ref: "#/components/schemas/OAuthConsentChallenge"
"302":
description: Redirect to login (no session / unverified email) or to registered redirect_uri (pre-validated client error)
headers:
Location:
schema:
type: string
"400":
description: Invalid authorize request (pre-redirect failure — unknown client, redirect mismatch, malformed params)
content:
application/json:
schema:
$ref: "#/components/schemas/CloudError"
"404":
description: OAuth disabled
content:
application/json:
schema:
$ref: "#/components/schemas/CloudError"
post:
operationId: postOAuthAuthorize
tags: [auth]
summary: "[cloud-only] Submit OAuth consent decision"
description: |
[cloud-only] JSON-only consent submission. The handler verifies the per-row CSRF token, atomically marks the authorization request consumed (single-use covers both allow and deny paths), then returns the redirect URL the browser must navigate to. The URL contains either `code` + original `state` for allow, or the RFC 6749 §5.2 error and `state` for deny.
Workspace membership is re-checked at submission time. Consent is persisted keyed by `(user_id, client_id, resource_id, workspace_id)`; broadening the previously approved scope set requires a fresh consent flow.
x-runtime: [cloud]
security: []
requestBody:
required: true
content:
application/json:
schema:
type: object
required: [oauth_request_id, csrf_token, decision, workspace_id]
properties:
oauth_request_id: { type: string, format: uuid }
csrf_token: { type: string }
decision: { type: string, enum: [allow, deny] }
workspace_id: { type: string }
responses:
"200":
description: Redirect URL for the frontend to navigate to (allow → with code+state; deny → with error+state)
content:
application/json:
schema:
$ref: "#/components/schemas/OAuthAuthorizeRedirectResponse"
"400":
description: Bad request (CSRF mismatch, expired/consumed request, inaccessible workspace)
content:
application/json:
schema:
$ref: "#/components/schemas/CloudError"
"403":
description: Scope broadening on consent re-grant — fresh consent flow required
content:
application/json:
schema:
$ref: "#/components/schemas/CloudError"
"404":
description: OAuth disabled
content:
application/json:
schema:
$ref: "#/components/schemas/CloudError"
/oauth/token:
post:
operationId: postOAuthToken
tags: [auth]
summary: "[cloud-only] Exchange authorization code or refresh token for a resource-bound access token"
description: |
[cloud-only] OAuth 2.1 token endpoint (RFC 6749 §3.2). Public clients only — `client_secret` is rejected.
Two grant types are supported:
- `authorization_code` — exchanges the code minted by `/oauth/authorize` (with PKCE verifier) for an access token + first refresh token. Single-use; reuse fails closed.
- `refresh_token` — rotates the refresh token. Old token immediately invalid; presenting an already-rotated token revokes the entire token family and emits a security metric.
Both grant types re-validate canonical user state, current workspace membership, and the resource's active flag at every mint. A code or refresh token bound to a deactivated resource fails closed.
Errors follow RFC 6749 §5.2. Logs never contain raw codes, refresh tokens, or minted tokens.
Per RFC 6749 §5.1, every 200 and 400 response carries `Cache-Control: no-store` and `Pragma: no-cache` so intermediaries cannot cache token-bearing or state-change-reason responses.
x-runtime: [cloud]
security: []
requestBody:
required: true
content:
application/x-www-form-urlencoded:
schema:
type: object
required: [grant_type, client_id]
properties:
grant_type: { type: string, enum: [authorization_code, refresh_token] }
client_id: { type: string }
code: { type: string }
redirect_uri: { type: string }
code_verifier: { type: string }
refresh_token: { type: string }
scope: { type: string }
client_secret: { type: string }
responses:
"200":
description: New token pair
headers:
Cache-Control:
schema:
type: string
description: 'Always "no-store" per RFC 6749 §5.1'
Pragma:
schema:
type: string
description: 'Always "no-cache" per RFC 6749 §5.1'
content:
application/json:
schema:
$ref: "#/components/schemas/OAuthTokenResponse"
"400":
description: RFC 6749 §5.2 error
headers:
Cache-Control:
schema:
type: string
description: 'Always "no-store" per RFC 6749 §5.1'
Pragma:
schema:
type: string
description: 'Always "no-cache" per RFC 6749 §5.1'
content:
application/json:
schema:
$ref: "#/components/schemas/OAuthTokenError"
"404":
description: OAuth disabled
content:
application/json:
schema:
$ref: "#/components/schemas/CloudError"
/oauth/register:
post:
operationId: postOAuthRegister
tags: [auth]
summary: "[cloud-only] Dynamic Client Registration (RFC 7591)"
description: |
[cloud-only] Public, unauthenticated, insert-only RFC 7591 §3.1 client registration. Used by MCP-spec-compliant clients to self-register a public OAuth client without operator involvement.
Policy:
- Public clients only — `token_endpoint_auth_method` is forced to `none`. Confidential-client registration is out of scope this phase.
- Server-owned `resource_grants`. Caller-supplied `scope` or `resource_grants` is rejected as `invalid_client_metadata` (would be a privilege-escalation surface). Dynamic clients receive the same scopes the active resource publishes.
- Application-type-aware redirect URI policy. `application_type=native` accepts loopback (`127.0.0.1`, `::1`, `localhost`) and reverse-DNS-shaped custom schemes; `application_type=web` accepts HTTPS to hosts in an operator-controlled allowlist only. `application_type` is REQUIRED on the request — missing or empty rejects with `invalid_client_metadata`.
- Anti-impersonation: reserved client names are rejected from third parties via NFKC-folded compare.
- Generated `client_id` carries a stable prefix to distinguish dynamic from seeded clients in audit logs.
- Cache-Control: `no-store` on every 201 and 400 response (the response carries fresh credentials and rejection reasons).
x-runtime: [cloud]
security: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/OAuthRegisterRequest"
responses:
"201":
description: Registered. Body echoes the metadata RFC 7591 §3.2.1 requires.
headers:
Cache-Control:
schema:
type: string
description: 'Always "no-store"'
Pragma:
schema:
type: string
description: 'Always "no-cache"'
content:
application/json:
schema:
$ref: "#/components/schemas/OAuthRegisterResponse"
"400":
description: RFC 7591 §3.2.2 invalid client metadata
headers:
Cache-Control:
schema:
type: string
description: 'Always "no-store"'
Pragma:
schema:
type: string
description: 'Always "no-cache"'
content:
application/json:
schema:
$ref: "#/components/schemas/OAuthRegisterError"
"404":
description: OAuth disabled
content:
application/json:
schema:
$ref: "#/components/schemas/CloudError"
"503":
description: No active resource is configured — DCR cannot mint a usable client until an active resource row is seeded.
content:
application/json:
schema:
$ref: "#/components/schemas/CloudError"
# ---------------------------------------------------------------------------
# Billing (cloud)
# ---------------------------------------------------------------------------
@ -7090,24 +7361,35 @@ components:
type: string
description: Target path on the runtime filesystem
AssetImportRequest:
ImportPublishedAssetsRequest:
type: object
x-runtime: [cloud]
description: "[cloud-only] A single asset to import from an external URL."
description: "[cloud-only] Request body for importing published assets into the caller's library."
required:
- url
- published_asset_ids
properties:
url:
type: string
format: uri
description: URL of the asset to import
name:
type: string
description: Display name for the imported asset
tags:
published_asset_ids:
type: array
description: IDs of published assets (inputs and models) to import.
items:
type: string
share_id:
type: string
nullable: true
description: |
Optional. Share ID of the published workflow these assets belong to. When provided (non-null, non-empty): all `published_asset_ids` must belong to this share's workflow version; returns 400 if the share is not found or any asset does not belong to it. When omitted, null, or empty string: no share-scoped validation is performed and the assets are validated only against global rules (preserved for clients that have not yet adopted `share_id`).
ImportPublishedAssetsResponse:
type: object
x-runtime: [cloud]
description: "[cloud-only] Response after importing published assets. Each returned `AssetInfo.id` is the caller's newly-created private asset ID, not the published asset ID supplied in the request."
required:
- assets
properties:
assets:
type: array
items:
$ref: "#/components/schemas/AssetInfo"
RemoteAssetMetadata:
type: object
@ -7424,6 +7706,325 @@ components:
description: RSA exponent (base64url)
additionalProperties: true
OAuthAuthorizationServerMetadata:
type: object
x-runtime: [cloud]
description: "[cloud-only] OAuth 2.1 authorization-server metadata (RFC 8414)."
required:
- issuer
- authorization_endpoint
- token_endpoint
- jwks_uri
- response_types_supported
- grant_types_supported
- code_challenge_methods_supported
- token_endpoint_auth_methods_supported
properties:
issuer:
type: string
format: uri
authorization_endpoint:
type: string
format: uri
token_endpoint:
type: string
format: uri
jwks_uri:
type: string
format: uri
registration_endpoint:
type: string
format: uri
description: "[cloud-only] RFC 7591 §3.1 Dynamic Client Registration endpoint. Advertised so MCP-spec-compliant clients can auto-discover and self-register without operator involvement. Present only when DCR is enabled."
response_types_supported:
type: array
items:
type: string
grant_types_supported:
type: array
items:
type: string
code_challenge_methods_supported:
type: array
items:
type: string
token_endpoint_auth_methods_supported:
type: array
items:
type: string
scopes_supported:
type: array
items:
type: string
OAuthProtectedResourceMetadata:
type: object
x-runtime: [cloud]
description: "[cloud-only] OAuth 2.1 protected-resource metadata (RFC 9728)."
required:
- resource
- authorization_servers
- scopes_supported
properties:
resource:
type: string
format: uri
authorization_servers:
type: array
items:
type: string
format: uri
scopes_supported:
type: array
items:
type: string
bearer_methods_supported:
type: array
items:
type: string
OAuthConsentChallenge:
type: object
x-runtime: [cloud]
description: "[cloud-only] Server-side state describing the OAuth consent decision the user is being asked to make. Returned by GET /oauth/authorize when a valid session exists; the frontend renders the consent UI from this payload and POSTs the decision back. Browser never sees the original OAuth params on resume."
required:
- oauth_request_id
- csrf_token
- client_display_name
- resource_display_name
- scopes
- workspaces
properties:
oauth_request_id:
type: string
format: uuid
description: Opaque server-side identifier for the authorization-request row. Carried back unchanged in the consent submission.
csrf_token:
type: string
description: Per-row CSRF token bound to this authorization request (not to the session). Must be echoed back on POST.
client_display_name:
type: string
description: Human-readable name of the OAuth client requesting authorization.
resource_display_name:
type: string
description: Human-readable name of the protected resource.
scopes:
type: array
description: Scopes the client is requesting for this resource. The frontend should present these for the user to approve.
items:
type: string
workspaces:
type: array
description: Workspaces the user can select from. Membership is re-checked on POST.
items:
$ref: "#/components/schemas/OAuthConsentChallengeWorkspace"
OAuthConsentChallengeWorkspace:
type: object
x-runtime: [cloud]
description: "[cloud-only] One workspace option presented in the OAuth consent challenge."
required: [id, name, type, role]
properties:
id: { type: string }
name: { type: string }
type: { type: string, enum: [personal, team] }
role: { type: string, enum: [owner, member] }
OAuthAuthorizeRedirectResponse:
type: object
x-runtime: [cloud]
description: "[cloud-only] Redirect target produced after a JSON consent submission. The frontend must navigate the browser to this URL so custom-scheme client callbacks work without relying on fetch-visible 302 headers."
required:
- redirect_url
properties:
redirect_url:
type: string
format: uri
description: OAuth client redirect URI with either code+state for allow, or error+state for deny.
OAuthTokenResponse:
type: object
x-runtime: [cloud]
description: "[cloud-only] RFC 6749 §5.1 successful token response."
required: [access_token, token_type, expires_in, refresh_token, scope]
properties:
access_token:
type: string
description: Resource-bound access token (audience matches the protected resource).
token_type:
type: string
enum: [Bearer]
expires_in:
type: integer
description: Access token lifetime in seconds.
refresh_token:
type: string
description: Opaque refresh token. Rotates on every successful refresh; presenting an already-rotated token revokes the entire family.
scope:
type: string
description: Space-delimited scopes granted with this token.
OAuthTokenError:
type: object
x-runtime: [cloud]
description: "[cloud-only] RFC 6749 §5.2 error response."
required: [error]
properties:
error:
type: string
description: 'RFC 6749 §5.2 error code: invalid_request, invalid_client, invalid_grant, unauthorized_client, unsupported_grant_type, invalid_scope.'
error_description:
type: string
description: Human-readable, no leak of internal storage state.
OAuthRegisterRequest:
type: object
x-runtime: [cloud]
additionalProperties: false
description: "[cloud-only] RFC 7591 §2 client metadata document. Only the fields the server honors are listed; presence of `scope` or `resource_grants` in the request is rejected (`invalid_client_metadata`) because those are server-owned for dynamic clients."
required:
- redirect_uris
- application_type
properties:
redirect_uris:
type: array
items:
type: string
minItems: 1
maxItems: 5
description: 15 redirect URIs. Validated against `application_type` policy.
client_name:
type: string
maxLength: 100
description: Human-readable name shown in the consent UI. Reserved-name list rejects impersonation of major clients.
application_type:
type: string
enum: [native, web]
description: |
RFC 7591 §2 application_type. **REQUIRED** — clients MUST declare intent; the server does not default this field. `native` for desktop / CLI / MCP-spec-strict clients (loopback redirects); `web` for hosted clients (HTTPS only, host must be allowlisted). A missing or explicitly empty `application_type` rejects with `invalid_client_metadata`.
token_endpoint_auth_method:
type: string
enum: [none]
description: 'Public clients only this phase — must be `none` if present. The server forces `none` regardless.'
grant_types:
type: array
items:
type: string
enum: [authorization_code, refresh_token]
description: Optional. Defaults to `["authorization_code","refresh_token"]`.
response_types:
type: array
items:
type: string
enum: [code]
description: Optional. Defaults to `["code"]`.
scope:
type: string
nullable: true
description: "**REJECTED IF PRESENT.** Dynamic clients do not pick scopes — the server assigns scopes from the active resource's published list. Sending `scope` in the registration body is treated as a privilege-escalation attempt and returns `invalid_client_metadata`."
resource_grants:
type: object
nullable: true
additionalProperties:
type: array
items:
type: string
description: "**REJECTED IF PRESENT.** Same reason as `scope`. The set of resources and scopes a dynamic client may request is server-policy, not request-driven."
client_uri:
type: string
nullable: true
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
logo_uri:
type: string
nullable: true
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
tos_uri:
type: string
nullable: true
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
policy_uri:
type: string
nullable: true
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
software_id:
type: string
nullable: true
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
software_version:
type: string
nullable: true
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
contacts:
type: array
nullable: true
items:
type: string
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
jwks:
type: object
nullable: true
additionalProperties: true
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
jwks_uri:
type: string
nullable: true
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
OAuthRegisterResponse:
type: object
x-runtime: [cloud]
description: "[cloud-only] RFC 7591 §3.2.1 successful registration response."
required:
- client_id
- client_id_issued_at
- redirect_uris
- grant_types
- response_types
- token_endpoint_auth_method
- application_type
properties:
client_id:
type: string
description: Server-generated client_id.
client_id_issued_at:
type: integer
format: int64
description: Unix timestamp (seconds) when the client was registered.
client_name:
type: string
redirect_uris:
type: array
items:
type: string
grant_types:
type: array
items:
type: string
response_types:
type: array
items:
type: string
token_endpoint_auth_method:
type: string
enum: [none]
application_type:
type: string
enum: [native, web]
OAuthRegisterError:
type: object
x-runtime: [cloud]
description: "[cloud-only] RFC 7591 §3.2.2 error response."
required:
- error
properties:
error:
type: string
enum: [invalid_redirect_uri, invalid_client_metadata]
error_description:
type: string
nullable: true
BillingBalance:
type: object
x-runtime: [cloud]

View File

@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
filelock
av>=14.2.0
comfy-kitchen>=0.2.8
comfy-aimdo==0.3.0
comfy-aimdo==0.4.3
requests
simpleeval>=1.0.0
blake3

View File

@ -14,7 +14,6 @@ from tests.execution.test_execution import ComfyClient, run_warmup
class TestAsyncNodes:
@fixture(scope="class", autouse=True, params=[
(False, 0),
(True, 0),
(True, 100),
])
def _server(self, args_pytest, request):
@ -29,6 +28,8 @@ class TestAsyncNodes:
use_lru, lru_size = request.param
if use_lru:
pargs += ['--cache-lru', str(lru_size)]
else:
pargs += ['--cache-classic']
# Running server with args: pargs
p = subprocess.Popen(pargs)
yield

View File

@ -183,8 +183,7 @@ class TestExecution:
# Initialize server and client
#
@fixture(scope="class", autouse=True, params=[
{ "extra_args" : [], "should_cache_results" : True },
{ "extra_args" : ["--cache-lru", 0], "should_cache_results" : True },
{ "extra_args" : ["--cache-classic"], "should_cache_results" : True },
{ "extra_args" : ["--cache-lru", 100], "should_cache_results" : True },
{ "extra_args" : ["--cache-none"], "should_cache_results" : False },
])