mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-25 02:26:58 +08:00
Compare commits
3 Commits
matt/be-94
...
temp_pr
| Author | SHA1 | Date | |
|---|---|---|---|
| e6e75152e0 | |||
| e715be9105 | |||
| d48a8d417b |
@ -39,7 +39,6 @@ from app.assets.services import (
|
||||
update_asset_metadata,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.services.cursor import InvalidCursorError
|
||||
from app.assets.services.tagging import list_tag_histogram
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
@ -173,7 +172,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
metadata=result.ref.system_metadata,
|
||||
job_id=result.ref.job_id,
|
||||
prompt_id=result.ref.job_id, # deprecated alias of job_id, kept for compatibility
|
||||
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
|
||||
created_at=result.ref.created_at,
|
||||
updated_at=result.ref.updated_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
@ -210,37 +209,24 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
||||
order_candidate = (q.order or "desc").lower()
|
||||
order = order_candidate if order_candidate in {"asc", "desc"} else "desc"
|
||||
|
||||
try:
|
||||
result = list_assets_page(
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
after=q.after,
|
||||
)
|
||||
except InvalidCursorError as e:
|
||||
return _build_error_response(400, "INVALID_CURSOR", str(e))
|
||||
result = list_assets_page(
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries = [_build_asset_response(item) for item in result.items]
|
||||
|
||||
# has_more semantics differ by mode:
|
||||
# - cursor mode: a non-empty next_cursor means there are more results.
|
||||
# - offset mode: derived from total - (offset + page size).
|
||||
if q.after is not None:
|
||||
has_more = result.next_cursor is not None
|
||||
else:
|
||||
has_more = (q.offset + len(summaries)) < result.total
|
||||
|
||||
payload = schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=result.total,
|
||||
has_more=has_more,
|
||||
next_cursor=result.next_cursor,
|
||||
has_more=(q.offset + len(summaries)) < result.total,
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||
|
||||
|
||||
@ -59,11 +59,6 @@ class ListAssetsQuery(BaseModel):
|
||||
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
# Opaque keyset cursor. When supplied, `offset` is ignored. Cursor pagination
|
||||
# is supported for sort values `created_at`, `updated_at`, `name`, `size`.
|
||||
# Supplying `after` together with `sort=last_access_time` returns
|
||||
# 400 INVALID_CURSOR; that sort only supports offset/limit.
|
||||
after: str | None = None
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
|
||||
"created_at"
|
||||
|
||||
@ -40,8 +40,6 @@ class AssetsList(BaseModel):
|
||||
assets: list[Asset]
|
||||
total: int
|
||||
has_more: bool
|
||||
# Opaque cursor for the next page. Omitted when there are no more results.
|
||||
next_cursor: str | None = None
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
|
||||
@ -266,18 +266,9 @@ def list_references_page(
|
||||
metadata_filter: dict | None = None,
|
||||
sort: str | None = None,
|
||||
order: str | None = None,
|
||||
after_cursor_value: object | None = None,
|
||||
after_cursor_id: str | None = None,
|
||||
) -> tuple[list[AssetReference], dict[str, list[str]], int]:
|
||||
"""List references with pagination, filtering, and sorting.
|
||||
|
||||
When ``after_cursor_value``/``after_cursor_id`` are supplied the query uses
|
||||
keyset pagination — ``offset`` is ignored and a WHERE clause selects rows
|
||||
strictly after the given ``(sort_col, id)`` position in the active sort
|
||||
direction. The cursor value must already be typed for the column
|
||||
(datetime for time sorts, int for size, str for name); the caller decodes
|
||||
the opaque cursor string and resolves to the typed value.
|
||||
|
||||
Returns (references, tag_map, total_count).
|
||||
"""
|
||||
base = (
|
||||
@ -306,31 +297,9 @@ def list_references_page(
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetReference.created_at)
|
||||
descending = order == "desc"
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
# Keyset WHERE: (sort_col, id) strictly less-than / greater-than the cursor.
|
||||
# Equivalent to: sort_col <op> v OR (sort_col = v AND id <op> cursor_id).
|
||||
if after_cursor_value is not None and after_cursor_id is not None:
|
||||
if descending:
|
||||
keyset = sa.or_(
|
||||
sort_col < after_cursor_value,
|
||||
sa.and_(sort_col == after_cursor_value, AssetReference.id < after_cursor_id),
|
||||
)
|
||||
else:
|
||||
keyset = sa.or_(
|
||||
sort_col > after_cursor_value,
|
||||
sa.and_(sort_col == after_cursor_value, AssetReference.id > after_cursor_id),
|
||||
)
|
||||
base = base.where(keyset)
|
||||
|
||||
# Secondary ORDER BY id (matching the primary direction) gives the keyset
|
||||
# comparison a deterministic tiebreaker on duplicate sort_col values.
|
||||
id_exp = AssetReference.id.desc() if descending else AssetReference.id.asc()
|
||||
sort_exp = sort_col.desc() if descending else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp, id_exp).limit(limit)
|
||||
if after_cursor_id is None:
|
||||
base = base.offset(offset)
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(sa.func.count())
|
||||
|
||||
@ -1,19 +1,8 @@
|
||||
import contextlib
|
||||
import mimetypes
|
||||
import os
|
||||
from datetime import timezone
|
||||
from typing import Sequence
|
||||
|
||||
from app.assets.services.cursor import (
|
||||
CursorPayload,
|
||||
InvalidCursorError,
|
||||
decode_cursor,
|
||||
decode_cursor_int,
|
||||
decode_cursor_time,
|
||||
encode_cursor,
|
||||
encode_cursor_from_time,
|
||||
)
|
||||
|
||||
|
||||
from app.assets.database.models import Asset
|
||||
from app.assets.database.queries import (
|
||||
@ -253,11 +242,6 @@ def get_asset_by_hash(asset_hash: str) -> AssetData | None:
|
||||
return extract_asset_data(asset)
|
||||
|
||||
|
||||
# Sort fields that support cursor pagination. `last_access_time` is not
|
||||
# in this list — it falls back to offset/limit.
|
||||
_CURSOR_SORT_FIELDS = ("created_at", "updated_at", "name", "size")
|
||||
|
||||
|
||||
def list_assets_page(
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
@ -268,39 +252,7 @@ def list_assets_page(
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
after: str | None = None,
|
||||
) -> ListAssetsResult:
|
||||
"""List assets with optional cursor pagination.
|
||||
|
||||
When ``after`` is supplied it overrides ``offset``. The cursor's sort field
|
||||
must match ``sort`` and be in the cursor-supported allowlist; mismatches
|
||||
raise InvalidCursorError so the handler can map to 400 INVALID_CURSOR.
|
||||
"""
|
||||
cursor_value: object | None = None
|
||||
cursor_id: str | None = None
|
||||
# Mint next_cursor on every page where the sort is cursor-supported, not
|
||||
# only when the request itself arrived with a cursor. Otherwise a first
|
||||
# request (no `after`) returns next_cursor=None and the client can never
|
||||
# enter cursor mode.
|
||||
mint_cursor = sort in _CURSOR_SORT_FIELDS
|
||||
|
||||
if after is not None:
|
||||
if sort not in _CURSOR_SORT_FIELDS:
|
||||
raise InvalidCursorError(
|
||||
f"cursor pagination is not supported for sort={sort!r}"
|
||||
)
|
||||
payload = decode_cursor(after, _CURSOR_SORT_FIELDS, expected_order=order)
|
||||
if payload.sort_field != sort:
|
||||
raise InvalidCursorError(
|
||||
f"cursor sort field {payload.sort_field!r} does not match request sort {sort!r}"
|
||||
)
|
||||
cursor_value, cursor_id = _resolve_cursor_value(payload), payload.id
|
||||
|
||||
# Over-fetch by one row so we can distinguish "exactly `limit` rows total
|
||||
# remaining" from "more rows past this page" without a second query. Drop
|
||||
# the sentinel before returning.
|
||||
fetch_limit = limit + 1 if mint_cursor else limit
|
||||
|
||||
with create_session() as session:
|
||||
refs, tag_map, total = list_references_page(
|
||||
session,
|
||||
@ -309,22 +261,12 @@ def list_assets_page(
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=fetch_limit,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
after_cursor_value=cursor_value,
|
||||
after_cursor_id=cursor_id,
|
||||
)
|
||||
|
||||
next_cursor: str | None = None
|
||||
if mint_cursor and len(refs) > limit:
|
||||
# There's at least one more row past this page — mint a cursor from
|
||||
# the last row of the page (i.e. index `limit - 1`, since we
|
||||
# over-fetched), and drop the sentinel.
|
||||
next_cursor = _encode_next_cursor(refs[limit - 1], sort, order)
|
||||
refs = refs[:limit]
|
||||
|
||||
items: list[AssetSummaryData] = []
|
||||
for ref in refs:
|
||||
items.append(
|
||||
@ -335,39 +277,7 @@ def list_assets_page(
|
||||
)
|
||||
)
|
||||
|
||||
return ListAssetsResult(items=items, total=total, next_cursor=next_cursor)
|
||||
|
||||
|
||||
def _resolve_cursor_value(payload: CursorPayload) -> object:
|
||||
"""Map a decoded cursor payload to a column-typed Python value."""
|
||||
if payload.sort_field in ("created_at", "updated_at"):
|
||||
# DB stores naive UTC; strip tzinfo so the comparison binds against a
|
||||
# `TIMESTAMP WITHOUT TIME ZONE` column without an offset shift.
|
||||
return decode_cursor_time(payload).replace(tzinfo=None)
|
||||
if payload.sort_field == "size":
|
||||
return decode_cursor_int(payload)
|
||||
return payload.value # name, str-typed
|
||||
|
||||
|
||||
def _encode_next_cursor(ref, sort: str, order: str) -> str | None:
|
||||
"""Mint a cursor pointing at *ref* for the given sort dimension.
|
||||
|
||||
Returns None when the boundary row carries a NULL sort value (e.g. an asset
|
||||
record whose size_bytes hasn't been backfilled). Continuing pagination
|
||||
across a NULL boundary is undefined under keyset ordering — better to
|
||||
truncate cleanly here than to mint a cursor that mis-positions.
|
||||
"""
|
||||
if sort == "name":
|
||||
return encode_cursor("name", ref.name, ref.id, order=order)
|
||||
if sort == "size":
|
||||
if ref.asset is None or ref.asset.size_bytes is None:
|
||||
return None
|
||||
return encode_cursor("size", str(ref.asset.size_bytes), ref.id, order=order)
|
||||
# created_at / updated_at — DB datetimes are naive UTC; attach tz before encoding.
|
||||
value = ref.created_at if sort == "created_at" else ref.updated_at
|
||||
if value is None:
|
||||
return None
|
||||
return encode_cursor_from_time(sort, value.replace(tzinfo=timezone.utc), ref.id, order=order)
|
||||
return ListAssetsResult(items=items, total=total)
|
||||
|
||||
|
||||
def resolve_hash_to_path(
|
||||
|
||||
@ -1,225 +0,0 @@
|
||||
"""Opaque keyset-pagination cursor for /api/assets.
|
||||
|
||||
Payload JSON uses short keys to keep the encoded length small:
|
||||
|
||||
{"s": <sort_field>, "v": <value>, "id": <id>, "o": <order>}
|
||||
|
||||
The `o` key binds the cursor to the sort direction it was minted under,
|
||||
so replaying a `desc` cursor against an `asc` request fails with
|
||||
``INVALID_CURSOR`` rather than silently walking the wrong direction.
|
||||
`o` is mandatory on every payload — a cursor without it is rejected as
|
||||
malformed.
|
||||
|
||||
Encoding is base64url with no padding. JSON serialization escapes `<`,
|
||||
`>`, `&`, U+2028, and U+2029 in encoded string values so asset names
|
||||
containing those characters produce a stable, byte-identical wire form
|
||||
across any compatible implementation of the same payload format.
|
||||
|
||||
Time values are serialized as Unix microseconds (UTC) — microsecond
|
||||
precision is sufficient to round-trip the timestamps stored by the
|
||||
database without rounding rows in the same millisecond bucket.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Iterable, Optional
|
||||
|
||||
|
||||
class InvalidCursorError(ValueError):
|
||||
"""Raised on a malformed, oversized, or unsupported-sort-field cursor.
|
||||
|
||||
Map to a 400 response with code ``INVALID_CURSOR`` at the handler.
|
||||
"""
|
||||
|
||||
|
||||
# Wire-format length caps. Cursors are user-controlled, so caps protect the
|
||||
# decode path from oversized allocations and downstream SQL predicates from
|
||||
# unbounded strings.
|
||||
#
|
||||
# MAX_CURSOR_VALUE_LENGTH is 512 to fit the `AssetReference.name` column max
|
||||
# (`String(512)`) — otherwise a long-named asset would mint a cursor the same
|
||||
# server then refuses on the next request.
|
||||
MAX_ENCODED_CURSOR_LENGTH = 1024
|
||||
MAX_CURSOR_VALUE_LENGTH = 512
|
||||
MAX_CURSOR_ID_LENGTH = 128
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CursorPayload:
|
||||
sort_field: str
|
||||
value: str
|
||||
id: str
|
||||
order: str
|
||||
|
||||
|
||||
_VALID_ORDERS = ("asc", "desc")
|
||||
|
||||
|
||||
def encode_cursor(sort_field: str, value: str, id: str, order: str = "desc") -> str:
|
||||
"""Encode a cursor payload as a base64url (no-padding) string.
|
||||
|
||||
`order` binds the cursor to the sort direction it was minted under so a
|
||||
later request with a flipped `order` query parameter is rejected with
|
||||
``INVALID_CURSOR`` rather than silently walking the wrong direction.
|
||||
"""
|
||||
if order not in _VALID_ORDERS:
|
||||
raise InvalidCursorError(f"order must be one of {_VALID_ORDERS}, got {order!r}")
|
||||
# Symmetric input validation: the encoder must reject anything the
|
||||
# decoder rejects, or the same server will mint cursors it then 400s on
|
||||
# the next request.
|
||||
if not id:
|
||||
raise InvalidCursorError("id must be non-empty")
|
||||
if len(id) > MAX_CURSOR_ID_LENGTH:
|
||||
raise InvalidCursorError("id exceeds maximum length")
|
||||
if len(value) > MAX_CURSOR_VALUE_LENGTH:
|
||||
raise InvalidCursorError("value exceeds maximum length")
|
||||
payload = {"s": sort_field, "v": value, "id": id, "o": order}
|
||||
raw = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
|
||||
# Match the default JSON escaping of HTML-significant characters and JS
|
||||
# line/paragraph separators (U+2028 / U+2029) so an asset name carrying
|
||||
# any of them encodes to identical bytes across runtimes. None of these
|
||||
# characters appear in JSON structural syntax, so a global replace on the
|
||||
# serialized output can only touch encoded values. Use explicit \uXXXX
|
||||
# escapes for U+2028 / U+2029 so the source survives any editor / git
|
||||
# tooling that normalizes invisible separators.
|
||||
raw = (
|
||||
raw.replace("<", "\\u003c")
|
||||
.replace(">", "\\u003e")
|
||||
.replace("&", "\\u0026")
|
||||
.replace("\u2028", "\\u2028")
|
||||
.replace("\u2029", "\\u2029")
|
||||
)
|
||||
encoded = base64.urlsafe_b64encode(raw.encode("utf-8")).rstrip(b"=").decode("ascii")
|
||||
# Final wire-size guard: the per-field caps above are char-counted, but the
|
||||
# wire cap applies to the base64url of the UTF-8-encoded, escape-expanded
|
||||
# payload. A value full of multibyte or HTML-significant characters (e.g.
|
||||
# 512 \u00d7 "\u00e9" or 512 \u00d7 "<") inflates well past MAX_ENCODED_CURSOR_LENGTH even
|
||||
# though it passes the char-count check. Refuse to mint a cursor the decoder
|
||||
# on the next request would reject.
|
||||
if len(encoded) > MAX_ENCODED_CURSOR_LENGTH:
|
||||
raise InvalidCursorError("encoded cursor exceeds maximum length")
|
||||
return encoded
|
||||
|
||||
|
||||
def encode_cursor_from_time(sort_field: str, t: datetime, id: str, order: str = "desc") -> str:
|
||||
"""Encode a time-typed cursor at Unix microsecond precision.
|
||||
|
||||
Accepts an aware datetime (any timezone) and normalizes to UTC. Naive
|
||||
datetimes are rejected so callers can't accidentally encode the local
|
||||
wall-clock value of a UTC-stored timestamp.
|
||||
"""
|
||||
if t.tzinfo is None:
|
||||
raise ValueError("encode_cursor_from_time requires an aware datetime")
|
||||
micros = _datetime_to_unix_micros(t.astimezone(timezone.utc))
|
||||
return encode_cursor(sort_field, str(micros), id, order=order)
|
||||
|
||||
|
||||
def decode_cursor(
|
||||
cursor: str,
|
||||
allowed_sort_fields: Iterable[str],
|
||||
expected_order: str | None = None,
|
||||
) -> CursorPayload:
|
||||
"""Parse an opaque cursor.
|
||||
|
||||
``allowed_sort_fields`` is the endpoint's accepted sort-field list — a
|
||||
cursor carrying a field outside this set is rejected so a cursor minted
|
||||
for one column can't be replayed against another (e.g. a ``created_at``
|
||||
timestamp string compared against a ``name`` column).
|
||||
|
||||
``expected_order`` (``"asc"``/``"desc"``), when supplied, must match the
|
||||
payload's ``o`` field. ``o`` is required on every payload; a cursor
|
||||
missing it is rejected as malformed.
|
||||
|
||||
Passing no allowed fields rejects every cursor.
|
||||
"""
|
||||
if len(cursor) > MAX_ENCODED_CURSOR_LENGTH:
|
||||
raise InvalidCursorError("cursor exceeds maximum length")
|
||||
|
||||
try:
|
||||
# urlsafe_b64decode requires correct padding; we strip on encode, so
|
||||
# restore the trailing '=' pad here.
|
||||
padding = "=" * (-len(cursor) % 4)
|
||||
raw = base64.urlsafe_b64decode(cursor + padding)
|
||||
except (ValueError, base64.binascii.Error) as e:
|
||||
raise InvalidCursorError(f"encoding: {e}") from e
|
||||
|
||||
try:
|
||||
decoded = json.loads(raw)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
raise InvalidCursorError(f"payload: {e}") from e
|
||||
|
||||
if not isinstance(decoded, dict):
|
||||
raise InvalidCursorError("payload: expected object")
|
||||
|
||||
sort_field = decoded.get("s")
|
||||
value = decoded.get("v")
|
||||
id = decoded.get("id")
|
||||
order = decoded.get("o")
|
||||
|
||||
if not isinstance(sort_field, str) or not isinstance(value, str) or not isinstance(id, str):
|
||||
raise InvalidCursorError("payload: missing or non-string s/v/id")
|
||||
|
||||
if id == "":
|
||||
raise InvalidCursorError("missing id")
|
||||
if len(id) > MAX_CURSOR_ID_LENGTH:
|
||||
raise InvalidCursorError("id exceeds maximum length")
|
||||
if len(value) > MAX_CURSOR_VALUE_LENGTH:
|
||||
raise InvalidCursorError("value exceeds maximum length")
|
||||
|
||||
if sort_field not in allowed_sort_fields:
|
||||
raise InvalidCursorError(f"unsupported sort field {sort_field!r}")
|
||||
|
||||
if not isinstance(order, str):
|
||||
raise InvalidCursorError("missing or non-string o")
|
||||
if order not in _VALID_ORDERS:
|
||||
raise InvalidCursorError(f"unsupported order {order!r}")
|
||||
if expected_order is not None and order != expected_order:
|
||||
raise InvalidCursorError(
|
||||
f"cursor order {order!r} does not match request order {expected_order!r}"
|
||||
)
|
||||
|
||||
return CursorPayload(sort_field=sort_field, value=value, id=id, order=order)
|
||||
|
||||
|
||||
def decode_cursor_time(payload: Optional[CursorPayload]) -> datetime:
|
||||
"""Parse a time-typed cursor value as Unix microseconds, returning UTC."""
|
||||
if payload is None:
|
||||
raise InvalidCursorError("nil cursor payload")
|
||||
try:
|
||||
micros = int(payload.value)
|
||||
except ValueError as e:
|
||||
raise InvalidCursorError(f"value is not a valid timestamp: {e}") from e
|
||||
try:
|
||||
return _unix_micros_to_datetime(micros)
|
||||
except (OverflowError, OSError, ValueError) as e:
|
||||
# Crafted out-of-range microseconds (e.g. > datetime.MAX_YEAR) blow up
|
||||
# in fromtimestamp / datetime construction. Map to 400, not 500.
|
||||
raise InvalidCursorError(f"value is out of representable range: {e}") from e
|
||||
|
||||
|
||||
def decode_cursor_int(payload: Optional[CursorPayload]) -> int:
|
||||
"""Parse a cursor value as a base-10 integer."""
|
||||
if payload is None:
|
||||
raise InvalidCursorError("nil cursor payload")
|
||||
try:
|
||||
return int(payload.value)
|
||||
except ValueError as e:
|
||||
raise InvalidCursorError(f"value is not a valid integer: {e}") from e
|
||||
|
||||
|
||||
_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _datetime_to_unix_micros(t: datetime) -> int:
|
||||
"""Convert an aware UTC datetime to Unix microseconds (integer math)."""
|
||||
delta = t - _EPOCH
|
||||
return (delta.days * 86_400 + delta.seconds) * 1_000_000 + delta.microseconds
|
||||
|
||||
|
||||
def _unix_micros_to_datetime(micros: int) -> datetime:
|
||||
"""Convert Unix microseconds to a UTC datetime, preserving precision."""
|
||||
seconds, micro_remainder = divmod(micros, 1_000_000)
|
||||
return datetime.fromtimestamp(seconds, tz=timezone.utc).replace(microsecond=micro_remainder)
|
||||
@ -71,7 +71,6 @@ class AssetSummaryData:
|
||||
class ListAssetsResult:
|
||||
items: list[AssetSummaryData]
|
||||
total: int
|
||||
next_cursor: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@ -1613,16 +1613,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
#use all ModelPatcherDynamic this is ignored and its all done dynamically.
|
||||
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
|
||||
|
||||
def restore_loaded_backups(self):
|
||||
restored = self.model.model_loaded_weight_memory
|
||||
for key in list(self.backup.keys()):
|
||||
bk = self.backup.pop(key)
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
for key in list(self.backup_buffers.keys()):
|
||||
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
return restored
|
||||
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
|
||||
|
||||
@ -1639,7 +1629,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
num_patches = 0
|
||||
allocated_size = 0
|
||||
self.restore_loaded_backups()
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
with self.use_ejected():
|
||||
self.unpatch_hooks()
|
||||
@ -1726,9 +1716,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
force_load=True
|
||||
|
||||
if force_load:
|
||||
if hasattr(m, "_v"):
|
||||
comfy_aimdo.model_vbar.vbar_unpin(m._v)
|
||||
delattr(m, "_v")
|
||||
force_load_param(self, "weight", device_to)
|
||||
force_load_param(self, "bias", device_to)
|
||||
else:
|
||||
@ -1786,7 +1773,13 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||
|
||||
if freed < memory_to_free:
|
||||
freed += self.restore_loaded_backups()
|
||||
for key in list(self.backup.keys()):
|
||||
bk = self.backup.pop(key)
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
for key in list(self.backup_buffers.keys()):
|
||||
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
|
||||
freed += self.model.model_loaded_weight_memory
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
return freed
|
||||
|
||||
|
||||
@ -1019,11 +1019,10 @@ def bislerp(samples, width, height):
|
||||
|
||||
def lanczos(samples, width, height):
|
||||
#the below API is strict and expects grayscale to be squeezed
|
||||
if samples.ndim == 4:
|
||||
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
|
||||
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
|
||||
images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
||||
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
||||
images = [torch.from_numpy(t).movedim(-1, 0) if (t := np.array(image).astype(np.float32) / 255.0).ndim == 3 else torch.from_numpy(t) for image in images]
|
||||
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
|
||||
result = torch.stack(images)
|
||||
return result.to(samples.device, samples.dtype)
|
||||
|
||||
|
||||
@ -35,19 +35,6 @@ class AnthropicMessage(BaseModel):
|
||||
content: list[AnthropicTextContent | AnthropicImageContent] = Field(...)
|
||||
|
||||
|
||||
class AnthropicThinkingConfig(BaseModel):
|
||||
type: Literal["enabled", "disabled", "adaptive"] = Field(...)
|
||||
budget_tokens: int | None = Field(
|
||||
None, ge=1024,
|
||||
description="Reasoning budget in tokens. Used when type is 'enabled'. Must be less than max_tokens.",
|
||||
)
|
||||
|
||||
|
||||
class AnthropicOutputConfig(BaseModel):
|
||||
"""Used with `thinking.type='adaptive'` on models like Opus 4.7."""
|
||||
effort: Literal["low", "medium", "high"] | None = Field(None)
|
||||
|
||||
|
||||
class AnthropicMessagesRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
messages: list[AnthropicMessage] = Field(...)
|
||||
@ -57,8 +44,6 @@ class AnthropicMessagesRequest(BaseModel):
|
||||
top_p: float | None = Field(None, ge=0.0, le=1.0)
|
||||
top_k: int | None = Field(None, ge=0)
|
||||
stop_sequences: list[str] | None = Field(None)
|
||||
thinking: AnthropicThinkingConfig | None = Field(None)
|
||||
output_config: AnthropicOutputConfig | None = Field(None)
|
||||
|
||||
|
||||
class AnthropicResponseTextBlock(BaseModel):
|
||||
@ -66,14 +51,6 @@ class AnthropicResponseTextBlock(BaseModel):
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class AnthropicResponseThinkingBlock(BaseModel):
|
||||
type: Literal["thinking"] = "thinking"
|
||||
thinking: str = Field(...)
|
||||
|
||||
|
||||
AnthropicResponseBlock = AnthropicResponseTextBlock | AnthropicResponseThinkingBlock
|
||||
|
||||
|
||||
class AnthropicCacheCreationUsage(BaseModel):
|
||||
ephemeral_5m_input_tokens: int | None = Field(None)
|
||||
ephemeral_1h_input_tokens: int | None = Field(None)
|
||||
@ -92,7 +69,7 @@ class AnthropicMessagesResponse(BaseModel):
|
||||
type: str | None = Field(None)
|
||||
role: str | None = Field(None)
|
||||
model: str | None = Field(None)
|
||||
content: list[AnthropicResponseBlock] | None = Field(None)
|
||||
content: list[AnthropicResponseTextBlock] | None = Field(None)
|
||||
stop_reason: str | None = Field(None)
|
||||
stop_sequence: str | None = Field(None)
|
||||
usage: AnthropicMessagesUsage | None = Field(None)
|
||||
|
||||
@ -1,93 +0,0 @@
|
||||
"""Pydantic models for the OpenRouter chat completions API.
|
||||
|
||||
See: https://openrouter.ai/docs/api/api-reference/chat/send-chat-completion-request
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class OpenRouterTextContent(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class OpenRouterImageUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class OpenRouterImageContent(BaseModel):
|
||||
type: Literal["image_url"] = "image_url"
|
||||
image_url: OpenRouterImageUrl = Field(...)
|
||||
|
||||
|
||||
class OpenRouterVideoUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class OpenRouterVideoContent(BaseModel):
|
||||
type: Literal["video_url"] = "video_url"
|
||||
video_url: OpenRouterVideoUrl = Field(...)
|
||||
|
||||
|
||||
OpenRouterContentBlock = OpenRouterTextContent | OpenRouterImageContent | OpenRouterVideoContent
|
||||
|
||||
|
||||
class OpenRouterMessage(BaseModel):
|
||||
role: Literal["system", "user", "assistant"] = Field(...)
|
||||
content: str | list[OpenRouterContentBlock] = Field(...)
|
||||
|
||||
|
||||
class OpenRouterReasoningConfig(BaseModel):
|
||||
effort: str | None = Field(None)
|
||||
exclude: bool | None = Field(None, description="If true, model reasons but reasoning is excluded from response.")
|
||||
|
||||
|
||||
class OpenRouterWebSearchOptions(BaseModel):
|
||||
search_context_size: str | None = Field(None)
|
||||
|
||||
|
||||
class OpenRouterChatRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
messages: list[OpenRouterMessage] = Field(...)
|
||||
seed: int | None = Field(None)
|
||||
reasoning: OpenRouterReasoningConfig | None = Field(None)
|
||||
web_search_options: OpenRouterWebSearchOptions | None = Field(None)
|
||||
stream: bool = Field(False)
|
||||
|
||||
|
||||
class OpenRouterUsage(BaseModel):
|
||||
prompt_tokens: int | None = Field(None)
|
||||
completion_tokens: int | None = Field(None)
|
||||
total_tokens: int | None = Field(None)
|
||||
cost: float | None = Field(None, description="Server-side authoritative USD cost of the call.")
|
||||
|
||||
|
||||
class OpenRouterResponseMessage(BaseModel):
|
||||
role: str | None = Field(None)
|
||||
content: str | None = Field(None)
|
||||
reasoning: str | None = Field(None)
|
||||
refusal: str | None = Field(None)
|
||||
|
||||
|
||||
class OpenRouterChoice(BaseModel):
|
||||
index: int | None = Field(None)
|
||||
message: OpenRouterResponseMessage | None = Field(None)
|
||||
finish_reason: str | None = Field(None)
|
||||
|
||||
|
||||
class OpenRouterError(BaseModel):
|
||||
code: int | str | None = Field(None)
|
||||
message: str | None = Field(None)
|
||||
metadata: dict | None = Field(None)
|
||||
|
||||
|
||||
class OpenRouterChatResponse(BaseModel):
|
||||
id: str | None = Field(None)
|
||||
model: str | None = Field(None)
|
||||
object: str | None = Field(None)
|
||||
provider: str | None = Field(None)
|
||||
choices: list[OpenRouterChoice] | None = Field(None)
|
||||
usage: OpenRouterUsage | None = Field(None)
|
||||
error: OpenRouterError | None = Field(None)
|
||||
@ -9,11 +9,8 @@ from comfy_api_nodes.apis.anthropic import (
|
||||
AnthropicMessage,
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicMessagesResponse,
|
||||
AnthropicOutputConfig,
|
||||
AnthropicResponseTextBlock,
|
||||
AnthropicRole,
|
||||
AnthropicTextContent,
|
||||
AnthropicThinkingConfig,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
@ -35,29 +32,15 @@ CLAUDE_MODELS: dict[str, str] = {
|
||||
"Haiku 4.5": "claude-haiku-4-5-20251001",
|
||||
}
|
||||
|
||||
_THINKING_UNSUPPORTED = {"Haiku 4.5"}
|
||||
# Models that use the newer "adaptive" thinking mode (Opus 4.7 requires it; older models keep the explicit budget API).
|
||||
# Anthropic decides the actual budget when adaptive is used, based on the `output_config.effort` hint.
|
||||
_ADAPTIVE_THINKING_MODELS = {"Opus 4.7", "Opus 4.6", "Sonnet 4.6"}
|
||||
|
||||
# Budget mode (Sonnet 4.5): effort -> reasoning budget in tokens. Must be < max_tokens.
|
||||
# Sized so even the "high" budget fits comfortably under the default max_tokens=32768.
|
||||
_REASONING_BUDGET: dict[str, int] = {
|
||||
"low": 2048,
|
||||
"medium": 8192,
|
||||
"high": 16384,
|
||||
}
|
||||
_REASONING_EFFORTS = ["off", "low", "medium", "high"]
|
||||
|
||||
|
||||
def _claude_model_inputs(model_label: str):
|
||||
inputs: list = [
|
||||
def _claude_model_inputs():
|
||||
return [
|
||||
IO.Int.Input(
|
||||
"max_tokens",
|
||||
default=32768,
|
||||
min=4096,
|
||||
max=64000,
|
||||
tooltip="Maximum number of tokens to generate (includes reasoning tokens when enabled).",
|
||||
default=16000,
|
||||
min=32,
|
||||
max=32000,
|
||||
tooltip="Maximum number of tokens to generate before stopping.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
@ -66,24 +49,10 @@ def _claude_model_inputs(model_label: str):
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip=(
|
||||
"Controls randomness. 0.0 is deterministic, 1.0 is most random. "
|
||||
"Ignored for Opus 4.7 and any model when reasoning_effort is set."
|
||||
),
|
||||
tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random. Ignored for Opus 4.7.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
if model_label not in _THINKING_UNSUPPORTED:
|
||||
inputs.append(
|
||||
IO.Combo.Input(
|
||||
"reasoning_effort",
|
||||
options=_REASONING_EFFORTS,
|
||||
default="off",
|
||||
tooltip="Extended thinking effort. 'off' disables reasoning.",
|
||||
advanced=True,
|
||||
)
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
def _model_price_per_million(model: str) -> tuple[float, float] | None:
|
||||
@ -126,11 +95,7 @@ def calculate_tokens_price(response: AnthropicMessagesResponse) -> float | None:
|
||||
def _get_text_from_response(response: AnthropicMessagesResponse) -> str:
|
||||
if not response.content:
|
||||
return ""
|
||||
# Thinking blocks are silently dropped — we never want reasoning in the output.
|
||||
return "\n".join(
|
||||
block.text for block in response.content
|
||||
if isinstance(block, AnthropicResponseTextBlock) and block.text
|
||||
)
|
||||
return "\n".join(block.text for block in response.content if block.text)
|
||||
|
||||
|
||||
async def _build_image_content_blocks(
|
||||
@ -168,10 +133,7 @@ class ClaudeNode(IO.ComfyNode):
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(label, _claude_model_inputs(label))
|
||||
for label in CLAUDE_MODELS
|
||||
],
|
||||
options=[IO.DynamicCombo.Option(label, _claude_model_inputs()) for label in CLAUDE_MODELS],
|
||||
tooltip="The Claude model used to generate the response.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
@ -245,29 +207,8 @@ class ClaudeNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_label = model["model"]
|
||||
max_tokens = model.get("max_tokens", 32768)
|
||||
reasoning_effort = model.get("reasoning_effort", "off")
|
||||
thinking_enabled = reasoning_effort not in ("off", None) and model_label not in _THINKING_UNSUPPORTED
|
||||
|
||||
# Anthropic requires temperature to be unset (defaults to 1.0) when thinking is enabled.
|
||||
# Opus 4.7 also rejects user-supplied temperature.
|
||||
if thinking_enabled or model_label == "Opus 4.7":
|
||||
temperature = None
|
||||
else:
|
||||
temperature = model.get("temperature", 1.0)
|
||||
|
||||
thinking_cfg: AnthropicThinkingConfig | None = None
|
||||
output_cfg: AnthropicOutputConfig | None = None
|
||||
if thinking_enabled:
|
||||
if model_label in _ADAPTIVE_THINKING_MODELS:
|
||||
# Adaptive mode - Anthropic chooses the budget based on effort hint
|
||||
thinking_cfg = AnthropicThinkingConfig(type="adaptive")
|
||||
output_cfg = AnthropicOutputConfig(effort=reasoning_effort)
|
||||
else:
|
||||
# Budget mode (Sonnet 4.5). Leave at least 1024 tokens for the actual response
|
||||
budget = _REASONING_BUDGET[reasoning_effort]
|
||||
budget = min(budget, max(1024, max_tokens - 1024))
|
||||
thinking_cfg = AnthropicThinkingConfig(type="enabled", budget_tokens=budget)
|
||||
max_tokens = model["max_tokens"]
|
||||
temperature = None if model_label == "Opus 4.7" else model["temperature"]
|
||||
|
||||
image_tensors: list[Input.Image] = [t for t in (images or {}).values() if t is not None]
|
||||
if sum(get_number_of_images(t) for t in image_tensors) > CLAUDE_MAX_IMAGES:
|
||||
@ -288,8 +229,6 @@ class ClaudeNode(IO.ComfyNode):
|
||||
messages=[AnthropicMessage(role=AnthropicRole.user, content=content)],
|
||||
system=system_prompt or None,
|
||||
temperature=temperature,
|
||||
thinking=thinking_cfg,
|
||||
output_config=output_cfg,
|
||||
),
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
|
||||
@ -43,16 +43,15 @@ from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
downscale_video_to_max_pixels,
|
||||
get_number_of_images,
|
||||
image_tensor_pair_to_batch,
|
||||
poll_op,
|
||||
resize_video_to_pixel_budget,
|
||||
sync_op,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
upscale_video_to_min_pixels,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
@ -111,13 +110,12 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st
|
||||
max_px = limits.get("max")
|
||||
if min_px and pixels < min_px:
|
||||
raise ValueError(
|
||||
f"Reference video {index} is too small: {w}x{h} = {pixels:,} total pixels. "
|
||||
f"Minimum for this model is {min_px:,} total pixels."
|
||||
f"Reference video {index} is too small: {w}x{h} = {pixels:,}px. " f"Minimum is {min_px:,}px for this model."
|
||||
)
|
||||
if max_px and pixels > max_px:
|
||||
raise ValueError(
|
||||
f"Reference video {index} is too large: {w}x{h} = {pixels:,} total pixels. "
|
||||
f"Maximum for this model is {max_px:,} total pixels. Try downscaling the video."
|
||||
f"Reference video {index} is too large: {w}x{h} = {pixels:,}px. "
|
||||
f"Maximum is {max_px:,}px for this model. Try downscaling the video."
|
||||
)
|
||||
|
||||
|
||||
@ -1678,14 +1676,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
"first_frame_asset_id",
|
||||
default="",
|
||||
tooltip="Seedance asset_id to use as the first frame. "
|
||||
"Mutually exclusive with the first_frame image input.",
|
||||
"Mutually exclusive with the first_frame image input.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"last_frame_asset_id",
|
||||
default="",
|
||||
tooltip="Seedance asset_id to use as the last frame. "
|
||||
"Mutually exclusive with the last_frame image input.",
|
||||
"Mutually exclusive with the last_frame image input.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
@ -1867,20 +1865,11 @@ def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16
|
||||
IO.Boolean.Input(
|
||||
"auto_downscale",
|
||||
default=False,
|
||||
advanced=True,
|
||||
optional=True,
|
||||
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
|
||||
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"auto_upscale",
|
||||
default=False,
|
||||
advanced=True,
|
||||
optional=True,
|
||||
tooltip="Automatically upscale reference videos that are below the model's minimum pixel count "
|
||||
"for the selected resolution. Aspect ratio is preserved; videos already meeting the minimum are "
|
||||
"untouched. Note: upscaling a low-resolution source does not add real detail and may produce "
|
||||
"lower-quality generations.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_assets",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
@ -2041,13 +2030,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
max_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("max")
|
||||
if max_px:
|
||||
for key in reference_videos:
|
||||
reference_videos[key] = downscale_video_to_max_pixels(reference_videos[key], max_px)
|
||||
|
||||
if model.get("auto_upscale") and reference_videos:
|
||||
min_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("min")
|
||||
if min_px:
|
||||
for key in reference_videos:
|
||||
reference_videos[key] = upscale_video_to_min_pixels(reference_videos[key], min_px)
|
||||
reference_videos[key] = resize_video_to_pixel_budget(reference_videos[key], max_px)
|
||||
|
||||
total_video_duration = 0.0
|
||||
for i, key in enumerate(reference_videos, 1):
|
||||
|
||||
@ -1,374 +0,0 @@
|
||||
"""API Nodes for OpenRouter LLM chat completions."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.openrouter import (
|
||||
OpenRouterChatRequest,
|
||||
OpenRouterChatResponse,
|
||||
OpenRouterContentBlock,
|
||||
OpenRouterImageContent,
|
||||
OpenRouterImageUrl,
|
||||
OpenRouterMessage,
|
||||
OpenRouterReasoningConfig,
|
||||
OpenRouterTextContent,
|
||||
OpenRouterVideoContent,
|
||||
OpenRouterVideoUrl,
|
||||
OpenRouterWebSearchOptions,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
OPENROUTER_CHAT_ENDPOINT = "/proxy/openrouter/api/v1/chat/completions"
|
||||
|
||||
|
||||
Profile = Literal["standard", "reasoning", "frontier_reasoning", "perplexity", "perplexity_reasoning"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ModelSpec:
|
||||
slug: str # exact OpenRouter model id
|
||||
profile: Profile
|
||||
price_in: float # USD per token (prompt)
|
||||
price_out: float # USD per token (completion)
|
||||
max_images: int = 0 # 0 = no image input; otherwise max URL-passed images supported
|
||||
max_videos: int = 0 # 0 = no video input; otherwise max URL-passed videos supported
|
||||
|
||||
|
||||
MODELS: list[_ModelSpec] = [
|
||||
_ModelSpec("anthropic/claude-opus-4.7", "frontier_reasoning", 0.000005, 0.000025, max_images=20),
|
||||
_ModelSpec("openai/gpt-5.5-pro", "frontier_reasoning", 0.00003, 0.00018, max_images=20),
|
||||
_ModelSpec("openai/gpt-5.5", "frontier_reasoning", 0.000005, 0.00003, max_images=20),
|
||||
_ModelSpec("google/gemini-3.5-flash", "reasoning", 0.0000015, 0.000009, max_images=20, max_videos=4),
|
||||
_ModelSpec("x-ai/grok-4.20", "reasoning", 0.00000125, 0.0000025, max_images=20),
|
||||
_ModelSpec("x-ai/grok-4.3", "reasoning", 0.00000125, 0.0000025, max_images=20),
|
||||
_ModelSpec("deepseek/deepseek-v4-pro", "reasoning", 0.000000435, 0.00000087),
|
||||
_ModelSpec("deepseek/deepseek-v4-flash", "reasoning", 0.000000112, 0.000000224),
|
||||
_ModelSpec("deepseek/deepseek-v3.2", "reasoning", 0.000000252, 0.000000378),
|
||||
_ModelSpec("qwen/qwen3.6-max-preview", "reasoning", 0.00000104, 0.00000624),
|
||||
_ModelSpec("qwen/qwen3.6-plus", "reasoning", 0.000000325, 0.00000195, max_images=10, max_videos=4),
|
||||
_ModelSpec("qwen/qwen3.6-flash", "reasoning", 0.0000001875, 0.000001125, max_images=10, max_videos=4),
|
||||
_ModelSpec("mistralai/mistral-large-2512", "standard", 0.0000005, 0.0000015, max_images=8),
|
||||
_ModelSpec("mistralai/mistral-medium-3-5", "reasoning", 0.0000015, 0.0000075, max_images=8),
|
||||
_ModelSpec("z-ai/glm-4.6", "reasoning", 0.00000043, 0.00000174),
|
||||
_ModelSpec("z-ai/glm-5", "reasoning", 0.0000006, 0.00000192),
|
||||
_ModelSpec("moonshotai/kimi-k2.6", "reasoning", 0.00000073, 0.00000349, max_images=10),
|
||||
_ModelSpec("moonshotai/kimi-k2-thinking", "reasoning", 0.0000006, 0.0000025),
|
||||
_ModelSpec("perplexity/sonar-pro", "perplexity", 0.000003, 0.000015),
|
||||
_ModelSpec("perplexity/sonar-reasoning-pro", "perplexity_reasoning", 0.000002, 0.000008),
|
||||
_ModelSpec("perplexity/sonar-deep-research", "perplexity_reasoning", 0.000002, 0.000008),
|
||||
]
|
||||
|
||||
_MODELS_BY_SLUG: dict[str, _ModelSpec] = {m.slug: m for m in MODELS}
|
||||
_REASONING_EFFORTS = ["off", "low", "medium", "high"]
|
||||
_SEARCH_CONTEXT_SIZES = ["low", "medium", "high"]
|
||||
|
||||
|
||||
def _reasoning_extra_inputs() -> list:
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"reasoning_effort",
|
||||
options=_REASONING_EFFORTS,
|
||||
default="off",
|
||||
tooltip="Reasoning effort. 'off' disables reasoning entirely.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _perplexity_extra_inputs() -> list:
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"search_context_size",
|
||||
options=_SEARCH_CONTEXT_SIZES,
|
||||
default="medium",
|
||||
tooltip="How much web search context to retrieve. Larger = more grounded but slower/pricier.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _profile_inputs(profile: Profile) -> list:
|
||||
if profile == "standard":
|
||||
return []
|
||||
if profile in ("reasoning", "frontier_reasoning"):
|
||||
return _reasoning_extra_inputs()
|
||||
if profile == "perplexity":
|
||||
return _perplexity_extra_inputs()
|
||||
if profile == "perplexity_reasoning":
|
||||
return _perplexity_extra_inputs() + _reasoning_extra_inputs()
|
||||
raise ValueError(f"Unknown profile: {profile}")
|
||||
|
||||
|
||||
def _media_inputs(spec: _ModelSpec) -> list:
|
||||
extras: list = []
|
||||
if spec.max_images > 0:
|
||||
extras.append(
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, spec.max_images + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional reference image(s) — up to {spec.max_images}. Sent as URLs.",
|
||||
)
|
||||
)
|
||||
if spec.max_videos > 0:
|
||||
extras.append(
|
||||
IO.Autogrow.Input(
|
||||
"videos",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Video.Input("video"),
|
||||
names=[f"video_{i}" for i in range(1, spec.max_videos + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional reference video(s) — up to {spec.max_videos}. Sent as URLs.",
|
||||
)
|
||||
)
|
||||
return extras
|
||||
|
||||
|
||||
def _inputs_for_model(spec: _ModelSpec) -> list:
|
||||
return _profile_inputs(spec.profile) + _media_inputs(spec)
|
||||
|
||||
|
||||
def _build_model_options() -> list[IO.DynamicCombo.Option]:
|
||||
return [IO.DynamicCombo.Option(spec.slug, _inputs_for_model(spec)) for spec in MODELS]
|
||||
|
||||
|
||||
def _calculate_price(response: OpenRouterChatResponse) -> float | None:
|
||||
if response.usage and response.usage.cost is not None:
|
||||
return float(response.usage.cost)
|
||||
return None
|
||||
|
||||
|
||||
def _price_badge_jsonata() -> str:
|
||||
rates_pairs = []
|
||||
for spec in MODELS:
|
||||
prompt_per_1k = spec.price_in * 1000
|
||||
completion_per_1k = spec.price_out * 1000
|
||||
rates_pairs.append(f' "{spec.slug}": [{prompt_per_1k:.8g}, {completion_per_1k:.8g}]')
|
||||
rates_block = ",\n".join(rates_pairs)
|
||||
return (
|
||||
"(\n"
|
||||
" $rates := {\n"
|
||||
f"{rates_block}\n"
|
||||
" };\n"
|
||||
" $r := $lookup($rates, widgets.model);\n"
|
||||
" $r ? {\n"
|
||||
' "type": "list_usd",\n'
|
||||
' "usd": $r,\n'
|
||||
' "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }\n'
|
||||
' } : {"type": "text", "text": "Token-based"}\n'
|
||||
")"
|
||||
)
|
||||
|
||||
|
||||
async def _build_image_blocks(
|
||||
cls: type[IO.ComfyNode], spec: _ModelSpec, images: list[Input.Image]
|
||||
) -> list[OpenRouterImageContent]:
|
||||
urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
images,
|
||||
max_images=spec.max_images,
|
||||
total_pixels=2048 * 2048,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading reference images",
|
||||
)
|
||||
return [OpenRouterImageContent(image_url=OpenRouterImageUrl(url=url)) for url in urls]
|
||||
|
||||
|
||||
async def _build_video_blocks(cls: type[IO.ComfyNode], videos: list[Input.Video]) -> list[OpenRouterVideoContent]:
|
||||
blocks: list[OpenRouterVideoContent] = []
|
||||
total = len(videos)
|
||||
for idx, video in enumerate(videos):
|
||||
label = "Uploading reference video"
|
||||
if total > 1:
|
||||
label = f"{label} ({idx + 1}/{total})"
|
||||
url = await upload_video_to_comfyapi(cls, video, wait_label=label)
|
||||
blocks.append(OpenRouterVideoContent(video_url=OpenRouterVideoUrl(url=url)))
|
||||
return blocks
|
||||
|
||||
|
||||
def _user_message(prompt: str, media_blocks: list[OpenRouterContentBlock]) -> OpenRouterMessage:
|
||||
if not media_blocks:
|
||||
return OpenRouterMessage(role="user", content=prompt)
|
||||
blocks: list[OpenRouterContentBlock] = list(media_blocks)
|
||||
blocks.append(OpenRouterTextContent(text=prompt))
|
||||
return OpenRouterMessage(role="user", content=blocks)
|
||||
|
||||
|
||||
def _build_messages(
|
||||
system_prompt: str, prompt: str, media_blocks: list[OpenRouterContentBlock]
|
||||
) -> list[OpenRouterMessage]:
|
||||
messages: list[OpenRouterMessage] = []
|
||||
if system_prompt:
|
||||
messages.append(OpenRouterMessage(role="system", content=system_prompt))
|
||||
messages.append(_user_message(prompt, media_blocks))
|
||||
return messages
|
||||
|
||||
|
||||
def _build_request(
|
||||
slug: str,
|
||||
system_prompt: str,
|
||||
prompt: str,
|
||||
media_blocks: list[OpenRouterContentBlock],
|
||||
*,
|
||||
seed: int,
|
||||
reasoning_effort: str | None,
|
||||
search_context_size: str | None,
|
||||
) -> OpenRouterChatRequest:
|
||||
reasoning_cfg: OpenRouterReasoningConfig | None = None
|
||||
if reasoning_effort and reasoning_effort != "off":
|
||||
# exclude=True asks providers to reason internally but not return the trace
|
||||
reasoning_cfg = OpenRouterReasoningConfig(effort=reasoning_effort, exclude=True)
|
||||
web_search_cfg: OpenRouterWebSearchOptions | None = None
|
||||
if search_context_size:
|
||||
web_search_cfg = OpenRouterWebSearchOptions(search_context_size=search_context_size)
|
||||
return OpenRouterChatRequest(
|
||||
model=slug,
|
||||
messages=_build_messages(system_prompt, prompt, media_blocks),
|
||||
seed=seed if seed > 0 else None,
|
||||
reasoning=reasoning_cfg,
|
||||
web_search_options=web_search_cfg,
|
||||
)
|
||||
|
||||
|
||||
def _extract_text(response: OpenRouterChatResponse) -> str:
|
||||
if response.error:
|
||||
code = response.error.code if response.error.code is not None else "unknown"
|
||||
raise ValueError(f"OpenRouter error ({code}): {response.error.message or 'no message'}")
|
||||
if not response.choices:
|
||||
raise ValueError("Empty response from OpenRouter (no choices).")
|
||||
message = response.choices[0].message
|
||||
if not message:
|
||||
raise ValueError("Empty response from OpenRouter (no message).")
|
||||
if message.refusal:
|
||||
raise ValueError(f"Model refused to respond: {message.refusal}")
|
||||
return message.content or ""
|
||||
|
||||
|
||||
class OpenRouterLLMNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="OpenRouterLLMNode",
|
||||
display_name="OpenRouter LLM",
|
||||
category="api node/text/OpenRouter",
|
||||
essentials_category="Text Generation",
|
||||
description=(
|
||||
"Generate text responses through OpenRouter. Routes to a curated set of popular "
|
||||
"models from xAI, DeepSeek, Qwen, Mistral, Z.AI (GLM), Moonshot (Kimi), and "
|
||||
"Perplexity Sonar."
|
||||
),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text input to the model.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=_build_model_options(),
|
||||
tooltip="The OpenRouter model used to generate the response.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for sampling. Set to 0 to omit. Most models treat this as a hint only.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Foundational instructions that dictate the model's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.String.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr=_price_badge_jsonata(),
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
slug: str = model["model"]
|
||||
spec = _MODELS_BY_SLUG.get(slug)
|
||||
if spec is None:
|
||||
raise ValueError(f"Unknown OpenRouter model: {slug}")
|
||||
|
||||
reasoning_effort: str | None = model.get("reasoning_effort")
|
||||
search_context_size: str | None = model.get("search_context_size")
|
||||
|
||||
image_tensors: list[Input.Image] = [t for t in (model.get("images") or {}).values() if t is not None]
|
||||
if image_tensors and sum(get_number_of_images(t) for t in image_tensors) > spec.max_images:
|
||||
raise ValueError(f"Up to {spec.max_images} images are supported for {slug}.")
|
||||
video_inputs: list[Input.Video] = [v for v in (model.get("videos") or {}).values() if v is not None]
|
||||
if video_inputs and len(video_inputs) > spec.max_videos:
|
||||
raise ValueError(f"Up to {spec.max_videos} videos are supported for {slug}.")
|
||||
|
||||
media_blocks: list[OpenRouterContentBlock] = []
|
||||
if image_tensors:
|
||||
media_blocks.extend(await _build_image_blocks(cls, spec, image_tensors))
|
||||
if video_inputs:
|
||||
media_blocks.extend(await _build_video_blocks(cls, video_inputs))
|
||||
|
||||
request = _build_request(
|
||||
slug,
|
||||
system_prompt,
|
||||
prompt,
|
||||
media_blocks,
|
||||
seed=seed,
|
||||
reasoning_effort=reasoning_effort,
|
||||
search_context_size=search_context_size,
|
||||
)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=OPENROUTER_CHAT_ENDPOINT, method="POST"),
|
||||
response_model=OpenRouterChatResponse,
|
||||
data=request,
|
||||
price_extractor=_calculate_price,
|
||||
)
|
||||
return IO.NodeOutput(_extract_text(response))
|
||||
|
||||
|
||||
class OpenRouterExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [OpenRouterLLMNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> OpenRouterExtension:
|
||||
return OpenRouterExtension()
|
||||
@ -16,17 +16,16 @@ from .conversions import (
|
||||
convert_mask_to_image,
|
||||
downscale_image_tensor,
|
||||
downscale_image_tensor_by_max_side,
|
||||
downscale_video_to_max_pixels,
|
||||
image_tensor_pair_to_batch,
|
||||
pil_to_bytesio,
|
||||
resize_mask_to_image,
|
||||
resize_video_to_pixel_budget,
|
||||
tensor_to_base64_string,
|
||||
tensor_to_bytesio,
|
||||
tensor_to_pil,
|
||||
text_filepath_to_base64_string,
|
||||
text_filepath_to_data_uri,
|
||||
trim_video,
|
||||
upscale_video_to_min_pixels,
|
||||
video_to_base64_string,
|
||||
)
|
||||
from .download_helpers import (
|
||||
@ -89,17 +88,16 @@ __all__ = [
|
||||
"convert_mask_to_image",
|
||||
"downscale_image_tensor",
|
||||
"downscale_image_tensor_by_max_side",
|
||||
"downscale_video_to_max_pixels",
|
||||
"image_tensor_pair_to_batch",
|
||||
"pil_to_bytesio",
|
||||
"resize_mask_to_image",
|
||||
"resize_video_to_pixel_budget",
|
||||
"tensor_to_base64_string",
|
||||
"tensor_to_bytesio",
|
||||
"tensor_to_pil",
|
||||
"text_filepath_to_base64_string",
|
||||
"text_filepath_to_data_uri",
|
||||
"trim_video",
|
||||
"upscale_video_to_min_pixels",
|
||||
"video_to_base64_string",
|
||||
# Validation utilities
|
||||
"get_image_dimensions",
|
||||
|
||||
@ -415,48 +415,14 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
|
||||
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
||||
|
||||
|
||||
def downscale_video_to_max_pixels(video: Input.Video, max_pixels: int) -> Input.Video:
|
||||
"""Downscale a video to fit within ``max_pixels`` (w * h), preserving aspect ratio.
|
||||
def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video:
|
||||
"""Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio.
|
||||
|
||||
Returns the original video object untouched when it already fits. Preserves frame rate, duration, and audio.
|
||||
Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
|
||||
"""
|
||||
src_w, src_h = video.get_dimensions()
|
||||
scale_dims = _compute_downscale_dims(src_w, src_h, max_pixels)
|
||||
if scale_dims is None:
|
||||
return video
|
||||
return _apply_video_scale(video, scale_dims)
|
||||
|
||||
|
||||
def _compute_upscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None:
|
||||
"""Return upscaled (w, h) with even dims meeting at least ``total_pixels``, or None if already large enough.
|
||||
|
||||
Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions
|
||||
are rounded up to even values (many codecs require divisible-by-2). The result is guaranteed to be at
|
||||
least ``total_pixels``.
|
||||
"""
|
||||
pixels = src_w * src_h
|
||||
if pixels >= total_pixels:
|
||||
return None
|
||||
scale = math.sqrt(total_pixels / pixels)
|
||||
new_w = math.ceil(src_w * scale)
|
||||
new_h = math.ceil(src_h * scale)
|
||||
if new_w % 2:
|
||||
new_w += 1
|
||||
if new_h % 2:
|
||||
new_h += 1
|
||||
return new_w, new_h
|
||||
|
||||
|
||||
def upscale_video_to_min_pixels(video: Input.Video, min_pixels: int) -> Input.Video:
|
||||
"""Upscale a video to meet at least ``min_pixels`` (w * h), preserving aspect ratio.
|
||||
|
||||
Returns the original video object untouched when it already meets the minimum. Preserves frame rate,
|
||||
duration, and audio. Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
|
||||
Note: upscaling a low-resolution source does not add real detail; downstream model quality may suffer.
|
||||
"""
|
||||
src_w, src_h = video.get_dimensions()
|
||||
scale_dims = _compute_upscale_dims(src_w, src_h, min_pixels)
|
||||
scale_dims = _compute_downscale_dims(src_w, src_h, total_pixels)
|
||||
if scale_dims is None:
|
||||
return video
|
||||
return _apply_video_scale(video, scale_dims)
|
||||
|
||||
@ -543,7 +543,7 @@ class AudioConcat(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioConcat",
|
||||
search_aliases=["join audio", "combine audio", "append audio"],
|
||||
display_name="Concatenate Audio",
|
||||
display_name="Audio Concat",
|
||||
description="Concatenates the audio1 to audio2 in the specified direction.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -597,7 +597,7 @@ class AudioMerge(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioMerge",
|
||||
search_aliases=["mix audio", "overlay audio", "layer audio"],
|
||||
display_name="Merge Audio",
|
||||
display_name="Audio Merge",
|
||||
description="Combine two audio tracks by overlaying their waveforms.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -667,9 +667,8 @@ class AudioAdjustVolume(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioAdjustVolume",
|
||||
search_aliases=["audio gain", "loudness", "audio level"],
|
||||
display_name="Adjust Audio Volume",
|
||||
display_name="Audio Adjust Volume",
|
||||
category="audio",
|
||||
description="Adjust the volume of the audio by a specified amount in decibels (dB).",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.Int.Input(
|
||||
|
||||
@ -47,10 +47,8 @@ class LoadImageDataSetFromFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadImageDataSetFromFolder",
|
||||
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
|
||||
display_name="Load Image (from Folder)",
|
||||
category="image",
|
||||
description="Load a dataset of images from a specified folder and return a list of images. Supported formats: PNG, JPG, JPEG, WEBP.",
|
||||
display_name="Load Image Dataset from Folder",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
@ -86,16 +84,14 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadImageTextDataSetFromFolder",
|
||||
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
|
||||
display_name="Load Image-Text (from Folder)",
|
||||
category="image",
|
||||
description="Load a dataset of pairs of images and text captions from a specified folder and return them as a list. Supported formats: PNG, JPG, JPEG, WEBP.",
|
||||
display_name="Load Image and Text Dataset from Folder",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"folder",
|
||||
options=folder_paths.get_input_subfolders(),
|
||||
tooltip="The folder to load images and text captions from.",
|
||||
tooltip="The folder to load images from.",
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
@ -210,10 +206,8 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveImageDataSetToFolder",
|
||||
search_aliases=["save folder", "save to folder", "save dataset", "save images", "export dataset"],
|
||||
display_name="Save Image (to Folder) (DEPRECATED)",
|
||||
category="image",
|
||||
description="Save a dataset of images to a specified folder. Supported formats: PNG.",
|
||||
display_name="Save Image Dataset to Folder",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive images as list
|
||||
@ -232,7 +226,6 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[],
|
||||
is_deprecated=True, # This node is redundant and superseded by existing Save Image nodes where the target folder can be specified in the filename_prefix
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -253,20 +246,14 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveImageTextDataSetToFolder",
|
||||
search_aliases=["save folder", "save to folder", "save dataset", "save images", "save text", "export dataset"],
|
||||
display_name="Save Image-Text (to Folder)",
|
||||
category="image",
|
||||
description="Save a dataset of pairs of images and text captions to a specified folder. Images are saved as PNG files and captions are saved as TXT files with the same filename_prefix.",
|
||||
display_name="Save Image and Text Dataset to Folder",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive both images and texts as lists
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="List of images to save."),
|
||||
io.String.Input("texts",
|
||||
optional=True,
|
||||
force_input=True,
|
||||
tooltip="List of text captions to save."
|
||||
),
|
||||
io.String.Input("texts", tooltip="List of text captions to save."),
|
||||
io.String.Input(
|
||||
"folder_name",
|
||||
default="dataset",
|
||||
@ -283,7 +270,7 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, folder_name, filename_prefix, texts=None):
|
||||
def execute(cls, images, texts, folder_name, filename_prefix):
|
||||
# Extract scalar values
|
||||
folder_name = folder_name[0]
|
||||
filename_prefix = filename_prefix[0]
|
||||
@ -292,12 +279,11 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
saved_files = save_images_to_folder(images, output_dir, filename_prefix)
|
||||
|
||||
# Save captions
|
||||
if texts:
|
||||
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
|
||||
caption_filename = filename.replace(".png", ".txt")
|
||||
caption_path = os.path.join(output_dir, caption_filename)
|
||||
with open(caption_path, "w", encoding="utf-8") as f:
|
||||
f.write(caption)
|
||||
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
|
||||
caption_filename = filename.replace(".png", ".txt")
|
||||
caption_path = os.path.join(output_dir, caption_filename)
|
||||
with open(caption_path, "w", encoding="utf-8") as f:
|
||||
f.write(caption)
|
||||
|
||||
logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
|
||||
return io.NodeOutput()
|
||||
@ -328,13 +314,11 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
|
||||
Child classes should set:
|
||||
node_id: Unique node identifier (required)
|
||||
search_aliases: List of search aliases (optional)
|
||||
display_name: Display name (optional, defaults to node_id)
|
||||
description: Node description (optional)
|
||||
extra_inputs: List of additional io.Input objects beyond "images" (optional)
|
||||
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
|
||||
is_output_list: True (list output) or False (single output) (optional, default True)
|
||||
is_deprecated: True if the node is deprecated (optional, default False)
|
||||
|
||||
Child classes must implement ONE of:
|
||||
_process(cls, image, **kwargs) -> tensor (for single-item processing)
|
||||
@ -342,13 +326,12 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
"""
|
||||
|
||||
node_id = None
|
||||
search_aliases = []
|
||||
display_name = None
|
||||
description = None
|
||||
extra_inputs = []
|
||||
is_group_process = None # None = auto-detect, True/False = explicit
|
||||
is_output_list = None # None = auto-detect based on processing mode
|
||||
is_deprecated = False
|
||||
|
||||
@classmethod
|
||||
def _detect_processing_mode(cls):
|
||||
"""Detect whether this node uses group or individual processing.
|
||||
@ -419,10 +402,8 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
|
||||
return io.Schema(
|
||||
node_id=cls.node_id,
|
||||
search_aliases=cls.search_aliases,
|
||||
display_name=cls.display_name or cls.node_id,
|
||||
category=cls.category,
|
||||
description=cls.description,
|
||||
category="dataset/image",
|
||||
is_experimental=True,
|
||||
is_input_list=is_group, # True for group, False for individual
|
||||
inputs=inputs,
|
||||
@ -491,13 +472,11 @@ class TextProcessingNode(io.ComfyNode):
|
||||
|
||||
Child classes should set:
|
||||
node_id: Unique node identifier (required)
|
||||
search_aliases: List of search aliases (optional)
|
||||
display_name: Display name (optional, defaults to node_id)
|
||||
description: Node description (optional)
|
||||
extra_inputs: List of additional io.Input objects beyond "texts" (optional)
|
||||
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
|
||||
is_output_list: True (list output) or False (single output) (optional, default True)
|
||||
is_deprecated: True if the node is deprecated (optional, default False)
|
||||
|
||||
Child classes must implement ONE of:
|
||||
_process(cls, text, **kwargs) -> str (for single-item processing)
|
||||
@ -505,13 +484,12 @@ class TextProcessingNode(io.ComfyNode):
|
||||
"""
|
||||
|
||||
node_id = None
|
||||
search_aliases = []
|
||||
display_name = None
|
||||
description = None
|
||||
extra_inputs = []
|
||||
is_group_process = None # None = auto-detect, True/False = explicit
|
||||
is_output_list = None # None = auto-detect based on processing mode
|
||||
is_deprecated = False
|
||||
|
||||
@classmethod
|
||||
def _detect_processing_mode(cls):
|
||||
"""Detect whether this node uses group or individual processing.
|
||||
@ -649,17 +627,15 @@ class TextProcessingNode(io.ComfyNode):
|
||||
|
||||
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
||||
node_id = "ResizeImagesByShorterEdge"
|
||||
display_name = "Resize Images by Shorter Edge (DEPRECATED)"
|
||||
category = "image/transform"
|
||||
description = "Resize images so that the shorter edge matches the specified dimension while preserving aspect ratio."
|
||||
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale shorter dimension
|
||||
display_name = "Resize Images by Shorter Edge"
|
||||
description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"shorter_edge",
|
||||
default=512,
|
||||
min=1,
|
||||
max=8192,
|
||||
tooltip="Target dimension for the shorter edge.",
|
||||
tooltip="Target length for the shorter edge.",
|
||||
),
|
||||
]
|
||||
|
||||
@ -679,17 +655,15 @@ class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
||||
|
||||
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
||||
node_id = "ResizeImagesByLongerEdge"
|
||||
display_name = "Resize Images by Longer Edge (DEPRECATED)"
|
||||
category = "image/transform"
|
||||
description = "Resize images so that the longer edge matches the specified dimension while preserving aspect ratio."
|
||||
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale longer dimension
|
||||
display_name = "Resize Images by Longer Edge"
|
||||
description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio."
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"longer_edge",
|
||||
default=1024,
|
||||
min=1,
|
||||
max=8192,
|
||||
tooltip="Target dimension for the longer edge.",
|
||||
tooltip="Target length for the longer edge.",
|
||||
),
|
||||
]
|
||||
|
||||
@ -712,10 +686,8 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
||||
|
||||
class CenterCropImagesNode(ImageProcessingNode):
|
||||
node_id = "CenterCropImages"
|
||||
search_aliases=["crop", "cut", "trim"]
|
||||
display_name="Crop Image (Center)"
|
||||
category="image/transform"
|
||||
description = "Center crop an image to the specified dimensions."
|
||||
display_name = "Center Crop Images"
|
||||
description = "Center crop all images to the specified dimensions."
|
||||
extra_inputs = [
|
||||
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
|
||||
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
|
||||
@ -734,11 +706,10 @@ class CenterCropImagesNode(ImageProcessingNode):
|
||||
|
||||
class RandomCropImagesNode(ImageProcessingNode):
|
||||
node_id = "RandomCropImages"
|
||||
search_aliases=["crop", "cut", "trim"]
|
||||
display_name = "Crop Image (Random)"
|
||||
category="image/transform"
|
||||
description = "Randomly crop an image to the specified dimensions."
|
||||
|
||||
display_name = "Random Crop Images"
|
||||
description = (
|
||||
"Randomly crop all images to the specified dimensions (for data augmentation)."
|
||||
)
|
||||
extra_inputs = [
|
||||
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
|
||||
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
|
||||
@ -763,9 +734,7 @@ class RandomCropImagesNode(ImageProcessingNode):
|
||||
|
||||
class NormalizeImagesNode(ImageProcessingNode):
|
||||
node_id = "NormalizeImages"
|
||||
search_aliases=["normalize", "normalize colors"]
|
||||
display_name = "Normalize Image Colors"
|
||||
category = "image/color"
|
||||
display_name = "Normalize Images"
|
||||
description = "Normalize images using mean and standard deviation."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
@ -793,10 +762,8 @@ class NormalizeImagesNode(ImageProcessingNode):
|
||||
|
||||
class AdjustBrightnessNode(ImageProcessingNode):
|
||||
node_id = "AdjustBrightness"
|
||||
search_aliases=["brightness"]
|
||||
display_name = "Adjust Brightness"
|
||||
category="image/adjustments"
|
||||
description = "Adjust the brightness of an image."
|
||||
description = "Adjust brightness of all images."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
"factor",
|
||||
@ -814,10 +781,8 @@ class AdjustBrightnessNode(ImageProcessingNode):
|
||||
|
||||
class AdjustContrastNode(ImageProcessingNode):
|
||||
node_id = "AdjustContrast"
|
||||
search_aliases=["contrast"]
|
||||
display_name = "Adjust Contrast"
|
||||
category="image/adjustments"
|
||||
description = "Adjust the contrast of an image."
|
||||
description = "Adjust contrast of all images."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
"factor",
|
||||
@ -835,10 +800,8 @@ class AdjustContrastNode(ImageProcessingNode):
|
||||
|
||||
class ShuffleDatasetNode(ImageProcessingNode):
|
||||
node_id = "ShuffleDataset"
|
||||
search_aliases=["shuffle", "randomize", "mix"]
|
||||
display_name = "Shuffle Images List"
|
||||
category = "image/batch"
|
||||
description = "Randomly shuffle the order of images in a list."
|
||||
display_name = "Shuffle Image Dataset"
|
||||
description = "Randomly shuffle the order of images in the dataset."
|
||||
is_group_process = True # Requires full list to shuffle
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
@ -860,15 +823,13 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ShuffleImageTextDataset",
|
||||
search_aliases=["shuffle", "randomize", "mix"],
|
||||
display_name = "Shuffle Pairs of Image-Text",
|
||||
category = "image/batch",
|
||||
description = "Randomly shuffle the order of pairs of image-text in a list.",
|
||||
display_name="Shuffle Image-Text Dataset",
|
||||
category="dataset/image",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="List of images to shuffle."),
|
||||
io.String.Input("texts", tooltip="List of texts to shuffle.", force_input=True),
|
||||
io.String.Input("texts", tooltip="List of texts to shuffle."),
|
||||
io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
@ -904,11 +865,8 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
|
||||
|
||||
class TextToLowercaseNode(TextProcessingNode):
|
||||
node_id = "TextToLowercase"
|
||||
search_aliases=["lowercase"]
|
||||
display_name = "Convert Text to Lowercase (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Convert text to lowercase."
|
||||
is_deprecated = True # This node is superseded by the Convert Text Case node
|
||||
display_name = "Text to Lowercase"
|
||||
description = "Convert all texts to lowercase."
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -917,11 +875,8 @@ class TextToLowercaseNode(TextProcessingNode):
|
||||
|
||||
class TextToUppercaseNode(TextProcessingNode):
|
||||
node_id = "TextToUppercase"
|
||||
search_aliases=["uppercase"]
|
||||
display_name = "Convert Text to Uppercase (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Convert text to uppercase."
|
||||
is_deprecated = True # This node is superseded by the Convert Text Case node
|
||||
display_name = "Text to Uppercase"
|
||||
description = "Convert all texts to uppercase."
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -930,10 +885,8 @@ class TextToUppercaseNode(TextProcessingNode):
|
||||
|
||||
class TruncateTextNode(TextProcessingNode):
|
||||
node_id = "TruncateText"
|
||||
search_aliases=["truncate", "cut", "shorten"]
|
||||
display_name = "Truncate Text"
|
||||
category = "text"
|
||||
description = "Truncate text to a maximum length."
|
||||
description = "Truncate all texts to a maximum length."
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
|
||||
@ -947,10 +900,8 @@ class TruncateTextNode(TextProcessingNode):
|
||||
|
||||
class AddTextPrefixNode(TextProcessingNode):
|
||||
node_id = "AddTextPrefix"
|
||||
display_name = "Add Text Prefix (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Add Text Prefix"
|
||||
description = "Add a prefix to all texts."
|
||||
is_deprecated = True # This node is superseded by the Concatenate Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("prefix", default="", tooltip="Prefix to add."),
|
||||
]
|
||||
@ -962,10 +913,8 @@ class AddTextPrefixNode(TextProcessingNode):
|
||||
|
||||
class AddTextSuffixNode(TextProcessingNode):
|
||||
node_id = "AddTextSuffix"
|
||||
display_name = "Add Text Suffix (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Add Text Suffix"
|
||||
description = "Add a suffix to all texts."
|
||||
is_deprecated = True # This node is superseded by the Concatenate Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("suffix", default="", tooltip="Suffix to add."),
|
||||
]
|
||||
@ -977,10 +926,8 @@ class AddTextSuffixNode(TextProcessingNode):
|
||||
|
||||
class ReplaceTextNode(TextProcessingNode):
|
||||
node_id = "ReplaceText"
|
||||
display_name = "Replace Text (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Replace Text"
|
||||
description = "Replace text in all texts."
|
||||
is_deprecated = True # This node is superseded by the other Replace Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("find", default="", tooltip="Text to find."),
|
||||
io.String.Input("replace", default="", tooltip="Text to replace with."),
|
||||
@ -993,10 +940,8 @@ class ReplaceTextNode(TextProcessingNode):
|
||||
|
||||
class StripWhitespaceNode(TextProcessingNode):
|
||||
node_id = "StripWhitespace"
|
||||
display_name = "Strip Whitespace (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Strip Whitespace"
|
||||
description = "Strip leading and trailing whitespace from all texts."
|
||||
is_deprecated = True # This node is superseded by the Trim Text node
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -1007,13 +952,11 @@ class StripWhitespaceNode(TextProcessingNode):
|
||||
|
||||
|
||||
class ImageDeduplicationNode(ImageProcessingNode):
|
||||
"""Remove duplicate or very similar images from a list using perceptual hashing."""
|
||||
"""Remove duplicate or very similar images from the dataset using perceptual hashing."""
|
||||
|
||||
node_id = "ImageDeduplication"
|
||||
search_aliases=["deduplicate", "remove duplicates", "similarity filter"]
|
||||
display_name = "Deduplicate Images"
|
||||
category = "image/batch"
|
||||
description = "Remove duplicate or very similar images from a list."
|
||||
display_name = "Image Deduplication"
|
||||
description = "Remove duplicate or very similar images from the dataset."
|
||||
is_group_process = True # Requires full list to compare images
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
@ -1083,9 +1026,7 @@ class ImageGridNode(ImageProcessingNode):
|
||||
"""Combine multiple images into a single grid/collage."""
|
||||
|
||||
node_id = "ImageGrid"
|
||||
search_aliases=["grid", "collage", "combine"]
|
||||
display_name = "Make Image Grid"
|
||||
category="image/batch"
|
||||
display_name = "Image Grid"
|
||||
description = "Arrange multiple images into a grid layout."
|
||||
is_group_process = True # Requires full list to create grid
|
||||
is_output_list = False # Outputs single grid image
|
||||
@ -1161,12 +1102,9 @@ class MergeImageListsNode(ImageProcessingNode):
|
||||
"""Merge multiple image lists into a single list."""
|
||||
|
||||
node_id = "MergeImageLists"
|
||||
search_aliases=["list", "merge list", "make list"]
|
||||
display_name = "Merge Image Lists (DEPRECATED)"
|
||||
category = "image/batch"
|
||||
display_name = "Merge Image Lists"
|
||||
description = "Concatenate multiple image lists into one."
|
||||
is_group_process = True # Receives images as list
|
||||
is_deprecated = True # This node is superseded by the Create List node
|
||||
|
||||
@classmethod
|
||||
def _group_process(cls, images):
|
||||
@ -1181,11 +1119,9 @@ class MergeTextListsNode(TextProcessingNode):
|
||||
"""Merge multiple text lists into a single list."""
|
||||
|
||||
node_id = "MergeTextLists"
|
||||
display_name = "Merge Text Lists (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Merge Text Lists"
|
||||
description = "Concatenate multiple text lists into one."
|
||||
is_group_process = True # Receives texts as list
|
||||
is_deprecated = True # This node is superseded by the Create List node
|
||||
|
||||
@classmethod
|
||||
def _group_process(cls, texts):
|
||||
@ -1206,10 +1142,8 @@ class ResolutionBucket(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ResolutionBucket",
|
||||
search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"],
|
||||
display_name="Resolution Bucket",
|
||||
category="training",
|
||||
description="Group latents and conditionings into buckets",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
@ -1302,8 +1236,7 @@ class MakeTrainingDataset(io.ComfyNode):
|
||||
node_id="MakeTrainingDataset",
|
||||
search_aliases=["encode dataset"],
|
||||
display_name="Make Training Dataset",
|
||||
category="training",
|
||||
description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_input_list=True, # images and texts as lists
|
||||
inputs=[
|
||||
@ -1318,7 +1251,6 @@ class MakeTrainingDataset(io.ComfyNode):
|
||||
"texts",
|
||||
optional=True,
|
||||
tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
|
||||
force_input=True
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
@ -1388,10 +1320,9 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveTrainingDataset",
|
||||
search_aliases=["export dataset", "save dataset"],
|
||||
search_aliases=["export training data"],
|
||||
display_name="Save Training Dataset",
|
||||
category="training",
|
||||
description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive lists
|
||||
@ -1493,8 +1424,7 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
node_id="LoadTrainingDataset",
|
||||
search_aliases=["import dataset", "training data"],
|
||||
display_name="Load Training Dataset",
|
||||
category="training",
|
||||
description="Load encoded training dataset (latents + conditioning) from disk for use in training.",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input(
|
||||
|
||||
@ -419,17 +419,15 @@ class VoxelToMeshBasic(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VoxelToMeshBasic",
|
||||
display_name="Voxel to Mesh (Basic) (DEPRECATED)",
|
||||
display_name="Voxel to Mesh (Basic)",
|
||||
category="3d",
|
||||
description="Converts a voxel grid to a mesh.",
|
||||
is_deprecated=True, # This node is superseded by the Voxel To Mesh node
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output(),
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -455,10 +453,9 @@ class VoxelToMesh(IO.ComfyNode):
|
||||
node_id="VoxelToMesh",
|
||||
display_name="Voxel to Mesh",
|
||||
category="3d",
|
||||
description="Converts a voxel grid to a mesh.",
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
IO.Combo.Input("algorithm", options=["surface net", "basic"]),
|
||||
IO.Combo.Input("algorithm", options=["surface net", "basic"], advanced=True),
|
||||
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
|
||||
@ -3,15 +3,23 @@ from __future__ import annotations
|
||||
import nodes
|
||||
import folder_paths
|
||||
|
||||
import av
|
||||
import json
|
||||
|
||||
import os
|
||||
import re
|
||||
import math
|
||||
import numpy as np
|
||||
import struct
|
||||
import torch
|
||||
|
||||
import zlib
|
||||
import comfy.utils
|
||||
from fractions import Fraction
|
||||
|
||||
from server import PromptServer
|
||||
from comfy_api.latest import ComfyExtension, IO, UI
|
||||
from comfy.cli_args import args
|
||||
from typing_extensions import override
|
||||
|
||||
SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later.
|
||||
@ -55,10 +63,9 @@ class ImageCropV2(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCropV2",
|
||||
search_aliases=["crop", "cut", "trim"],
|
||||
search_aliases=["trim"],
|
||||
display_name="Crop Image",
|
||||
category="image/transform",
|
||||
description = "Crop an image to the specified dimensions.",
|
||||
essentials_category="Image Tools",
|
||||
has_intermediate_output=True,
|
||||
inputs=[
|
||||
@ -835,6 +842,405 @@ class ImageMergeTileList(IO.ComfyNode):
|
||||
return IO.NodeOutput(merged_image)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Format specifications
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Maps (file_format, bit_depth, has_alpha) -> (numpy dtype scale, av pixel format,
|
||||
# stream pix_fmt). Keeps the encode path declarative instead of branchy.
|
||||
_FORMAT_SPECS = {
|
||||
("png", "8-bit", False): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgb24", "stream_fmt": "rgb24"},
|
||||
("png", "8-bit", True): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgba", "stream_fmt": "rgba"},
|
||||
("png", "16-bit", False): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgb48le", "stream_fmt": "rgb48be"},
|
||||
("png", "16-bit", True): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgba64le", "stream_fmt": "rgba64be"},
|
||||
("exr", "32-bit float", False): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrpf32le", "stream_fmt": "gbrpf32le"},
|
||||
("exr", "32-bit float", True): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrapf32le", "stream_fmt": "gbrapf32le"},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Color transforms
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def srgb_to_linear(t: torch.Tensor) -> torch.Tensor:
|
||||
"""Inverse sRGB EOTF (IEC 61966-2-1). Operates on RGB channels only;
|
||||
alpha (if present as the 4th channel) is passed through unchanged."""
|
||||
if t.shape[-1] == 4:
|
||||
rgb, alpha = t[..., :3], t[..., 3:]
|
||||
return torch.cat([srgb_to_linear(rgb), alpha], dim=-1)
|
||||
|
||||
# Piecewise: linear toe below 0.04045, gamma curve above.
|
||||
low = t / 12.92
|
||||
high = ((t.clamp(min=0.0) + 0.055) / 1.055) ** 2.4
|
||||
return torch.where(t <= 0.04045, low, high)
|
||||
|
||||
|
||||
# HLG OETF constants from BT.2100 Table 5.
|
||||
_HLG_A = 0.17883277
|
||||
_HLG_B = 0.28466892
|
||||
_HLG_C = 0.55991072928 # = 0.5 - a*ln(4*a)
|
||||
|
||||
|
||||
def hlg_to_linear(t: torch.Tensor) -> torch.Tensor:
|
||||
"""Inverse HLG OETF (BT.2100). Maps a non-linear HLG signal in [0, 1] to
|
||||
*scene*-linear light in [0, 1]. Per BT.2100 Note 5a, this is the correct
|
||||
transform when converting HLG to a linear scene-light representation
|
||||
(rather than display-light, which would also involve the HLG OOTF).
|
||||
|
||||
Operates on RGB channels only; alpha is passed through unchanged."""
|
||||
if t.shape[-1] == 4:
|
||||
rgb, alpha = t[..., :3], t[..., 3:]
|
||||
return torch.cat([hlg_to_linear(rgb), alpha], dim=-1)
|
||||
|
||||
# Piecewise: sqrt branch below 0.5, log branch above.
|
||||
# Clamp inside the log branch so negative / out-of-range values don't blow up;
|
||||
# values above 1.0 are allowed and extrapolate naturally.
|
||||
low = (t ** 2) / 3.0
|
||||
high = (torch.exp((t.clamp(min=_HLG_C) - _HLG_C) / _HLG_A) + _HLG_B) / 12.0
|
||||
return torch.where(t <= 0.5, low, high)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metadata injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
|
||||
|
||||
|
||||
def _png_chunk(chunk_type: bytes, data: bytes) -> bytes:
|
||||
"""Build a single PNG chunk: length | type | data | CRC32(type+data)."""
|
||||
crc = zlib.crc32(chunk_type + data) & 0xFFFFFFFF
|
||||
return struct.pack(">I", len(data)) + chunk_type + data + struct.pack(">I", crc)
|
||||
|
||||
|
||||
def _png_text_chunk(keyword: str, text: str) -> bytes:
|
||||
"""tEXt chunk: latin-1 keyword + NUL + latin-1 text."""
|
||||
payload = keyword.encode("latin-1") + b"\x00" + text.encode("latin-1", errors="replace")
|
||||
return _png_chunk(b"tEXt", payload)
|
||||
|
||||
|
||||
def inject_png_metadata(png_bytes: bytes, prompt: dict | None, extra_pnginfo: dict | None) -> bytes:
|
||||
"""Insert ComfyUI prompt/workflow as tEXt chunks right after IHDR."""
|
||||
if not png_bytes.startswith(_PNG_SIGNATURE):
|
||||
return png_bytes
|
||||
|
||||
chunks: list[bytes] = []
|
||||
if prompt is not None:
|
||||
chunks.append(_png_text_chunk("prompt", json.dumps(prompt)))
|
||||
if extra_pnginfo:
|
||||
for key, value in extra_pnginfo.items():
|
||||
chunks.append(_png_text_chunk(key, json.dumps(value)))
|
||||
if not chunks:
|
||||
return png_bytes
|
||||
|
||||
# IHDR is always the first chunk; insert ours immediately after it.
|
||||
ihdr_length = struct.unpack(">I", png_bytes[8:12])[0]
|
||||
ihdr_end = 8 + 8 + ihdr_length + 4 # signature + (len+type) + data + crc
|
||||
return png_bytes[:ihdr_end] + b"".join(chunks) + png_bytes[ihdr_end:]
|
||||
|
||||
|
||||
# Standard chromaticities (CIE 1931 xy) for the colorspaces this node writes.
|
||||
# Each tuple is (Rx, Ry, Gx, Gy, Bx, By, Wx, Wy). All share D65 white point.
|
||||
_CHROMATICITIES = {
|
||||
# ITU-R BT.709 / sRGB primaries
|
||||
"Rec.709": (0.6400, 0.3300, 0.3000, 0.6000, 0.1500, 0.0600, 0.3127, 0.3290),
|
||||
# ITU-R BT.2020 (UHDTV / wide-gamut HDR) primaries
|
||||
"Rec.2020": (0.7080, 0.2920, 0.1700, 0.7970, 0.1310, 0.0460, 0.3127, 0.3290),
|
||||
}
|
||||
|
||||
|
||||
def _pack_chromaticities(primaries: tuple) -> bytes:
|
||||
"""Serialize 8 chromaticity floats into the EXR `chromaticities` payload."""
|
||||
return struct.pack("<8f", *primaries)
|
||||
|
||||
|
||||
def _exr_attribute(name: str, attr_type: str, value: bytes) -> bytes:
|
||||
"""Serialize one EXR header attribute: name\\0 type\\0 size:int32 value."""
|
||||
return (
|
||||
name.encode("utf-8") + b"\x00"
|
||||
+ attr_type.encode("utf-8") + b"\x00"
|
||||
+ struct.pack("<i", len(value))
|
||||
+ value
|
||||
)
|
||||
|
||||
|
||||
def inject_exr_metadata(
|
||||
exr_bytes: bytes,
|
||||
prompt: dict | None,
|
||||
extra_pnginfo: dict | None,
|
||||
colorspace: str | None = None,
|
||||
) -> bytes:
|
||||
"""Insert ComfyUI metadata and color-space info into an EXR header.
|
||||
|
||||
Color: EXR pixels are linear by convention. The standard way to describe
|
||||
their RGB→XYZ relationship is the `chromaticities` attribute. We pick the
|
||||
primaries that match what the user told us their input was:
|
||||
|
||||
colorspace="sRGB" → Rec. 709 / sRGB primaries (D65)
|
||||
colorspace="HDR" → Rec. 2020 / BT.2100 primaries (D65)
|
||||
|
||||
Pixels are always converted to linear scene light upstream (sRGB EOTF
|
||||
inverse for sRGB; HLG OETF inverse for HDR), so the file content is
|
||||
scene-linear in the indicated gamut. OpenEXR has no standard transfer-
|
||||
function attribute (the OpenEXR TSC has discussed adding one but it
|
||||
doesn't exist), so we don't invent one — `chromaticities` plus the EXR
|
||||
linear-by-convention rule fully specifies the color.
|
||||
|
||||
Prompt/workflow: written as plain `string` attributes using the same keys
|
||||
(`prompt`, `workflow`, ...) that Comfy uses for PNG tEXt chunks, so the
|
||||
same readers can pull them out symmetrically.
|
||||
|
||||
Implementation note: the chunk-offset table that follows the header stores
|
||||
*absolute* byte offsets into the file. Inserting N bytes into the header
|
||||
means every offset must be incremented by N or the file becomes unreadable.
|
||||
"""
|
||||
if len(exr_bytes) < 8 or exr_bytes[:4] != b"\x76\x2f\x31\x01":
|
||||
return exr_bytes
|
||||
|
||||
new_blob = b""
|
||||
if prompt is not None:
|
||||
new_blob += _exr_attribute("prompt", "string", json.dumps(prompt).encode("utf-8"))
|
||||
if extra_pnginfo:
|
||||
for key, value in extra_pnginfo.items():
|
||||
new_blob += _exr_attribute(key, "string", json.dumps(value).encode("utf-8"))
|
||||
if colorspace is not None:
|
||||
# Map each colorspace option to the RGB primaries the linear pixels
|
||||
# are now in. "sRGB" and "linear" both produce Rec. 709 linear; "HDR"
|
||||
# (HLG-encoded Rec. 2020 input) produces Rec. 2020 linear.
|
||||
primaries_name = {
|
||||
"sRGB": "Rec.709",
|
||||
"linear": "Rec.709",
|
||||
"HDR": "Rec.2020",
|
||||
}.get(colorspace, "Rec.709")
|
||||
new_blob += _exr_attribute(
|
||||
"chromaticities",
|
||||
"chromaticities",
|
||||
_pack_chromaticities(_CHROMATICITIES[primaries_name]),
|
||||
)
|
||||
if not new_blob:
|
||||
return exr_bytes
|
||||
|
||||
# Walk header attributes to find the terminating null byte, and pick up
|
||||
# dataWindow + compression so we know how many chunks the offset table has.
|
||||
pos = 8 # past magic (4) + version (4)
|
||||
data_window = None
|
||||
compression = 0
|
||||
while pos < len(exr_bytes) and exr_bytes[pos] != 0:
|
||||
name_end = exr_bytes.index(b"\x00", pos)
|
||||
attr_name = exr_bytes[pos:name_end].decode("latin-1", errors="replace")
|
||||
type_end = exr_bytes.index(b"\x00", name_end + 1)
|
||||
attr_type = exr_bytes[name_end + 1:type_end].decode("latin-1", errors="replace")
|
||||
size = struct.unpack("<i", exr_bytes[type_end + 1:type_end + 5])[0]
|
||||
value_start = type_end + 5
|
||||
value = exr_bytes[value_start:value_start + size]
|
||||
|
||||
if attr_name == "dataWindow" and attr_type == "box2i":
|
||||
data_window = struct.unpack("<iiii", value) # xMin, yMin, xMax, yMax
|
||||
elif attr_name == "compression" and attr_type == "compression":
|
||||
compression = value[0]
|
||||
|
||||
pos = value_start + size
|
||||
|
||||
if data_window is None:
|
||||
return exr_bytes # required attribute missing — don't risk corrupting
|
||||
|
||||
# Scanlines per chunk by compression, from the OpenEXR spec.
|
||||
scanlines_per_block = {
|
||||
0: 1, # NO_COMPRESSION
|
||||
1: 1, # RLE
|
||||
2: 1, # ZIPS
|
||||
3: 16, # ZIP
|
||||
4: 32, # PIZ
|
||||
5: 16, # PXR24
|
||||
6: 32, # B44
|
||||
7: 32, # B44A
|
||||
8: 256, # DWAA
|
||||
9: 256, # DWAB
|
||||
}.get(compression, 1)
|
||||
|
||||
_, y_min, _, y_max = data_window
|
||||
height = y_max - y_min + 1
|
||||
num_chunks = (height + scanlines_per_block - 1) // scanlines_per_block
|
||||
|
||||
header_end = pos # position of the terminating null byte
|
||||
table_start = header_end + 1
|
||||
pixel_start = table_start + num_chunks * 8
|
||||
delta = len(new_blob)
|
||||
|
||||
old_offsets = struct.unpack(f"<{num_chunks}Q", exr_bytes[table_start:pixel_start])
|
||||
new_table = struct.pack(f"<{num_chunks}Q", *(o + delta for o in old_offsets))
|
||||
|
||||
return (
|
||||
exr_bytes[:header_end] # header attributes
|
||||
+ new_blob # our new attributes
|
||||
+ exr_bytes[header_end:table_start] # terminating null byte
|
||||
+ new_table # shifted offset table
|
||||
+ exr_bytes[pixel_start:] # pixel data, untouched
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _encode_image(
|
||||
img_tensor: torch.Tensor,
|
||||
file_format: str,
|
||||
bit_depth: str,
|
||||
colorspace: str,
|
||||
) -> bytes:
|
||||
"""Encode a single HxWxC tensor to PNG or EXR bytes in memory.
|
||||
|
||||
For EXR the input is interpreted according to `colorspace` and converted
|
||||
to scene-linear (EXR's convention) before writing:
|
||||
|
||||
"sRGB" → input is sRGB-encoded Rec. 709; apply inverse sRGB EOTF.
|
||||
"HDR" → input is HLG-encoded Rec. 2020 (BT.2100); apply inverse HLG
|
||||
OETF to get scene-linear, per BT.2100 Note 5a.
|
||||
"linear" → input is already scene-linear (Rec. 709 primaries); write
|
||||
through unchanged. Use this for renderer/compositor output.
|
||||
|
||||
For PNG, colorspace selection does not modify pixels — PNG is delivered
|
||||
sRGB-encoded and there is no PNG path for wide-gamut HDR in this node.
|
||||
"""
|
||||
height, width, num_channels = img_tensor.shape
|
||||
has_alpha = num_channels == 4
|
||||
|
||||
spec = _FORMAT_SPECS[(file_format, bit_depth, has_alpha)]
|
||||
|
||||
if spec["dtype"] == np.float32:
|
||||
# EXR path: preserve full range, no clamp.
|
||||
if colorspace == "sRGB":
|
||||
img_tensor = srgb_to_linear(img_tensor)
|
||||
elif colorspace == "HDR":
|
||||
img_tensor = hlg_to_linear(img_tensor)
|
||||
img_np = img_tensor.cpu().numpy().astype(np.float32)
|
||||
else:
|
||||
# PNG path: quantize to integer range.
|
||||
scaled = (img_tensor * spec["scale"]).clamp(0, spec["scale"])
|
||||
img_np = scaled.to(torch.int32).cpu().numpy().astype(spec["dtype"])
|
||||
|
||||
# Encode directly via CodecContext. PyAV's `image2` muxer does NOT write to
|
||||
# BytesIO (it expects a real file path), so we bypass the container entirely.
|
||||
# For single-frame PNG/EXR the raw codec output IS the file.
|
||||
codec = av.CodecContext.create(file_format, "w")
|
||||
codec.width = width
|
||||
codec.height = height
|
||||
codec.pix_fmt = spec["stream_fmt"]
|
||||
codec.time_base = Fraction(1, 1)
|
||||
|
||||
frame = av.VideoFrame.from_ndarray(img_np, format=spec["frame_fmt"])
|
||||
if spec["frame_fmt"] != spec["stream_fmt"]:
|
||||
frame = frame.reformat(format=spec["stream_fmt"])
|
||||
frame.pts = 0
|
||||
frame.time_base = codec.time_base
|
||||
|
||||
packets = list(codec.encode(frame)) + list(codec.encode(None)) # flush with None
|
||||
return b"".join(bytes(p) for p in packets)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SaveImageAdvanced(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveImageAdvanced",
|
||||
search_aliases=["save", "save image", "export image", "output image", "write image"],
|
||||
display_name="Save Image (Advanced)",
|
||||
description="Saves the input images to your ComfyUI output directory.",
|
||||
category="image",
|
||||
essentials_category="Basics",
|
||||
inputs=[
|
||||
IO.Image.Input("images", tooltip="The images to save."),
|
||||
IO.String.Input(
|
||||
"filename_prefix",
|
||||
default="ComfyUI",
|
||||
tooltip=(
|
||||
"The prefix for the file to save. May include formatting tokens "
|
||||
"such as %date:yyyy-MM-dd% or %Empty Latent Image.width%."
|
||||
),
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"format",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("png", [
|
||||
IO.Combo.Input("bit_depth", options=["8-bit", "16-bit"],
|
||||
default="8-bit", advanced=True),
|
||||
IO.Combo.Input("input_color_space", options=["sRGB"],
|
||||
default="sRGB", advanced=True),
|
||||
]),
|
||||
IO.DynamicCombo.Option("exr", [
|
||||
IO.Combo.Input("bit_depth", options=["32-bit float"],
|
||||
default="32-bit float", advanced=True),
|
||||
IO.Combo.Input(
|
||||
"input_color_space",
|
||||
options=["sRGB", "HDR", "linear"],
|
||||
default="sRGB",
|
||||
advanced=True,
|
||||
tooltip=(
|
||||
"Colorspace of the input tensor. The EXR is "
|
||||
"always written as scene-linear in the matching "
|
||||
"gamut.\n"
|
||||
" 'sRGB' — input is sRGB-encoded Rec.709; "
|
||||
"the inverse sRGB EOTF is applied.\n"
|
||||
" 'HDR' — input is HLG-encoded Rec.2020 "
|
||||
"(BT.2100); the inverse HLG OETF is applied "
|
||||
"to get scene-linear light.\n"
|
||||
" 'linear' — input is already scene-linear "
|
||||
"(Rec.709 primaries); written through unchanged. "
|
||||
"Use this for renderer/compositor output."
|
||||
),
|
||||
),
|
||||
]),
|
||||
],
|
||||
tooltip="The file format in which to save the image.",
|
||||
),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, filename_prefix: str, format: dict) -> IO.NodeOutput:
|
||||
file_format = format["format"]
|
||||
bit_depth = format["bit_depth"]
|
||||
colorspace = format.get("input_color_space", "sRGB")
|
||||
|
||||
output_dir = folder_paths.get_output_directory()
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = (
|
||||
folder_paths.get_save_image_path(
|
||||
filename_prefix, output_dir, images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
)
|
||||
|
||||
prompt = cls.hidden.prompt
|
||||
extra_pnginfo = cls.hidden.extra_pnginfo
|
||||
write_metadata = not args.disable_metadata
|
||||
|
||||
results = []
|
||||
for batch_number, image in enumerate(images):
|
||||
encoded = _encode_image(image, file_format, bit_depth, colorspace)
|
||||
|
||||
if write_metadata:
|
||||
if file_format == "png":
|
||||
encoded = inject_png_metadata(encoded, prompt, extra_pnginfo)
|
||||
elif file_format == "exr":
|
||||
encoded = inject_exr_metadata(encoded, prompt, extra_pnginfo, colorspace)
|
||||
|
||||
name = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{name}_{counter:05}.{file_format}"
|
||||
with open(os.path.join(full_output_folder, file), "wb") as f:
|
||||
f.write(encoded)
|
||||
|
||||
results.append({"filename": file, "subfolder": subfolder, "type": "output"})
|
||||
counter += 1
|
||||
|
||||
return IO.NodeOutput(ui={"images": results})
|
||||
|
||||
|
||||
class ImagesExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -847,6 +1253,7 @@ class ImagesExtension(ComfyExtension):
|
||||
ImageAddNoise,
|
||||
SaveAnimatedWEBP,
|
||||
SaveAnimatedPNG,
|
||||
SaveImageAdvanced,
|
||||
SaveSVGNode,
|
||||
ImageStitch,
|
||||
ResizeAndPadImage,
|
||||
|
||||
@ -11,8 +11,8 @@ class LTXVAudioVAELoader(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAELoader",
|
||||
display_name="Load LTXV Audio VAE",
|
||||
category="loaders",
|
||||
display_name="LTXV Audio VAE Loader",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"ckpt_name",
|
||||
@ -40,7 +40,7 @@ class LTXVAudioVAEEncode(VAEEncodeAudio):
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEEncode",
|
||||
display_name="LTXV Audio VAE Encode",
|
||||
category="latent/audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Audio.Input("audio", tooltip="The audio to be encoded."),
|
||||
io.Vae.Input(
|
||||
@ -63,7 +63,7 @@ class LTXVAudioVAEDecode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEDecode",
|
||||
display_name="LTXV Audio VAE Decode",
|
||||
category="latent/audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Latent.Input("samples", tooltip="The latent to be decoded."),
|
||||
io.Vae.Input(
|
||||
|
||||
@ -28,7 +28,7 @@ from comfy_extras.mediapipe.face_landmarker import FaceLandmarker
|
||||
from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection
|
||||
|
||||
|
||||
FaceDetectionType = io.Custom("FACE_DETECTION_MODEL")
|
||||
FaceLandmarkerType = io.Custom("FACE_LANDMARKER")
|
||||
FaceLandmarksType = io.Custom("FACE_LANDMARKS")
|
||||
|
||||
_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights")
|
||||
@ -204,19 +204,18 @@ class LoadMediaPipeFaceLandmarker(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadMediaPipeFaceLandmarker",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
|
||||
display_name="Load Face Detection Model (MediaPipe)",
|
||||
display_name="Load MediaPipe Face Landmarker",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("detection"),
|
||||
tooltip="Face detection model from models/detection/."),
|
||||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("mediapipe"),
|
||||
tooltip="Face Landmarker safetensors from models/mediapipe/."),
|
||||
],
|
||||
outputs=[FaceDetectionType.Output()],
|
||||
outputs=[FaceLandmarkerType.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_name) -> io.NodeOutput:
|
||||
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("detection", model_name), safe_load=True)
|
||||
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("mediapipe", model_name), safe_load=True)
|
||||
wrapper = FaceLandmarkerModel(sd)
|
||||
return io.NodeOutput(wrapper)
|
||||
|
||||
@ -235,12 +234,10 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceLandmarker",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
|
||||
display_name="Detect Face Landmarks (MediaPipe)",
|
||||
display_name="MediaPipe Face Landmarker",
|
||||
category="image/detection",
|
||||
description="Detects facial landmarks using MediaPipe model.",
|
||||
inputs=[
|
||||
FaceDetectionType.Input("face_detection_model"),
|
||||
FaceLandmarkerType.Input("face_landmarker"),
|
||||
io.Image.Input("image"),
|
||||
io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short",
|
||||
tooltip="Face detector range. 'short' is tuned for close-up faces "
|
||||
@ -264,9 +261,9 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, face_detection_model, image, detector_variant, num_faces, min_confidence,
|
||||
def execute(cls, face_landmarker, image, detector_variant, num_faces, min_confidence,
|
||||
missing_frame_fallback) -> io.NodeOutput:
|
||||
canonical = face_detection_model.canonical_data
|
||||
canonical = face_landmarker.canonical_data
|
||||
img_np = _image_to_uint8(image)
|
||||
B, H, W = img_np.shape[:3]
|
||||
chunk = 16
|
||||
@ -279,7 +276,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq:
|
||||
for i in range(0, B, chunk):
|
||||
end = min(i + chunk, B)
|
||||
res.extend(face_detection_model.detect_batch(
|
||||
res.extend(face_landmarker.detect_batch(
|
||||
[img_np[bi] for bi in range(i, end)],
|
||||
num_faces=int(num_faces),
|
||||
score_thresh=float(min_confidence),
|
||||
@ -309,7 +306,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])})
|
||||
bboxes.append(per_bb)
|
||||
return io.NodeOutput({"frames": frames, "image_size": (H, W),
|
||||
"connection_sets": face_detection_model.connection_sets}, bboxes)
|
||||
"connection_sets": face_landmarker.connection_sets}, bboxes)
|
||||
|
||||
|
||||
# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose).
|
||||
@ -335,10 +332,8 @@ class MediaPipeFaceMeshVisualize(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceMeshVisualize",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection", "visualize"],
|
||||
display_name="Visualize Face Landmarks (MediaPipe)",
|
||||
display_name="MediaPipe Face Mesh Visualize",
|
||||
category="image/detection",
|
||||
description="Draws face landmarks mesh on the input image.",
|
||||
inputs=[
|
||||
FaceLandmarksType.Input("face_landmarks"),
|
||||
io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."),
|
||||
@ -448,10 +443,8 @@ class MediaPipeFaceMask(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceMask",
|
||||
search_aliases=["face", "facial", "mediapipe", "face mask", "blazeface", "face detection", "visualize"],
|
||||
display_name="Draw Face Mask (MediaPipe)",
|
||||
display_name="MediaPipe Face Mask",
|
||||
category="image/detection",
|
||||
description="Draws a mask from face landmarks.",
|
||||
inputs=[
|
||||
FaceLandmarksType.Input("face_landmarks"),
|
||||
io.DynamicCombo.Input(
|
||||
|
||||
@ -103,10 +103,8 @@ class MoGePanoramaInference(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGePanoramaInference",
|
||||
search_aliases=["moge", "panorama", "depth", "geometry", "depth estimation", "geometry estimation"],
|
||||
display_name="Run MoGe Panorama Inference",
|
||||
display_name="MoGe Panorama Inference",
|
||||
category="image/geometry_estimation",
|
||||
description="Run MoGe on an equirectangular panorama by splitting it into 12 perspective views, running inference on each, and merging the results into a single depth map.",
|
||||
inputs=[
|
||||
MoGeModelType.Input("moge_model"),
|
||||
io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."),
|
||||
@ -224,9 +222,7 @@ class MoGeInference(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGeInference",
|
||||
search_aliases=["moge", "depth", "geometry", "depth estimation", "geometry estimation"],
|
||||
display_name="Run MoGe Inference",
|
||||
description="Run MoGe on a single image to estimate depth and geometry.",
|
||||
display_name="MoGe Inference",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeModelType.Input("moge_model"),
|
||||
@ -281,9 +277,7 @@ class MoGeRender(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGeRender",
|
||||
search_aliases=["moge", "render", "geometry", "depth", "normal"],
|
||||
display_name="Render MoGe Geometry",
|
||||
description="Render a depth map or normal map from geometry data",
|
||||
display_name="MoGe Render",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeGeometry.Input("moge_geometry"),
|
||||
@ -348,9 +342,7 @@ class MoGePointMapToMesh(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGePointMapToMesh",
|
||||
search_aliases=["moge", "mesh", "geometry", "point map"],
|
||||
display_name="Convert MoGe Point Map to Mesh",
|
||||
description="Convert a MoGe point map into a 3D mesh.",
|
||||
display_name="MoGe Point Map to Mesh",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeGeometry.Input("moge_geometry"),
|
||||
|
||||
@ -60,7 +60,7 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom
|
||||
|
||||
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["detection"] = ([os.path.join(models_dir, "detection")], supported_pt_extensions)
|
||||
folder_names_and_paths["mediapipe"] = ([os.path.join(models_dir, "mediapipe")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
|
||||
58
openapi.yaml
58
openapi.yaml
@ -1517,22 +1517,6 @@ paths:
|
||||
schema:
|
||||
type: integer
|
||||
default: 0
|
||||
description: |
|
||||
Offset-based pagination. Cursor pagination via `after` is preferred
|
||||
for sequential walks (stable across concurrent inserts/deletes) but
|
||||
`offset` remains fully supported for random access (jump-to-page
|
||||
UIs, "showing items X–Y of N" displays). When both are supplied,
|
||||
`after` wins and `offset` is ignored.
|
||||
- name: after
|
||||
in: query
|
||||
schema:
|
||||
type: string
|
||||
description: |
|
||||
Opaque cursor for keyset pagination. Pass the `next_cursor` value
|
||||
from a previous response to fetch the next page. Stable across
|
||||
inserts/deletes between pages. Supported with `sort` values
|
||||
`created_at`, `updated_at`, `name`, and `size`. Malformed or
|
||||
unsupported cursors return 400 with `INVALID_CURSOR`.
|
||||
- name: include_tags
|
||||
in: query
|
||||
schema:
|
||||
@ -1591,12 +1575,6 @@ paths:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ListAssetsResponse"
|
||||
"400":
|
||||
description: Malformed query or cursor (e.g. `INVALID_CURSOR`)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/AssetsApiError"
|
||||
post:
|
||||
operationId: createAsset
|
||||
tags: [assets]
|
||||
@ -6772,42 +6750,6 @@ components:
|
||||
type: integer
|
||||
has_more:
|
||||
type: boolean
|
||||
next_cursor:
|
||||
type: string
|
||||
description: |
|
||||
Opaque cursor to fetch the next page. Pass back as the `after`
|
||||
query parameter. Omitted when there are no more results.
|
||||
|
||||
AssetsApiError:
|
||||
type: object
|
||||
description: Error envelope returned by the assets API on 400 responses.
|
||||
required:
|
||||
- error
|
||||
properties:
|
||||
error:
|
||||
type: object
|
||||
required:
|
||||
- code
|
||||
- message
|
||||
- details
|
||||
properties:
|
||||
code:
|
||||
type: string
|
||||
description: |
|
||||
Machine-readable error code. `INVALID_CURSOR` is returned when the
|
||||
`after` cursor is malformed, oversized, or its sort field does
|
||||
not match the request's `sort`. `INVALID_QUERY` covers other
|
||||
Pydantic validation failures.
|
||||
enum: [INVALID_CURSOR, INVALID_QUERY]
|
||||
message:
|
||||
type: string
|
||||
details:
|
||||
type: object
|
||||
description: |
|
||||
Free-form, code-specific context. `INVALID_QUERY` populates this
|
||||
with an `errors` array of Pydantic validation entries;
|
||||
`INVALID_CURSOR` returns an empty object.
|
||||
additionalProperties: true
|
||||
|
||||
TagInfo:
|
||||
type: object
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.43.18
|
||||
comfyui-workflow-templates==0.9.82
|
||||
comfyui-workflow-templates==0.9.79
|
||||
comfyui-embedded-docs==0.5.0
|
||||
torch
|
||||
torchsde
|
||||
|
||||
@ -1,112 +0,0 @@
|
||||
"""Keyset-pagination tiebreaker tests for list_references_page.
|
||||
|
||||
When multiple rows share the same primary sort value (e.g. four assets
|
||||
created in the same microsecond), the secondary `ORDER BY id` is what keeps
|
||||
keyset pagination from losing or repeating rows. This file exercises that
|
||||
branch directly against an in-memory SQLite session — engineering identical
|
||||
timestamps via HTTP is unreliable enough that we work at the query layer.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.database.queries.asset_reference import list_references_page
|
||||
|
||||
|
||||
def _make_ref(session: Session, created_at: datetime, name: str, owner: str = "") -> AssetReference:
|
||||
asset = Asset(hash=f"blake3:{uuid.uuid4().hex}", size_bytes=1024)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
ref = AssetReference(
|
||||
id=str(uuid.uuid4()),
|
||||
asset_id=asset.id,
|
||||
owner_id=owner,
|
||||
name=name,
|
||||
file_path=f"/tmp/{name}",
|
||||
created_at=created_at,
|
||||
updated_at=created_at,
|
||||
last_access_time=created_at,
|
||||
is_missing=False,
|
||||
)
|
||||
session.add(ref)
|
||||
return ref
|
||||
|
||||
|
||||
@pytest.mark.parametrize("order", ["desc", "asc"])
|
||||
def test_tiebreaker_walks_duplicate_sort_values(session: Session, order: str):
|
||||
"""Four rows with the SAME created_at must paginate cleanly under cursor
|
||||
mode — no row dropped, no row repeated, despite the primary sort column
|
||||
being non-discriminating.
|
||||
"""
|
||||
shared_ts = datetime(2024, 5, 20, 12, 0, 0) # naive UTC, like the DB stores
|
||||
refs = [_make_ref(session, shared_ts, f"tie_{i}.png") for i in range(4)]
|
||||
session.commit()
|
||||
|
||||
expected_ids = sorted([r.id for r in refs], reverse=(order == "desc"))
|
||||
|
||||
# Walk the cursor by hand: page size 2, take 3 pages (2 + 2 + 0).
|
||||
seen: list[str] = []
|
||||
after_value = None
|
||||
after_id = None
|
||||
for _ in range(4): # generous loop bound; ought to be 2 iterations
|
||||
page, _tag_map, _total = list_references_page(
|
||||
session,
|
||||
limit=2,
|
||||
sort="created_at",
|
||||
order=order,
|
||||
after_cursor_value=after_value,
|
||||
after_cursor_id=after_id,
|
||||
)
|
||||
if not page:
|
||||
break
|
||||
seen.extend(p.id for p in page)
|
||||
# Use the last row's (created_at, id) as the next cursor input.
|
||||
last = page[-1]
|
||||
after_value, after_id = last.created_at, last.id
|
||||
if len(page) < 2:
|
||||
break
|
||||
|
||||
assert seen == expected_ids, (
|
||||
f"keyset tiebreaker failed for order={order}: expected {expected_ids}, got {seen}"
|
||||
)
|
||||
|
||||
|
||||
def test_tiebreaker_no_duplicates_under_mixed_collisions(session: Session):
|
||||
"""Some rows share a timestamp, some don't. The cursor must still walk
|
||||
every row exactly once regardless of where ties sit relative to a
|
||||
page boundary."""
|
||||
t1 = datetime(2024, 5, 20, 12, 0, 0)
|
||||
t2 = datetime(2024, 5, 20, 12, 0, 1)
|
||||
layout = [t1, t1, t1, t2, t2] # three rows at t1, two at t2
|
||||
refs = [_make_ref(session, ts, f"mix_{i}.png") for i, ts in enumerate(layout)]
|
||||
session.commit()
|
||||
|
||||
all_ids = {r.id for r in refs}
|
||||
seen_set: set[str] = set()
|
||||
seen_list: list[str] = []
|
||||
after_value = None
|
||||
after_id = None
|
||||
for _ in range(6):
|
||||
page, _, _ = list_references_page(
|
||||
session,
|
||||
limit=2,
|
||||
sort="created_at",
|
||||
order="desc",
|
||||
after_cursor_value=after_value,
|
||||
after_cursor_id=after_id,
|
||||
)
|
||||
if not page:
|
||||
break
|
||||
for p in page:
|
||||
assert p.id not in seen_set, f"duplicate row {p.id} appeared in cursor walk"
|
||||
seen_set.add(p.id)
|
||||
seen_list.append(p.id)
|
||||
last = page[-1]
|
||||
after_value, after_id = last.created_at, last.id
|
||||
if len(page) < 2:
|
||||
break
|
||||
|
||||
assert seen_set == all_ids, f"missing rows: expected {all_ids}, got {seen_set}"
|
||||
@ -1,354 +0,0 @@
|
||||
"""Tests for app.assets.services.cursor.
|
||||
|
||||
The byte-identity fixtures below pin the wire format so a parallel
|
||||
implementation in another runtime can mint exchange-compatible cursors
|
||||
for the same payload. Drift here would break frontend pagination against
|
||||
any compatible backend.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.services.cursor import (
|
||||
MAX_CURSOR_ID_LENGTH,
|
||||
MAX_CURSOR_VALUE_LENGTH,
|
||||
MAX_ENCODED_CURSOR_LENGTH,
|
||||
CursorPayload,
|
||||
InvalidCursorError,
|
||||
decode_cursor,
|
||||
decode_cursor_int,
|
||||
decode_cursor_time,
|
||||
encode_cursor,
|
||||
encode_cursor_from_time,
|
||||
)
|
||||
|
||||
|
||||
ALLOWED = ("created_at", "updated_at", "name", "size")
|
||||
|
||||
|
||||
class TestRoundTrip:
|
||||
@pytest.mark.parametrize(
|
||||
"sort_field, value, id",
|
||||
[
|
||||
("created_at", "1716200000000000", "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7"),
|
||||
("size", "1024", "asset-123"),
|
||||
("name", "my-asset.png", "asset-abc"),
|
||||
("name", "résumé.txt", "asset-uni"),
|
||||
],
|
||||
)
|
||||
def test_encode_decode(self, sort_field, value, id):
|
||||
encoded = encode_cursor(sort_field, value, id)
|
||||
assert encoded != ""
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.sort_field == sort_field
|
||||
assert payload.value == value
|
||||
assert payload.id == id
|
||||
|
||||
|
||||
class TestTimeCursor:
|
||||
def test_microsecond_precision_preserved(self):
|
||||
# Pick a time with non-zero microseconds — encoding at ms would lose the µs.
|
||||
ts = datetime(2024, 5, 20, 12, 53, 20, 123456, tzinfo=timezone.utc)
|
||||
encoded = encode_cursor_from_time("created_at", ts, "id-1")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
# Value must be a microsecond integer string, not a millisecond one.
|
||||
assert payload.value == "1716209600123456"
|
||||
decoded = decode_cursor_time(payload)
|
||||
assert decoded == ts
|
||||
|
||||
def test_decode_returns_utc(self):
|
||||
payload = CursorPayload(sort_field="created_at", value="1716200000123456", id="id-1", order="desc")
|
||||
decoded = decode_cursor_time(payload)
|
||||
assert decoded.tzinfo == timezone.utc
|
||||
|
||||
def test_naive_datetime_rejected_on_encode(self):
|
||||
naive = datetime(2024, 5, 20, 12, 0, 0)
|
||||
with pytest.raises(ValueError):
|
||||
encode_cursor_from_time("created_at", naive, "id-1")
|
||||
|
||||
def test_non_integer_value_rejected_on_decode(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_time(CursorPayload("created_at", "not-a-number", "id-1", "desc"))
|
||||
|
||||
def test_none_payload_rejected(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_time(None)
|
||||
|
||||
def test_non_utc_aware_normalized(self):
|
||||
# Same instant, different timezone — must encode to the same micros.
|
||||
utc_ts = datetime(2024, 5, 20, 12, 0, 0, tzinfo=timezone.utc)
|
||||
offset_ts = utc_ts.astimezone(timezone(timedelta(hours=-5)))
|
||||
assert encode_cursor_from_time("created_at", utc_ts, "x") == encode_cursor_from_time(
|
||||
"created_at", offset_ts, "x"
|
||||
)
|
||||
|
||||
|
||||
class TestIntCursor:
|
||||
def test_decode_int(self):
|
||||
assert decode_cursor_int(CursorPayload("size", "1024", "id-1", "desc")) == 1024
|
||||
|
||||
def test_decode_int_rejects_non_int(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_int(CursorPayload("size", "abc", "id-1", "desc"))
|
||||
|
||||
def test_decode_int_rejects_none(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_int(None)
|
||||
|
||||
|
||||
class TestInvalidInputs:
|
||||
def test_oversized_cursor(self):
|
||||
oversized = "a" * (MAX_ENCODED_CURSOR_LENGTH + 1)
|
||||
with pytest.raises(InvalidCursorError, match="maximum length"):
|
||||
decode_cursor(oversized, ALLOWED)
|
||||
|
||||
def test_not_base64(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor("not base64!!!", ALLOWED)
|
||||
|
||||
def test_not_json(self):
|
||||
encoded = base64.urlsafe_b64encode(b"definitely not json").rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_empty_id(self):
|
||||
# Encoder rejects empty id symmetrically with the decoder, so build the
|
||||
# payload manually to exercise the decoder's missing-id branch.
|
||||
raw = b'{"s":"created_at","v":"1","id":"","o":"desc"}'
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="missing id"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_oversized_id(self):
|
||||
# Encoder enforces the cap symmetrically; hand-build to exercise decode.
|
||||
big_id = "a" * (MAX_CURSOR_ID_LENGTH + 1)
|
||||
raw = ('{"s":"created_at","v":"1","id":"' + big_id + '","o":"desc"}').encode("ascii")
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="id exceeds maximum length"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_oversized_value(self):
|
||||
# Encoder enforces the cap symmetrically; hand-build to exercise decode.
|
||||
big_v = "v" * (MAX_CURSOR_VALUE_LENGTH + 1)
|
||||
raw = ('{"s":"created_at","v":"' + big_v + '","id":"id-1","o":"desc"}').encode("ascii")
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="value exceeds maximum length"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_unsupported_sort_field(self):
|
||||
encoded = encode_cursor("execution_time", "1", "id-1")
|
||||
with pytest.raises(InvalidCursorError, match="unsupported sort field"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_no_allowed_fields_rejects_everything(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1")
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor(encoded, ())
|
||||
|
||||
def test_non_dict_payload_rejected(self):
|
||||
encoded = base64.urlsafe_b64encode(b'["array","not","dict"]').rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="expected object"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
|
||||
class TestEncodeAtCapsFits:
|
||||
def test_max_field_lengths_fit_wire_cap(self):
|
||||
# Worst-case payload: value and id at their per-field caps, with a long
|
||||
# sort field name. The encoded cursor must fit within MAX_ENCODED_CURSOR_LENGTH
|
||||
# so the wire cap cannot reject a cursor the encoder mints at the per-field caps.
|
||||
value = "v" * MAX_CURSOR_VALUE_LENGTH
|
||||
id = "i" * MAX_CURSOR_ID_LENGTH
|
||||
sort_field = "very_long_sort_field_name"
|
||||
|
||||
encoded = encode_cursor(sort_field, value, id)
|
||||
assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH
|
||||
payload = decode_cursor(encoded, (sort_field,))
|
||||
assert payload.value == value
|
||||
assert payload.id == id
|
||||
|
||||
|
||||
class TestDatetimeOverflow:
|
||||
"""Crafted cursors with extreme micros must map to InvalidCursorError,
|
||||
not OverflowError/OSError leaking as 500.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"micros_str",
|
||||
[
|
||||
"999999999999999999999", # 10^21 µs — past datetime.MAX_YEAR by ~14 orders
|
||||
"-999999999999999999999", # symmetric negative — pre-epoch overflow
|
||||
],
|
||||
)
|
||||
def test_out_of_range_micros_rejected(self, micros_str):
|
||||
encoded = encode_cursor("created_at", micros_str, "asset-x")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_time(payload)
|
||||
|
||||
|
||||
class TestEncoderDecoderSymmetry:
|
||||
"""The encoder must reject inputs the decoder rejects, or the same server
|
||||
will mint a cursor it then 400s on the next request.
|
||||
"""
|
||||
|
||||
def test_long_name_within_cap_round_trips(self):
|
||||
"""Assets allow names up to 512 chars (`String(512)`); the cursor
|
||||
encoder must round-trip a value at that cap so a freshly minted
|
||||
cursor never fails decode on the next request."""
|
||||
long_name = "n" * MAX_CURSOR_VALUE_LENGTH
|
||||
encoded = encode_cursor("name", long_name, "asset-x")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.value == long_name
|
||||
|
||||
def test_encoder_rejects_empty_id(self):
|
||||
with pytest.raises(InvalidCursorError, match="id must be non-empty"):
|
||||
encode_cursor("created_at", "1", "")
|
||||
|
||||
def test_encoder_rejects_oversized_id(self):
|
||||
with pytest.raises(InvalidCursorError, match="id exceeds maximum length"):
|
||||
encode_cursor("created_at", "1", "a" * (MAX_CURSOR_ID_LENGTH + 1))
|
||||
|
||||
def test_encoder_rejects_oversized_value(self):
|
||||
with pytest.raises(InvalidCursorError, match="value exceeds maximum length"):
|
||||
encode_cursor("name", "v" * (MAX_CURSOR_VALUE_LENGTH + 1), "id-1")
|
||||
|
||||
def test_encoder_rejects_multibyte_value_over_wire_cap(self):
|
||||
"""A value that passes the char-count cap can still inflate past the
|
||||
wire cap once UTF-8-encoded. Asset name made of 512 × multibyte
|
||||
characters (e.g. 'é' = 2 bytes) must be rejected at encode time, not
|
||||
minted into a cursor the next request will 400."""
|
||||
with pytest.raises(InvalidCursorError, match="encoded cursor exceeds maximum length"):
|
||||
encode_cursor("name", "é" * MAX_CURSOR_VALUE_LENGTH, "asset-multibyte")
|
||||
|
||||
def test_encoder_rejects_escape_heavy_value_over_wire_cap(self):
|
||||
"""Same wire-cap concern via escape expansion: each `<` serializes to
|
||||
the six-byte sequence `\\u003c`, so 512 of them blow past the encoded
|
||||
cap even though the raw char count is within the per-field limit."""
|
||||
with pytest.raises(InvalidCursorError, match="encoded cursor exceeds maximum length"):
|
||||
encode_cursor("name", "<" * MAX_CURSOR_VALUE_LENGTH, "asset-escape")
|
||||
|
||||
|
||||
class TestOrderBinding:
|
||||
def test_order_baked_into_payload(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1", order="asc")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.order == "asc"
|
||||
|
||||
def test_mismatched_order_rejected(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1", order="desc")
|
||||
with pytest.raises(InvalidCursorError, match="does not match request order"):
|
||||
decode_cursor(encoded, ALLOWED, expected_order="asc")
|
||||
|
||||
def test_matching_order_accepted(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1", order="desc")
|
||||
payload = decode_cursor(encoded, ALLOWED, expected_order="desc")
|
||||
assert payload.order == "desc"
|
||||
|
||||
def test_invalid_order_token_rejected_on_encode(self):
|
||||
with pytest.raises(ValueError):
|
||||
encode_cursor("created_at", "1", "id-1", order="sideways")
|
||||
|
||||
def test_invalid_order_token_rejected_on_decode(self):
|
||||
# Hand-craft a payload with an illegal `o` value.
|
||||
raw = b'{"s":"name","v":"x","id":"id-1","o":"sideways"}'
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="unsupported order"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_cursor_without_order_rejected(self):
|
||||
"""`o` is mandatory. A cursor minted without it is rejected as
|
||||
malformed rather than silently walking the keyset in whatever
|
||||
direction the request happens to ask for."""
|
||||
raw = b'{"s":"name","v":"x","id":"id-1"}'
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="missing or non-string o"):
|
||||
decode_cursor(encoded, ALLOWED, expected_order="desc")
|
||||
|
||||
|
||||
class TestHtmlSignificantCharEscaping:
|
||||
"""An asset name containing `<`, `>`, `&`, U+2028, or U+2029 must encode
|
||||
to the same escaped wire bytes as any compatible implementation of the
|
||||
same payload format. Drift here breaks cross-runtime byte-identity for
|
||||
those characters.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value, escaped_substring",
|
||||
[
|
||||
("foo<bar>.png", "\\u003c"), # `<` escaped
|
||||
("foo<bar>.png", "\\u003e"), # `>` escaped
|
||||
("foo&bar.png", "\\u0026"),
|
||||
("foo
bar.png", "\\u2028"), # JS line separator
|
||||
("foo
bar.png", "\\u2029"), # JS paragraph separator
|
||||
],
|
||||
)
|
||||
def test_html_significant_chars_escaped(self, value, escaped_substring):
|
||||
encoded = encode_cursor("name", value, "id-1")
|
||||
decoded_bytes = base64.urlsafe_b64decode(encoded + "=" * (-len(encoded) % 4))
|
||||
assert escaped_substring in decoded_bytes.decode("ascii"), (
|
||||
f"Expected {escaped_substring!r} in serialized payload, got: {decoded_bytes!r}"
|
||||
)
|
||||
|
||||
def test_value_round_trips_through_escape(self):
|
||||
"""Encoding then decoding a value with `<>&` should yield the original
|
||||
string — the escape only affects the wire form, not the decoded value."""
|
||||
original = "foo<&>bar.png"
|
||||
encoded = encode_cursor("name", original, "id-1")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.value == original
|
||||
|
||||
|
||||
class TestByteIdentityFixtures:
|
||||
"""Pin the wire format so it doesn't drift silently.
|
||||
|
||||
These fixtures assert exact byte equality of the encoded JSON payload —
|
||||
a change in key order, escape choice, separator whitespace, or anything
|
||||
else that shifts a byte fails the test loudly rather than diverging
|
||||
silently from any external consumer of the same payload format.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sort_field, value, id, order, expected_payload",
|
||||
[
|
||||
(
|
||||
"created_at",
|
||||
"1716200000000000",
|
||||
"a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7",
|
||||
"desc",
|
||||
'{"s":"created_at","v":"1716200000000000","id":"a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7","o":"desc"}',
|
||||
),
|
||||
(
|
||||
"size",
|
||||
"1024",
|
||||
"asset-123",
|
||||
"asc",
|
||||
'{"s":"size","v":"1024","id":"asset-123","o":"asc"}',
|
||||
),
|
||||
(
|
||||
"name",
|
||||
"my-asset.png",
|
||||
"asset-abc",
|
||||
"desc",
|
||||
'{"s":"name","v":"my-asset.png","id":"asset-abc","o":"desc"}',
|
||||
),
|
||||
(
|
||||
"name",
|
||||
"foo<bar>&baz.png",
|
||||
"asset-html",
|
||||
"desc",
|
||||
# `<`, `>`, `&` escape to <, >, & in the value.
|
||||
'{"s":"name","v":"foo\\u003cbar\\u003e\\u0026baz.png","id":"asset-html","o":"desc"}',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_encoded_payload_shape_pinned(self, sort_field, value, id, order, expected_payload):
|
||||
encoded = encode_cursor(sort_field, value, id, order=order)
|
||||
decoded_bytes = base64.urlsafe_b64decode(encoded + "=" * (-len(encoded) % 4))
|
||||
assert decoded_bytes.decode("utf-8") == expected_payload, (
|
||||
f"wire format drifted for sort={sort_field!r}, value={value!r}:\n"
|
||||
f" expected: {expected_payload!r}\n"
|
||||
f" actual: {decoded_bytes.decode('utf-8')!r}"
|
||||
)
|
||||
@ -1,349 +0,0 @@
|
||||
"""Integration tests for cursor-based pagination on GET /api/assets.
|
||||
|
||||
These tests exercise the handler/service/query path end-to-end;
|
||||
cursor-encoding-level tests live in
|
||||
tests-unit/assets_test/services/test_cursor.py.
|
||||
"""
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def _seed(asset_factory, make_asset_bytes, count: int, tag: str) -> list[str]:
|
||||
names = [f"cursor_{i:02d}.safetensors" for i in range(count)]
|
||||
for n in names:
|
||||
asset_factory(
|
||||
n,
|
||||
["models", "checkpoints", "unit-tests", tag],
|
||||
{},
|
||||
make_asset_bytes(n, size=2048),
|
||||
)
|
||||
return sorted(names)
|
||||
|
||||
|
||||
def test_cursor_pages_all_items_in_order(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = _seed(asset_factory, make_asset_bytes, count=5, tag="cursor-walk")
|
||||
|
||||
params = {
|
||||
"include_tags": "unit-tests,cursor-walk",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "2",
|
||||
}
|
||||
|
||||
seen: list[str] = []
|
||||
after: str | None = None
|
||||
pages = 0
|
||||
while True:
|
||||
page_params = dict(params)
|
||||
if after is not None:
|
||||
page_params["after"] = after
|
||||
r = http.get(api_base + "/api/assets", params=page_params, timeout=120)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
seen.extend(a["name"] for a in body["assets"])
|
||||
pages += 1
|
||||
after = body.get("next_cursor")
|
||||
if after is None:
|
||||
break
|
||||
assert body["has_more"] is True
|
||||
assert pages < 10, "guard against runaway cursor loop"
|
||||
|
||||
assert seen == names, f"expected {names}, got {seen}"
|
||||
# Last page should have has_more False
|
||||
assert body["has_more"] is False
|
||||
assert "next_cursor" not in body
|
||||
|
||||
|
||||
def test_cursor_invalid_returns_400(http: requests.Session, api_base: str):
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"after": "not-a-real-cursor", "sort": "created_at"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 400, r.text
|
||||
body = r.json()
|
||||
assert body["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_sort_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
_seed(asset_factory, make_asset_bytes, count=2, tag="cursor-mismatch")
|
||||
|
||||
# Take a real cursor minted for sort=name.
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-mismatch",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
|
||||
# Replay against sort=created_at — should fail with INVALID_CURSOR.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"after": cursor, "sort": "created_at"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 400, r2.text
|
||||
assert r2.json()["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_wins_over_offset(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-vs-offset")
|
||||
|
||||
# Take a cursor that points past the first item.
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-vs-offset",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
|
||||
# Pass both 'after' and a large offset. Cursor must win; offset is ignored.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-vs-offset",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
"after": cursor,
|
||||
"offset": "999",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 200
|
||||
body = r2.json()
|
||||
# Should land on the second name in sorted order — not skip ahead by 999.
|
||||
assert [a["name"] for a in body["assets"]] == [names[1]]
|
||||
|
||||
|
||||
def test_next_cursor_absent_when_no_more_results(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
_seed(asset_factory, make_asset_bytes, count=2, tag="cursor-exhaust")
|
||||
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-exhaust",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "50",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
assert body["has_more"] is False
|
||||
assert "next_cursor" not in body
|
||||
|
||||
|
||||
def test_cursor_pagination_first_page_mints_cursor(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""First-page request (no `after`) must still return `next_cursor` when
|
||||
more rows exist, or pagination is unreachable from a cold start.
|
||||
"""
|
||||
_seed(asset_factory, make_asset_bytes, count=3, tag="cursor-first-page")
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,cursor-first-page", "sort": "name", "order": "asc", "limit": "2"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
assert body["has_more"] is True
|
||||
assert body.get("next_cursor"), "first page must mint a cursor when more rows exist"
|
||||
|
||||
|
||||
def test_cursor_no_spurious_cursor_when_page_size_equals_remainder(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""When `total` is an exact multiple of `limit`, the final page must
|
||||
NOT carry a next_cursor — there is nothing past it.
|
||||
"""
|
||||
_seed(asset_factory, make_asset_bytes, count=4, tag="cursor-exact-multiple")
|
||||
# Page 1
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
# Page 2 — should exhaust the set with no cursor for a phantom page 3
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2", "after": cursor},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 200, r2.text
|
||||
body = r2.json()
|
||||
assert len(body["assets"]) == 2
|
||||
assert body["has_more"] is False
|
||||
assert "next_cursor" not in body
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sort_field", ["created_at", "updated_at", "size"])
|
||||
def test_cursor_walks_for_non_name_sorts(sort_field, http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""Cursor pagination must work for every sort field the contract claims.
|
||||
|
||||
Without this, the `created_at` / `updated_at` (time-encoded micros) and
|
||||
`size` (int-encoded) cursor paths go entirely unexercised end-to-end.
|
||||
"""
|
||||
# Sizes increase strictly by index, so `size desc` has a deterministic
|
||||
# expected order. Time-based sorts (created_at / updated_at) can tie when
|
||||
# rows are inserted faster than the DB's timestamp resolution; for those
|
||||
# we check coverage and no-duplicates and let the keyset tiebreaker do
|
||||
# the rest, instead of sleeping between inserts and asserting an order
|
||||
# that depends on clock granularity.
|
||||
names = []
|
||||
for i in range(4):
|
||||
n = f"cursor_{sort_field}_{i:02d}.safetensors"
|
||||
asset_factory(n, ["models", "checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i))
|
||||
names.append(n)
|
||||
|
||||
params = {
|
||||
"include_tags": f"unit-tests,cursor-{sort_field}",
|
||||
"sort": sort_field,
|
||||
"order": "desc",
|
||||
"limit": "2",
|
||||
}
|
||||
seen: list[str] = []
|
||||
after: str | None = None
|
||||
pages = 0
|
||||
while True:
|
||||
page_params = dict(params)
|
||||
if after is not None:
|
||||
page_params["after"] = after
|
||||
r = http.get(api_base + "/api/assets", params=page_params, timeout=120)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
seen.extend(a["name"] for a in body["assets"])
|
||||
after = body.get("next_cursor")
|
||||
pages += 1
|
||||
if after is None:
|
||||
break
|
||||
assert pages < 10, "guard against runaway cursor loop"
|
||||
|
||||
# No duplicates: a faulty keyset boundary that returns the same row across
|
||||
# two pages must fail this check.
|
||||
assert len(seen) == len(set(seen)), (
|
||||
f"cursor walk repeated rows for sort={sort_field}: {seen}"
|
||||
)
|
||||
# Full coverage: every seeded asset reached exactly once.
|
||||
assert set(seen) == set(names), (
|
||||
f"missing items for sort={sort_field}: expected {set(names)}, got {set(seen)}"
|
||||
)
|
||||
# Strict order check for the only field with a clock-independent ordering.
|
||||
if sort_field == "size":
|
||||
assert seen == list(reversed(names)), (
|
||||
f"size cursor walked out of order: got {seen}, expected {list(reversed(names))}"
|
||||
)
|
||||
|
||||
|
||||
def test_cursor_order_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""A cursor minted under desc order replayed against asc must 400, not
|
||||
silently walk the wrong direction."""
|
||||
_seed(asset_factory, make_asset_bytes, count=3, tag="cursor-order-flip")
|
||||
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-order-flip",
|
||||
"sort": "name",
|
||||
"order": "desc",
|
||||
"limit": "1",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
|
||||
# Replay with order flipped to asc — server must reject the cursor.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-order-flip",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
"after": cursor,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 400, r2.text
|
||||
assert r2.json()["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_invalid_cursor_at_microsecond_boundary(http: requests.Session, api_base: str):
|
||||
"""A cursor carrying an out-of-range microsecond timestamp must map to
|
||||
400 INVALID_CURSOR, not 500."""
|
||||
import base64
|
||||
import json
|
||||
# 10^18 microseconds ≈ year 33658, well past datetime.MAX_YEAR.
|
||||
# `o` and `order=` must be set; otherwise decode fails earlier on the
|
||||
# missing-order branch and the µs-overflow path is never exercised.
|
||||
payload = {"s": "created_at", "o": "desc", "v": "999999999999999999999", "id": "asset-x"}
|
||||
raw = json.dumps(payload, separators=(",", ":")).encode("utf-8")
|
||||
cursor = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"after": cursor, "sort": "created_at", "order": "desc"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 400, r.text
|
||||
assert r.json()["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_pagination_stable_after_delete(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-delete")
|
||||
|
||||
# Page 1.
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-delete",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "2",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
page1_names = [a["name"] for a in body["assets"]]
|
||||
cursor = body["next_cursor"]
|
||||
assert cursor is not None
|
||||
assert page1_names == names[:2]
|
||||
|
||||
# Delete an item from page 1 (already returned) — cursor should still
|
||||
# locate the next page from where it was minted, not re-index.
|
||||
target_id = body["assets"][0]["id"]
|
||||
d = http.delete(api_base + f"/api/assets/{target_id}", timeout=120)
|
||||
assert d.status_code in (200, 204), d.text
|
||||
|
||||
# Page 2 via cursor.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-delete",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "2",
|
||||
"after": cursor,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 200, r2.text
|
||||
body2 = r2.json()
|
||||
assert [a["name"] for a in body2["assets"]] == names[2:]
|
||||
Reference in New Issue
Block a user