Compare commits

..

3 Commits

Author SHA1 Message Date
95fdc6cf91 Repo security stuff. (#14019) 2026-05-20 17:17:55 -07:00
5aa5ccc9e0 Multi-threaded load of models from disk (big load time speedups & Offload to disk) (CORE-43,CORE-152,CORE-164,CORE-165,CORE-117) (#13802)
* model_management: disable non-dynamic smart memory

Disable smart memory outright for non dynamic models.

This is a minor step towards deprecation of --disable-dynamic-vram
and the legacy ModelPatcher.

This is needed for estimate-free model development, where new models
can opt-out of supplying a memory estimate and not have to worry
about hard VRAM allocations due to legacy non-dynamic model patchers

This is also a general stability increase for a lot of stray use cases
where estimates may still be off and going forward we are not going
to accurately maintain such estimates.

* pinned_memory: implement with aimdo growable buffer

Use a single growable buffer so we can do threaded pre-warming on
pinned memory.

* mm: use aimdo to do transfer from disk to pin

Aimdo implements a faster threaded loader.

* Add stream host pin buffer for AIMDO casts

Introduce per-offload-stream HostBuffer reuse for pinned staging,
include it in cast buffer reset synchronization.

Defer actual casts that go via this pin path to a separate pass
such that the buffer can be allocated monolithically (to avoid
cudaHostRegister thrash).

* remove old pin path

* Implement JIT pinned memory pressure

Replace the predictive pin pressure mechanism with JIT PIN memory
pressure.

* LowVRAMPatch: change to two-phase visit

* lora: re-implement as inplace swiss-army-knife operation

* prepare for multiple pin sets

* implement pinned loras

* requirements: comfy-aimdo 0.4.0

* ops: remove unused arg

This was defeatured in aimdo iteration

* ops: sync the CPU with only the offload stream activity

This was syncing with the offload stream which itself is synced with the
compute stream, so this was syncing CPU with compute transitively. Define
the event to sync it more gently.

* pins: implement freeing intermediate for pinned memory

Pinning is more important than inactive intermediates and the stream
pin buffer is more important than even active intermediates.

* execution: implement pin eviction on RAM presure

Add back proper pin freeing on RAM pressure

* implement pin registration swaps

Uncap the windows pins from 50% by extending the pool and have a pressure
mechanism to move the pin reservations om demand.

This unfortunately implies a GPU sync to do the freeing so significant
hysterisis needs to be added to consolidate these pressure events.

* cli_args/execution: Implement lower background cache-ram threshold

Limit the amount of RAM background intermediates can use, so that
switching workflows doesn't degrade performance too much.

* make default

* bump aimdo

* model-patcher: force-cast tiny weights

Flux 2 gets crazy stalls due to a mix of tiny and giant weights
creating lopsided steam buffer rotations which creates stalls.

* ops: refactor in prep for chunking

* mm: delegate pin-on-the-way to aimdo

Aimdo is able to chunk and slice this on the way for better CPU->GPU
overlap. The main advantage is the ability to shorten the bus contention
window between previous weight transfer and the next weights vbar
fault.

* bump aimdo

* pinning updates

* specify hostbuf max allocation size

There a signs of virtual memory exhaustion on some linux systems when
throwing 128GB for every little piece. Pass the actual to save aimdo
from over-estimates

* tests: update execution tests for caching

The default caching changed to ram-cache so update these tests
accordingly.

Remove the LRU 0 test as this also falls through to RAM cache.
2026-05-20 17:03:58 -07:00
4d6a058bf1 feat: MediaPipe face detection (CORE-235) (#14009)
* Initial mediapipe face detection support

* Update face_geometry.py

* Account for diff sized batch input

* Model folder placeholder
2026-05-20 16:07:48 -07:00
32 changed files with 1726 additions and 1391 deletions

View File

@ -1,2 +1,5 @@
# Admins
* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai
/CODEOWNERS @comfyanonymous
/.ci/ @comfyanonymous
/.github/ @comfyanonymous

View File

@ -39,7 +39,6 @@ from app.assets.services import (
update_asset_metadata,
upload_from_temp_path,
)
from app.assets.services.cursor import InvalidCursorError
from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef()
@ -210,40 +209,24 @@ async def list_assets_route(request: web.Request) -> web.Response:
order_candidate = (q.order or "desc").lower()
order = order_candidate if order_candidate in {"asc", "desc"} else "desc"
try:
result = list_assets_page(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=sort,
order=order,
after=q.after,
)
except InvalidCursorError as e:
return web.json_response(
{"error": {"code": "INVALID_CURSOR", "message": str(e)}},
status=400,
)
result = list_assets_page(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=sort,
order=order,
)
summaries = [_build_asset_response(item) for item in result.items]
# has_more semantics differ by mode:
# - cursor mode: a non-empty next_cursor means there are more results.
# - offset mode: derived from total - (offset + page size).
if q.after is not None:
has_more = result.next_cursor is not None
else:
has_more = (q.offset + len(summaries)) < result.total
payload = schemas_out.AssetsList(
assets=summaries,
total=result.total,
has_more=has_more,
next_cursor=result.next_cursor,
has_more=(q.offset + len(summaries)) < result.total,
)
return web.json_response(payload.model_dump(mode="json", exclude_none=True))

View File

@ -59,11 +59,6 @@ class ListAssetsQuery(BaseModel):
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
# Opaque keyset cursor. When supplied, `offset` is ignored. Cursor pagination
# is supported for sort values `created_at`, `updated_at`, `name`, `size`.
# Supplying `after` together with `sort=last_access_time` returns
# 400 INVALID_CURSOR; that sort only supports offset/limit.
after: str | None = None
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
"created_at"

View File

@ -40,8 +40,6 @@ class AssetsList(BaseModel):
assets: list[Asset]
total: int
has_more: bool
# Opaque cursor for the next page. Omitted when there are no more results.
next_cursor: str | None = None
class TagUsage(BaseModel):

View File

@ -266,18 +266,9 @@ def list_references_page(
metadata_filter: dict | None = None,
sort: str | None = None,
order: str | None = None,
after_cursor_value: object | None = None,
after_cursor_id: str | None = None,
) -> tuple[list[AssetReference], dict[str, list[str]], int]:
"""List references with pagination, filtering, and sorting.
When ``after_cursor_value``/``after_cursor_id`` are supplied the query uses
keyset pagination — ``offset`` is ignored and a WHERE clause selects rows
strictly after the given ``(sort_col, id)`` position in the active sort
direction. The cursor value must already be typed for the column
(datetime for time sorts, int for size, str for name); the caller decodes
the opaque cursor string and resolves to the typed value.
Returns (references, tag_map, total_count).
"""
base = (
@ -306,31 +297,9 @@ def list_references_page(
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetReference.created_at)
descending = order == "desc"
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
# Keyset WHERE: (sort_col, id) strictly less-than / greater-than the cursor.
# Equivalent to: sort_col <op> v OR (sort_col = v AND id <op> cursor_id).
if after_cursor_value is not None and after_cursor_id is not None:
if descending:
keyset = sa.or_(
sort_col < after_cursor_value,
sa.and_(sort_col == after_cursor_value, AssetReference.id < after_cursor_id),
)
else:
keyset = sa.or_(
sort_col > after_cursor_value,
sa.and_(sort_col == after_cursor_value, AssetReference.id > after_cursor_id),
)
base = base.where(keyset)
# Secondary ORDER BY id (matching the primary direction) gives the keyset
# comparison a deterministic tiebreaker on duplicate sort_col values.
id_exp = AssetReference.id.desc() if descending else AssetReference.id.asc()
sort_exp = sort_col.desc() if descending else sort_col.asc()
base = base.order_by(sort_exp, id_exp).limit(limit)
if after_cursor_id is None:
base = base.offset(offset)
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())

View File

@ -1,19 +1,8 @@
import contextlib
import mimetypes
import os
from datetime import timezone
from typing import Sequence
from app.assets.services.cursor import (
CursorPayload,
InvalidCursorError,
decode_cursor,
decode_cursor_int,
decode_cursor_time,
encode_cursor,
encode_cursor_from_time,
)
from app.assets.database.models import Asset
from app.assets.database.queries import (
@ -253,12 +242,6 @@ def get_asset_by_hash(asset_hash: str) -> AssetData | None:
return extract_asset_data(asset)
# Sort fields that support cursor pagination. Mirrors cloud's allowlist
# (created_at, updated_at, name, size). `last_access_time` is OSS-only and
# falls back to offset/limit — no cloud contract to match.
_CURSOR_SORT_FIELDS = ("created_at", "updated_at", "name", "size")
def list_assets_page(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
@ -269,39 +252,7 @@ def list_assets_page(
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
after: str | None = None,
) -> ListAssetsResult:
"""List assets with optional cursor pagination.
When ``after`` is supplied it overrides ``offset``. The cursor's sort field
must match ``sort`` and be in the cursor-supported allowlist; mismatches
raise InvalidCursorError so the handler can map to 400 INVALID_CURSOR.
"""
cursor_value: object | None = None
cursor_id: str | None = None
# Mint next_cursor on every page where the sort is cursor-supported, not
# only when the request itself arrived with a cursor. Otherwise a first
# request (no `after`) returns next_cursor=None and the client can never
# enter cursor mode.
mint_cursor = sort in _CURSOR_SORT_FIELDS
if after is not None:
if sort not in _CURSOR_SORT_FIELDS:
raise InvalidCursorError(
f"cursor pagination is not supported for sort={sort!r}"
)
payload = decode_cursor(after, _CURSOR_SORT_FIELDS, expected_order=order)
if payload.sort_field != sort:
raise InvalidCursorError(
f"cursor sort field {payload.sort_field!r} does not match request sort {sort!r}"
)
cursor_value, cursor_id = _resolve_cursor_value(payload), payload.id
# Over-fetch by one row so we can distinguish "exactly `limit` rows total
# remaining" from "more rows past this page" without a second query. Drop
# the sentinel before returning.
fetch_limit = limit + 1 if mint_cursor else limit
with create_session() as session:
refs, tag_map, total = list_references_page(
session,
@ -310,22 +261,12 @@ def list_assets_page(
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=fetch_limit,
limit=limit,
offset=offset,
sort=sort,
order=order,
after_cursor_value=cursor_value,
after_cursor_id=cursor_id,
)
next_cursor: str | None = None
if mint_cursor and len(refs) > limit:
# There's at least one more row past this page — mint a cursor from
# the last row of the page (i.e. index `limit - 1`, since we
# over-fetched), and drop the sentinel.
next_cursor = _encode_next_cursor(refs[limit - 1], sort, order)
refs = refs[:limit]
items: list[AssetSummaryData] = []
for ref in refs:
items.append(
@ -336,39 +277,7 @@ def list_assets_page(
)
)
return ListAssetsResult(items=items, total=total, next_cursor=next_cursor)
def _resolve_cursor_value(payload: CursorPayload) -> object:
"""Map a decoded cursor payload to a column-typed Python value."""
if payload.sort_field in ("created_at", "updated_at"):
# DB stores naive UTC; strip tzinfo so the comparison binds against a
# `TIMESTAMP WITHOUT TIME ZONE` column without an offset shift.
return decode_cursor_time(payload).replace(tzinfo=None)
if payload.sort_field == "size":
return decode_cursor_int(payload)
return payload.value # name, str-typed
def _encode_next_cursor(ref, sort: str, order: str) -> str | None:
"""Mint a cursor pointing at *ref* for the given sort dimension.
Returns None when the boundary row carries a NULL sort value (e.g. an asset
record whose size_bytes hasn't been backfilled). Continuing pagination
across a NULL boundary is undefined under keyset ordering — better to
truncate cleanly here than to mint a cursor that mis-positions.
"""
if sort == "name":
return encode_cursor("name", ref.name, ref.id, order=order)
if sort == "size":
if ref.asset is None or ref.asset.size_bytes is None:
return None
return encode_cursor("size", str(ref.asset.size_bytes), ref.id, order=order)
# created_at / updated_at — DB datetimes are naive UTC; attach tz before encoding.
value = ref.created_at if sort == "created_at" else ref.updated_at
if value is None:
return None
return encode_cursor_from_time(sort, value.replace(tzinfo=timezone.utc), ref.id, order=order)
return ListAssetsResult(items=items, total=total)
def resolve_hash_to_path(

View File

@ -1,219 +0,0 @@
"""Opaque keyset-pagination cursor for /api/assets.
Wire format aligns with the cloud Go implementation in
`common/pagination/cursor.go` so the frontend sees one contract across
runtimes. Payload JSON uses short keys to keep the encoded length small:
{"s": <sort_field>, "v": <value>, "id": <id>, "o": <order>}
The `o` key binds the cursor to the sort direction it was minted under,
so replaying a `desc` cursor against an `asc` request fails with
``INVALID_CURSOR`` rather than silently walking the wrong direction.
Decoders accept payloads without `o` for backward compatibility with
cursors minted before the binding was introduced (these skip the order
check); new cursors always include it. Cloud has a follow-up to mirror
the field — until then, Python and cloud cursors differ by exactly the
`o` key.
Encoding is base64url with no padding. JSON serialization escapes `<`,
`>`, `&`, U+2028, and U+2029 to match Go's default `json.Marshal`
behavior so asset names containing those characters produce
byte-identical cursors across runtimes.
Time values are serialized as Unix microseconds (UTC) — microsecond
precision matches PostgreSQL's `timestamp` type, so a cursor minted from
a stored timestamp compares back exactly without rounding rows in the
same millisecond bucket.
"""
from __future__ import annotations
import base64
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Iterable, Optional
class InvalidCursorError(ValueError):
"""Raised on a malformed, oversized, or unsupported-sort-field cursor.
Map to a 400 response with code ``INVALID_CURSOR`` at the handler.
"""
# Wire-format length caps. Cursors are user-controlled, so caps protect the
# decode path from oversized allocations and downstream SQL predicates from
# unbounded strings.
#
# MAX_CURSOR_VALUE_LENGTH is 512 (vs cloud's 256) to fit OSS's
# `AssetReference.name` column max (String(512)) — otherwise a long-named
# asset would mint a cursor the same server then refuses on the next request.
# Cloud's data model has shorter names so its lower cap is fine there;
# cross-runtime byte-identity is unaffected because no real cloud cursor ever
# carries a value > 256.
MAX_ENCODED_CURSOR_LENGTH = 1024
MAX_CURSOR_VALUE_LENGTH = 512
MAX_CURSOR_ID_LENGTH = 128
@dataclass(frozen=True)
class CursorPayload:
sort_field: str
value: str
id: str
# None means "minted by a producer that did not bind order" (e.g. a cloud
# cursor from before BE-944's follow-up lands). New cursors always set it.
order: str | None = None
# Order direction tokens. Mirrored on the cloud follow-up so cursors carrying
# `o` are interchangeable between runtimes once both sides ship the field.
_VALID_ORDERS = ("asc", "desc")
def encode_cursor(sort_field: str, value: str, id: str, order: str = "desc") -> str:
"""Encode a cursor payload as a base64url (no-padding) string.
`order` binds the cursor to the sort direction it was minted under so a
later request with a flipped `order` query parameter is rejected with
``INVALID_CURSOR`` rather than silently walking the wrong direction.
"""
if order not in _VALID_ORDERS:
raise ValueError(f"order must be one of {_VALID_ORDERS}, got {order!r}")
payload = {"s": sort_field, "v": value, "id": id, "o": order}
raw = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
# Go's default `json.Marshal` escapes these characters in string values; we
# do the same so an asset name containing one of them produces byte-
# identical cursors across runtimes. None of these characters appear in
# JSON structural syntax, so a global replace on the serialized output is
# safe — it can only touch characters from the encoded values.
raw = (
raw.replace("<", "\\u003c")
.replace(">", "\\u003e")
.replace("&", "\\u0026")
.replace("", "\\u2028")
.replace("", "\\u2029")
)
return base64.urlsafe_b64encode(raw.encode("utf-8")).rstrip(b"=").decode("ascii")
def encode_cursor_from_time(sort_field: str, t: datetime, id: str, order: str = "desc") -> str:
"""Encode a time-typed cursor at Unix microsecond precision.
Accepts an aware datetime (any timezone) and normalizes to UTC. Naive
datetimes are rejected so callers can't accidentally encode the local
wall-clock value of a UTC-stored timestamp.
"""
if t.tzinfo is None:
raise ValueError("encode_cursor_from_time requires an aware datetime")
micros = _datetime_to_unix_micros(t.astimezone(timezone.utc))
return encode_cursor(sort_field, str(micros), id, order=order)
def decode_cursor(
cursor: str,
allowed_sort_fields: Iterable[str],
expected_order: str | None = None,
) -> CursorPayload:
"""Parse an opaque cursor.
``allowed_sort_fields`` is the endpoint's accepted sort-field list — a
cursor carrying a field outside this set is rejected so a cursor minted
for one column can't be replayed against another (e.g. a ``created_at``
timestamp string compared against a ``name`` column).
``expected_order`` (``"asc"``/``"desc"``), when supplied, must match the
payload's ``o`` field. Cursors minted without ``o`` (e.g. by an older
cloud build) pass the check unconditionally — the binding is best-effort
until both runtimes ship the field.
Passing no allowed fields rejects every cursor.
"""
if len(cursor) > MAX_ENCODED_CURSOR_LENGTH:
raise InvalidCursorError("cursor exceeds maximum length")
try:
# urlsafe_b64decode requires correct padding; we strip on encode, so
# restore the trailing '=' pad here.
padding = "=" * (-len(cursor) % 4)
raw = base64.urlsafe_b64decode(cursor + padding)
except (ValueError, base64.binascii.Error) as e:
raise InvalidCursorError(f"encoding: {e}") from e
try:
decoded = json.loads(raw)
except (json.JSONDecodeError, UnicodeDecodeError) as e:
raise InvalidCursorError(f"payload: {e}") from e
if not isinstance(decoded, dict):
raise InvalidCursorError("payload: expected object")
sort_field = decoded.get("s")
value = decoded.get("v")
id = decoded.get("id")
order = decoded.get("o") # may be absent on legacy cursors
if not isinstance(sort_field, str) or not isinstance(value, str) or not isinstance(id, str):
raise InvalidCursorError("payload: missing or non-string s/v/id")
if id == "":
raise InvalidCursorError("missing id")
if len(id) > MAX_CURSOR_ID_LENGTH:
raise InvalidCursorError("id exceeds maximum length")
if len(value) > MAX_CURSOR_VALUE_LENGTH:
raise InvalidCursorError("value exceeds maximum length")
if sort_field not in allowed_sort_fields:
raise InvalidCursorError(f"unsupported sort field {sort_field!r}")
if order is not None and not isinstance(order, str):
raise InvalidCursorError("payload: non-string o")
if order is not None and order not in _VALID_ORDERS:
raise InvalidCursorError(f"unsupported order {order!r}")
if expected_order is not None and order is not None and order != expected_order:
raise InvalidCursorError(
f"cursor order {order!r} does not match request order {expected_order!r}"
)
return CursorPayload(sort_field=sort_field, value=value, id=id, order=order)
def decode_cursor_time(payload: Optional[CursorPayload]) -> datetime:
"""Parse a time-typed cursor value as Unix microseconds, returning UTC."""
if payload is None:
raise InvalidCursorError("nil cursor payload")
try:
micros = int(payload.value)
except ValueError as e:
raise InvalidCursorError(f"value is not a valid timestamp: {e}") from e
try:
return _unix_micros_to_datetime(micros)
except (OverflowError, OSError, ValueError) as e:
# Crafted out-of-range microseconds (e.g. > datetime.MAX_YEAR) blow up
# in fromtimestamp / datetime construction. Map to 400, not 500.
raise InvalidCursorError(f"value is out of representable range: {e}") from e
def decode_cursor_int(payload: Optional[CursorPayload]) -> int:
"""Parse a cursor value as a base-10 integer."""
if payload is None:
raise InvalidCursorError("nil cursor payload")
try:
return int(payload.value)
except ValueError as e:
raise InvalidCursorError(f"value is not a valid integer: {e}") from e
_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)
def _datetime_to_unix_micros(t: datetime) -> int:
"""Convert an aware UTC datetime to Unix microseconds (integer math)."""
delta = t - _EPOCH
return (delta.days * 86_400 + delta.seconds) * 1_000_000 + delta.microseconds
def _unix_micros_to_datetime(micros: int) -> datetime:
"""Convert Unix microseconds to a UTC datetime, preserving precision."""
seconds, micro_remainder = divmod(micros, 1_000_000)
return datetime.fromtimestamp(seconds, tz=timezone.utc).replace(microsecond=micro_remainder)

View File

@ -71,7 +71,6 @@ class AssetSummaryData:
class ListAssetsResult:
items: list[AssetSummaryData]
total: int
next_cursor: str | None = None
@dataclass(frozen=True)

View File

@ -110,13 +110,11 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
CACHE_RAM_AUTO_GB = -1.0
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metavar="GB", help="Use RAM pressure caching with the specified headroom thresholds. This is the default caching mode. The first value sets the active-cache threshold; the optional second value sets the inactive-cache/pin threshold. Defaults when no values are provided: active 25%% of system RAM (min 4GB, max 32GB), inactive 75%% of system RAM (min 12GB, max 96GB).")
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@ -245,6 +243,9 @@ if comfy.options.args_parsing:
else:
args = parser.parse_args([])
if args.cache_ram is not None and len(args.cache_ram) > 2:
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
if args.windows_standalone_build:
args.auto_launch = True

View File

@ -484,16 +484,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
return weight
def prefetch_prepared_value(value, allocate_buffer, stream):
def prefetch_prepared_value(value, counter, destination, stream, copy):
if isinstance(value, torch.Tensor):
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
size = comfy.memory_management.vram_aligned_size(value)
offset = counter[0]
counter[0] += size
if destination is None:
return value
dest = destination[offset:offset + size]
if copy:
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
elif isinstance(value, weight_adapter.WeightAdapterBase):
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy))
elif isinstance(value, tuple):
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value)
elif isinstance(value, list):
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value]
return value

View File

@ -15,7 +15,7 @@ class TensorFileSlice(NamedTuple):
size: int
def read_tensor_file_slice_into(tensor, destination):
def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
if isinstance(tensor, QuantizedTensor):
if not isinstance(destination, QuantizedTensor):
@ -23,12 +23,17 @@ def read_tensor_file_slice_into(tensor, destination):
if tensor._layout_cls != destination._layout_cls:
return False
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream,
destination2=(destination2._qdata if destination2 is not None else None)):
return False
dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
if destination2 is not None:
dst_orig_dtype = destination2._params.orig_dtype
destination2._params.copy_from(destination._params, non_blocking=True)
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
return True
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
@ -48,6 +53,17 @@ def read_tensor_file_slice_into(tensor, destination):
if info.size == 0:
return True
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
if hostbuf is not None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
device_ptr = destination2.data_ptr() if destination2 is not None else 0
hostbuf.read_file_slice(file_obj, info.offset, info.size,
offset=destination.data_ptr() - hostbuf.get_raw_address(),
stream=stream_ptr,
device_ptr=device_ptr,
device=None if destination2 is None else destination2.device.index)
return True
buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))
@ -151,7 +167,7 @@ def set_ram_cache_release_state(callback, headroom):
extra_ram_release_callback = callback
RAM_CACHE_HEADROOM = max(0, int(headroom))
def extra_ram_release(target):
def extra_ram_release(target, free_active=False):
if extra_ram_release_callback is None:
return 0
return extra_ram_release_callback(target)
return extra_ram_release_callback(target, free_active=free_active)

View File

@ -31,6 +31,7 @@ from contextlib import nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.host_buffer
import comfy_aimdo.vram_buffer
class VRAMState(Enum):
@ -495,6 +496,14 @@ except:
current_loaded_models = []
DIRTY_MMAPS = set()
PIN_PRESSURE_HYSTERESIS = 256 * 1024 * 1024
#Freeing registerables on pressure does imply a GPU sync, so go big on
#the hysteresis so each expensive sync gives us back a good chunk.
REGISTERABLE_PIN_HYSTERESIS = 2048 * 1024 * 1024
def module_size(module):
module_mem = 0
sd = module.state_dict()
@ -503,27 +512,46 @@ def module_size(module):
module_mem += t.nbytes
return module_mem
def module_mmap_residency(module, free=False):
mmap_touched_mem = 0
module_mem = 0
bounced_mmaps = set()
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nbytes
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
continue
mmap_touched_mem += t.nbytes
if not free:
continue
storage._comfy_tensor_mmap_touched = False
mmap_obj = storage._comfy_tensor_mmap_refs[0]
if mmap_obj in bounced_mmaps:
continue
mmap_obj.bounce()
bounced_mmaps.add(mmap_obj)
return mmap_touched_mem, module_mem
def mark_mmap_dirty(storage):
mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None)
if mmap_refs is not None:
DIRTY_MMAPS.add(mmap_refs[0])
def free_pins(size, evict_active=False):
freed_total = 0
for loaded_model in reversed(current_loaded_models):
if size <= 0:
return freed_total
model = loaded_model.model
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
freed = model.partially_unload_ram(size)
freed_total += freed
size -= freed
return freed_total
def ensure_pin_budget(size, evict_active=False):
shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available
if shortfall <= 0:
return True
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
return free_pins(to_free, evict_active=evict_active) >= shortfall
def ensure_pin_registerable(size, evict_active=False):
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if shortfall <= 0:
return True
shortfall += REGISTERABLE_PIN_HYSTERESIS
for loaded_model in reversed(current_loaded_models):
model = loaded_model.model
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
shortfall -= model.unregister_inactive_pins(shortfall)
if shortfall <= 0:
return True
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
class LoadedModel:
def __init__(self, model):
@ -553,9 +581,6 @@ class LoadedModel:
def model_memory(self):
return self.model.model_size()
def model_mmap_residency(self, free=False):
return self.model.model_mmap_residency(free=free)
def model_loaded_memory(self):
return self.model.loaded_size()
@ -635,15 +660,9 @@ WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
import comfy.windows
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
def get_free_ram():
return comfy.windows.get_free_ram()
else:
def get_free_ram():
return psutil.virtual_memory().available
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@ -657,7 +676,6 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
cleanup_models_gc()
comfy.memory_management.extra_ram_release(max(pins_required, ram_required))
unloaded_model = []
can_unload = []
unloaded_models = []
@ -673,11 +691,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for x in can_unload_sorted:
i = x[-1]
memory_to_free = 1e32
pins_to_free = 1e32
if not DISABLE_SMART_MEMORY or device is None:
if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
pins_to_free = pins_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
if for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size()
@ -685,18 +701,6 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i)
if pins_to_free > 0:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
for x in can_unload_sorted:
i = x[-1]
ram_to_free = ram_required - psutil.virtual_memory().available
if ram_to_free <= 0 and i not in unloaded_model:
continue
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
if resident_memory > 0:
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
@ -762,29 +766,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
model_to_unload.model.detach(unpatch_all=False)
model_to_unload.model_finalizer.detach()
total_memory_required = {}
total_pins_required = {}
total_ram_required = {}
for loaded_model in models_to_load:
device = loaded_model.device
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
resident_memory, model_memory = loaded_model.model_mmap_residency()
pinned_memory = loaded_model.model.pinned_memory_size()
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
#make this JIT to keep as much pinned as possible.
pins_required = model_memory - pinned_memory
ram_required = model_memory - resident_memory
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem,
device,
for_dynamic=free_for_dynamic,
pins_required=total_pins_required[device],
ram_required=total_ram_required[device])
for_dynamic=free_for_dynamic)
for device in total_memory_required:
if device != torch.device("cpu"):
@ -1180,6 +1171,7 @@ STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
STREAM_PIN_BUFFERS = {}
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
@ -1220,21 +1212,66 @@ def get_aimdo_cast_buffer(offload_stream, device):
if cast_buffer is None:
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer
def get_pin_buffer(offload_stream):
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
if pin_buffer is None:
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3))
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
elif offload_stream is not None:
event = getattr(pin_buffer, "_comfy_event", None)
if event is not None:
event.synchronize()
delattr(pin_buffer, "_comfy_event")
return pin_buffer
def resize_pin_buffer(pin_buffer, size):
global TOTAL_PINNED_MEMORY
old_size = pin_buffer.size
if size <= old_size:
return True
growth = size - old_size
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
ensure_pin_budget(growth, evict_active=True)
ensure_pin_registerable(growth, evict_active=True)
try:
pin_buffer.extend(size=size, reallocate=True)
except RuntimeError:
return False
TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
return True
def reset_cast_buffers():
global TOTAL_PINNED_MEMORY
global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0)
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
if offload_stream is not None:
offload_stream.synchronize()
synchronize()
for mmap_obj in DIRTY_MMAPS:
mmap_obj.bounce()
DIRTY_MMAPS.clear()
for pin_buffer in STREAM_PIN_BUFFERS.values():
TOTAL_PINNED_MEMORY -= pin_buffer.size
TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY)
for loaded_model in current_loaded_models:
model = loaded_model.model
if model is not None and model.is_dynamic():
model.model.dynamic_pins[model.load_device]["active"] = False
model.partially_unload_ram(1e30, subsets=[ "patches" ])
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0])
STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
STREAM_PIN_BUFFERS.clear()
soft_empty_cache()
def get_offload_stream(device):
@ -1280,7 +1317,7 @@ def sync_stream(device, stream):
current_stream(device).wait_stream(stream)
def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
wf_context = nullcontext()
if stream is not None:
wf_context = stream
@ -1288,17 +1325,20 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
wf_context = wf_context.as_context(stream)
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
with wf_context:
for tensor in tensors:
dest_view = dest_views.pop(0)
dest2_view = dest2_views.pop(0) if dest2_views is not None else None
if tensor is None:
continue
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view, stream=stream, destination2=dest2_view):
continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
if hasattr(storage, "_comfy_tensor_mmap_touched"):
storage._comfy_tensor_mmap_touched = True
mark_mmap_dirty(storage)
dest_view.copy_(tensor, non_blocking=non_blocking)
if dest2_view is not None:
dest2_view.copy_(dest_view, non_blocking=non_blocking)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
@ -1339,14 +1379,18 @@ TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
ram = get_total_memory(torch.device("cpu"))
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50%
MAX_PINNED_MEMORY = ram * 0.40 # Windows limit is apparently 50%
else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90
MAX_PINNED_MEMORY = ram * 0.90
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
def pinned_hostbuf_size(size):
return max(0, int(min(size, MAX_PINNED_MEMORY) * 2))
def discard_cuda_async_error():
try:
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
@ -1378,8 +1422,8 @@ def pin_memory(tensor):
return False
size = tensor.nbytes
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
ensure_pin_registerable(size)
ptr = tensor.data_ptr()
if ptr == 0:
@ -1416,7 +1460,8 @@ def unpin_memory(tensor):
return False
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
size = PINNED_MEMORY.pop(ptr)
TOTAL_PINNED_MEMORY -= size
return True
else:
logging.warning("Unpin error.")

View File

@ -35,6 +35,7 @@ import comfy.model_management
import comfy.ops
import comfy.patcher_extension
import comfy.utils
import comfy_aimdo.host_buffer
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@ -117,6 +118,8 @@ def string_to_seed(data):
return comfy.utils.string_to_seed(data)
class LowVramPatch:
is_lowvram_patch = True
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
@ -124,11 +127,21 @@ class LowVramPatch:
self.set_func = set_func
self.prepared_patches = None
def prepare(self, allocate_buffer, stream):
self.prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
def memory_required(self):
counter = [0]
for patch in self.patches[self.key]:
comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False)
return counter[0]
def prepare(self, destination, stream, copy=True, commit=True):
counter = [0]
prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4])
for patch in self.patches[self.key]
]
if commit:
self.prepared_patches = prepared_patches
return prepared_patches
def clear_prepared(self):
self.prepared_patches = None
@ -341,9 +354,6 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model)
return self.size
def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free)
def loaded_size(self):
return self.model.model_loaded_weight_memory
@ -1118,8 +1128,12 @@ class ModelPatcher:
# Pinned memory pressure tracking is only implemented for DynamicVram loading
return 0
def loaded_ram_size(self):
# Loaded RAM pressure tracking is only implemented for DynamicVram loading
return 0
def partially_unload_ram(self, ram_to_unload):
pass
return 0
def detach(self, unpatch_all=True):
self.eject_model()
@ -1550,6 +1564,16 @@ class ModelPatcherDynamic(ModelPatcher):
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {}
if not hasattr(self.model, "dynamic_pins"):
self.model.dynamic_pins = {}
if self.load_device not in self.model.dynamic_pins:
self.model.dynamic_pins[self.load_device] = {
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"hostbufs_initialized": False,
"failed": False,
"active": False,
}
self.non_dynamic_delegate_model = None
assert load_device is not None
@ -1611,6 +1635,14 @@ class ModelPatcherDynamic(ModelPatcher):
self.unpatch_hooks()
vbar = self._vbar_get(create=True)
pin_state = self.model.dynamic_pins[self.load_device]
if not pin_state["hostbufs_initialized"]:
hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size())
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["hostbufs_initialized"] = True
pin_state["failed"] = False
pin_state["active"] = True
if vbar is not None:
vbar.prioritize()
@ -1636,7 +1668,9 @@ class ModelPatcherDynamic(ModelPatcher):
if key in self.patches:
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
return (True, 0)
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
lowvram_patch = LowVramPatch(key, self.patches)
lowvram_patch._pin_state = pin_state
setattr(m, param_key + "_lowvram_function", lowvram_patch)
num_patches += 1
else:
setattr(m, param_key + "_lowvram_function", None)
@ -1653,6 +1687,9 @@ class ModelPatcherDynamic(ModelPatcher):
def force_load_param(self, param_key, device_to):
key = key_param_name_to_key(n, param_key)
weight, _, _ = get_key_weight(self.model, key)
if weight is None:
return
if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to, force_cast=True)
@ -1662,17 +1699,23 @@ class ModelPatcherDynamic(ModelPatcher):
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
m.pin_failed = False
m.seed_key = n
m._pin_state = pin_state
set_dirty(m, dirty)
force_load, v_weight_size = setup_param(self, m, n, "weight")
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
force_load = force_load or force_load_bias
v_weight_size += v_weight_bias
#Models that mix tiny and giant weights can causing lopsided stream buffer
#rotations and stall. force the tinys over.
if module_mem > 16 * 1024:
force_load, v_weight_size = setup_param(self, m, n, "weight")
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
force_load = force_load or force_load_bias
v_weight_size += v_weight_bias
if force_load:
logging.info(f"Module {n} has resizing Lora - force loading")
else:
force_load=True
if force_load:
logging.info(f"Module {n} has resizing Lora - force loading")
force_load_param(self, "weight", device_to)
force_load_param(self, "bias", device_to)
else:
@ -1740,23 +1783,58 @@ class ModelPatcherDynamic(ModelPatcher):
return freed
def pinned_memory_size(self):
total = 0
loading = self._load_list(for_dynamic=True)
for x in loading:
_, _, _, _, m, _ = x
pin = comfy.pinned_memory.get_pin(m)
if pin is not None:
total += pin.numel() * pin.element_size()
return total
def loaded_ram_size(self):
return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
self.model.dynamic_pins[self.load_device]["patches"][0].size)
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading:
*_, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
if ram_to_unload <= 0:
return
def pinned_memory_size(self):
return (self.model.dynamic_pins[self.load_device]["weights"][3][0] +
self.model.dynamic_pins[self.load_device]["patches"][3][0])
def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
split = stack_split[0]
while split >= 0:
module, offset = stack[split]
split -= 1
stack_split[0] = split
if not module._pin_registered:
continue
size = module._pin.numel() * module._pin.element_size()
if torch.cuda.cudart().cudaHostUnregister(module._pin.data_ptr()) != 0:
comfy.model_management.discard_cuda_async_error()
continue
module._pin_registered = False
comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size)
pinned_size[0] = max(0, pinned_size[0] - size)
freed += size
ram_to_unload -= size
if ram_to_unload <= 0:
return freed
return freed
def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
while len(stack) > 0:
module, offset = stack.pop()
size = module._pin.numel() * module._pin.element_size()
del module._pin
hostbuf.truncate(offset, do_unregister=module._pin_registered)
stack_split[0] = min(stack_split[0], len(stack) - 1)
if module._pin_registered:
comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size)
pinned_size[0] = max(0, pinned_size[0] - size)
freed += size
ram_to_unload -= size
if ram_to_unload <= 0:
return freed
return freed
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
#This isn't used by the core at all and can only be to load a model out of

View File

@ -75,6 +75,8 @@ except:
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
@ -91,6 +93,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
offload_stream = None
cast_buffer = None
cast_buffer_offset = 0
stream_pin_hostbuf = None
stream_pin_offset = 0
stream_pin_queue = []
def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream
@ -124,6 +129,22 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
cast_buffer_offset += buffer_size
return buffer
def get_stream_pin_buffer_offset(buffer_size):
nonlocal stream_pin_hostbuf
nonlocal stream_pin_offset
if buffer_size == 0 or offload_stream is None:
return None
if stream_pin_hostbuf is None:
stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
if stream_pin_hostbuf is None:
return None
offset = stream_pin_offset
stream_pin_offset += buffer_size
return offset
for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
@ -162,23 +183,47 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
if xfer_dest is None:
xfer_dest = get_cast_buffer(dest_size)
if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s)
pin = comfy.pinned_memory.get_pin(s)
else:
pin = None
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
if xfer_source is not None:
if getattr(xfer_source, "is_lowvram_patch", False):
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
else:
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
if pin is not None:
comfy.model_management.cast_to_gathered(xfer_source, pin)
xfer_source = [ pin ]
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
def handle_pin(m, pin, source, dest, subset="weights", size=None):
if pin is not None:
cast_maybe_lowvram_patch([pin], dest, offload_stream)
return
if signature is None:
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
pin = comfy.pinned_memory.get_pin(m, subset=subset)
if pin is not None:
if isinstance(source, list):
comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest)
else:
cast_maybe_lowvram_patch(source, pin, None)
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
return
if pin is None:
pin_offset = get_stream_pin_buffer_offset(size)
if pin_offset is not None:
stream_pin_queue.append((source, pin_offset, size, dest))
return
cast_maybe_lowvram_patch(source, dest, offload_stream)
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_source = getattr(s, param_key + "_lowvram_function", None)
if lowvram_source is not None:
ensure_offload_stream(s, cast_buffer_offset, False)
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
lowvram_size = lowvram_source.memory_required()
lowvram_dest = get_cast_buffer(lowvram_size)
lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)
pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
handle_pin(lowvram_source, pin, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size)
prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest
@ -186,6 +231,23 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch
if stream_pin_offset > 0:
if stream_pin_hostbuf.size < stream_pin_offset:
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
for xfer_source, _, _, xfer_dest in stream_pin_queue:
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
return offload_stream
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf)
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
if isinstance(xfer_source, list):
comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest)
else:
cast_maybe_lowvram_patch(xfer_source, pin, None)
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
stream_pin_hostbuf._comfy_event = offload_stream.record_event()
return offload_stream

View File

@ -2,42 +2,62 @@ import comfy.model_management
import comfy.memory_management
import comfy_aimdo.host_buffer
import comfy_aimdo.torch
import torch
from comfy.cli_args import args
def get_pin(module):
return getattr(module, "_pin", None)
def get_pin(module, subset="weights"):
pin = getattr(module, "_pin", None)
if pin is None or module._pin_registered or args.disable_pinned_memory:
return pin
def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
_, _, stack_split, pinned_size = module._pin_state[subset]
size = pin.nbytes
comfy.model_management.ensure_pin_registerable(size)
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
comfy.model_management.discard_cuda_async_error()
return pin
module._pin_registered = True
stack_split[0] = max(stack_split[0], module._pin_stack_index)
comfy.model_management.TOTAL_PINNED_MEMORY += size
pinned_size[0] += size
return pin
def pin_memory(module, subset="weights", size=None):
pin_state = module._pin_state
if args.disable_pinned_memory:
return
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
pin = get_pin(module, subset)
if pin is not None or pin_state["failed"]:
return
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
module.pin_failed = True
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
if size is None:
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
offset = hostbuf.size
registerable_size = size + max(0, hostbuf.size - pinned_size[0])
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
if (not comfy.model_management.ensure_pin_budget(size) or
not comfy.model_management.ensure_pin_registerable(registerable_size)):
pin_state["failed"] = True
return False
try:
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
hostbuf.extend(size=size)
except RuntimeError:
module.pin_failed = True
pin_state["failed"] = True
return False
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
module._pin_hostbuf = hostbuf
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
stack.append((module, offset))
module._pin_registered = True
module._pin_stack_index = len(stack) - 1
stack_split[0] = max(stack_split[0], module._pin_stack_index)
comfy.model_management.TOTAL_PINNED_MEMORY += size
pinned_size[0] += size
return True
def unpin_memory(module):
if get_pin(module) is None:
return 0
size = module._pin.numel() * module._pin.element_size()
comfy.model_management.TOTAL_PINNED_MEMORY -= size
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
comfy.model_management.TOTAL_PINNED_MEMORY = 0
del module._pin
del module._pin_hostbuf
return size

View File

@ -113,7 +113,6 @@ def load_safetensors(ckpt):
"_comfy_tensor_file_slice",
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
setattr(storage, "_comfy_tensor_mmap_touched", False)
sd[name] = tensor
return sd, header.get("__metadata__", {}),
@ -1451,4 +1450,3 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res
return res

View File

@ -1,52 +0,0 @@
import ctypes
import logging
import psutil
from ctypes import wintypes
import comfy_aimdo.control
psapi = ctypes.WinDLL("psapi")
kernel32 = ctypes.WinDLL("kernel32")
class PERFORMANCE_INFORMATION(ctypes.Structure):
_fields_ = [
("cb", wintypes.DWORD),
("CommitTotal", ctypes.c_size_t),
("CommitLimit", ctypes.c_size_t),
("CommitPeak", ctypes.c_size_t),
("PhysicalTotal", ctypes.c_size_t),
("PhysicalAvailable", ctypes.c_size_t),
("SystemCache", ctypes.c_size_t),
("KernelTotal", ctypes.c_size_t),
("KernelPaged", ctypes.c_size_t),
("KernelNonpaged", ctypes.c_size_t),
("PageSize", ctypes.c_size_t),
("HandleCount", wintypes.DWORD),
("ProcessCount", wintypes.DWORD),
("ThreadCount", wintypes.DWORD),
]
def get_free_ram():
#Windows is way too conservative and chalks recently used uncommitted model RAM
#as "in-use". So, calculate free RAM for the sake of general use as the greater of:
#
#1: What psutil says
#2: Total Memory - (Committed Memory - VRAM in use)
#
#We have to subtract VRAM in use from the comitted memory as WDDM creates a naked
#commit charge for all VRAM used just incase it wants to page it all out. This just
#isn't realistic so "overcommit" on our calculations by just subtracting it off.
pi = PERFORMANCE_INFORMATION()
pi.cb = ctypes.sizeof(pi)
if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal")
return psutil.virtual_memory().available
committed = pi.CommitTotal * pi.PageSize
total = pi.PhysicalTotal * pi.PageSize
return max(psutil.virtual_memory().available,
total - (committed - comfy_aimdo.control.get_total_vram_usage()))

View File

@ -0,0 +1,111 @@
"""Pure-numpy port of MediaPipe's face_geometry (FACE_LANDMARK_PIPELINE mode)
+ weighted Procrustes solver. Computes the 4x4 facial transformation matrix.
"""
from __future__ import annotations
import math
import numpy as np
def _solve_weighted_orthogonal_problem(src: np.ndarray, tgt: np.ndarray, weights: np.ndarray) -> np.ndarray:
"""Weighted orthogonal Procrustes (similarity). Returns 4x4 M with
`target ≈ M @ homogeneous(source)` in the weighted LS sense. fp64 for
SVD stability. Port of procrustes_solver.cc."""
sqrt_w = np.sqrt(weights.astype(np.float64))
w_total = float((sqrt_w ** 2).sum())
ws = src.astype(np.float64) * sqrt_w
wt = tgt.astype(np.float64) * sqrt_w
c_w = (ws @ sqrt_w) / w_total
centered = ws - np.outer(c_w, sqrt_w)
U, _S, Vt = np.linalg.svd(wt @ centered.T, full_matrices=True)
# Disallow reflection: flip the least-significant axis when det(U)·det(V)<0.
post, pre = U.copy(), Vt.T.copy()
if np.linalg.det(post) * np.linalg.det(pre) < 0:
post[:, 2] *= -1.0
R = post @ pre.T
denom = float((centered * ws).sum())
if denom < 1e-12:
raise ValueError("Procrustes denominator collapsed (degenerate source).")
scale = float((R @ centered * wt).sum()) / denom
translation = ((wt - scale * (R @ ws)) @ sqrt_w) / w_total
M = np.eye(4, dtype=np.float64)
M[:3, :3] = scale * R
M[:3, 3] = translation
return M
def _estimate_scale(canonical: np.ndarray, runtime: np.ndarray, weights: np.ndarray) -> float:
"""scale = ‖first column of M[:3]‖ per geometry_pipeline.cc::EstimateScale."""
return float(np.linalg.norm(_solve_weighted_orthogonal_problem(canonical, runtime, weights)[:3, 0]))
def solve_facial_transformation_matrix(
landmarks_normalized: np.ndarray,
canonical_vertices: np.ndarray,
procrustes_indices: np.ndarray,
procrustes_weights: np.ndarray,
image_width: int,
image_height: int,
# face_geometry_calculator_options.pbtxt defaults
vertical_fov_degrees: float = 63.0,
near: float = 1.0,
) -> np.ndarray:
"""4x4 facial transformation matrix via two-pass scale recovery
`landmarks_normalized` is (N, 3) in MediaPipe normalized convention: x, y
in [0,1] with TOP-LEFT origin, z in width-scaled units.
"""
h_near = 2.0 * near * math.tan(0.5 * math.radians(vertical_fov_degrees))
w_near = image_width * h_near / image_height
sub = procrustes_indices.astype(np.int64)
screen = landmarks_normalized[sub].T.astype(np.float64).copy()
canon = canonical_vertices[sub].T.astype(np.float64).copy()
weights = procrustes_weights.astype(np.float64)
# ProjectXY (TOP_LEFT y-flip, then scale all 3 axes; z uses x-scale).
screen[1] = 1.0 - screen[1]
screen[0] = screen[0] * w_near - 0.5 * w_near
screen[1] = screen[1] * h_near - 0.5 * h_near
screen[2] = screen[2] * w_near
depth_offset = float(screen[2].mean())
def _unproject(s: np.ndarray, scale: float) -> np.ndarray:
s = s.copy()
s[2] = (s[2] - depth_offset + near) / scale
s[0] *= s[2] / near
s[1] *= s[2] / near
s[2] *= -1.0
return s
first = screen.copy()
first[2] *= -1.0
s1 = _estimate_scale(canon, first, weights) # 1st pass: Procrustes on projected XY
s2 = _estimate_scale(canon, _unproject(screen, s1), weights) # 2nd pass: rescale z by s1, un-project XY
return _solve_weighted_orthogonal_problem(canon, _unproject(screen, s1 * s2), weights).astype(np.float32)
def transformation_matrix_from_detection(face_dict: dict, image_width: int, image_height: int, canonical_data: dict) -> np.ndarray:
"""Adapt a FaceLandmarker face dict to MP's normalized convention and solve.
FaceMesh emits (x, y, z) in 192-canonical units; MP's geometry expects
z_norm = z_canonical * scale_x / image_width"""
lmks_xy, lmks_3d = face_dict["landmarks_xy"], face_dict["landmarks_3d"]
aug = np.concatenate([lmks_3d[:, :2].astype(np.float64), np.ones((lmks_xy.shape[0], 1))], axis=1)
M, *_ = np.linalg.lstsq(aug, lmks_xy.astype(np.float64), rcond=None)
scale_x = float(np.linalg.norm(M[0]))
z_scale = scale_x / image_width if scale_x > 1e-6 else 1.0 / image_width
normalized = np.empty((lmks_xy.shape[0], 3), dtype=np.float32)
normalized[:, 0] = lmks_xy[:, 0] / image_width
normalized[:, 1] = lmks_xy[:, 1] / image_height
normalized[:, 2] = lmks_3d[:, 2] * z_scale
return solve_facial_transformation_matrix(
normalized, canonical_data["canonical_vertices"],
canonical_data["procrustes_indices"], canonical_data["procrustes_weights"],
image_width=image_width, image_height=image_height,
)

View File

@ -0,0 +1,682 @@
"""Pure-PyTorch port of MediaPipe's face_landmarker_v2_with_blendshapes.task:
BlazeFace detector → FaceMesh v2 → ARKit-52 blendshapes."""
from __future__ import annotations
import math
from functools import lru_cache
from typing import List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from scipy.special import expit
from torch import Tensor, nn
# Values below must stay verbatim with the published face_landmarker_v2 graph
# face_blendshapes_graph.cc::kLandmarksSubsetIdxs
_BS_INPUT_INDICES: Tuple[int, ...] = (
0, 1, 4, 5, 6, 7, 8, 10, 13, 14, 17, 21, 33, 37, 39, 40, 46, 52, 53, 54,
55, 58, 61, 63, 65, 66, 67, 70, 78, 80, 81, 82, 84, 87, 88, 91, 93, 95,
103, 105, 107, 109, 127, 132, 133, 136, 144, 145, 146, 148, 149, 150, 152,
153, 154, 155, 157, 158, 159, 160, 161, 162, 163, 168, 172, 173, 176, 178,
181, 185, 191, 195, 197, 234, 246, 249, 251, 263, 267, 269, 270, 276, 282,
283, 284, 285, 288, 291, 293, 295, 296, 297, 300, 308, 310, 311, 312, 314,
317, 318, 321, 323, 324, 332, 334, 336, 338, 356, 361, 362, 365, 373, 374,
375, 377, 378, 379, 380, 381, 382, 384, 385, 386, 387, 388, 389, 390, 397,
398, 400, 402, 405, 409, 415, 454, 466, 468, 469, 470, 471, 472, 473, 474,
475, 476, 477,
)
# face_blendshapes_graph.cc::kCategoryNames
BLENDSHAPE_NAMES: Tuple[str, ...] = (
"_neutral", "browDownLeft", "browDownRight", "browInnerUp", "browOuterUpLeft",
"browOuterUpRight", "cheekPuff", "cheekSquintLeft", "cheekSquintRight",
"eyeBlinkLeft", "eyeBlinkRight", "eyeLookDownLeft", "eyeLookDownRight",
"eyeLookInLeft", "eyeLookInRight", "eyeLookOutLeft", "eyeLookOutRight",
"eyeLookUpLeft", "eyeLookUpRight", "eyeSquintLeft", "eyeSquintRight",
"eyeWideLeft", "eyeWideRight", "jawForward", "jawLeft", "jawOpen",
"jawRight", "mouthClose", "mouthDimpleLeft", "mouthDimpleRight",
"mouthFrownLeft", "mouthFrownRight", "mouthFunnel", "mouthLeft",
"mouthLowerDownLeft", "mouthLowerDownRight", "mouthPressLeft",
"mouthPressRight", "mouthPucker", "mouthRight", "mouthRollLower",
"mouthRollUpper", "mouthShrugLower", "mouthShrugUpper", "mouthSmileLeft",
"mouthSmileRight", "mouthStretchLeft", "mouthStretchRight",
"mouthUpperUpLeft", "mouthUpperUpRight", "noseSneerLeft", "noseSneerRight",
)
# face_detection.pbtxt — short-range BlazeFace.
_BF_NUM_LAYERS = 4
_BF_INPUT_SIZE = 128
_BF_STRIDES = (8, 16, 16, 16)
_BF_ANCHOR_OFFSET_X = 0.5
_BF_ANCHOR_OFFSET_Y = 0.5
_BF_ASPECT_RATIOS = (1.0,)
_BF_INTERP_SCALE_AR = 1.0
_BF_BOX_SCALE = 128.0
_BF_KP_OFFSET = 4
_BF_SCORE_CLIP = 100.0
_BF_MIN_SCORE = 0.5
# face_detection_full_range.pbtxt — 48x48 grid at stride 4, 1 anchor/cell.
_BF_FR_INPUT_SIZE = 192
_BF_FR_GRID = 48
_BF_FR_NUM_ANCHORS = _BF_FR_GRID * _BF_FR_GRID
_BF_FR_BOX_SCALE = 192.0
_BF_FR_SCORE_CLIP = 100.0
_FM_INPUT_SIZE = 192
# Face ROI: 1.5xbbox rect warped anisotropically into 192x192.
_FACE_LEFT_EYE_KP = 0
_FACE_RIGHT_EYE_KP = 1
_FACE_ROI_SCALE_X = 1.5
_FACE_ROI_SCALE_Y = 1.5
_FACE_ROI_TARGET_ANGLE = 0.0
def _tf_same_pad(x: Tensor, kernel: int, stride: int) -> Tensor:
"""TF SAME pad (asymmetric on stride-2; PyTorch's symmetric pad undershoots by 1 px)."""
H, W = x.shape[-2], x.shape[-1]
pad_h = max(((H + stride - 1) // stride - 1) * stride + kernel - H, 0)
pad_w = max(((W + stride - 1) // stride - 1) * stride + kernel - W, 0)
if pad_h == 0 and pad_w == 0:
return x
return F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
# BlazeFace short-range: stem 5x5/s2 → 16 BlazeBlocks → parallel heads at
# 16²x88 (2 anchors/cell) and 8²x96 (6/cell) = 896 anchors. (in, out, stride):
_BLAZEFACE_BLOCKS = [
(24, 24, 1), (24, 28, 1), (28, 32, 2), (32, 36, 1),
(36, 42, 1), (42, 48, 2), (48, 56, 1), (56, 64, 1),
(64, 72, 1), (72, 80, 1), (80, 88, 1), (88, 96, 2),
(96, 96, 1), (96, 96, 1), (96, 96, 1), (96, 96, 1),
]
class BlazeFaceBlock(nn.Module):
"""DW 3x3 + PW + residual. Residual max-pools on stride>1, channel-pads on out_ch>in_ch."""
def __init__(self, in_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride
self.depthwise = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, device=device, dtype=dtype)
self.pointwise = ops.Conv2d(in_ch, out_ch, 1, padding=0, bias=True, device=device, dtype=dtype)
def forward(self, x: Tensor) -> Tensor:
residual = F.max_pool2d(x, 2, 2) if self.stride > 1 else x
if self.out_ch > self.in_ch:
residual = F.pad(residual, (0, 0, 0, 0, 0, self.out_ch - self.in_ch))
x = _tf_same_pad(x, 3, self.stride) if self.stride > 1 else F.pad(x, (1, 1, 1, 1))
return F.relu(self.pointwise(self.depthwise(x)) + residual)
class BlazeFace(nn.Module):
"""Short-range BlazeFace: (B, 3, 128, 128) in [-1, 1] → 896 anchors x 17."""
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.stem = ops.Conv2d(3, 24, 5, stride=2, padding=0, bias=True, **kw)
self.blocks = nn.ModuleList(BlazeFaceBlock(i, o, s, device=device, dtype=dtype, operations=operations)
for (i, o, s) in _BLAZEFACE_BLOCKS)
# 16²x2 + 8²x6 = 512 + 384 = 896 anchors.
self.cls_16 = ops.Conv2d(88, 2, 1, padding=0, bias=True, **kw)
self.cls_8 = ops.Conv2d(96, 6, 1, padding=0, bias=True, **kw)
self.reg_16 = ops.Conv2d(88, 32, 1, padding=0, bias=True, **kw)
self.reg_8 = ops.Conv2d(96, 96, 1, padding=0, bias=True, **kw)
def forward(self, image_chw_normalized: Tensor) -> tuple[Tensor, Tensor]:
x = F.relu(self.stem(_tf_same_pad(image_chw_normalized, 5, 2)))
# 16x16 tap is block-10 output (before the 88→96 stride-2 in block 11).
for i in range(11):
x = self.blocks[i](x)
feat_16 = x
for i in range(11, 16):
x = self.blocks[i](x)
feat_8 = x
def flat(t, a, k): # NHWC flatten → (B, H*W*A, K)
B, _, H, W = t.shape
return t.permute(0, 2, 3, 1).reshape(B, H * W * a, k)
cls = torch.cat([flat(self.cls_16(feat_16), 2, 1), flat(self.cls_8(feat_8), 6, 1)], dim=1)
reg = torch.cat([flat(self.reg_16(feat_16), 2, 16), flat(self.reg_8(feat_8), 6, 16)], dim=1)
return reg, cls
# BlazeFace full-range (face_detection_full_range_sparse.tflite): MobileNetV2-ish
# backbone + top-down FPN, 192² input → 2304 anchors at the 48x48 grid.
class FRBlock(nn.Module):
"""Double inverted residual: DW → PW(mid) → DW → PW(out) [+ residual].
Per source tflite: dw* have no fused activation, pw1 is always ReLU, pw2
is ReLU only when no residual (else ReLU fuses into the ADD).
"""
def __init__(self, in_ch: int, mid_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.has_residual = (in_ch == out_ch and stride == 1)
self.dw1 = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, **kw)
self.pw1 = ops.Conv2d(in_ch, mid_ch, 1, padding=0, bias=True, **kw)
self.dw2 = ops.Conv2d(mid_ch, mid_ch, 3, stride=1, padding=0, groups=mid_ch, bias=True, **kw)
self.pw2 = ops.Conv2d(mid_ch, out_ch, 1, padding=0, bias=True, **kw)
def forward(self, x: Tensor) -> Tensor:
residual = x if self.has_residual else None
x = F.relu(self.pw1(self.dw1(F.pad(x, (1, 1, 1, 1)))))
x = self.pw2(self.dw2(F.pad(x, (1, 1, 1, 1))))
return F.relu(x + residual) if residual is not None else F.relu(x)
# (in_ch, mid_ch, out_ch, stride). Stages downsample 96²x32 → 48²x64 → 24²x128
# → 12²x192 → 6²x384. Lateral taps at indices 4, 7, 10 (see _FR_LATERAL_*).
_FR_BACKBONE_BLOCKS = [
(32, 8, 32, 1), (32, 8, 32, 1), # 96²x32
(32, 16, 64, 2), (64, 16, 64, 1), (64, 16, 64, 1), # 48²x64 — tap[0]
(64, 32, 128, 2), (128, 32, 128, 1), (128, 32, 128, 1), # 24²x128 — tap[1]
(128, 48, 192, 2), (192, 48, 192, 1), (192, 48, 192, 1), # 12²x192 — tap[2]
(192, 96, 384, 2), (384, 96, 384, 1), (384, 96, 384, 1), (384, 96, 384, 1), # 6²x384
]
_FR_LATERAL_TAP_INDICES = (4, 7, 10)
_FR_LATERAL_CHANNELS = ((64, 48), (128, 64), (192, 96)) # (in, out) per side-conv
# Decoder blocks per FPN level (after upsample-and-merge with the lateral).
_FR_DECODER_BLOCKS = [
[(96, 48, 96, 1), (96, 48, 96, 1)], # 12²x96
[(64, 32, 64, 1), (64, 32, 64, 1)], # 24²x64
[(48, 24, 48, 1)], # 48²x48 — feeds the heads
]
def _dcr_depth_to_space(t: Tensor, r: int, c_out: int) -> Tensor:
"""TF DEPTH_TO_SPACE in DCR layout (input channels = (i, j, c_out)).
pixel_shuffle uses CRD which permutes output channels for c_out > 1."""
B_, _, H_, W_ = t.shape
t = t.reshape(B_, r, r, c_out, H_, W_)
t = t.permute(0, 3, 4, 1, 5, 2).contiguous()
return t.reshape(B_, c_out, H_ * r, W_ * r)
class BlazeFaceFullRange(nn.Module):
"""Full-range face detector: (B, 3, 192, 192) in [-1, 1] → 2304 anchors x 17 values."""
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
mk_block = lambda i, m, o, s: FRBlock(i, m, o, s, device=device, dtype=dtype, operations=operations)
self.stem = ops.Conv2d(3, 32, 3, stride=2, padding=0, bias=True, **kw)
self.backbone = nn.ModuleList(mk_block(i, m, o, s) for (i, m, o, s) in _FR_BACKBONE_BLOCKS)
self.lateral_convs = nn.ModuleList(ops.Conv2d(i, o, 1, padding=0, bias=True, **kw) for (i, o) in _FR_LATERAL_CHANNELS)
self.top_conv = ops.Conv2d(384, 96, 1, padding=0, bias=True, **kw)
self.decoder_levels = nn.ModuleList(
nn.ModuleList(mk_block(i, m, o, s) for (i, m, o, s) in lvl) for lvl in _FR_DECODER_BLOCKS
)
# 96→64 before 12→24, 64→48 before 24→48.
self.decoder_reduce_convs = nn.ModuleList([
ops.Conv2d(96, 64, 1, padding=0, bias=True, **kw),
ops.Conv2d(64, 48, 1, padding=0, bias=True, **kw),
])
# Heads mix 2x2-cell info via DW-stride-2 + depth_to_space block_size=2.
self.cls_conv = ops.Conv2d(48, 4, 1, padding=0, bias=True, **kw)
self.cls_dw = ops.Conv2d(4, 4, 3, stride=2, padding=0, groups=4, bias=True, **kw)
self.reg_conv = ops.Conv2d(48, 64, 1, padding=0, bias=True, **kw)
self.reg_dw = ops.Conv2d(64, 64, 3, stride=2, padding=0, groups=64, bias=True, **kw)
def forward(self, image_chw_normalized: Tensor) -> tuple[Tensor, Tensor]:
# Symmetric pad-1 throughout (full-range tflite uses explicit TF PAD, not SAME).
x = F.relu(self.stem(F.pad(image_chw_normalized, (1, 1, 1, 1))))
tap_set = set(_FR_LATERAL_TAP_INDICES)
laterals: list[Tensor] = []
for i, blk in enumerate(self.backbone):
x = blk(x)
if i in tap_set:
laterals.append(x)
# top_conv / lateral_convs / decoder_reduce_convs all have fused ReLU in the tflite.
p = F.relu(self.top_conv(x))
laterals_rev = list(reversed(laterals))
lateral_convs_rev = list(reversed(self.lateral_convs))
for level in range(len(self.decoder_levels)):
lateral = laterals_rev[level]
p = F.interpolate(p, size=lateral.shape[-2:], mode="bilinear", align_corners=False)
p = p + F.relu(lateral_convs_rev[level](lateral))
for blk in self.decoder_levels[level]:
p = blk(p)
if level < len(self.decoder_reduce_convs):
p = F.relu(self.decoder_reduce_convs[level](p))
c = self.cls_dw(F.pad(self.cls_conv(p), (1, 1, 1, 1)))
c = _dcr_depth_to_space(c, r=2, c_out=1)
r = self.reg_dw(F.pad(self.reg_conv(p), (1, 1, 1, 1)))
r = _dcr_depth_to_space(r, r=2, c_out=16)
B = c.shape[0]
cls_out = c.permute(0, 2, 3, 1).reshape(B, _BF_FR_NUM_ANCHORS, 1)
reg_out = r.permute(0, 2, 3, 1).reshape(B, _BF_FR_NUM_ANCHORS, 16)
return reg_out, cls_out
@lru_cache(maxsize=1)
def _blazeface_full_range_anchors() -> np.ndarray:
"""2304 anchors over 48x48; anchor_w=anchor_h=1 (fixed_anchor_size)."""
feat = _BF_FR_GRID
yy, xx = np.meshgrid(np.arange(feat, dtype=np.float32), np.arange(feat, dtype=np.float32), indexing="ij")
cx, cy, ones = (xx + 0.5) / feat, (yy + 0.5) / feat, np.ones_like(xx)
return np.stack([cx, cy, ones, ones], axis=-1).reshape(_BF_FR_NUM_ANCHORS, 4)
def _decode_blazeface_full_range(regressors: np.ndarray, classificators: np.ndarray,
score_thresh: float = _BF_MIN_SCORE) -> np.ndarray:
"""Same decode as short-range with 2304-anchor grid and box_scale=192."""
scores = expit(np.clip(classificators[:, 0], -_BF_FR_SCORE_CLIP, _BF_FR_SCORE_CLIP))
keep = scores >= score_thresh
if not keep.any():
return np.empty((0, 17), dtype=np.float32)
r = regressors[keep] / _BF_FR_BOX_SCALE
a = _blazeface_full_range_anchors()[keep]
cxs, cys, aws, ahs = a[:, 0:1], a[:, 1:2], a[:, 2:3], a[:, 3:4]
xc, yc = r[:, 0:1] * aws + cxs, r[:, 1:2] * ahs + cys
w, h = r[:, 2:3] * aws, r[:, 3:4] * ahs
out = np.empty((r.shape[0], 17), dtype=np.float32)
out[:, 0:1], out[:, 1:2], out[:, 2:3], out[:, 3:4] = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2
out[:, 4:16:2] = r[:, _BF_KP_OFFSET::2] * aws + cxs
out[:, 5:16:2] = r[:, _BF_KP_OFFSET + 1::2] * ahs + cys
out[:, 16] = scores[keep]
return out
# FaceMesh (face_landmarks_detector.tflite): PReLU variant of BlazeBlock,
# 17 blocks, heads for 478x3 landmarks + presence.
_FACEMESH_BLOCKS = [ # (in_ch, out_ch, stride)
(16, 16, 1), (16, 16, 1), (16, 32, 2), (32, 32, 1), (32, 32, 1), (32, 64, 2),
(64, 64, 1), (64, 64, 1), (64, 128, 2), (128, 128, 1), (128, 128, 1), (128, 128, 2),
(128, 128, 1), (128, 128, 1), (128, 128, 2), (128, 128, 1), (128, 128, 1),
]
class FaceMeshBlock(nn.Module):
"""PReLU BlazeBlock: PReLU between DW and PW, and after the residual add."""
def __init__(self, in_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride
self.depthwise = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, **kw)
self.prelu_dwise = nn.PReLU(num_parameters=in_ch, **kw)
self.pointwise = ops.Conv2d(in_ch, out_ch, 1, padding=0, bias=True, **kw)
self.prelu_out = nn.PReLU(num_parameters=out_ch, **kw)
def forward(self, x: Tensor) -> Tensor:
residual = F.max_pool2d(x, 2, 2) if self.stride > 1 else x
if self.out_ch > self.in_ch:
residual = F.pad(residual, (0, 0, 0, 0, 0, self.out_ch - self.in_ch))
x = _tf_same_pad(x, 3, self.stride) if self.stride > 1 else F.pad(x, (1, 1, 1, 1))
return self.prelu_out(self.pointwise(self.prelu_dwise(self.depthwise(x))) + residual)
class FaceMesh(nn.Module):
NUM_LANDMARKS = 478
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.stem = ops.Conv2d(3, 16, 3, stride=2, padding=0, bias=True, **kw)
self.prelu_stem = nn.PReLU(num_parameters=16, **kw)
self.blocks = nn.ModuleList(FaceMeshBlock(i, o, s, device=device, dtype=dtype, operations=operations)
for (i, o, s) in _FACEMESH_BLOCKS)
self.head_reduce = ops.Conv2d(128, 8, 1, padding=0, bias=True, **kw)
self.prelu_head_reduce = nn.PReLU(num_parameters=8, **kw)
self.head_block = FaceMeshBlock(8, 8, 1, device=device, dtype=dtype, operations=operations)
self.head_presence = ops.Conv2d(8, 1, 3, padding=0, bias=True, **kw)
self.head_landmarks = ops.Conv2d(8, self.NUM_LANDMARKS * 3, 3, padding=0, bias=True, **kw)
def forward(self, face_chw_normalized: Tensor) -> tuple[Tensor, Tensor]:
"""(B, 3, 192, 192) in [0, 1] → ((B, 478, 3) landmarks in 192-canonical, (B,) presence)."""
x = self.prelu_stem(self.stem(_tf_same_pad(face_chw_normalized, 3, 2)))
for blk in self.blocks:
x = blk(x)
x = self.prelu_head_reduce(self.head_reduce(x))
x = self.head_block(x)
B = x.shape[0]
presence = self.head_presence(x).reshape(B)
lmks = self.head_landmarks(x).reshape(B, self.NUM_LANDMARKS, 3)
return lmks, presence
# FaceBlendshapes (MLP-Mixer "GhumMarkerPoserMlpMixerGeneral"):
# 146x2 → token-reduce 146→96 → embed 2→64 → +cls token → 4x mixer → cls→52.
_BS_NUM_INPUT_LANDMARKS = 146
_BS_NUM_TOKENS_REDUCED = 96
_BS_NUM_TOKENS = 97 # +1 cls
_BS_TOKEN_DIM = 64
_BS_TOKEN_MIX_HIDDEN = 384
_BS_CHANNEL_MIX_HIDDEN = 256
_BS_NUM_BLENDSHAPES = 52
_BS_LN_EPS = 1e-6
class MlpMixerBlock(nn.Module):
"""MLP-Mixer block: token-mixing MLP (over tokens) → channel-mixing MLP (over dim).
Both pre-LN, both residual. LN has no beta (bias=False) to match MP."""
def __init__(self, num_tokens: int, token_dim: int, token_hidden: int, channel_hidden: int,
device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
# bias=False → no LN beta (matches MP).
self.ln1 = ops.LayerNorm(token_dim, eps=_BS_LN_EPS, bias=False, **kw)
self.ln2 = ops.LayerNorm(token_dim, eps=_BS_LN_EPS, bias=False, **kw)
self.token_mlp1 = ops.Linear(num_tokens, token_hidden, bias=True, **kw)
self.token_mlp2 = ops.Linear(token_hidden, num_tokens, bias=True, **kw)
self.channel_mlp1 = ops.Linear(token_dim, channel_hidden, bias=True, **kw)
self.channel_mlp2 = ops.Linear(channel_hidden, token_dim, bias=True, **kw)
def forward(self, x: Tensor) -> Tensor:
y = self.ln1(x).transpose(1, 2)
x = x + self.token_mlp2(F.relu(self.token_mlp1(y))).transpose(1, 2)
return x + self.channel_mlp2(F.relu(self.channel_mlp1(self.ln2(x))))
class FaceBlendshapes(nn.Module):
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.token_reduce = ops.Linear(_BS_NUM_INPUT_LANDMARKS, _BS_NUM_TOKENS_REDUCED, bias=True, **kw)
self.token_embed = ops.Linear(2, _BS_TOKEN_DIM, bias=True, **kw)
self.cls_token = nn.Parameter(torch.zeros(1, 1, _BS_TOKEN_DIM, **kw))
self.blocks = nn.ModuleList(
MlpMixerBlock(_BS_NUM_TOKENS, _BS_TOKEN_DIM, _BS_TOKEN_MIX_HIDDEN, _BS_CHANNEL_MIX_HIDDEN,
device=device, dtype=dtype, operations=operations) for _ in range(4)
)
self.head = ops.Linear(_BS_TOKEN_DIM, _BS_NUM_BLENDSHAPES, bias=True, **kw)
@staticmethod
def _input_normalize(landmarks_2d: Tensor) -> Tensor:
# Centroid-subtract → L2 scale → x0.5. The 0.5 is baked into training.
centroid = landmarks_2d.mean(dim=1, keepdim=True)
x = landmarks_2d - centroid
mag = torch.sqrt((x * x).sum(dim=-1, keepdim=True))
scale = mag.mean(dim=1, keepdim=True)
return (x / scale.clamp(min=1e-12)) * 0.5
def forward(self, landmarks_2d: Tensor) -> Tensor:
"""(B, 146, 2) → (B, 52) in [0, 1]. Input units don't matter (centroid + L2 normalize)."""
x = self._input_normalize(landmarks_2d)
x = self.token_reduce(x.transpose(1, 2)).transpose(1, 2)
x = self.token_embed(x)
cls = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat([cls, x], dim=1)
for blk in self.blocks:
x = blk(x)
return torch.sigmoid(self.head(x[:, 0]))
@lru_cache(maxsize=1)
def _blazeface_anchors() -> np.ndarray:
"""896 anchors per SsdAnchorsCalculator (fixed_anchor_size → anchor_w=anchor_h=1)."""
per_ar = len(_BF_ASPECT_RATIOS) + (1 if _BF_INTERP_SCALE_AR > 0 else 0)
layer_anchors: List[np.ndarray] = []
layer = 0
while layer < _BF_NUM_LAYERS:
stride = _BF_STRIDES[layer]
last = layer
while last < _BF_NUM_LAYERS and _BF_STRIDES[last] == stride:
last += 1
per_cell = per_ar * (last - layer)
feat = (_BF_INPUT_SIZE + stride - 1) // stride
yy, xx = np.meshgrid(np.arange(feat, dtype=np.float32), np.arange(feat, dtype=np.float32), indexing="ij")
cx, cy, ones = (xx + _BF_ANCHOR_OFFSET_X) / feat, (yy + _BF_ANCHOR_OFFSET_Y) / feat, np.ones_like(xx)
cell = np.stack([cx, cy, ones, ones], axis=-1).reshape(-1, 4)
layer_anchors.append(np.repeat(cell, per_cell, axis=0))
layer = last
out = np.concatenate(layer_anchors, axis=0)
assert out.shape == (896, 4), out.shape
return out
def _decode_blazeface(regressors: np.ndarray, classificators: np.ndarray,
score_thresh: float = _BF_MIN_SCORE) -> np.ndarray:
"""Decode (regs (896,16), cls (896,1)) → (N, 17) = [xyxy, kp0x..kp5y, score] in [0, 1]."""
scores = expit(np.clip(classificators[:, 0], -_BF_SCORE_CLIP, _BF_SCORE_CLIP))
keep = scores >= score_thresh
if not keep.any():
return np.empty((0, 17), dtype=np.float32)
r = regressors[keep] / _BF_BOX_SCALE
a = _blazeface_anchors()[keep] # (N, 4) cx, cy, 1, 1
cxs, cys, aws, ahs = a[:, 0:1], a[:, 1:2], a[:, 2:3], a[:, 3:4]
xc, yc = r[:, 0:1] * aws + cxs, r[:, 1:2] * ahs + cys
w, h = r[:, 2:3] * aws, r[:, 3:4] * ahs
out = np.empty((r.shape[0], 17), dtype=np.float32)
out[:, 0:1], out[:, 1:2], out[:, 2:3], out[:, 3:4] = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2
out[:, 4:16:2] = r[:, _BF_KP_OFFSET::2] * aws + cxs
out[:, 5:16:2] = r[:, _BF_KP_OFFSET + 1::2] * ahs + cys
out[:, 16] = scores[keep]
return out
def _weighted_nms(detections: np.ndarray, iou_thresh: float = 0.5) -> np.ndarray:
"""MP weighted NMS — kept boxes are score-weighted averages of overlapping detections."""
if detections.shape[0] == 0:
return detections
dets = detections[np.argsort(-detections[:, 16])]
N = dets.shape[0]
areas = np.clip(dets[:, 2] - dets[:, 0], 0, None) * np.clip(dets[:, 3] - dets[:, 1], 0, None)
kept: List[np.ndarray] = []
used = np.zeros(N, dtype=bool)
for i in range(N):
if used[i]:
continue
ax1, ay1, ax2, ay2 = dets[i, 0:4]
merge_idx = [i]
for j in range(i + 1, N):
if used[j]:
continue
bx1, by1, bx2, by2 = dets[j, 0:4]
iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
ih = max(0.0, min(ay2, by2) - max(ay1, by1))
inter = iw * ih
union = areas[i] + areas[j] - inter
if union > 0 and inter / union > iou_thresh: # strict > matches MP
merge_idx.append(j)
used[j] = True
used[i] = True
cluster = dets[merge_idx]
ws = cluster[:, 16:17]
ws_sum = ws.sum()
merged = np.copy(cluster[0])
if ws_sum > 0:
merged[:16] = (cluster[:, :16] * ws).sum(axis=0) / ws_sum
kept.append(merged)
return np.stack(kept, axis=0) if kept else np.empty((0, 17), dtype=np.float32)
def _detection_to_face_rect(detection: np.ndarray, image_w: int, image_h: int) -> Tuple[float, float, float, float, float]:
"""Detection (normalized) → rotated 1.5xbbox ROI in image pixels (anisotropic)."""
xmin, ymin, xmax, ymax = detection[0:4]
lx = detection[4 + _FACE_LEFT_EYE_KP * 2 + 0] * image_w
ly = detection[4 + _FACE_LEFT_EYE_KP * 2 + 1] * image_h
rx = detection[4 + _FACE_RIGHT_EYE_KP * 2 + 0] * image_w
ry = detection[4 + _FACE_RIGHT_EYE_KP * 2 + 1] * image_h
# Image-y-down convention: angle = target - atan2(-dy, dx).
angle = _FACE_ROI_TARGET_ANGLE - math.atan2(ly - ry, rx - lx)
return (float((xmin + xmax) * 0.5 * image_w),
float((ymin + ymax) * 0.5 * image_h),
float((xmax - xmin) * image_w * _FACE_ROI_SCALE_X),
float((ymax - ymin) * image_h * _FACE_ROI_SCALE_Y),
float(angle))
def _sample_warp(image_chw: Tensor, src_x: Tensor, src_y: Tensor, padding_mode: str) -> Tensor:
"""Bilinear-sample image_chw at corner-aligned (src_x, src_y)."""
H, W = int(image_chw.shape[-2]), int(image_chw.shape[-1])
grid = torch.stack([(2.0 * src_x + 1.0) / W - 1.0,
(2.0 * src_y + 1.0) / H - 1.0], dim=-1).unsqueeze(0)
return F.grid_sample(image_chw.unsqueeze(0), grid, mode="bilinear",
align_corners=False, padding_mode=padding_mode).squeeze(0)
def _warp_face_crop(image_chw: Tensor, cx: float, cy: float, width: float, height: float,
angle: float, output_size: int = _FM_INPUT_SIZE) -> Tensor:
"""Rotated rect → output_size² with BORDER_REPLICATE. image_chw must be in [0, 1]."""
s_x, s_y = width / output_size, height / output_size
cos_a, sin_a = math.cos(angle), math.sin(angle)
arange = torch.arange(output_size, dtype=image_chw.dtype, device=image_chw.device) - output_size * 0.5
v_grid, u_grid = torch.meshgrid(arange, arange, indexing="ij")
src_x = cx + u_grid * s_x * cos_a - v_grid * s_y * sin_a
src_y = cy + u_grid * s_x * sin_a + v_grid * s_y * cos_a
return _sample_warp(image_chw, src_x, src_y, "border")
def _blazeface_input_warp(image_chw_raw: Tensor, target: int = _BF_INPUT_SIZE) -> Tuple[Tensor, float, float, float]:
"""Centered max(W,H) square → target² with BORDER_ZERO + [-1, 1] norm.
Sub-pixel grid_sample matters; integer-pad-then-resize drifts the bbox ~5%.
Returns (warped, sub_rect_cx, sub_rect_cy, sub_rect_size) — the triplet maps
tensor-normalized [0,1] detections back to image pixels.
"""
H, W = int(image_chw_raw.shape[1]), int(image_chw_raw.shape[2])
sub_rect_size = float(max(W, H))
sub_rect_cx, sub_rect_cy = W * 0.5, H * 0.5
s = sub_rect_size / target
arange = torch.arange(target, dtype=image_chw_raw.dtype, device=image_chw_raw.device) - target * 0.5
v_grid, u_grid = torch.meshgrid(arange, arange, indexing="ij")
out = _sample_warp(image_chw_raw, sub_rect_cx + u_grid * s, sub_rect_cy + v_grid * s, "zeros")
return (out / 127.5) - 1.0, sub_rect_cx, sub_rect_cy, sub_rect_size
class FaceLandmarker(nn.Module):
"""BlazeFace → FaceMesh v2 → blendshapes. `detector_variant` selects 'short'
(128², ≤2m) or 'full' (192² FPN, ≤5m). State dict uses inner-module prefixes
`detector.*` / `mesh.*` / `blendshapes.*`; the outer FaceLandmarkerModel
wrapper rewrites `detector_{variant}.*` keys to `detector.*` before loading.
"""
def __init__(self, device=None, dtype=None, operations=None, detector_variant: str = "short"):
super().__init__()
det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}.get(detector_variant)
self.detector_variant = detector_variant
self.detector = det_cls(device=device, dtype=dtype, operations=operations)
self.mesh = FaceMesh(device=device, dtype=dtype, operations=operations)
self.blendshapes = FaceBlendshapes(device=device, dtype=dtype, operations=operations)
self.register_buffer("_bs_idx", torch.tensor(_BS_INPUT_INDICES, dtype=torch.long), persistent=False)
def run_detector_batch(self, images_rgb_uint8: List[np.ndarray],
score_thresh: float = _BF_MIN_SCORE,
iou_thresh: float = 0.5):
"""Batched detector pass. Returns (img_raws, sub_rects, sizes, per_frame_decoded)
where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords."""
if not images_rgb_uint8:
return [], [], [], []
device, dtype = self.detector.stem.weight.device, self.detector.stem.weight.dtype
det_input_size, decode_fn = ((_BF_FR_INPUT_SIZE, _decode_blazeface_full_range)
if self.detector_variant == "full"
else (_BF_INPUT_SIZE, _decode_blazeface))
# Same-size frames: stack once and transfer once. Variable size falls back
# to per-image (only triggers for SAM3DBody's head crops).
sizes = [tuple(img.shape[:2]) for img in images_rgb_uint8]
if len(set(sizes)) == 1:
batch_chw = torch.from_numpy(np.stack(images_rgb_uint8, axis=0)).to(device, dtype).movedim(-1, -3).contiguous()
img_raws = [batch_chw[bi] for bi in range(batch_chw.shape[0])]
else:
img_raws = [torch.from_numpy(img).to(device, dtype).movedim(-1, -3).contiguous() for img in images_rgb_uint8]
warps = [_blazeface_input_warp(img_raw, det_input_size) for img_raw in img_raws]
det_crops = [w[0] for w in warps]
sub_rects = [(w[1], w[2], w[3]) for w in warps]
regs_b, cls_b = self.detector(torch.stack(det_crops, dim=0))
regs_np, cls_np = regs_b.float().cpu().numpy(), cls_b.float().cpu().numpy()
per_frame = []
for b in range(len(images_rgb_uint8)):
decoded = decode_fn(regs_np[b], cls_np[b], score_thresh=score_thresh)
per_frame.append(_weighted_nms(decoded, iou_thresh=iou_thresh) if decoded.shape[0] > 0 else decoded)
return img_raws, sub_rects, sizes, per_frame
def detect_batch(self, images_rgb_uint8: List[np.ndarray], num_faces: int = 1,
score_thresh: float = _BF_MIN_SCORE) -> List[List[dict]]:
"""Full pipeline batched across `images_rgb_uint8`. Returns one face-dict
list per image (empty if nothing detected). Face dict:
bbox_xyxy (4,) image pixels, blendshapes {52} ∈ [0,1],
landmarks_xy (478, 2) image pixels, landmarks_3d (478, 3) in
192-canonical (pre-transformation) units, presence float (raw logit).
"""
img_raws, sub_rects, sizes, per_frame_dets = self.run_detector_batch(
images_rgb_uint8, score_thresh=score_thresh,
)
# tensor-normalized → image-normalized [0,1] for _detection_to_face_rect.
for b, decoded in enumerate(per_frame_dets):
if decoded.shape[0] == 0:
continue
cx, cy, size = sub_rects[b]
H, W = sizes[b]
sx0, sy0 = cx - size * 0.5, cy - size * 0.5
decoded[:, 0:16:2] = (sx0 + size * decoded[:, 0:16:2]) / W
decoded[:, 1:16:2] = (sy0 + size * decoded[:, 1:16:2]) / H
if num_faces > 0:
per_frame_dets[b] = decoded[: int(num_faces)]
# Collect every detected face across all frames into one mesh input.
face_params: List[Tuple[int, float, float, float, float, float, float]] = []
mesh_crops: List[Tensor] = []
for b, dets in enumerate(per_frame_dets):
if dets.shape[0] == 0:
continue
H, W = sizes[b]
img_for_mesh = img_raws[b] / 255.0
for det in dets:
cx, cy, w, h, angle = _detection_to_face_rect(det, W, H)
mesh_crops.append(_warp_face_crop(img_for_mesh, cx, cy, w, h, angle, _FM_INPUT_SIZE))
face_params.append((b, float(det[16]), cx, cy, w, h, angle))
results: List[List[dict]] = [[] for _ in range(len(images_rgb_uint8))]
if not mesh_crops:
return results
lmks_canon_b, presence_b = self.mesh(torch.stack(mesh_crops, dim=0))
bs_out_b = self.blendshapes(lmks_canon_b[:, self._bs_idx, :2])
# Batched canonical→image affine
params_t = torch.tensor(
[(cx, cy, w, h, math.cos(a), math.sin(a)) for (_b, _s, cx, cy, w, h, a) in face_params],
device=lmks_canon_b.device, dtype=lmks_canon_b.dtype,
)
cxs, cys, ws, hs, cos_a, sin_a = params_t.unbind(dim=1)
inv = 1.0 / _FM_INPUT_SIZE
u = lmks_canon_b[..., 0] - _FM_INPUT_SIZE * 0.5
v = lmks_canon_b[..., 1] - _FM_INPUT_SIZE * 0.5
lmks_xy_t = torch.stack([
cxs[:, None] + u * (ws * inv * cos_a)[:, None] - v * (hs * inv * sin_a)[:, None],
cys[:, None] + u * (ws * inv * sin_a)[:, None] + v * (hs * inv * cos_a)[:, None],
], dim=-1)
lmks_xy_np = lmks_xy_t.float().cpu().numpy()
lmks_canon_np = lmks_canon_b.float().cpu().numpy()
presence_np = presence_b.float().cpu().numpy()
bs_np = bs_out_b.float().cpu().numpy()
for i, (b, score, *_) in enumerate(face_params):
lmks_xy = lmks_xy_np[i]
mn, mx = lmks_xy.min(0), lmks_xy.max(0)
results[b].append({
"bbox_xyxy": np.array([mn[0], mn[1], mx[0], mx[1]], dtype=np.float32),
"blendshapes": dict(zip(BLENDSHAPE_NAMES, bs_np[i].tolist())),
"landmarks_xy": lmks_xy,
"landmarks_3d": lmks_canon_np[i],
"presence": float(presence_np[i]),
"score": score,
})
return results

View File

@ -0,0 +1,502 @@
"""ComfyUI nodes for the pure-PyTorch MediaPipe Face Landmarker port.
Custom IO types:
FACE_LANDMARKER — FaceLandmarkerModel wrapper (ModelPatcher inside)
FACE_LANDMARKS — {"frames": List[List[face_dict]], "image_size": (H, W),
"connection_sets": dict[str, frozenset[(int, int)]]}
face_dict: bbox_xyxy, blendshapes, landmarks_xy,
landmarks_3d, presence, score, transformation_matrix
MediaPipeFaceLandmarker also emits the core BOUNDING_BOX type — pair with DrawBBoxes.
"""
from __future__ import annotations
import numpy as np
import torch
from PIL import Image, ImageColor, ImageDraw
from tqdm.auto import tqdm
from typing_extensions import override
import comfy.model_management
import comfy.model_patcher
import comfy.utils
import folder_paths
from comfy_api.latest import ComfyExtension, io
from comfy_extras.mediapipe.face_landmarker import FaceLandmarker
from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection
FaceLandmarkerType = io.Custom("FACE_LANDMARKER")
FaceLandmarksType = io.Custom("FACE_LANDMARKS")
_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights")
_CONTOUR_PARTS = ("face_oval", "left_eye", "right_eye", "left_eyebrow", "right_eyebrow", "lips")
class FaceLandmarkerModel:
"""Loaded FaceLandmarker variants + ModelPatcher per variant.
Safetensors layout: `detector_short.*` / `detector_full.*` plus shared
`mesh.*`, `blendshapes.*`, `canonical_*`, and `topology.*`.
PReLU forces plain-nn / fp32 (manual_cast strands buffers across devices).
"""
def __init__(self, state_dict: dict):
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = torch.float32
# FACEMESH_* connection sets, embedded as int32 (N, 2) under topology.*.
base: dict[str, frozenset] = {}
for k in [k for k in state_dict if k.startswith("topology.")]:
base[k[len("topology."):]] = frozenset(map(tuple, state_dict.pop(k).tolist()))
base["contours"] = frozenset().union(*(base[p] for p in _CONTOUR_PARTS))
base["all"] = base["contours"] | base["irises"] | base["nose"]
self.connection_sets: dict[str, frozenset] = base
self.canonical_data: dict[str, np.ndarray] = {k: state_dict.pop(k).numpy() for k in _CANONICAL_KEYS}
shared = {k: v for k, v in state_dict.items() if k.startswith(("mesh.", "blendshapes."))}
self.models: dict[str, FaceLandmarker] = {}
self.patchers: dict[str, comfy.model_patcher.ModelPatcher] = {}
for variant in ("short", "full"):
prefix = f"detector_{variant}."
sub = dict(shared)
sub.update({f"detector.{k[len(prefix):]}": v for k, v in state_dict.items() if k.startswith(prefix)})
fl = FaceLandmarker(device=offload_device, dtype=self.dtype, operations=None, detector_variant=variant).eval()
fl.load_state_dict(sub, strict=False)
self.models[variant] = fl
self.patchers[variant] = comfy.model_patcher.CoreModelPatcher(
fl, load_device=self.load_device, offload_device=offload_device,
size=comfy.model_management.module_size(fl),
)
def detect_batch(self, images, num_faces: int, score_thresh: float, variant: str):
comfy.model_management.load_model_gpu(self.patchers[variant])
return self.models[variant].detect_batch(images, num_faces=num_faces, score_thresh=score_thresh)
def _image_to_uint8(image: torch.Tensor) -> np.ndarray:
return image[..., :3].mul(255.0).add_(0.5).clamp_(0, 255).to(torch.uint8).cpu().numpy()
def _parse_color(color: str) -> tuple[int, int, int]:
try:
return ImageColor.getrgb(color)[:3]
except ValueError:
return (0, 255, 0)
def _copy_face(face: dict) -> dict:
"""Shallow copy of a face_dict with array-fields cloned so callers can mutate."""
return {
"bbox_xyxy": face["bbox_xyxy"].copy(),
"blendshapes": dict(face["blendshapes"]),
"landmarks_xy": face["landmarks_xy"].copy(),
"landmarks_3d": face["landmarks_3d"].copy(),
"presence": face["presence"],
"score": face["score"],
}
def _lerp_face(a: dict, b: dict, t: float) -> dict:
return {
"bbox_xyxy": (1 - t) * a["bbox_xyxy"] + t * b["bbox_xyxy"],
"blendshapes": {k: (1 - t) * a["blendshapes"][k] + t * b["blendshapes"][k] for k in a["blendshapes"]},
"landmarks_xy": (1 - t) * a["landmarks_xy"] + t * b["landmarks_xy"],
"landmarks_3d": (1 - t) * a["landmarks_3d"] + t * b["landmarks_3d"],
"presence": (1 - t) * a["presence"] + t * b["presence"],
"score": (1 - t) * a["score"] + t * b["score"],
}
def _match_faces(a: list[dict], b: list[dict]) -> list[tuple[int, int]]:
"""Greedy nearest-neighbour pairing of faces between two frames by bbox
centre distance. Unmatched (when counts differ) are dropped."""
if not a or not b:
return []
centers_a = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]),
0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in a])
centers_b = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]),
0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in b])
dists = np.linalg.norm(centers_a[:, None] - centers_b[None], axis=-1)
pairs: list[tuple[int, int]] = []
used_a: set[int] = set()
used_b: set[int] = set()
candidates = sorted((dists[ia, ib], ia, ib) for ia in range(len(a)) for ib in range(len(b)))
for _, ia, ib in candidates:
if ia in used_a or ib in used_b:
continue
pairs.append((ia, ib))
used_a.add(ia)
used_b.add(ib)
return pairs
def _fill_missing_frames(frames: list[list[dict]], mode: str) -> None:
"""In-place fill empty frame slots from neighbouring detections. Multi-face
aware: pairs faces across bracketing frames by greedy bbox-centre NN.
When counts differ, unmatched faces are dropped from the synthesised frame."""
if mode == "empty":
return
valid = [i for i, fr in enumerate(frames) if fr]
if not valid:
return # nothing to fill from
if mode == "previous":
last: list[dict] = []
for i, fr in enumerate(frames):
if fr:
last = fr
elif last:
frames[i] = [_copy_face(f) for f in last]
return
# interpolate: lerp between bracketing valid frames; clamp at ends.
for i in range(len(frames)):
if frames[i]:
continue
prev_i = max((v for v in valid if v < i), default=None)
next_i = min((v for v in valid if v > i), default=None)
if prev_i is None:
frames[i] = [_copy_face(f) for f in frames[next_i]]
elif next_i is None:
frames[i] = [_copy_face(f) for f in frames[prev_i]]
else:
t = (i - prev_i) / (next_i - prev_i)
pairs = _match_faces(frames[prev_i], frames[next_i])
frames[i] = [_lerp_face(frames[prev_i][a], frames[next_i][b], t) for a, b in pairs]
def _ordered_rings(edges: frozenset[tuple[int, int]]) -> list[list[int]]:
"""Walk an unordered edge set into one or more closed-loop vertex rings
(handles multi-loop sets like FACEMESH_LIPS: outer + inner)."""
adj: dict[int, set[int]] = {}
for a, b in edges:
adj.setdefault(a, set()).add(b)
adj.setdefault(b, set()).add(a)
visited: set[int] = set()
rings: list[list[int]] = []
for start in adj:
if start in visited:
continue
ring = [start]
visited.add(start)
prev, cur = -1, start
while True:
nxt = next((v for v in adj[cur] if v != prev), None)
if nxt is None or nxt == start:
break
ring.append(nxt)
visited.add(nxt)
prev, cur = cur, nxt
rings.append(ring)
return rings
class LoadMediaPipeFaceLandmarker(io.ComfyNode):
"""Load MediaPipe Face Landmarker v2 weights. Contains both detector variants
(short / full), shared mesh, blendshapes, and canonical geometry."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadMediaPipeFaceLandmarker",
display_name="Load MediaPipe Face Landmarker",
category="loaders",
inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("mediapipe"),
tooltip="Face Landmarker safetensors from models/mediapipe/."),
],
outputs=[FaceLandmarkerType.Output()],
)
@classmethod
def execute(cls, model_name) -> io.NodeOutput:
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("mediapipe", model_name), safe_load=True)
wrapper = FaceLandmarkerModel(sd)
return io.NodeOutput(wrapper)
# Per-frame fallback modes for detection failures in a batch.
_FALLBACK_MODES = ("empty", "previous", "interpolate")
class MediaPipeFaceLandmarker(io.ComfyNode):
"""BlazeFace → FaceMesh v2 → ARKit-52 blendshapes, batched across the
input. Also emits a BOUNDING_BOX list (landmark-extent bbox per face) —
pair with DrawBBoxes for detector-only viz or MediaPipeFaceMeshVisualize
for the mesh overlay."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceLandmarker",
display_name="MediaPipe Face Landmarker",
category="image/detection",
inputs=[
FaceLandmarkerType.Input("face_landmarker"),
io.Image.Input("image"),
io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short",
tooltip="Face detector range. 'short' is tuned for close-up faces "
"(within ~2 m of the camera); 'full' covers farther / smaller "
"faces (up to ~5 m) but is slower. 'both' runs both detectors and "
"keeps whichever found more faces per frame (~2× detection cost)."),
io.Int.Input("num_faces", default=1, min=0, max=16, step=1,
tooltip="Maximum faces to return per frame. 0 = no cap (return all detected)."),
io.Float.Input("min_confidence", default=0.5, min=0.0, max=1.0, step=0.01, advanced=True,
tooltip="BlazeFace score threshold. Lower to catch small/occluded faces."),
io.Combo.Input("missing_frame_fallback", options=list(_FALLBACK_MODES), default="empty", advanced=True,
tooltip="Per-frame behaviour when detection fails in a batch. "
"'empty' leaves the frame faceless. 'previous' copies the most recent successful "
"detection. 'interpolate' lerps landmarks/bbox/blendshapes between bracketing "
"successful frames. Multi-face: pairs faces across frames by greedy bbox-centre NN."),
],
outputs=[
FaceLandmarksType.Output(display_name="face_landmarks"),
io.BoundingBox.Output("bboxes"),
],
)
@classmethod
def execute(cls, face_landmarker, image, detector_variant, num_faces, min_confidence,
missing_frame_fallback) -> io.NodeOutput:
canonical = face_landmarker.canonical_data
img_np = _image_to_uint8(image)
B, H, W = img_np.shape[:3]
chunk = 16
is_both = detector_variant == "both"
total_work = 2 * B if is_both else B
pbar = comfy.utils.ProgressBar(total_work)
def _run(variant: str) -> list[list[dict]]:
res: list[list[dict]] = []
with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq:
for i in range(0, B, chunk):
end = min(i + chunk, B)
res.extend(face_landmarker.detect_batch(
[img_np[bi] for bi in range(i, end)],
num_faces=int(num_faces),
score_thresh=float(min_confidence),
variant=variant,
))
pbar.update_absolute(min(pbar.current + (end - i), total_work))
tq.update(end - i)
return res
if is_both:
short_res = _run("short")
full_res = _run("full")
# Per-frame keep whichever found more faces (tie → short).
frames: list[list[dict]] = [
short_res[bi] if len(short_res[bi]) >= len(full_res[bi]) else full_res[bi]
for bi in range(B)
]
else:
frames = _run(detector_variant)
_fill_missing_frames(frames, missing_frame_fallback)
bboxes = []
for per_frame in frames:
per_bb = []
for f in per_frame:
f["transformation_matrix"] = transformation_matrix_from_detection(f, W, H, canonical)
x1, y1, x2, y2 = (float(v) for v in f["bbox_xyxy"])
per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])})
bboxes.append(per_bb)
return io.NodeOutput({"frames": frames, "image_size": (H, W),
"connection_sets": face_landmarker.connection_sets}, bboxes)
# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose).
_ALL_CONNECTION_PARTS: tuple[str, ...] = (*_CONTOUR_PARTS, "irises", "nose")
_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = (
("face_oval", True),
("lips", True),
("left_eye", True),
("right_eye", True),
("left_eyebrow", True),
("right_eyebrow", True),
("irises", True),
("nose", True),
("tesselation", False),
)
class MediaPipeFaceMeshVisualize(io.ComfyNode):
"""Draw a FACEMESH_* subset over an image. Topology travels with the
FACE_LANDMARKS payload (set at detection time)."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceMeshVisualize",
display_name="MediaPipe Face Mesh Visualize",
category="image/detection",
inputs=[
FaceLandmarksType.Input("face_landmarks"),
io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."),
io.DynamicCombo.Input(
"connections",
tooltip="'all' = oval+eyes+brows+lips+irises+nose. 'fill' = solid face_oval polygon (silhouette mask). 'custom' = toggle each feature individually (including 'tesselation', the full 2547-edge wireframe).",
options=[
io.DynamicCombo.Option("all", []),
io.DynamicCombo.Option("fill", []),
io.DynamicCombo.Option("custom", [
io.Boolean.Input(feat, default=default,
tooltip=f"Draw the '{feat}' connection set.")
for feat, default in _CUSTOM_FEATURES
]),
],
),
io.Color.Input("color", default="#00ff00"),
io.Int.Input("thickness", default=1, min=0, max=8, step=1,
tooltip="Edge line thickness in pixels. 0 disables edge drawing."),
io.Int.Input("point_size", default=2, min=0, max=16, step=1,
tooltip="Landmark dot radius in pixels. 0 disables point drawing."),
],
outputs=[io.Image.Output()],
)
@classmethod
def execute(cls, face_landmarks, connections, color, thickness, point_size, image=None) -> io.NodeOutput:
sets = face_landmarks["connection_sets"]
sel = connections["connections"]
fill_rings: list[list[int]] | None = None
if sel == "fill":
fill_rings = _ordered_rings(sets["face_oval"])
edges = frozenset()
elif sel == "custom":
parts = [feat for feat, _ in _CUSTOM_FEATURES if connections.get(feat, False)]
edges = frozenset().union(*(sets[p] for p in parts))
else: # "all"
edges = frozenset().union(*(sets[p] for p in _ALL_CONNECTION_PARTS))
rgb, thick, psize = _parse_color(color), int(thickness), int(point_size)
frames = face_landmarks["frames"]
if image is None:
H, W = face_landmarks["image_size"]
img_np = np.zeros((len(frames), H, W, 3), dtype=np.uint8)
else:
img_np = _image_to_uint8(image)
B = img_np.shape[0]
n_frames = len(frames)
pbar = comfy.utils.ProgressBar(B)
out = np.empty_like(img_np)
for bi in range(B):
faces = frames[bi] if bi < n_frames else []
out[bi] = _draw_mesh(img_np[bi], faces, edges, rgb, thick, psize, fill_rings)
pbar.update_absolute(bi + 1)
return io.NodeOutput(torch.from_numpy(out).to(
device=comfy.model_management.intermediate_device(),
dtype=comfy.model_management.intermediate_dtype(),
).div_(255.0))
def _draw_mesh(image_rgb: np.ndarray, faces: list, edges,
rgb: tuple[int, int, int], thickness: int,
point_size: int, fill_rings: list[list[int]] | None = None) -> np.ndarray:
draw_edges = thickness > 0 and edges
if not faces or (fill_rings is None and not draw_edges and point_size <= 0):
return image_rgb.copy()
pil = Image.fromarray(image_rgb)
draw = ImageDraw.Draw(pil)
r = point_size * 0.5
if fill_rings is not None:
for f in faces:
lmks = f["landmarks_xy"]
for ring in fill_rings:
draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=rgb)
return np.asarray(pil)
for f in faces:
lmks = f["landmarks_xy"]
n = lmks.shape[0]
if draw_edges:
for a, b in edges:
if a < n and b < n:
draw.line([(float(lmks[a, 0]), float(lmks[a, 1])),
(float(lmks[b, 0]), float(lmks[b, 1]))], fill=rgb, width=thickness)
if point_size == 1:
draw.point(lmks.flatten().tolist(), fill=rgb)
elif point_size > 1:
for x, y in lmks:
draw.ellipse((float(x) - r, float(y) - r, float(x) + r, float(y) + r), fill=rgb)
return np.asarray(pil)
# Mask region presets — closed-loop topologies only.
_MASK_REGIONS: tuple[str, ...] = ("face_oval", "lips", "left_eye", "right_eye", "irises")
_MASK_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = (
("face_oval", True),
("lips", False),
("left_eye", False),
("right_eye", False),
("irises", False),
)
class MediaPipeFaceMask(io.ComfyNode):
"""Binary mask from face landmarks, filled polygon per face. One mask per
frame in the batch; faces in the same frame composite (union)."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceMask",
display_name="MediaPipe Face Mask",
category="image/detection",
inputs=[
FaceLandmarksType.Input("face_landmarks"),
io.DynamicCombo.Input(
"regions",
tooltip="'all' = union of face_oval+lips+eyes+irises (which collapses to face_oval since it encloses the rest). 'custom' = toggle each region individually for combos like lips+eyes.",
options=[
io.DynamicCombo.Option("all", []),
io.DynamicCombo.Option("custom", [
io.Boolean.Input(reg, default=default,
tooltip=f"Include the '{reg}' region in the mask.")
for reg, default in _MASK_CUSTOM_FEATURES
]),
],
),
],
outputs=[io.Mask.Output()],
)
@classmethod
def execute(cls, face_landmarks, regions) -> io.NodeOutput:
sets = face_landmarks["connection_sets"]
sel = regions["regions"]
if sel == "custom":
picked = [reg for reg, _ in _MASK_CUSTOM_FEATURES if regions.get(reg, False)]
else:
picked = list(_MASK_REGIONS)
rings = [r for reg in picked for r in _ordered_rings(sets[reg])]
frames = face_landmarks["frames"]
H, W = face_landmarks["image_size"]
masks = np.zeros((len(frames), H, W), dtype=np.uint8)
pbar = comfy.utils.ProgressBar(len(frames))
for bi, per_frame in enumerate(frames):
if per_frame:
pil = Image.new("L", (W, H), 0)
draw = ImageDraw.Draw(pil)
for f in per_frame:
lmks = f["landmarks_xy"]
for ring in rings:
draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=255)
masks[bi] = np.asarray(pil)
pbar.update_absolute(bi + 1)
return io.NodeOutput(torch.from_numpy(masks).to(
device=comfy.model_management.intermediate_device(),
dtype=comfy.model_management.intermediate_dtype(),
).div_(255.0))
class MediaPipeFaceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [LoadMediaPipeFaceLandmarker, MediaPipeFaceLandmarker, MediaPipeFaceMeshVisualize, MediaPipeFaceMask]
async def comfy_entrypoint() -> MediaPipeFaceExtension:
return MediaPipeFaceExtension()

View File

@ -2,6 +2,7 @@ import copy
import heapq
import inspect
import logging
import psutil
import sys
import threading
import time
@ -727,6 +728,7 @@ class PromptExecutor:
self._notify_prompt_lifecycle("start", prompt_id)
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
ram_inactive_headroom = int(self.cache_args["ram_inactive"] * (1024 ** 3))
ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom)
@ -780,8 +782,14 @@ class PromptExecutor:
execution_list.complete_node_execution()
if self.cache_type == CacheType.RAM_PRESSURE:
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
ram_release_callback(ram_headroom, free_active=True)
ram_release_callback(ram_inactive_headroom)
ram_shortfall = ram_headroom - psutil.virtual_memory().available
freed = comfy.model_management.free_pins(ram_shortfall + 512 * (1024 ** 2))
if freed < ram_shortfall:
if freed > 64 * (1024 ** 2):
# AIMDO MEM_DECOMMIT can outrun psutil.available catching up.
time.sleep(0.05)
ram_release_callback(ram_headroom, free_active=True)
else:
# Only execute when the while-loop ends without break
# Send cached UI for intermediate output nodes that weren't executed

View File

@ -60,6 +60,8 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
folder_names_and_paths["mediapipe"] = ([os.path.join(models_dir, "mediapipe")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input")

20
main.py
View File

@ -283,19 +283,25 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
def prompt_worker(q, server_instance):
current_time: float = 0.0
cache_ram = args.cache_ram
if cache_ram < 0:
cache_ram = 0
cache_ram_inactive = 0
if not args.cache_classic and not args.cache_none and args.cache_lru <= 0:
cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0))
cache_ram_inactive = min(96.0, max(12.0, comfy.model_management.total_ram * 0.75 / 1024.0))
if len(args.cache_ram) > 0:
cache_ram = args.cache_ram[0]
if len(args.cache_ram) > 1:
cache_ram_inactive = args.cache_ram[1]
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.RAM_PRESSURE
if args.cache_classic:
cache_type = execution.CacheType.CLASSIC
elif args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none:
cache_type = execution.CacheType.NONE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } )
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram, "ram_inactive" : cache_ram_inactive } )
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0

View File

@ -2444,6 +2444,7 @@ async def init_builtin_extra_nodes():
"nodes_hidream_o1.py",
"nodes_save_3d.py",
"nodes_moge.py",
"nodes_mediapipe.py",
]
import_failed = []

View File

@ -1517,22 +1517,6 @@ paths:
schema:
type: integer
default: 0
description: |
Offset-based pagination. Cursor pagination via `after` is preferred
for sequential walks (stable across concurrent inserts/deletes) but
`offset` remains fully supported for random access (jump-to-page
UIs, "showing items XY of N" displays). When both are supplied,
`after` wins and `offset` is ignored.
- name: after
in: query
schema:
type: string
description: |
Opaque cursor for keyset pagination. Pass the `next_cursor` value
from a previous response to fetch the next page. Stable across
inserts/deletes between pages. Supported with `sort` values
`created_at`, `updated_at`, `name`, and `size`. Malformed or
unsupported cursors return 400 with `INVALID_CURSOR`.
- name: include_tags
in: query
schema:
@ -1597,12 +1581,6 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/ListAssetsResponse"
"400":
description: Malformed query or cursor (e.g. `INVALID_CURSOR`)
content:
application/json:
schema:
$ref: "#/components/schemas/AssetsApiError"
post:
operationId: createAsset
tags: [assets]
@ -6501,34 +6479,6 @@ components:
type: integer
has_more:
type: boolean
next_cursor:
type: string
description: |
Opaque cursor to fetch the next page. Pass back as the `after`
query parameter. Omitted when there are no more results.
AssetsApiError:
type: object
description: Error envelope returned by the assets API on 400 responses.
required:
- error
properties:
error:
type: object
required:
- code
- message
properties:
code:
type: string
description: |
Machine-readable error code. `INVALID_CURSOR` is returned when the
`after` cursor is malformed, oversized, or its sort field does
not match the request's `sort`. `INVALID_QUERY` covers other
Pydantic validation failures.
enum: [INVALID_CURSOR, INVALID_QUERY]
message:
type: string
TagInfo:
type: object

View File

@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
filelock
av>=14.2.0
comfy-kitchen>=0.2.8
comfy-aimdo==0.3.0
comfy-aimdo==0.4.3
requests
simpleeval>=1.0.0
blake3

View File

@ -1,112 +0,0 @@
"""Keyset-pagination tiebreaker tests for list_references_page.
When multiple rows share the same primary sort value (e.g. four assets
created in the same microsecond), the secondary `ORDER BY id` is what keeps
keyset pagination from losing or repeating rows. This file exercises that
branch directly against an in-memory SQLite session — engineering identical
timestamps via HTTP is unreliable enough that we work at the query layer.
"""
import uuid
from datetime import datetime
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries.asset_reference import list_references_page
def _make_ref(session: Session, created_at: datetime, name: str, owner: str = "") -> AssetReference:
asset = Asset(hash=f"blake3:{uuid.uuid4().hex}", size_bytes=1024)
session.add(asset)
session.flush()
ref = AssetReference(
id=str(uuid.uuid4()),
asset_id=asset.id,
owner_id=owner,
name=name,
file_path=f"/tmp/{name}",
created_at=created_at,
updated_at=created_at,
last_access_time=created_at,
is_missing=False,
)
session.add(ref)
return ref
@pytest.mark.parametrize("order", ["desc", "asc"])
def test_tiebreaker_walks_duplicate_sort_values(session: Session, order: str):
"""Four rows with the SAME created_at must paginate cleanly under cursor
mode — no row dropped, no row repeated, despite the primary sort column
being non-discriminating.
"""
shared_ts = datetime(2024, 5, 20, 12, 0, 0) # naive UTC, like the DB stores
refs = [_make_ref(session, shared_ts, f"tie_{i}.png") for i in range(4)]
session.commit()
expected_ids = sorted([r.id for r in refs], reverse=(order == "desc"))
# Walk the cursor by hand: page size 2, take 3 pages (2 + 2 + 0).
seen: list[str] = []
after_value = None
after_id = None
for _ in range(4): # generous loop bound; ought to be 2 iterations
page, _tag_map, _total = list_references_page(
session,
limit=2,
sort="created_at",
order=order,
after_cursor_value=after_value,
after_cursor_id=after_id,
)
if not page:
break
seen.extend(p.id for p in page)
# Use the last row's (created_at, id) as the next cursor input.
last = page[-1]
after_value, after_id = last.created_at, last.id
if len(page) < 2:
break
assert seen == expected_ids, (
f"keyset tiebreaker failed for order={order}: expected {expected_ids}, got {seen}"
)
def test_tiebreaker_no_duplicates_under_mixed_collisions(session: Session):
"""Some rows share a timestamp, some don't. The cursor must still walk
every row exactly once regardless of where ties sit relative to a
page boundary."""
t1 = datetime(2024, 5, 20, 12, 0, 0)
t2 = datetime(2024, 5, 20, 12, 0, 1)
layout = [t1, t1, t1, t2, t2] # three rows at t1, two at t2
refs = [_make_ref(session, ts, f"mix_{i}.png") for i, ts in enumerate(layout)]
session.commit()
all_ids = {r.id for r in refs}
seen_set: set[str] = set()
seen_list: list[str] = []
after_value = None
after_id = None
for _ in range(6):
page, _, _ = list_references_page(
session,
limit=2,
sort="created_at",
order="desc",
after_cursor_value=after_value,
after_cursor_id=after_id,
)
if not page:
break
for p in page:
assert p.id not in seen_set, f"duplicate row {p.id} appeared in cursor walk"
seen_set.add(p.id)
seen_list.append(p.id)
last = page[-1]
after_value, after_id = last.created_at, last.id
if len(page) < 2:
break
assert seen_set == all_ids, f"missing rows: expected {all_ids}, got {seen_set}"

View File

@ -1,295 +0,0 @@
"""Tests for app.assets.services.cursor.
The wire format must stay byte-identical with the cloud Go implementation
(common/pagination/cursor.go in Comfy-Org/cloud) so the frontend sees one
contract across runtimes. The byte-identity fixture below mirrors the Go
test cases — any drift here means cloud and OSS minted different cursors
for the same triple, which would break FE pagination across backends.
"""
from __future__ import annotations
import base64
from datetime import datetime, timedelta, timezone
import pytest
from app.assets.services.cursor import (
MAX_CURSOR_ID_LENGTH,
MAX_CURSOR_VALUE_LENGTH,
MAX_ENCODED_CURSOR_LENGTH,
CursorPayload,
InvalidCursorError,
decode_cursor,
decode_cursor_int,
decode_cursor_time,
encode_cursor,
encode_cursor_from_time,
)
ALLOWED = ("created_at", "updated_at", "name", "size")
class TestRoundTrip:
@pytest.mark.parametrize(
"sort_field, value, id",
[
("created_at", "1716200000000000", "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7"),
("size", "1024", "asset-123"),
("name", "my-asset.png", "asset-abc"),
("name", "résumé.txt", "asset-uni"),
],
)
def test_encode_decode(self, sort_field, value, id):
encoded = encode_cursor(sort_field, value, id)
assert encoded != ""
payload = decode_cursor(encoded, ALLOWED)
assert payload.sort_field == sort_field
assert payload.value == value
assert payload.id == id
class TestTimeCursor:
def test_microsecond_precision_preserved(self):
# Pick a time with non-zero microseconds — encoding at ms would lose the µs.
ts = datetime(2024, 5, 20, 12, 53, 20, 123456, tzinfo=timezone.utc)
encoded = encode_cursor_from_time("created_at", ts, "id-1")
payload = decode_cursor(encoded, ALLOWED)
# Value must be a microsecond integer string, not a millisecond one.
assert payload.value == "1716209600123456"
decoded = decode_cursor_time(payload)
assert decoded == ts
def test_decode_returns_utc(self):
payload = CursorPayload(sort_field="created_at", value="1716200000123456", id="id-1")
decoded = decode_cursor_time(payload)
assert decoded.tzinfo == timezone.utc
def test_naive_datetime_rejected_on_encode(self):
naive = datetime(2024, 5, 20, 12, 0, 0)
with pytest.raises(ValueError):
encode_cursor_from_time("created_at", naive, "id-1")
def test_non_integer_value_rejected_on_decode(self):
with pytest.raises(InvalidCursorError):
decode_cursor_time(CursorPayload("created_at", "not-a-number", "id-1"))
def test_none_payload_rejected(self):
with pytest.raises(InvalidCursorError):
decode_cursor_time(None)
def test_non_utc_aware_normalized(self):
# Same instant, different timezone — must encode to the same micros.
utc_ts = datetime(2024, 5, 20, 12, 0, 0, tzinfo=timezone.utc)
offset_ts = utc_ts.astimezone(timezone(timedelta(hours=-5)))
assert encode_cursor_from_time("created_at", utc_ts, "x") == encode_cursor_from_time(
"created_at", offset_ts, "x"
)
class TestIntCursor:
def test_decode_int(self):
assert decode_cursor_int(CursorPayload("size", "1024", "id-1")) == 1024
def test_decode_int_rejects_non_int(self):
with pytest.raises(InvalidCursorError):
decode_cursor_int(CursorPayload("size", "abc", "id-1"))
def test_decode_int_rejects_none(self):
with pytest.raises(InvalidCursorError):
decode_cursor_int(None)
class TestInvalidInputs:
def test_oversized_cursor(self):
oversized = "a" * (MAX_ENCODED_CURSOR_LENGTH + 1)
with pytest.raises(InvalidCursorError, match="maximum length"):
decode_cursor(oversized, ALLOWED)
def test_not_base64(self):
with pytest.raises(InvalidCursorError):
decode_cursor("not base64!!!", ALLOWED)
def test_not_json(self):
encoded = base64.urlsafe_b64encode(b"definitely not json").rstrip(b"=").decode("ascii")
with pytest.raises(InvalidCursorError):
decode_cursor(encoded, ALLOWED)
def test_empty_id(self):
encoded = encode_cursor("created_at", "1", "")
with pytest.raises(InvalidCursorError, match="missing id"):
decode_cursor(encoded, ALLOWED)
def test_oversized_id(self):
encoded = encode_cursor("created_at", "1", "a" * (MAX_CURSOR_ID_LENGTH + 1))
with pytest.raises(InvalidCursorError, match="id exceeds maximum length"):
decode_cursor(encoded, ALLOWED)
def test_oversized_value(self):
encoded = encode_cursor("created_at", "v" * (MAX_CURSOR_VALUE_LENGTH + 1), "id-1")
with pytest.raises(InvalidCursorError, match="value exceeds maximum length"):
decode_cursor(encoded, ALLOWED)
def test_unsupported_sort_field(self):
encoded = encode_cursor("execution_time", "1", "id-1")
with pytest.raises(InvalidCursorError, match="unsupported sort field"):
decode_cursor(encoded, ALLOWED)
def test_no_allowed_fields_rejects_everything(self):
encoded = encode_cursor("created_at", "1", "id-1")
with pytest.raises(InvalidCursorError):
decode_cursor(encoded, ())
def test_non_dict_payload_rejected(self):
encoded = base64.urlsafe_b64encode(b'["array","not","dict"]').rstrip(b"=").decode("ascii")
with pytest.raises(InvalidCursorError, match="expected object"):
decode_cursor(encoded, ALLOWED)
class TestEncodeAtCapsFits:
def test_max_field_lengths_fit_wire_cap(self):
# Worst-case payload: value and id at their per-field caps, with a long
# sort field name. The encoded cursor must fit within MAX_ENCODED_CURSOR_LENGTH
# so the wire cap cannot reject a cursor the encoder mints at the per-field caps.
value = "v" * MAX_CURSOR_VALUE_LENGTH
id = "i" * MAX_CURSOR_ID_LENGTH
sort_field = "very_long_sort_field_name"
encoded = encode_cursor(sort_field, value, id)
assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH
payload = decode_cursor(encoded, (sort_field,))
assert payload.value == value
assert payload.id == id
class TestDatetimeOverflow:
"""Crafted cursors with extreme micros must map to InvalidCursorError,
not OverflowError/OSError leaking as 500.
"""
@pytest.mark.parametrize(
"micros_str",
[
"999999999999999999999", # 10^21 µs — past datetime.MAX_YEAR by ~14 orders
"-999999999999999999999", # symmetric negative — pre-epoch overflow
],
)
def test_out_of_range_micros_rejected(self, micros_str):
encoded = encode_cursor("created_at", micros_str, "asset-x")
payload = decode_cursor(encoded, ALLOWED)
with pytest.raises(InvalidCursorError):
decode_cursor_time(payload)
class TestEncoderDecoderSymmetry:
"""The encoder must reject inputs the decoder rejects, or the same server
will mint a cursor it then 400s on the next request.
"""
def test_long_name_within_cap_round_trips(self):
"""OSS assets allow names up to 512 chars (`String(512)`); cursor must
handle that. Cloud's lower cap is acceptable on its side because the
cloud schema doesn't permit names that long."""
long_name = "n" * MAX_CURSOR_VALUE_LENGTH
encoded = encode_cursor("name", long_name, "asset-x")
payload = decode_cursor(encoded, ALLOWED)
assert payload.value == long_name
class TestOrderBinding:
def test_order_baked_into_payload(self):
encoded = encode_cursor("created_at", "1", "id-1", order="asc")
payload = decode_cursor(encoded, ALLOWED)
assert payload.order == "asc"
def test_mismatched_order_rejected(self):
encoded = encode_cursor("created_at", "1", "id-1", order="desc")
with pytest.raises(InvalidCursorError, match="does not match request order"):
decode_cursor(encoded, ALLOWED, expected_order="asc")
def test_matching_order_accepted(self):
encoded = encode_cursor("created_at", "1", "id-1", order="desc")
payload = decode_cursor(encoded, ALLOWED, expected_order="desc")
assert payload.order == "desc"
def test_invalid_order_token_rejected_on_encode(self):
with pytest.raises(ValueError):
encode_cursor("created_at", "1", "id-1", order="sideways")
def test_invalid_order_token_rejected_on_decode(self):
# Hand-craft a payload with an illegal `o` value.
raw = b'{"s":"name","v":"x","id":"id-1","o":"sideways"}'
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
with pytest.raises(InvalidCursorError, match="unsupported order"):
decode_cursor(encoded, ALLOWED)
def test_legacy_cursor_without_order_accepted(self):
"""Cursors minted by a producer that didn't include `o` (e.g. an older
cloud build) must still decode — the order binding is best-effort
until cloud mirrors the field."""
raw = b'{"s":"name","v":"x","id":"id-1"}'
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
payload = decode_cursor(encoded, ALLOWED, expected_order="desc")
assert payload.order is None # binding skipped, decode succeeds
class TestGoCompatJsonEscaping:
"""An asset name containing `<`, `>`, `&`, U+2028, or U+2029 must encode
to the same bytes Go's default `json.Marshal` would produce. The fixtures
below are generated by Go's encoder; any drift here means cross-runtime
byte-identity for HTML-significant characters is broken.
"""
@pytest.mark.parametrize(
"value, escaped_substring",
[
("foo<bar>.png", "\\u003c"), # `<` escaped
("foo<bar>.png", "\\u003e"), # `>` escaped
("foo&bar.png", "\\u0026"),
("foobar.png", "\\u2028"), # JS line separator
("foobar.png", "\\u2029"), # JS paragraph separator
],
)
def test_html_significant_chars_escaped(self, value, escaped_substring):
encoded = encode_cursor("name", value, "id-1")
decoded_bytes = base64.urlsafe_b64decode(encoded + "=" * (-len(encoded) % 4))
assert escaped_substring in decoded_bytes.decode("ascii"), (
f"Expected {escaped_substring!r} in serialized payload, got: {decoded_bytes!r}"
)
def test_value_round_trips_through_escape(self):
"""Encoding then decoding a value with `<>&` should yield the original
string — the escape only affects the wire form, not the decoded value."""
original = "foo<&>bar.png"
encoded = encode_cursor("name", original, "id-1")
payload = decode_cursor(encoded, ALLOWED)
assert payload.value == original
class TestByteIdentityFixtures:
"""Pin the wire format so it doesn't drift silently.
NOTE — these fixtures will need updates on the cloud side once cloud
mirrors the `o` (order binding) field. Until then, Python cursors and
cloud cursors differ by exactly that key. The structural format (`s`,
`v`, `id` plus base64url + Go-compat escaping) remains aligned.
"""
@pytest.mark.parametrize(
"sort_field, value, id, order",
[
("created_at", "1716200000000000", "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7", "desc"),
("size", "1024", "asset-123", "asc"),
("name", "my-asset.png", "asset-abc", "desc"),
("name", "foo<bar>&baz.png", "asset-html", "desc"), # exercises Go-compat escaping
],
)
def test_encoded_payload_shape_pinned(self, sort_field, value, id, order):
encoded = encode_cursor(sort_field, value, id, order=order)
decoded_bytes = base64.urlsafe_b64decode(encoded + "=" * (-len(encoded) % 4))
decoded_str = decoded_bytes.decode("ascii")
# Keys appear in the documented order; binding `o` is present.
for needle in (f'"s":"{sort_field}"', f'"id":"{id}"', f'"o":"{order}"'):
assert needle in decoded_str, (
f"Expected {needle!r} in payload {decoded_str!r}"
)

View File

@ -1,332 +0,0 @@
"""Integration tests for cursor-based pagination on GET /api/assets.
Wire contract is shared with cloud's Go implementation (BE-893). These tests
exercise the handler/service/query path end-to-end; cursor-encoding-level
tests live in tests-unit/assets_test/services/test_cursor.py.
"""
import time
import pytest
import requests
def _seed(asset_factory, make_asset_bytes, count: int, tag: str) -> list[str]:
names = [f"cursor_{i:02d}.safetensors" for i in range(count)]
for n in names:
asset_factory(
n,
["models", "checkpoints", "unit-tests", tag],
{},
make_asset_bytes(n, size=2048),
)
return sorted(names)
def test_cursor_pages_all_items_in_order(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
names = _seed(asset_factory, make_asset_bytes, count=5, tag="cursor-walk")
params = {
"include_tags": "unit-tests,cursor-walk",
"sort": "name",
"order": "asc",
"limit": "2",
}
seen: list[str] = []
after: str | None = None
pages = 0
while True:
page_params = dict(params)
if after is not None:
page_params["after"] = after
r = http.get(api_base + "/api/assets", params=page_params, timeout=120)
assert r.status_code == 200, r.text
body = r.json()
seen.extend(a["name"] for a in body["assets"])
pages += 1
after = body.get("next_cursor")
if after is None:
break
assert body["has_more"] is True
assert pages < 10, "guard against runaway cursor loop"
assert seen == names, f"expected {names}, got {seen}"
# Last page should have has_more False
assert body["has_more"] is False
assert "next_cursor" not in body
def test_cursor_invalid_returns_400(http: requests.Session, api_base: str):
r = http.get(
api_base + "/api/assets",
params={"after": "not-a-real-cursor", "sort": "created_at"},
timeout=120,
)
assert r.status_code == 400, r.text
body = r.json()
assert body["error"]["code"] == "INVALID_CURSOR"
def test_cursor_sort_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
_seed(asset_factory, make_asset_bytes, count=2, tag="cursor-mismatch")
# Take a real cursor minted for sort=name.
r = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-mismatch",
"sort": "name",
"order": "asc",
"limit": "1",
},
timeout=120,
)
assert r.status_code == 200
cursor = r.json()["next_cursor"]
assert cursor is not None
# Replay against sort=created_at — should fail with INVALID_CURSOR.
r2 = http.get(
api_base + "/api/assets",
params={"after": cursor, "sort": "created_at"},
timeout=120,
)
assert r2.status_code == 400, r2.text
assert r2.json()["error"]["code"] == "INVALID_CURSOR"
def test_cursor_wins_over_offset(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-vs-offset")
# Take a cursor that points past the first item.
r = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-vs-offset",
"sort": "name",
"order": "asc",
"limit": "1",
},
timeout=120,
)
assert r.status_code == 200, r.text
cursor = r.json()["next_cursor"]
assert cursor is not None
# Pass both 'after' and a large offset. Cursor must win; offset is ignored.
r2 = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-vs-offset",
"sort": "name",
"order": "asc",
"limit": "1",
"after": cursor,
"offset": "999",
},
timeout=120,
)
assert r2.status_code == 200
body = r2.json()
# Should land on the second name in sorted order — not skip ahead by 999.
assert [a["name"] for a in body["assets"]] == [names[1]]
def test_next_cursor_absent_when_no_more_results(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
_seed(asset_factory, make_asset_bytes, count=2, tag="cursor-exhaust")
r = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-exhaust",
"sort": "name",
"order": "asc",
"limit": "50",
},
timeout=120,
)
assert r.status_code == 200, r.text
body = r.json()
assert body["has_more"] is False
assert "next_cursor" not in body
def test_cursor_pagination_first_page_mints_cursor(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
"""First-page request (no `after`) must still return `next_cursor` when
more rows exist, or pagination is unreachable from a cold start.
"""
_seed(asset_factory, make_asset_bytes, count=3, tag="cursor-first-page")
r = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,cursor-first-page", "sort": "name", "order": "asc", "limit": "2"},
timeout=120,
)
assert r.status_code == 200, r.text
body = r.json()
assert body["has_more"] is True
assert body.get("next_cursor"), "first page must mint a cursor when more rows exist"
def test_cursor_no_spurious_cursor_when_page_size_equals_remainder(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
"""When `total` is an exact multiple of `limit`, the final page must
NOT carry a next_cursor — there is nothing past it.
"""
_seed(asset_factory, make_asset_bytes, count=4, tag="cursor-exact-multiple")
# Page 1
r = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2"},
timeout=120,
)
assert r.status_code == 200, r.text
cursor = r.json()["next_cursor"]
assert cursor is not None
# Page 2 — should exhaust the set with no cursor for a phantom page 3
r2 = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2", "after": cursor},
timeout=120,
)
assert r2.status_code == 200, r2.text
body = r2.json()
assert len(body["assets"]) == 2
assert body["has_more"] is False
assert "next_cursor" not in body
@pytest.mark.parametrize("sort_field", ["created_at", "updated_at", "size"])
def test_cursor_walks_for_non_name_sorts(sort_field, http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
"""Cursor pagination must work for every sort field the contract claims.
Without this, the `created_at` / `updated_at` (time-encoded micros) and
`size` (int-encoded) cursor paths go entirely unexercised end-to-end.
"""
# Stagger create_time slightly so created_at / updated_at sort is well-defined.
names = []
for i in range(4):
n = f"cursor_{sort_field}_{i:02d}.safetensors"
asset_factory(n, ["models", "checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i))
time.sleep(0.05) # ensure distinct timestamps
names.append(n)
params = {
"include_tags": f"unit-tests,cursor-{sort_field}",
"sort": sort_field,
"order": "desc",
"limit": "2",
}
seen: list[str] = []
after: str | None = None
pages = 0
while True:
page_params = dict(params)
if after is not None:
page_params["after"] = after
r = http.get(api_base + "/api/assets", params=page_params, timeout=120)
assert r.status_code == 200, r.text
body = r.json()
seen.extend(a["name"] for a in body["assets"])
after = body.get("next_cursor")
pages += 1
if after is None:
break
assert pages < 10, "guard against runaway cursor loop"
assert set(seen) == set(names), f"missing items for sort={sort_field}: expected {set(names)}, got {set(seen)}"
def test_cursor_order_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
"""A cursor minted under desc order replayed against asc must 400, not
silently walk the wrong direction."""
_seed(asset_factory, make_asset_bytes, count=3, tag="cursor-order-flip")
r = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-order-flip",
"sort": "name",
"order": "desc",
"limit": "1",
},
timeout=120,
)
assert r.status_code == 200, r.text
cursor = r.json()["next_cursor"]
assert cursor is not None
# Replay with order flipped to asc — server must reject the cursor.
r2 = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-order-flip",
"sort": "name",
"order": "asc",
"limit": "1",
"after": cursor,
},
timeout=120,
)
assert r2.status_code == 400, r2.text
assert r2.json()["error"]["code"] == "INVALID_CURSOR"
def test_cursor_invalid_cursor_at_microsecond_boundary(http: requests.Session, api_base: str):
"""A cursor carrying an out-of-range microsecond timestamp must map to
400 INVALID_CURSOR, not 500."""
import base64
import json
# 10^18 microseconds ≈ year 33658, well past datetime.MAX_YEAR.
payload = {"s": "created_at", "v": "999999999999999999999", "id": "asset-x"}
raw = json.dumps(payload, separators=(",", ":")).encode("utf-8")
cursor = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
r = http.get(
api_base + "/api/assets",
params={"after": cursor, "sort": "created_at"},
timeout=120,
)
assert r.status_code == 400, r.text
assert r.json()["error"]["code"] == "INVALID_CURSOR"
def test_cursor_pagination_stable_after_delete(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-delete")
# Page 1.
r = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-delete",
"sort": "name",
"order": "asc",
"limit": "2",
},
timeout=120,
)
assert r.status_code == 200
body = r.json()
page1_names = [a["name"] for a in body["assets"]]
cursor = body["next_cursor"]
assert cursor is not None
assert page1_names == names[:2]
# Delete an item from page 1 (already returned) — cursor should still
# locate the next page from where it was minted, not re-index.
target_id = body["assets"][0]["id"]
d = http.delete(api_base + f"/api/assets/{target_id}", timeout=120)
assert d.status_code in (200, 204), d.text
# Page 2 via cursor.
r2 = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-delete",
"sort": "name",
"order": "asc",
"limit": "2",
"after": cursor,
},
timeout=120,
)
assert r2.status_code == 200, r2.text
body2 = r2.json()
assert [a["name"] for a in body2["assets"]] == names[2:]

View File

@ -14,7 +14,6 @@ from tests.execution.test_execution import ComfyClient, run_warmup
class TestAsyncNodes:
@fixture(scope="class", autouse=True, params=[
(False, 0),
(True, 0),
(True, 100),
])
def _server(self, args_pytest, request):
@ -29,6 +28,8 @@ class TestAsyncNodes:
use_lru, lru_size = request.param
if use_lru:
pargs += ['--cache-lru', str(lru_size)]
else:
pargs += ['--cache-classic']
# Running server with args: pargs
p = subprocess.Popen(pargs)
yield

View File

@ -183,8 +183,7 @@ class TestExecution:
# Initialize server and client
#
@fixture(scope="class", autouse=True, params=[
{ "extra_args" : [], "should_cache_results" : True },
{ "extra_args" : ["--cache-lru", 0], "should_cache_results" : True },
{ "extra_args" : ["--cache-classic"], "should_cache_results" : True },
{ "extra_args" : ["--cache-lru", 100], "should_cache_results" : True },
{ "extra_args" : ["--cache-none"], "should_cache_results" : False },
])