mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-22 01:00:07 +08:00
Compare commits
7 Commits
matt/asset
...
alexis/uti
| Author | SHA1 | Date | |
|---|---|---|---|
| c4f4bd3fd9 | |||
| 4259a0c7c3 | |||
| af3d9b60af | |||
| 7b7c5fed7c | |||
| 1668aaf037 | |||
| ea174d3f12 | |||
| 9f9b32ed97 |
@ -401,16 +401,12 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
)
|
||||
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
# tag[1] may be the standalone category ("checkpoints") or the
|
||||
# slash-joined shape ("checkpoints/flux/...") that
|
||||
# `get_name_and_tags_from_asset_path` and cloud both emit. Match
|
||||
# `resolve_destination_from_tags` by extracting the first segment.
|
||||
category = spec.tags[1].split("/", 1)[0] if len(spec.tags) >= 2 else ""
|
||||
if (
|
||||
len(spec.tags) < 2
|
||||
or category not in folder_paths.folder_names_and_paths
|
||||
or spec.tags[1] not in folder_paths.folder_names_and_paths
|
||||
):
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
category = spec.tags[1] if len(spec.tags) >= 2 else ""
|
||||
return _build_error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{category}'"
|
||||
)
|
||||
|
||||
@ -327,12 +327,7 @@ def list_references_page(
|
||||
select(AssetReferenceTag.asset_reference_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
|
||||
# Preserve insertion order so the structural first tag (the root
|
||||
# category like "models") stays in position 0 and the path-derived
|
||||
# sub-path tag stays in position 1, matching cloud's behavior.
|
||||
# tag_name is a deterministic tiebreaker when multiple tags share
|
||||
# an added_at (same-batch insert via set_reference_tags).
|
||||
.order_by(AssetReferenceTag.added_at.asc(), AssetReferenceTag.tag_name.asc())
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
)
|
||||
for ref_id, tag_name in rows.all():
|
||||
tag_map[ref_id].append(tag_name)
|
||||
@ -360,8 +355,7 @@ def fetch_reference_asset_and_tags(
|
||||
build_visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetReference.tags))
|
||||
# See list_references_page for the rationale behind ordering by added_at.
|
||||
.order_by(AssetReferenceTag.added_at.asc(), Tag.name.asc())
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = session.execute(stmt).all()
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -21,12 +20,7 @@ from app.assets.database.queries.common import (
|
||||
build_visible_owner_clause,
|
||||
iter_row_chunks,
|
||||
)
|
||||
from app.assets.helpers import (
|
||||
escape_sql_like_string,
|
||||
expand_bucket_prefixes,
|
||||
get_utc_now,
|
||||
normalize_tags,
|
||||
)
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -50,26 +44,6 @@ class SetTagsResult:
|
||||
total: list[str]
|
||||
|
||||
|
||||
def _next_added_at_base(session: Session, reference_id: str) -> datetime:
|
||||
"""Return a timestamp strictly greater than any existing
|
||||
`added_at` for this reference. On platforms where the wall clock
|
||||
has insufficient resolution between back-to-back commits (notably
|
||||
Windows), two write batches on the same reference can otherwise
|
||||
share a microsecond — the `ORDER BY added_at, tag_name` retrieval
|
||||
then falls back to the alphabetic tiebreaker and user-tier tags
|
||||
sort ahead of path-tier tags they were meant to follow.
|
||||
"""
|
||||
existing_max = session.execute(
|
||||
sa.select(sa.func.max(AssetReferenceTag.added_at)).where(
|
||||
AssetReferenceTag.asset_reference_id == reference_id
|
||||
)
|
||||
).scalar()
|
||||
now = get_utc_now()
|
||||
if existing_max is None:
|
||||
return now
|
||||
return max(existing_max + timedelta(microseconds=1), now)
|
||||
|
||||
|
||||
def validate_tags_exist(session: Session, tags: list[str]) -> None:
|
||||
"""Raise ValueError if any of the given tag names do not exist."""
|
||||
existing_tag_names = set(
|
||||
@ -103,13 +77,7 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
|
||||
session.execute(
|
||||
select(AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id == reference_id)
|
||||
# Match the response-path ordering used by
|
||||
# list_references_page / fetch_reference_asset_and_tags so
|
||||
# upload responses and subsequent GETs agree on tag order.
|
||||
.order_by(
|
||||
AssetReferenceTag.added_at.asc(),
|
||||
AssetReferenceTag.tag_name.asc(),
|
||||
)
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
)
|
||||
).all()
|
||||
]
|
||||
@ -121,7 +89,7 @@ def set_reference_tags(
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> SetTagsResult:
|
||||
desired = expand_bucket_prefixes(normalize_tags(tags))
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
@ -130,22 +98,15 @@ def set_reference_tags(
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
# Stagger added_at by microsecond per tag so the retrieval ORDER BY
|
||||
# added_at preserves input order. Per-tag get_utc_now() calls can
|
||||
# collide at microsecond resolution on fast machines, dropping the
|
||||
# query to the tag_name alphabetical tiebreaker — same fix as in
|
||||
# batch_insert_seed_assets. Read max(existing) so this batch sorts
|
||||
# strictly after any prior batch on the same reference.
|
||||
base_ts = _next_added_at_base(session, reference_id)
|
||||
session.add_all(
|
||||
[
|
||||
AssetReferenceTag(
|
||||
asset_reference_id=reference_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=base_ts + timedelta(microseconds=i),
|
||||
added_at=get_utc_now(),
|
||||
)
|
||||
for i, t in enumerate(to_add)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
@ -175,7 +136,7 @@ def add_tags_to_reference(
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
norm = expand_bucket_prefixes(normalize_tags(tags))
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_reference_tags(session, reference_id=reference_id)
|
||||
return AddTagsResult(added=[], already_present=[], total_tags=total)
|
||||
@ -185,17 +146,10 @@ def add_tags_to_reference(
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
# Preserve the caller's insertion order rather than alphabetizing —
|
||||
# the retrieval ORDER BY added_at + microsecond stagger only meaningfully
|
||||
# preserves insertion order if "the order we insert in" actually matches
|
||||
# the caller's intent.
|
||||
want = set(norm)
|
||||
to_add = [t for t in norm if t not in current]
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
# See set_reference_tags for the rationale behind the per-tag stagger
|
||||
# and the max(existing) seed.
|
||||
base_ts = _next_added_at_base(session, reference_id)
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
@ -204,9 +158,9 @@ def add_tags_to_reference(
|
||||
asset_reference_id=reference_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=base_ts + timedelta(microseconds=i),
|
||||
added_at=get_utc_now(),
|
||||
)
|
||||
for i, t in enumerate(to_add)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
@ -47,50 +47,6 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
|
||||
|
||||
|
||||
def _known_bucket_prefixes() -> set[str]:
|
||||
"""Lowercased model-category names eligible for standalone-prefix
|
||||
expansion. Tags whose first slash segment matches one of these get
|
||||
the bucket inserted as a separate token, so FE filters like
|
||||
``include_tags=models,checkpoints`` keep matching even when the
|
||||
asset lives in a nested subfolder (`models/checkpoints/flux/foo`).
|
||||
|
||||
Bare user labels with slashes whose first segment is not a registered
|
||||
bucket (e.g. ``my-org/team-a``) pass through unchanged.
|
||||
"""
|
||||
try:
|
||||
import folder_paths
|
||||
|
||||
return {
|
||||
name.lower()
|
||||
for name in folder_paths.folder_names_and_paths.keys()
|
||||
if name != "custom_nodes"
|
||||
}
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
|
||||
def expand_bucket_prefixes(tags: list[str]) -> list[str]:
|
||||
"""Insert standalone bucket tokens after any slash-joined tag whose
|
||||
first segment is a registered model category. Preserves caller order
|
||||
and is idempotent (existing bucket tokens are not duplicated).
|
||||
"""
|
||||
if not tags:
|
||||
return list(tags)
|
||||
buckets = _known_bucket_prefixes()
|
||||
if not buckets:
|
||||
return list(tags)
|
||||
seen = set(tags)
|
||||
result: list[str] = []
|
||||
for t in tags:
|
||||
result.append(t)
|
||||
if "/" in t:
|
||||
prefix = t.split("/", 1)[0]
|
||||
if prefix.lower() in buckets and prefix not in seen:
|
||||
result.append(prefix)
|
||||
seen.add(prefix)
|
||||
return result
|
||||
|
||||
|
||||
def validate_blake3_hash(s: str) -> str:
|
||||
"""Validate and normalize a blake3 hash string.
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@ -13,14 +13,13 @@ from app.assets.database.queries import (
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
bulk_insert_tags_and_meta,
|
||||
delete_assets_by_ids,
|
||||
ensure_tags_exist,
|
||||
get_existing_asset_ids,
|
||||
get_reference_ids_by_ids,
|
||||
get_references_by_paths_and_asset_ids,
|
||||
get_unreferenced_unhashed_asset_ids,
|
||||
restore_references_by_paths,
|
||||
)
|
||||
from app.assets.helpers import expand_bucket_prefixes, get_utc_now
|
||||
from app.assets.helpers import get_utc_now
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.assets.services.metadata_extract import ExtractedMetadata
|
||||
@ -234,20 +233,13 @@ def batch_insert_seed_assets(
|
||||
if ref_id not in inserted_ref_ids:
|
||||
continue
|
||||
|
||||
# Stagger added_at by microsecond per tag within a reference so
|
||||
# the retrieval ORDER BY added_at preserves the input list order
|
||||
# (the path-derived root category stays at position 0). Without
|
||||
# this, every tag in a bulk-insert batch shares current_time and
|
||||
# the tag_name tiebreaker sorts them alphabetically — putting the
|
||||
# subpath tag ahead of "models" since "c"/"d"/"l" < "m".
|
||||
ref_tags = expand_bucket_prefixes(ref_data["tags"])
|
||||
for tag_idx, tag in enumerate(ref_tags):
|
||||
for tag in ref_data["tags"]:
|
||||
tag_rows.append(
|
||||
{
|
||||
"asset_reference_id": ref_id,
|
||||
"tag_name": tag,
|
||||
"origin": "automatic",
|
||||
"added_at": current_time + timedelta(microseconds=tag_idx),
|
||||
"added_at": current_time,
|
||||
}
|
||||
)
|
||||
|
||||
@ -269,16 +261,6 @@ def batch_insert_seed_assets(
|
||||
}
|
||||
)
|
||||
|
||||
if tag_rows:
|
||||
# Bucket-prefix expansion may have introduced tags the caller did
|
||||
# not register via the upstream tag_pool (e.g. `checkpoints` for a
|
||||
# nested `checkpoints/flux/foo` path). Pre-register the full set so
|
||||
# the AssetReferenceTag.tag_name FK is satisfied; the underlying
|
||||
# insert is ON CONFLICT DO NOTHING so re-registration is idempotent.
|
||||
ensure_tags_exist(
|
||||
session, {row["tag_name"] for row in tag_rows}, tag_type="user"
|
||||
)
|
||||
|
||||
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)
|
||||
|
||||
return BulkInsertResult(
|
||||
|
||||
@ -3,6 +3,7 @@ from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import folder_paths
|
||||
from app.assets.helpers import normalize_tags
|
||||
|
||||
|
||||
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
|
||||
@ -26,51 +27,27 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs).
|
||||
|
||||
Accepts both the legacy one-tag-per-directory shape
|
||||
(``["models", "diffusers", "Kolors", "text_encoder"]``) and the
|
||||
slash-joined shape emitted by :func:`get_name_and_tags_from_asset_path`
|
||||
(``["models", "diffusers/Kolors/text_encoder"]``). Hybrid shapes that
|
||||
mix the two within a single call (e.g.
|
||||
``["models", "diffusers", "Kolors/text_encoder"]``) are also
|
||||
accepted: each entry after ``tags[0]`` is split on ``/`` and
|
||||
concatenated, so the two shapes — and any mix of them — resolve to
|
||||
the same destination. The same safety checks are applied to each
|
||||
component after expansion.
|
||||
"""
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
if not tags:
|
||||
raise ValueError("tags must not be empty")
|
||||
root = tags[0].lower()
|
||||
|
||||
# Expand any slash-joined entries into individual path components so
|
||||
# the rest of the function can treat both tag shapes uniformly. Each
|
||||
# component is also stripped, so " a / b " behaves like ["a", "b"].
|
||||
expanded: list[str] = []
|
||||
for t in tags[1:]:
|
||||
for part in str(t).split("/"):
|
||||
part = part.strip()
|
||||
if part:
|
||||
expanded.append(part)
|
||||
|
||||
if root == "models":
|
||||
if not expanded:
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
category = expanded[0]
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[category][0]
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{category}'")
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{category}'")
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = expanded[1:]
|
||||
raw_subdirs = tags[2:]
|
||||
elif root == "input":
|
||||
base_dir = os.path.abspath(folder_paths.get_input_directory())
|
||||
raw_subdirs = expanded
|
||||
raw_subdirs = tags[1:]
|
||||
elif root == "output":
|
||||
base_dir = os.path.abspath(folder_paths.get_output_directory())
|
||||
raw_subdirs = expanded
|
||||
raw_subdirs = tags[1:]
|
||||
else:
|
||||
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
|
||||
_sep_chars = frozenset(("/", "\\", os.sep))
|
||||
@ -183,21 +160,7 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return (name, tags) derived from a filesystem path.
|
||||
|
||||
- name: base filename with extension
|
||||
- tags: [root_category] for paths with no parent subdirectories,
|
||||
[root_category, slash_joined_subpath] otherwise. The parent subpath
|
||||
(everything between the root category and the filename) is collapsed
|
||||
into a single tag rather than emitted as one tag per directory, so
|
||||
consumers can use ``tags[1]`` as a stable category identifier that
|
||||
survives nested directory layouts (e.g. diffusers components).
|
||||
|
||||
The subpath is lowercased to match the canonicalization applied by
|
||||
:func:`ensure_tags_exist`; without that, the
|
||||
``asset_reference_tags.tag_name`` FK to the lowercased ``tags.name``
|
||||
would fail for any path containing uppercase letters. The root
|
||||
category is lowercase by construction in
|
||||
:func:`get_asset_category_and_relative_path`, so no separate cast
|
||||
is applied here. Consumers that need to look up providers keyed on
|
||||
original-case paths should normalize their lookup key to lowercase.
|
||||
- tags: [root_category] + parent folder names in order
|
||||
|
||||
Raises:
|
||||
ValueError: path does not belong to any known root.
|
||||
@ -207,7 +170,4 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
parent_parts = [
|
||||
part for part in p.parent.parts if part not in (".", "..", p.anchor)
|
||||
]
|
||||
tags = [root_category]
|
||||
if parent_parts:
|
||||
tags.append("/".join(parent_parts).lower())
|
||||
return p.name, list(dict.fromkeys(t.strip() for t in tags if t.strip()))
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
|
||||
@ -543,7 +543,7 @@ class AudioConcat(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioConcat",
|
||||
search_aliases=["join audio", "combine audio", "append audio"],
|
||||
display_name="Audio Concat",
|
||||
display_name="Concatenate Audio",
|
||||
description="Concatenates the audio1 to audio2 in the specified direction.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -597,7 +597,7 @@ class AudioMerge(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioMerge",
|
||||
search_aliases=["mix audio", "overlay audio", "layer audio"],
|
||||
display_name="Audio Merge",
|
||||
display_name="Merge Audio",
|
||||
description="Combine two audio tracks by overlaying their waveforms.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -667,8 +667,9 @@ class AudioAdjustVolume(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioAdjustVolume",
|
||||
search_aliases=["audio gain", "loudness", "audio level"],
|
||||
display_name="Audio Adjust Volume",
|
||||
display_name="Adjust Audio Volume",
|
||||
category="audio",
|
||||
description="Adjust the volume of the audio by a specified amount in decibels (dB).",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.Int.Input(
|
||||
|
||||
@ -47,8 +47,10 @@ class LoadImageDataSetFromFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadImageDataSetFromFolder",
|
||||
display_name="Load Image Dataset from Folder",
|
||||
category="dataset",
|
||||
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
|
||||
display_name="Load Image (from Folder)",
|
||||
category="image",
|
||||
description="Load a dataset of images from a specified folder and return a list of images. Supported formats: PNG, JPG, JPEG, WEBP.",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
@ -84,14 +86,16 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadImageTextDataSetFromFolder",
|
||||
display_name="Load Image and Text Dataset from Folder",
|
||||
category="dataset",
|
||||
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
|
||||
display_name="Load Image-Text (from Folder)",
|
||||
category="image",
|
||||
description="Load a dataset of pairs of images and text captions from a specified folder and return them as a list. Supported formats: PNG, JPG, JPEG, WEBP.",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"folder",
|
||||
options=folder_paths.get_input_subfolders(),
|
||||
tooltip="The folder to load images from.",
|
||||
tooltip="The folder to load images and text captions from.",
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
@ -206,8 +210,10 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveImageDataSetToFolder",
|
||||
display_name="Save Image Dataset to Folder",
|
||||
category="dataset",
|
||||
search_aliases=["save folder", "save to folder", "save dataset", "save images", "export dataset"],
|
||||
display_name="Save Image (to Folder) (DEPRECATED)",
|
||||
category="image",
|
||||
description="Save a dataset of images to a specified folder. Supported formats: PNG.",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive images as list
|
||||
@ -226,6 +232,7 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[],
|
||||
is_deprecated=True, # This node is redundant and superseded by existing Save Image nodes where the target folder can be specified in the filename_prefix
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -246,14 +253,20 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveImageTextDataSetToFolder",
|
||||
display_name="Save Image and Text Dataset to Folder",
|
||||
category="dataset",
|
||||
search_aliases=["save folder", "save to folder", "save dataset", "save images", "save text", "export dataset"],
|
||||
display_name="Save Image-Text (to Folder)",
|
||||
category="image",
|
||||
description="Save a dataset of pairs of images and text captions to a specified folder. Images are saved as PNG files and captions are saved as TXT files with the same filename_prefix.",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive both images and texts as lists
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="List of images to save."),
|
||||
io.String.Input("texts", tooltip="List of text captions to save."),
|
||||
io.String.Input("texts",
|
||||
optional=True,
|
||||
force_input=True,
|
||||
tooltip="List of text captions to save."
|
||||
),
|
||||
io.String.Input(
|
||||
"folder_name",
|
||||
default="dataset",
|
||||
@ -270,7 +283,7 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, texts, folder_name, filename_prefix):
|
||||
def execute(cls, images, folder_name, filename_prefix, texts=None):
|
||||
# Extract scalar values
|
||||
folder_name = folder_name[0]
|
||||
filename_prefix = filename_prefix[0]
|
||||
@ -279,11 +292,12 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
saved_files = save_images_to_folder(images, output_dir, filename_prefix)
|
||||
|
||||
# Save captions
|
||||
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
|
||||
caption_filename = filename.replace(".png", ".txt")
|
||||
caption_path = os.path.join(output_dir, caption_filename)
|
||||
with open(caption_path, "w", encoding="utf-8") as f:
|
||||
f.write(caption)
|
||||
if texts:
|
||||
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
|
||||
caption_filename = filename.replace(".png", ".txt")
|
||||
caption_path = os.path.join(output_dir, caption_filename)
|
||||
with open(caption_path, "w", encoding="utf-8") as f:
|
||||
f.write(caption)
|
||||
|
||||
logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
|
||||
return io.NodeOutput()
|
||||
@ -314,11 +328,13 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
|
||||
Child classes should set:
|
||||
node_id: Unique node identifier (required)
|
||||
search_aliases: List of search aliases (optional)
|
||||
display_name: Display name (optional, defaults to node_id)
|
||||
description: Node description (optional)
|
||||
extra_inputs: List of additional io.Input objects beyond "images" (optional)
|
||||
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
|
||||
is_output_list: True (list output) or False (single output) (optional, default True)
|
||||
is_deprecated: True if the node is deprecated (optional, default False)
|
||||
|
||||
Child classes must implement ONE of:
|
||||
_process(cls, image, **kwargs) -> tensor (for single-item processing)
|
||||
@ -326,12 +342,13 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
"""
|
||||
|
||||
node_id = None
|
||||
search_aliases = []
|
||||
display_name = None
|
||||
description = None
|
||||
extra_inputs = []
|
||||
is_group_process = None # None = auto-detect, True/False = explicit
|
||||
is_output_list = None # None = auto-detect based on processing mode
|
||||
|
||||
is_deprecated = False
|
||||
@classmethod
|
||||
def _detect_processing_mode(cls):
|
||||
"""Detect whether this node uses group or individual processing.
|
||||
@ -402,8 +419,10 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
|
||||
return io.Schema(
|
||||
node_id=cls.node_id,
|
||||
search_aliases=cls.search_aliases,
|
||||
display_name=cls.display_name or cls.node_id,
|
||||
category="dataset/image",
|
||||
category=cls.category,
|
||||
description=cls.description,
|
||||
is_experimental=True,
|
||||
is_input_list=is_group, # True for group, False for individual
|
||||
inputs=inputs,
|
||||
@ -472,11 +491,13 @@ class TextProcessingNode(io.ComfyNode):
|
||||
|
||||
Child classes should set:
|
||||
node_id: Unique node identifier (required)
|
||||
search_aliases: List of search aliases (optional)
|
||||
display_name: Display name (optional, defaults to node_id)
|
||||
description: Node description (optional)
|
||||
extra_inputs: List of additional io.Input objects beyond "texts" (optional)
|
||||
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
|
||||
is_output_list: True (list output) or False (single output) (optional, default True)
|
||||
is_deprecated: True if the node is deprecated (optional, default False)
|
||||
|
||||
Child classes must implement ONE of:
|
||||
_process(cls, text, **kwargs) -> str (for single-item processing)
|
||||
@ -484,12 +505,13 @@ class TextProcessingNode(io.ComfyNode):
|
||||
"""
|
||||
|
||||
node_id = None
|
||||
search_aliases = []
|
||||
display_name = None
|
||||
description = None
|
||||
extra_inputs = []
|
||||
is_group_process = None # None = auto-detect, True/False = explicit
|
||||
is_output_list = None # None = auto-detect based on processing mode
|
||||
|
||||
is_deprecated = False
|
||||
@classmethod
|
||||
def _detect_processing_mode(cls):
|
||||
"""Detect whether this node uses group or individual processing.
|
||||
@ -627,15 +649,17 @@ class TextProcessingNode(io.ComfyNode):
|
||||
|
||||
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
||||
node_id = "ResizeImagesByShorterEdge"
|
||||
display_name = "Resize Images by Shorter Edge"
|
||||
description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
|
||||
display_name = "Resize Images by Shorter Edge (DEPRECATED)"
|
||||
category = "image/transform"
|
||||
description = "Resize images so that the shorter edge matches the specified dimension while preserving aspect ratio."
|
||||
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale shorter dimension
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"shorter_edge",
|
||||
default=512,
|
||||
min=1,
|
||||
max=8192,
|
||||
tooltip="Target length for the shorter edge.",
|
||||
tooltip="Target dimension for the shorter edge.",
|
||||
),
|
||||
]
|
||||
|
||||
@ -655,15 +679,17 @@ class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
||||
|
||||
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
||||
node_id = "ResizeImagesByLongerEdge"
|
||||
display_name = "Resize Images by Longer Edge"
|
||||
description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio."
|
||||
display_name = "Resize Images by Longer Edge (DEPRECATED)"
|
||||
category = "image/transform"
|
||||
description = "Resize images so that the longer edge matches the specified dimension while preserving aspect ratio."
|
||||
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale longer dimension
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"longer_edge",
|
||||
default=1024,
|
||||
min=1,
|
||||
max=8192,
|
||||
tooltip="Target length for the longer edge.",
|
||||
tooltip="Target dimension for the longer edge.",
|
||||
),
|
||||
]
|
||||
|
||||
@ -686,8 +712,10 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
||||
|
||||
class CenterCropImagesNode(ImageProcessingNode):
|
||||
node_id = "CenterCropImages"
|
||||
display_name = "Center Crop Images"
|
||||
description = "Center crop all images to the specified dimensions."
|
||||
search_aliases=["crop", "cut", "trim"]
|
||||
display_name="Crop Image (Center)"
|
||||
category="image/transform"
|
||||
description = "Center crop an image to the specified dimensions."
|
||||
extra_inputs = [
|
||||
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
|
||||
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
|
||||
@ -706,10 +734,11 @@ class CenterCropImagesNode(ImageProcessingNode):
|
||||
|
||||
class RandomCropImagesNode(ImageProcessingNode):
|
||||
node_id = "RandomCropImages"
|
||||
display_name = "Random Crop Images"
|
||||
description = (
|
||||
"Randomly crop all images to the specified dimensions (for data augmentation)."
|
||||
)
|
||||
search_aliases=["crop", "cut", "trim"]
|
||||
display_name = "Crop Image (Random)"
|
||||
category="image/transform"
|
||||
description = "Randomly crop an image to the specified dimensions."
|
||||
|
||||
extra_inputs = [
|
||||
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
|
||||
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
|
||||
@ -734,7 +763,9 @@ class RandomCropImagesNode(ImageProcessingNode):
|
||||
|
||||
class NormalizeImagesNode(ImageProcessingNode):
|
||||
node_id = "NormalizeImages"
|
||||
display_name = "Normalize Images"
|
||||
search_aliases=["normalize", "normalize colors"]
|
||||
display_name = "Normalize Image Colors"
|
||||
category = "image/color"
|
||||
description = "Normalize images using mean and standard deviation."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
@ -762,8 +793,10 @@ class NormalizeImagesNode(ImageProcessingNode):
|
||||
|
||||
class AdjustBrightnessNode(ImageProcessingNode):
|
||||
node_id = "AdjustBrightness"
|
||||
search_aliases=["brightness"]
|
||||
display_name = "Adjust Brightness"
|
||||
description = "Adjust brightness of all images."
|
||||
category="image/adjustments"
|
||||
description = "Adjust the brightness of an image."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
"factor",
|
||||
@ -781,8 +814,10 @@ class AdjustBrightnessNode(ImageProcessingNode):
|
||||
|
||||
class AdjustContrastNode(ImageProcessingNode):
|
||||
node_id = "AdjustContrast"
|
||||
search_aliases=["contrast"]
|
||||
display_name = "Adjust Contrast"
|
||||
description = "Adjust contrast of all images."
|
||||
category="image/adjustments"
|
||||
description = "Adjust the contrast of an image."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
"factor",
|
||||
@ -800,8 +835,10 @@ class AdjustContrastNode(ImageProcessingNode):
|
||||
|
||||
class ShuffleDatasetNode(ImageProcessingNode):
|
||||
node_id = "ShuffleDataset"
|
||||
display_name = "Shuffle Image Dataset"
|
||||
description = "Randomly shuffle the order of images in the dataset."
|
||||
search_aliases=["shuffle", "randomize", "mix"]
|
||||
display_name = "Shuffle Images List"
|
||||
category = "image/batch"
|
||||
description = "Randomly shuffle the order of images in a list."
|
||||
is_group_process = True # Requires full list to shuffle
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
@ -823,13 +860,15 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ShuffleImageTextDataset",
|
||||
display_name="Shuffle Image-Text Dataset",
|
||||
category="dataset/image",
|
||||
search_aliases=["shuffle", "randomize", "mix"],
|
||||
display_name = "Shuffle Pairs of Image-Text",
|
||||
category = "image/batch",
|
||||
description = "Randomly shuffle the order of pairs of image-text in a list.",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="List of images to shuffle."),
|
||||
io.String.Input("texts", tooltip="List of texts to shuffle."),
|
||||
io.String.Input("texts", tooltip="List of texts to shuffle.", force_input=True),
|
||||
io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
@ -865,8 +904,11 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
|
||||
|
||||
class TextToLowercaseNode(TextProcessingNode):
|
||||
node_id = "TextToLowercase"
|
||||
display_name = "Text to Lowercase"
|
||||
description = "Convert all texts to lowercase."
|
||||
search_aliases=["lowercase"]
|
||||
display_name = "Convert Text to Lowercase (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Convert text to lowercase."
|
||||
is_deprecated = True # This node is superseded by the Convert Text Case node
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -875,8 +917,11 @@ class TextToLowercaseNode(TextProcessingNode):
|
||||
|
||||
class TextToUppercaseNode(TextProcessingNode):
|
||||
node_id = "TextToUppercase"
|
||||
display_name = "Text to Uppercase"
|
||||
description = "Convert all texts to uppercase."
|
||||
search_aliases=["uppercase"]
|
||||
display_name = "Convert Text to Uppercase (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Convert text to uppercase."
|
||||
is_deprecated = True # This node is superseded by the Convert Text Case node
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -885,8 +930,10 @@ class TextToUppercaseNode(TextProcessingNode):
|
||||
|
||||
class TruncateTextNode(TextProcessingNode):
|
||||
node_id = "TruncateText"
|
||||
search_aliases=["truncate", "cut", "shorten"]
|
||||
display_name = "Truncate Text"
|
||||
description = "Truncate all texts to a maximum length."
|
||||
category = "text"
|
||||
description = "Truncate text to a maximum length."
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
|
||||
@ -900,8 +947,10 @@ class TruncateTextNode(TextProcessingNode):
|
||||
|
||||
class AddTextPrefixNode(TextProcessingNode):
|
||||
node_id = "AddTextPrefix"
|
||||
display_name = "Add Text Prefix"
|
||||
display_name = "Add Text Prefix (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Add a prefix to all texts."
|
||||
is_deprecated = True # This node is superseded by the Concatenate Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("prefix", default="", tooltip="Prefix to add."),
|
||||
]
|
||||
@ -913,8 +962,10 @@ class AddTextPrefixNode(TextProcessingNode):
|
||||
|
||||
class AddTextSuffixNode(TextProcessingNode):
|
||||
node_id = "AddTextSuffix"
|
||||
display_name = "Add Text Suffix"
|
||||
display_name = "Add Text Suffix (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Add a suffix to all texts."
|
||||
is_deprecated = True # This node is superseded by the Concatenate Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("suffix", default="", tooltip="Suffix to add."),
|
||||
]
|
||||
@ -926,8 +977,10 @@ class AddTextSuffixNode(TextProcessingNode):
|
||||
|
||||
class ReplaceTextNode(TextProcessingNode):
|
||||
node_id = "ReplaceText"
|
||||
display_name = "Replace Text"
|
||||
display_name = "Replace Text (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Replace text in all texts."
|
||||
is_deprecated = True # This node is superseded by the other Replace Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("find", default="", tooltip="Text to find."),
|
||||
io.String.Input("replace", default="", tooltip="Text to replace with."),
|
||||
@ -940,8 +993,10 @@ class ReplaceTextNode(TextProcessingNode):
|
||||
|
||||
class StripWhitespaceNode(TextProcessingNode):
|
||||
node_id = "StripWhitespace"
|
||||
display_name = "Strip Whitespace"
|
||||
display_name = "Strip Whitespace (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Strip leading and trailing whitespace from all texts."
|
||||
is_deprecated = True # This node is superseded by the Trim Text node
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -952,11 +1007,13 @@ class StripWhitespaceNode(TextProcessingNode):
|
||||
|
||||
|
||||
class ImageDeduplicationNode(ImageProcessingNode):
|
||||
"""Remove duplicate or very similar images from the dataset using perceptual hashing."""
|
||||
"""Remove duplicate or very similar images from a list using perceptual hashing."""
|
||||
|
||||
node_id = "ImageDeduplication"
|
||||
display_name = "Image Deduplication"
|
||||
description = "Remove duplicate or very similar images from the dataset."
|
||||
search_aliases=["deduplicate", "remove duplicates", "similarity filter"]
|
||||
display_name = "Deduplicate Images"
|
||||
category = "image/batch"
|
||||
description = "Remove duplicate or very similar images from a list."
|
||||
is_group_process = True # Requires full list to compare images
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
@ -1026,7 +1083,9 @@ class ImageGridNode(ImageProcessingNode):
|
||||
"""Combine multiple images into a single grid/collage."""
|
||||
|
||||
node_id = "ImageGrid"
|
||||
display_name = "Image Grid"
|
||||
search_aliases=["grid", "collage", "combine"]
|
||||
display_name = "Make Image Grid"
|
||||
category="image/batch"
|
||||
description = "Arrange multiple images into a grid layout."
|
||||
is_group_process = True # Requires full list to create grid
|
||||
is_output_list = False # Outputs single grid image
|
||||
@ -1102,9 +1161,12 @@ class MergeImageListsNode(ImageProcessingNode):
|
||||
"""Merge multiple image lists into a single list."""
|
||||
|
||||
node_id = "MergeImageLists"
|
||||
display_name = "Merge Image Lists"
|
||||
search_aliases=["list", "merge list", "make list"]
|
||||
display_name = "Merge Image Lists (DEPRECATED)"
|
||||
category = "image/batch"
|
||||
description = "Concatenate multiple image lists into one."
|
||||
is_group_process = True # Receives images as list
|
||||
is_deprecated = True # This node is superseded by the Create List node
|
||||
|
||||
@classmethod
|
||||
def _group_process(cls, images):
|
||||
@ -1119,9 +1181,11 @@ class MergeTextListsNode(TextProcessingNode):
|
||||
"""Merge multiple text lists into a single list."""
|
||||
|
||||
node_id = "MergeTextLists"
|
||||
display_name = "Merge Text Lists"
|
||||
display_name = "Merge Text Lists (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Concatenate multiple text lists into one."
|
||||
is_group_process = True # Receives texts as list
|
||||
is_deprecated = True # This node is superseded by the Create List node
|
||||
|
||||
@classmethod
|
||||
def _group_process(cls, texts):
|
||||
@ -1142,8 +1206,10 @@ class ResolutionBucket(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ResolutionBucket",
|
||||
search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"],
|
||||
display_name="Resolution Bucket",
|
||||
category="dataset",
|
||||
category="training",
|
||||
description="Group latents and conditionings into buckets",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
@ -1236,7 +1302,8 @@ class MakeTrainingDataset(io.ComfyNode):
|
||||
node_id="MakeTrainingDataset",
|
||||
search_aliases=["encode dataset"],
|
||||
display_name="Make Training Dataset",
|
||||
category="dataset",
|
||||
category="training",
|
||||
description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.",
|
||||
is_experimental=True,
|
||||
is_input_list=True, # images and texts as lists
|
||||
inputs=[
|
||||
@ -1251,6 +1318,7 @@ class MakeTrainingDataset(io.ComfyNode):
|
||||
"texts",
|
||||
optional=True,
|
||||
tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
|
||||
force_input=True
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
@ -1320,9 +1388,10 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveTrainingDataset",
|
||||
search_aliases=["export training data"],
|
||||
search_aliases=["export dataset", "save dataset"],
|
||||
display_name="Save Training Dataset",
|
||||
category="dataset",
|
||||
category="training",
|
||||
description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive lists
|
||||
@ -1424,7 +1493,8 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
node_id="LoadTrainingDataset",
|
||||
search_aliases=["import dataset", "training data"],
|
||||
display_name="Load Training Dataset",
|
||||
category="dataset",
|
||||
category="training",
|
||||
description="Load encoded training dataset (latents + conditioning) from disk for use in training.",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input(
|
||||
|
||||
@ -419,15 +419,17 @@ class VoxelToMeshBasic(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VoxelToMeshBasic",
|
||||
display_name="Voxel to Mesh (Basic)",
|
||||
display_name="Voxel to Mesh (Basic) (DEPRECATED)",
|
||||
category="3d",
|
||||
description="Converts a voxel grid to a mesh.",
|
||||
is_deprecated=True, # This node is superseded by the Voxel To Mesh node
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output(),
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -453,9 +455,10 @@ class VoxelToMesh(IO.ComfyNode):
|
||||
node_id="VoxelToMesh",
|
||||
display_name="Voxel to Mesh",
|
||||
category="3d",
|
||||
description="Converts a voxel grid to a mesh.",
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
IO.Combo.Input("algorithm", options=["surface net", "basic"], advanced=True),
|
||||
IO.Combo.Input("algorithm", options=["surface net", "basic"]),
|
||||
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
|
||||
@ -55,9 +55,10 @@ class ImageCropV2(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCropV2",
|
||||
search_aliases=["trim"],
|
||||
search_aliases=["crop", "cut", "trim"],
|
||||
display_name="Crop Image",
|
||||
category="image/transform",
|
||||
description = "Crop an image to the specified dimensions.",
|
||||
essentials_category="Image Tools",
|
||||
has_intermediate_output=True,
|
||||
inputs=[
|
||||
|
||||
@ -15,7 +15,7 @@ class SwitchNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfySwitchNode",
|
||||
display_name="Switch",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Boolean.Input("switch"),
|
||||
@ -46,7 +46,7 @@ class SoftSwitchNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfySoftSwitchNode",
|
||||
display_name="Soft Switch",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Boolean.Input("switch"),
|
||||
@ -136,7 +136,7 @@ class DCTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="DCTestNode",
|
||||
display_name="DCTest",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
is_output_node=True,
|
||||
inputs=[io.DynamicCombo.Input("combo", options=[
|
||||
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
|
||||
@ -174,7 +174,7 @@ class AutogrowNamesTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="AutogrowNamesTestNode",
|
||||
display_name="AutogrowNamesTest",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
inputs=[
|
||||
_io.Autogrow.Input("autogrow", template=template)
|
||||
],
|
||||
@ -194,7 +194,7 @@ class AutogrowPrefixTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="AutogrowPrefixTestNode",
|
||||
display_name="AutogrowPrefixTest",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
inputs=[
|
||||
_io.Autogrow.Input("autogrow", template=template)
|
||||
],
|
||||
@ -213,7 +213,7 @@ class ComboOutputTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComboOptionTestNode",
|
||||
display_name="ComboOptionTest",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
|
||||
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
|
||||
outputs=[io.Combo.Output(), io.Combo.Output()],
|
||||
@ -230,7 +230,7 @@ class ConvertStringToComboNode(io.ComfyNode):
|
||||
node_id="ConvertStringToComboNode",
|
||||
search_aliases=["string to dropdown", "text to combo"],
|
||||
display_name="Convert String to Combo",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
inputs=[io.String.Input("string")],
|
||||
outputs=[io.Combo.Output()],
|
||||
)
|
||||
@ -246,7 +246,7 @@ class InvertBooleanNode(io.ComfyNode):
|
||||
node_id="InvertBooleanNode",
|
||||
search_aliases=["not", "toggle", "negate", "flip boolean"],
|
||||
display_name="Invert Boolean",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
inputs=[io.Boolean.Input("boolean")],
|
||||
outputs=[io.Boolean.Output()],
|
||||
)
|
||||
|
||||
@ -11,8 +11,8 @@ class LTXVAudioVAELoader(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAELoader",
|
||||
display_name="LTXV Audio VAE Loader",
|
||||
category="audio",
|
||||
display_name="Load LTXV Audio VAE",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"ckpt_name",
|
||||
@ -40,7 +40,7 @@ class LTXVAudioVAEEncode(VAEEncodeAudio):
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEEncode",
|
||||
display_name="LTXV Audio VAE Encode",
|
||||
category="audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Audio.Input("audio", tooltip="The audio to be encoded."),
|
||||
io.Vae.Input(
|
||||
@ -63,7 +63,7 @@ class LTXVAudioVAEDecode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEDecode",
|
||||
display_name="LTXV Audio VAE Decode",
|
||||
category="audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Latent.Input("samples", tooltip="The latent to be decoded."),
|
||||
io.Vae.Input(
|
||||
|
||||
@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfyMathExpression",
|
||||
display_name="Math Expression",
|
||||
category="logic",
|
||||
category="utils",
|
||||
search_aliases=[
|
||||
"expression", "formula", "calculate", "calculator",
|
||||
"eval", "math",
|
||||
|
||||
@ -28,7 +28,7 @@ from comfy_extras.mediapipe.face_landmarker import FaceLandmarker
|
||||
from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection
|
||||
|
||||
|
||||
FaceLandmarkerType = io.Custom("FACE_LANDMARKER")
|
||||
FaceDetectionType = io.Custom("FACE_DETECTION_MODEL")
|
||||
FaceLandmarksType = io.Custom("FACE_LANDMARKS")
|
||||
|
||||
_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights")
|
||||
@ -204,18 +204,19 @@ class LoadMediaPipeFaceLandmarker(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadMediaPipeFaceLandmarker",
|
||||
display_name="Load MediaPipe Face Landmarker",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
|
||||
display_name="Load Face Detection Model (MediaPipe)",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("mediapipe"),
|
||||
tooltip="Face Landmarker safetensors from models/mediapipe/."),
|
||||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("detection"),
|
||||
tooltip="Face detection model from models/detection/."),
|
||||
],
|
||||
outputs=[FaceLandmarkerType.Output()],
|
||||
outputs=[FaceDetectionType.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_name) -> io.NodeOutput:
|
||||
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("mediapipe", model_name), safe_load=True)
|
||||
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("detection", model_name), safe_load=True)
|
||||
wrapper = FaceLandmarkerModel(sd)
|
||||
return io.NodeOutput(wrapper)
|
||||
|
||||
@ -234,10 +235,12 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceLandmarker",
|
||||
display_name="MediaPipe Face Landmarker",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
|
||||
display_name="Detect Face Landmarks (MediaPipe)",
|
||||
category="image/detection",
|
||||
description="Detects facial landmarks using MediaPipe model.",
|
||||
inputs=[
|
||||
FaceLandmarkerType.Input("face_landmarker"),
|
||||
FaceDetectionType.Input("face_detection_model"),
|
||||
io.Image.Input("image"),
|
||||
io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short",
|
||||
tooltip="Face detector range. 'short' is tuned for close-up faces "
|
||||
@ -261,9 +264,9 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, face_landmarker, image, detector_variant, num_faces, min_confidence,
|
||||
def execute(cls, face_detection_model, image, detector_variant, num_faces, min_confidence,
|
||||
missing_frame_fallback) -> io.NodeOutput:
|
||||
canonical = face_landmarker.canonical_data
|
||||
canonical = face_detection_model.canonical_data
|
||||
img_np = _image_to_uint8(image)
|
||||
B, H, W = img_np.shape[:3]
|
||||
chunk = 16
|
||||
@ -276,7 +279,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq:
|
||||
for i in range(0, B, chunk):
|
||||
end = min(i + chunk, B)
|
||||
res.extend(face_landmarker.detect_batch(
|
||||
res.extend(face_detection_model.detect_batch(
|
||||
[img_np[bi] for bi in range(i, end)],
|
||||
num_faces=int(num_faces),
|
||||
score_thresh=float(min_confidence),
|
||||
@ -306,7 +309,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])})
|
||||
bboxes.append(per_bb)
|
||||
return io.NodeOutput({"frames": frames, "image_size": (H, W),
|
||||
"connection_sets": face_landmarker.connection_sets}, bboxes)
|
||||
"connection_sets": face_detection_model.connection_sets}, bboxes)
|
||||
|
||||
|
||||
# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose).
|
||||
@ -332,8 +335,10 @@ class MediaPipeFaceMeshVisualize(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceMeshVisualize",
|
||||
display_name="MediaPipe Face Mesh Visualize",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection", "visualize"],
|
||||
display_name="Visualize Face Landmarks (MediaPipe)",
|
||||
category="image/detection",
|
||||
description="Draws face landmarks mesh on the input image.",
|
||||
inputs=[
|
||||
FaceLandmarksType.Input("face_landmarks"),
|
||||
io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."),
|
||||
@ -443,8 +448,10 @@ class MediaPipeFaceMask(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceMask",
|
||||
display_name="MediaPipe Face Mask",
|
||||
search_aliases=["face", "facial", "mediapipe", "face mask", "blazeface", "face detection", "visualize"],
|
||||
display_name="Draw Face Mask (MediaPipe)",
|
||||
category="image/detection",
|
||||
description="Draws a mask from face landmarks.",
|
||||
inputs=[
|
||||
FaceLandmarksType.Input("face_landmarks"),
|
||||
io.DynamicCombo.Input(
|
||||
|
||||
@ -103,8 +103,10 @@ class MoGePanoramaInference(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGePanoramaInference",
|
||||
display_name="MoGe Panorama Inference",
|
||||
search_aliases=["moge", "panorama", "depth", "geometry", "depth estimation", "geometry estimation"],
|
||||
display_name="Run MoGe Panorama Inference",
|
||||
category="image/geometry_estimation",
|
||||
description="Run MoGe on an equirectangular panorama by splitting it into 12 perspective views, running inference on each, and merging the results into a single depth map.",
|
||||
inputs=[
|
||||
MoGeModelType.Input("moge_model"),
|
||||
io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."),
|
||||
@ -222,7 +224,9 @@ class MoGeInference(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGeInference",
|
||||
display_name="MoGe Inference",
|
||||
search_aliases=["moge", "depth", "geometry", "depth estimation", "geometry estimation"],
|
||||
display_name="Run MoGe Inference",
|
||||
description="Run MoGe on a single image to estimate depth and geometry.",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeModelType.Input("moge_model"),
|
||||
@ -277,7 +281,9 @@ class MoGeRender(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGeRender",
|
||||
display_name="MoGe Render",
|
||||
search_aliases=["moge", "render", "geometry", "depth", "normal"],
|
||||
display_name="Render MoGe Geometry",
|
||||
description="Render a depth map or normal map from geometry data",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeGeometry.Input("moge_geometry"),
|
||||
@ -342,7 +348,9 @@ class MoGePointMapToMesh(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGePointMapToMesh",
|
||||
display_name="MoGe Point Map to Mesh",
|
||||
search_aliases=["moge", "mesh", "geometry", "point map"],
|
||||
display_name="Convert MoGe Point Map to Mesh",
|
||||
description="Convert a MoGe point map into a 3D mesh.",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeGeometry.Input("moge_geometry"),
|
||||
|
||||
@ -14,7 +14,7 @@ class CreateList(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="CreateList",
|
||||
display_name="Create List",
|
||||
category="logic",
|
||||
category="utils",
|
||||
is_input_list=True,
|
||||
search_aliases=["Image Iterator", "Text Iterator", "Iterator"],
|
||||
inputs=[io.Autogrow.Input("inputs", template=template_autogrow)],
|
||||
|
||||
@ -60,7 +60,7 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom
|
||||
|
||||
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["mediapipe"] = ([os.path.join(models_dir, "mediapipe")], supported_pt_extensions)
|
||||
folder_names_and_paths["detection"] = ([os.path.join(models_dir, "detection")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
|
||||
673
openapi.yaml
673
openapi.yaml
@ -1556,12 +1556,6 @@ paths:
|
||||
type: string
|
||||
enum: [asc, desc]
|
||||
description: Sort direction
|
||||
- name: job_ids
|
||||
in: query
|
||||
schema:
|
||||
type: string
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Comma-separated UUIDs to filter assets by associated job."
|
||||
- name: include_public
|
||||
in: query
|
||||
schema:
|
||||
@ -2514,37 +2508,25 @@ paths:
|
||||
|
||||
/api/assets/import:
|
||||
post:
|
||||
operationId: importAssets
|
||||
operationId: importPublishedAssets
|
||||
tags: [assets]
|
||||
summary: Import assets from external URLs
|
||||
description: "[cloud-only] Imports one or more assets from external URLs into the cloud asset store."
|
||||
summary: "[cloud-only] Import published assets into the caller's library"
|
||||
description: |
|
||||
[cloud-only] Imports the specified published assets into the caller's asset library. New DB records reference the same storage objects; no file copying occurs. Assets the caller already owns (by hash) are deduplicated. The `id` field on each returned `AssetInfo` is the caller's newly-created private asset ID, not the published asset ID supplied in the request.
|
||||
x-runtime: [cloud]
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required:
|
||||
- imports
|
||||
properties:
|
||||
imports:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/AssetImportRequest"
|
||||
description: Assets to import
|
||||
$ref: "#/components/schemas/ImportPublishedAssetsRequest"
|
||||
responses:
|
||||
"200":
|
||||
description: Import initiated
|
||||
description: Successfully imported assets
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
assets:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/Asset"
|
||||
$ref: "#/components/schemas/ImportPublishedAssetsResponse"
|
||||
"400":
|
||||
description: Bad request
|
||||
content:
|
||||
@ -3790,6 +3772,295 @@ paths:
|
||||
schema:
|
||||
$ref: "#/components/schemas/JwksResponse"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OAuth 2.1 / RFC 7591 Dynamic Client Registration (cloud)
|
||||
# ---------------------------------------------------------------------------
|
||||
/.well-known/oauth-authorization-server:
|
||||
get:
|
||||
operationId: getOAuthAuthorizationServer
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] OAuth 2.1 authorization-server metadata (RFC 8414)"
|
||||
description: "[cloud-only] Public metadata document for OAuth 2.1 clients. Cached 5 minutes."
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
responses:
|
||||
"200":
|
||||
description: Authorization-server metadata
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthAuthorizationServerMetadata"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
/.well-known/oauth-protected-resource:
|
||||
get:
|
||||
operationId: getOAuthProtectedResource
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] OAuth 2.1 protected-resource metadata (RFC 9728)"
|
||||
description: "[cloud-only] Public metadata describing the currently advertised protected resource. Cached 5 minutes."
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
responses:
|
||||
"200":
|
||||
description: Protected-resource metadata
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthProtectedResourceMetadata"
|
||||
"404":
|
||||
description: OAuth disabled or no active resource configured
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
/oauth/authorize:
|
||||
get:
|
||||
operationId: getOAuthAuthorize
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] Begin or resume an OAuth 2.1 authorization request"
|
||||
description: |
|
||||
[cloud-only] Two modes:
|
||||
- **Initial entry** (OAuth params present): validates client/redirect/resource/scopes, persists a server-side authorization-request row, and either redirects (no session / unverified email) to the configured frontend login URL carrying only the opaque `oauth_request_id`, or returns the JSON consent challenge for the frontend to render.
|
||||
- **Resume** (`oauth_request_id` present): loads the server-side row, fails closed if expired/consumed/unknown, returns the JSON consent challenge. Browser-replayed OAuth params are intentionally ignored.
|
||||
|
||||
The frontend renders the consent UI from the JSON payload and POSTs the user's decision back to this endpoint.
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
parameters:
|
||||
- { name: response_type, in: query, required: false, schema: { type: string } }
|
||||
- { name: client_id, in: query, required: false, schema: { type: string } }
|
||||
- { name: redirect_uri, in: query, required: false, schema: { type: string } }
|
||||
- { name: scope, in: query, required: false, schema: { type: string } }
|
||||
- name: state
|
||||
in: query
|
||||
required: false
|
||||
schema: { type: string }
|
||||
description: |
|
||||
RFC 6749 §10.12 marks `state` as RECOMMENDED. Cloud hardening makes it REQUIRED on the initial-entry path (omitted only on the resume path where `oauth_request_id` is supplied instead). This parameter is `required: false` at the spec level only because the operation is dual-mode (initial entry vs. resume); the runtime rejects empty `state` on the initial-entry path with a stable `invalid_request` 400.
|
||||
- { name: code_challenge, in: query, required: false, schema: { type: string } }
|
||||
- { name: code_challenge_method, in: query, required: false, schema: { type: string } }
|
||||
- { name: resource, in: query, required: false, schema: { type: string } }
|
||||
- { name: oauth_request_id, in: query, required: false, schema: { type: string } }
|
||||
responses:
|
||||
"200":
|
||||
description: Consent challenge payload (session present, email verified). Frontend renders the consent UI from this payload and POSTs back to /oauth/authorize.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthConsentChallenge"
|
||||
"302":
|
||||
description: Redirect to login (no session / unverified email) or to registered redirect_uri (pre-validated client error)
|
||||
headers:
|
||||
Location:
|
||||
schema:
|
||||
type: string
|
||||
"400":
|
||||
description: Invalid authorize request (pre-redirect failure — unknown client, redirect mismatch, malformed params)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
post:
|
||||
operationId: postOAuthAuthorize
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] Submit OAuth consent decision"
|
||||
description: |
|
||||
[cloud-only] JSON-only consent submission. The handler verifies the per-row CSRF token, atomically marks the authorization request consumed (single-use covers both allow and deny paths), then returns the redirect URL the browser must navigate to. The URL contains either `code` + original `state` for allow, or the RFC 6749 §5.2 error and `state` for deny.
|
||||
|
||||
Workspace membership is re-checked at submission time. Consent is persisted keyed by `(user_id, client_id, resource_id, workspace_id)`; broadening the previously approved scope set requires a fresh consent flow.
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required: [oauth_request_id, csrf_token, decision, workspace_id]
|
||||
properties:
|
||||
oauth_request_id: { type: string, format: uuid }
|
||||
csrf_token: { type: string }
|
||||
decision: { type: string, enum: [allow, deny] }
|
||||
workspace_id: { type: string }
|
||||
responses:
|
||||
"200":
|
||||
description: Redirect URL for the frontend to navigate to (allow → with code+state; deny → with error+state)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthAuthorizeRedirectResponse"
|
||||
"400":
|
||||
description: Bad request (CSRF mismatch, expired/consumed request, inaccessible workspace)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
"403":
|
||||
description: Scope broadening on consent re-grant — fresh consent flow required
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
/oauth/token:
|
||||
post:
|
||||
operationId: postOAuthToken
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] Exchange authorization code or refresh token for a resource-bound access token"
|
||||
description: |
|
||||
[cloud-only] OAuth 2.1 token endpoint (RFC 6749 §3.2). Public clients only — `client_secret` is rejected.
|
||||
|
||||
Two grant types are supported:
|
||||
- `authorization_code` — exchanges the code minted by `/oauth/authorize` (with PKCE verifier) for an access token + first refresh token. Single-use; reuse fails closed.
|
||||
- `refresh_token` — rotates the refresh token. Old token immediately invalid; presenting an already-rotated token revokes the entire token family and emits a security metric.
|
||||
|
||||
Both grant types re-validate canonical user state, current workspace membership, and the resource's active flag at every mint. A code or refresh token bound to a deactivated resource fails closed.
|
||||
|
||||
Errors follow RFC 6749 §5.2. Logs never contain raw codes, refresh tokens, or minted tokens.
|
||||
|
||||
Per RFC 6749 §5.1, every 200 and 400 response carries `Cache-Control: no-store` and `Pragma: no-cache` so intermediaries cannot cache token-bearing or state-change-reason responses.
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/x-www-form-urlencoded:
|
||||
schema:
|
||||
type: object
|
||||
required: [grant_type, client_id]
|
||||
properties:
|
||||
grant_type: { type: string, enum: [authorization_code, refresh_token] }
|
||||
client_id: { type: string }
|
||||
code: { type: string }
|
||||
redirect_uri: { type: string }
|
||||
code_verifier: { type: string }
|
||||
refresh_token: { type: string }
|
||||
scope: { type: string }
|
||||
client_secret: { type: string }
|
||||
responses:
|
||||
"200":
|
||||
description: New token pair
|
||||
headers:
|
||||
Cache-Control:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-store" per RFC 6749 §5.1'
|
||||
Pragma:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-cache" per RFC 6749 §5.1'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthTokenResponse"
|
||||
"400":
|
||||
description: RFC 6749 §5.2 error
|
||||
headers:
|
||||
Cache-Control:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-store" per RFC 6749 §5.1'
|
||||
Pragma:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-cache" per RFC 6749 §5.1'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthTokenError"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
/oauth/register:
|
||||
post:
|
||||
operationId: postOAuthRegister
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] Dynamic Client Registration (RFC 7591)"
|
||||
description: |
|
||||
[cloud-only] Public, unauthenticated, insert-only RFC 7591 §3.1 client registration. Used by MCP-spec-compliant clients to self-register a public OAuth client without operator involvement.
|
||||
|
||||
Policy:
|
||||
|
||||
- Public clients only — `token_endpoint_auth_method` is forced to `none`. Confidential-client registration is out of scope this phase.
|
||||
- Server-owned `resource_grants`. Caller-supplied `scope` or `resource_grants` is rejected as `invalid_client_metadata` (would be a privilege-escalation surface). Dynamic clients receive the same scopes the active resource publishes.
|
||||
- Application-type-aware redirect URI policy. `application_type=native` accepts loopback (`127.0.0.1`, `::1`, `localhost`) and reverse-DNS-shaped custom schemes; `application_type=web` accepts HTTPS to hosts in an operator-controlled allowlist only. `application_type` is REQUIRED on the request — missing or empty rejects with `invalid_client_metadata`.
|
||||
- Anti-impersonation: reserved client names are rejected from third parties via NFKC-folded compare.
|
||||
- Generated `client_id` carries a stable prefix to distinguish dynamic from seeded clients in audit logs.
|
||||
- Cache-Control: `no-store` on every 201 and 400 response (the response carries fresh credentials and rejection reasons).
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthRegisterRequest"
|
||||
responses:
|
||||
"201":
|
||||
description: Registered. Body echoes the metadata RFC 7591 §3.2.1 requires.
|
||||
headers:
|
||||
Cache-Control:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-store"'
|
||||
Pragma:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-cache"'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthRegisterResponse"
|
||||
"400":
|
||||
description: RFC 7591 §3.2.2 invalid client metadata
|
||||
headers:
|
||||
Cache-Control:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-store"'
|
||||
Pragma:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-cache"'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthRegisterError"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
"503":
|
||||
description: No active resource is configured — DCR cannot mint a usable client until an active resource row is seeded.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Billing (cloud)
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -7090,24 +7361,35 @@ components:
|
||||
type: string
|
||||
description: Target path on the runtime filesystem
|
||||
|
||||
AssetImportRequest:
|
||||
ImportPublishedAssetsRequest:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] A single asset to import from an external URL."
|
||||
description: "[cloud-only] Request body for importing published assets into the caller's library."
|
||||
required:
|
||||
- url
|
||||
- published_asset_ids
|
||||
properties:
|
||||
url:
|
||||
type: string
|
||||
format: uri
|
||||
description: URL of the asset to import
|
||||
name:
|
||||
type: string
|
||||
description: Display name for the imported asset
|
||||
tags:
|
||||
published_asset_ids:
|
||||
type: array
|
||||
description: IDs of published assets (inputs and models) to import.
|
||||
items:
|
||||
type: string
|
||||
share_id:
|
||||
type: string
|
||||
nullable: true
|
||||
description: |
|
||||
Optional. Share ID of the published workflow these assets belong to. When provided (non-null, non-empty): all `published_asset_ids` must belong to this share's workflow version; returns 400 if the share is not found or any asset does not belong to it. When omitted, null, or empty string: no share-scoped validation is performed and the assets are validated only against global rules (preserved for clients that have not yet adopted `share_id`).
|
||||
|
||||
ImportPublishedAssetsResponse:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Response after importing published assets. Each returned `AssetInfo.id` is the caller's newly-created private asset ID, not the published asset ID supplied in the request."
|
||||
required:
|
||||
- assets
|
||||
properties:
|
||||
assets:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/AssetInfo"
|
||||
|
||||
RemoteAssetMetadata:
|
||||
type: object
|
||||
@ -7424,6 +7706,325 @@ components:
|
||||
description: RSA exponent (base64url)
|
||||
additionalProperties: true
|
||||
|
||||
OAuthAuthorizationServerMetadata:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] OAuth 2.1 authorization-server metadata (RFC 8414)."
|
||||
required:
|
||||
- issuer
|
||||
- authorization_endpoint
|
||||
- token_endpoint
|
||||
- jwks_uri
|
||||
- response_types_supported
|
||||
- grant_types_supported
|
||||
- code_challenge_methods_supported
|
||||
- token_endpoint_auth_methods_supported
|
||||
properties:
|
||||
issuer:
|
||||
type: string
|
||||
format: uri
|
||||
authorization_endpoint:
|
||||
type: string
|
||||
format: uri
|
||||
token_endpoint:
|
||||
type: string
|
||||
format: uri
|
||||
jwks_uri:
|
||||
type: string
|
||||
format: uri
|
||||
registration_endpoint:
|
||||
type: string
|
||||
format: uri
|
||||
description: "[cloud-only] RFC 7591 §3.1 Dynamic Client Registration endpoint. Advertised so MCP-spec-compliant clients can auto-discover and self-register without operator involvement. Present only when DCR is enabled."
|
||||
response_types_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
grant_types_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
code_challenge_methods_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
token_endpoint_auth_methods_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
scopes_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
|
||||
OAuthProtectedResourceMetadata:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] OAuth 2.1 protected-resource metadata (RFC 9728)."
|
||||
required:
|
||||
- resource
|
||||
- authorization_servers
|
||||
- scopes_supported
|
||||
properties:
|
||||
resource:
|
||||
type: string
|
||||
format: uri
|
||||
authorization_servers:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
format: uri
|
||||
scopes_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
bearer_methods_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
|
||||
OAuthConsentChallenge:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Server-side state describing the OAuth consent decision the user is being asked to make. Returned by GET /oauth/authorize when a valid session exists; the frontend renders the consent UI from this payload and POSTs the decision back. Browser never sees the original OAuth params on resume."
|
||||
required:
|
||||
- oauth_request_id
|
||||
- csrf_token
|
||||
- client_display_name
|
||||
- resource_display_name
|
||||
- scopes
|
||||
- workspaces
|
||||
properties:
|
||||
oauth_request_id:
|
||||
type: string
|
||||
format: uuid
|
||||
description: Opaque server-side identifier for the authorization-request row. Carried back unchanged in the consent submission.
|
||||
csrf_token:
|
||||
type: string
|
||||
description: Per-row CSRF token bound to this authorization request (not to the session). Must be echoed back on POST.
|
||||
client_display_name:
|
||||
type: string
|
||||
description: Human-readable name of the OAuth client requesting authorization.
|
||||
resource_display_name:
|
||||
type: string
|
||||
description: Human-readable name of the protected resource.
|
||||
scopes:
|
||||
type: array
|
||||
description: Scopes the client is requesting for this resource. The frontend should present these for the user to approve.
|
||||
items:
|
||||
type: string
|
||||
workspaces:
|
||||
type: array
|
||||
description: Workspaces the user can select from. Membership is re-checked on POST.
|
||||
items:
|
||||
$ref: "#/components/schemas/OAuthConsentChallengeWorkspace"
|
||||
|
||||
OAuthConsentChallengeWorkspace:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] One workspace option presented in the OAuth consent challenge."
|
||||
required: [id, name, type, role]
|
||||
properties:
|
||||
id: { type: string }
|
||||
name: { type: string }
|
||||
type: { type: string, enum: [personal, team] }
|
||||
role: { type: string, enum: [owner, member] }
|
||||
|
||||
OAuthAuthorizeRedirectResponse:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Redirect target produced after a JSON consent submission. The frontend must navigate the browser to this URL so custom-scheme client callbacks work without relying on fetch-visible 302 headers."
|
||||
required:
|
||||
- redirect_url
|
||||
properties:
|
||||
redirect_url:
|
||||
type: string
|
||||
format: uri
|
||||
description: OAuth client redirect URI with either code+state for allow, or error+state for deny.
|
||||
|
||||
OAuthTokenResponse:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] RFC 6749 §5.1 successful token response."
|
||||
required: [access_token, token_type, expires_in, refresh_token, scope]
|
||||
properties:
|
||||
access_token:
|
||||
type: string
|
||||
description: Resource-bound access token (audience matches the protected resource).
|
||||
token_type:
|
||||
type: string
|
||||
enum: [Bearer]
|
||||
expires_in:
|
||||
type: integer
|
||||
description: Access token lifetime in seconds.
|
||||
refresh_token:
|
||||
type: string
|
||||
description: Opaque refresh token. Rotates on every successful refresh; presenting an already-rotated token revokes the entire family.
|
||||
scope:
|
||||
type: string
|
||||
description: Space-delimited scopes granted with this token.
|
||||
|
||||
OAuthTokenError:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] RFC 6749 §5.2 error response."
|
||||
required: [error]
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
description: 'RFC 6749 §5.2 error code: invalid_request, invalid_client, invalid_grant, unauthorized_client, unsupported_grant_type, invalid_scope.'
|
||||
error_description:
|
||||
type: string
|
||||
description: Human-readable, no leak of internal storage state.
|
||||
|
||||
OAuthRegisterRequest:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
additionalProperties: false
|
||||
description: "[cloud-only] RFC 7591 §2 client metadata document. Only the fields the server honors are listed; presence of `scope` or `resource_grants` in the request is rejected (`invalid_client_metadata`) because those are server-owned for dynamic clients."
|
||||
required:
|
||||
- redirect_uris
|
||||
- application_type
|
||||
properties:
|
||||
redirect_uris:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
minItems: 1
|
||||
maxItems: 5
|
||||
description: 1–5 redirect URIs. Validated against `application_type` policy.
|
||||
client_name:
|
||||
type: string
|
||||
maxLength: 100
|
||||
description: Human-readable name shown in the consent UI. Reserved-name list rejects impersonation of major clients.
|
||||
application_type:
|
||||
type: string
|
||||
enum: [native, web]
|
||||
description: |
|
||||
RFC 7591 §2 application_type. **REQUIRED** — clients MUST declare intent; the server does not default this field. `native` for desktop / CLI / MCP-spec-strict clients (loopback redirects); `web` for hosted clients (HTTPS only, host must be allowlisted). A missing or explicitly empty `application_type` rejects with `invalid_client_metadata`.
|
||||
token_endpoint_auth_method:
|
||||
type: string
|
||||
enum: [none]
|
||||
description: 'Public clients only this phase — must be `none` if present. The server forces `none` regardless.'
|
||||
grant_types:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
enum: [authorization_code, refresh_token]
|
||||
description: Optional. Defaults to `["authorization_code","refresh_token"]`.
|
||||
response_types:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
enum: [code]
|
||||
description: Optional. Defaults to `["code"]`.
|
||||
scope:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Dynamic clients do not pick scopes — the server assigns scopes from the active resource's published list. Sending `scope` in the registration body is treated as a privilege-escalation attempt and returns `invalid_client_metadata`."
|
||||
resource_grants:
|
||||
type: object
|
||||
nullable: true
|
||||
additionalProperties:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: "**REJECTED IF PRESENT.** Same reason as `scope`. The set of resources and scopes a dynamic client may request is server-policy, not request-driven."
|
||||
client_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
logo_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
tos_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
policy_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
software_id:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
software_version:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
contacts:
|
||||
type: array
|
||||
nullable: true
|
||||
items:
|
||||
type: string
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
jwks:
|
||||
type: object
|
||||
nullable: true
|
||||
additionalProperties: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
jwks_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
|
||||
OAuthRegisterResponse:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] RFC 7591 §3.2.1 successful registration response."
|
||||
required:
|
||||
- client_id
|
||||
- client_id_issued_at
|
||||
- redirect_uris
|
||||
- grant_types
|
||||
- response_types
|
||||
- token_endpoint_auth_method
|
||||
- application_type
|
||||
properties:
|
||||
client_id:
|
||||
type: string
|
||||
description: Server-generated client_id.
|
||||
client_id_issued_at:
|
||||
type: integer
|
||||
format: int64
|
||||
description: Unix timestamp (seconds) when the client was registered.
|
||||
client_name:
|
||||
type: string
|
||||
redirect_uris:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
grant_types:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
response_types:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
token_endpoint_auth_method:
|
||||
type: string
|
||||
enum: [none]
|
||||
application_type:
|
||||
type: string
|
||||
enum: [native, web]
|
||||
|
||||
OAuthRegisterError:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] RFC 7591 §3.2.2 error response."
|
||||
required:
|
||||
- error
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
enum: [invalid_redirect_uri, invalid_client_metadata]
|
||||
error_description:
|
||||
type: string
|
||||
nullable: true
|
||||
|
||||
BillingBalance:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
|
||||
@ -21,7 +21,6 @@ from app.assets.database.queries import (
|
||||
get_reference_ids_by_ids,
|
||||
ensure_tags_exist,
|
||||
add_tags_to_reference,
|
||||
set_reference_tags,
|
||||
)
|
||||
from app.assets.helpers import get_utc_now
|
||||
|
||||
@ -160,153 +159,6 @@ class TestListReferencesPage:
|
||||
assert refs[0].name == "large"
|
||||
|
||||
|
||||
class TestTagRetrievalOrder:
|
||||
"""End-to-end check: tags written through the public write paths come
|
||||
back from the public read paths in insertion order rather than the
|
||||
composite-PK alphabetical order SQLite would otherwise impose.
|
||||
|
||||
Each test deliberately picks tag names that would sort differently
|
||||
under alphabetical vs insertion order, so an alphabetical regression
|
||||
fails loudly.
|
||||
"""
|
||||
|
||||
def _make_ref(self, session: Session) -> AssetReference:
|
||||
asset = _make_asset(session, "h1")
|
||||
return _make_reference(session, asset, name="x.bin")
|
||||
|
||||
def test_set_reference_tags_preserves_input_order_in_list(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
# "checkpoints" < "models" alphabetically; if added_at stagger
|
||||
# works, list_references_page returns insertion order.
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
assert tag_map[ref.id] == ["models", "checkpoints"]
|
||||
|
||||
def test_set_reference_tags_preserves_input_order_in_fetch(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
# Subpath tag sorts before "models" alphabetically.
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "diffusers/kolors/text_encoder"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
result = fetch_reference_asset_and_tags(session, ref.id)
|
||||
assert result is not None
|
||||
_, _, tags = result
|
||||
# Bucket-prefix expansion appends the standalone `diffusers` token
|
||||
# at path-tier (microsecond stagger) so FE set-membership filters
|
||||
# match nested category paths.
|
||||
assert tags == ["models", "diffusers/kolors/text_encoder", "diffusers"]
|
||||
|
||||
def test_add_tags_to_reference_lands_after_path_tags(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
# "aaa-..." sorts before both path tags alphabetically. If added_at
|
||||
# stagger is missing, alphabetic tiebreak would hoist it to tags[0].
|
||||
add_tags_to_reference(
|
||||
session, reference_id=ref.id, tags=["aaa-user-tag"], origin="manual"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
assert tag_map[ref.id] == ["models", "checkpoints", "aaa-user-tag"]
|
||||
|
||||
def test_multi_tag_batch_lands_after_path_tags(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
# Three user tags inserted in non-alphabetical input order. Per-tag
|
||||
# microsecond stagger should preserve at least the "user batch is
|
||||
# after path tags" property; within the user batch insertion order
|
||||
# is also preserved.
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["zzz-z", "favorite", "experiment-q4"],
|
||||
origin="manual",
|
||||
)
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
tags = tag_map[ref.id]
|
||||
assert tags[0:2] == ["models", "checkpoints"]
|
||||
assert set(tags[2:]) == {"zzz-z", "favorite", "experiment-q4"}
|
||||
|
||||
def test_user_batch_lands_after_path_batch_under_clock_collision(
|
||||
self, session: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Windows-specific race: when two back-to-back commits share the
|
||||
same datetime.now() microsecond, the path-tier and user-tier
|
||||
added_at values used to collide and alphabetic tiebreak would
|
||||
hoist user tags ahead of path tags. The fix reads
|
||||
max(existing_added_at) for the reference and seeds the next batch
|
||||
past it, deterministically restoring insertion order.
|
||||
|
||||
This test simulates the collision by pinning get_utc_now() so the
|
||||
platform-dependent race becomes a platform-independent failure.
|
||||
"""
|
||||
ref = self._make_ref(session)
|
||||
|
||||
from datetime import datetime
|
||||
from app.assets.database import queries as queries_pkg
|
||||
from app.assets.database.queries import tags as tags_module
|
||||
|
||||
frozen = datetime(2026, 1, 1, 0, 0, 0)
|
||||
monkeypatch.setattr(tags_module, "get_utc_now", lambda: frozen)
|
||||
monkeypatch.setattr(queries_pkg, "get_utc_now", lambda: frozen, raising=False)
|
||||
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
# Same frozen timestamp — without the max(existing) seed, the
|
||||
# user batch would share added_at with the path batch and
|
||||
# `aaa-user-tag` would sort to position 0 via the alphabetic
|
||||
# tiebreaker.
|
||||
add_tags_to_reference(
|
||||
session, reference_id=ref.id, tags=["aaa-user-tag"], origin="manual"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
assert tag_map[ref.id] == ["models", "checkpoints", "aaa-user-tag"]
|
||||
|
||||
def test_remove_then_add_does_not_disrupt_path_tag_positions(
|
||||
self, session: Session
|
||||
):
|
||||
ref = self._make_ref(session)
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "loras/my/custom/path"],
|
||||
)
|
||||
session.commit()
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["temp-tag"])
|
||||
session.commit()
|
||||
from app.assets.database.queries import remove_tags_from_reference
|
||||
|
||||
remove_tags_from_reference(session, reference_id=ref.id, tags=["temp-tag"])
|
||||
session.commit()
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["second-tag"])
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
# `loras` is expanded from the nested category path; user-added
|
||||
# tags trail behind it via the microsecond stagger.
|
||||
assert tag_map[ref.id] == [
|
||||
"models",
|
||||
"loras/my/custom/path",
|
||||
"loras",
|
||||
"second-tag",
|
||||
]
|
||||
|
||||
|
||||
class TestFetchReferenceAssetAndTags:
|
||||
def test_returns_none_for_nonexistent(self, session: Session):
|
||||
result = fetch_reference_asset_and_tags(session, "nonexistent")
|
||||
|
||||
@ -160,120 +160,6 @@ class TestAddTagsToReference:
|
||||
add_tags_to_reference(session, reference_id="nonexistent", tags=["x"])
|
||||
|
||||
|
||||
class TestBucketPrefixExpansion:
|
||||
"""The standalone bucket token must appear in the asset's tag set for
|
||||
nested category paths so FE filters like
|
||||
`include_tags=models,checkpoints` continue to match.
|
||||
"""
|
||||
|
||||
def test_set_reference_tags_inserts_bucket_for_nested_path(
|
||||
self, session: Session
|
||||
):
|
||||
asset = _make_asset(session, "hash-nested")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
result = set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "checkpoints/flux"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert set(result.total) == {"models", "checkpoints/flux", "checkpoints"}
|
||||
stored = get_reference_tags(session, reference_id=ref.id)
|
||||
# tag[1] keeps the slash-joined positional contract; the standalone
|
||||
# bucket lands after it via path-tier microsecond stagger so user
|
||||
# tags remain at the tail.
|
||||
assert stored[:3] == ["models", "checkpoints/flux", "checkpoints"]
|
||||
|
||||
def test_set_reference_tags_idempotent_on_replay(self, session: Session):
|
||||
asset = _make_asset(session, "hash-replay")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "checkpoints/flux"],
|
||||
)
|
||||
# Replay with the same caller-supplied set; expansion is already
|
||||
# baked in, so nothing should be added or removed.
|
||||
result = set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "checkpoints/flux"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert result.added == []
|
||||
assert result.removed == []
|
||||
assert set(result.total) == {"models", "checkpoints/flux", "checkpoints"}
|
||||
|
||||
def test_add_tags_to_reference_expands_bucket(self, session: Session):
|
||||
asset = _make_asset(session, "hash-add")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
result = add_tags_to_reference(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["loras/style/v2"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert set(result.added) == {"loras/style/v2", "loras"}
|
||||
stored = get_reference_tags(session, reference_id=ref.id)
|
||||
assert "loras" in stored
|
||||
assert "loras/style/v2" in stored
|
||||
|
||||
def test_add_tags_does_not_duplicate_existing_bucket(self, session: Session):
|
||||
asset = _make_asset(session, "hash-dedupe")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
add_tags_to_reference(
|
||||
session, reference_id=ref.id, tags=["models", "checkpoints"]
|
||||
)
|
||||
result = add_tags_to_reference(
|
||||
session, reference_id=ref.id, tags=["checkpoints/flux"]
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# `checkpoints` was already there from the first add; only the
|
||||
# slash-joined token is genuinely new.
|
||||
assert result.added == ["checkpoints/flux"]
|
||||
assert "checkpoints" in result.already_present
|
||||
|
||||
def test_flat_category_is_unaffected(self, session: Session):
|
||||
asset = _make_asset(session, "hash-flat")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
result = set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "checkpoints"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert set(result.total) == {"models", "checkpoints"}
|
||||
assert get_reference_tags(session, reference_id=ref.id) == [
|
||||
"models",
|
||||
"checkpoints",
|
||||
]
|
||||
|
||||
def test_unknown_prefix_passes_through(self, session: Session):
|
||||
asset = _make_asset(session, "hash-user")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
# `my-org` isn't a registered bucket — the slash-joined user tag
|
||||
# should not trigger bucket expansion.
|
||||
result = set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["my-org/team-a"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert result.total == ["my-org/team-a"]
|
||||
|
||||
|
||||
class TestRemoveTagsFromReference:
|
||||
def test_removes_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
|
||||
@ -4,7 +4,7 @@ from pathlib import Path
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets
|
||||
|
||||
|
||||
@ -102,82 +102,6 @@ class TestBatchInsertSeedAssets:
|
||||
assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}"
|
||||
|
||||
|
||||
class TestBucketPrefixExpansionOnIngest:
|
||||
"""Path-scanning ingest must persist the standalone bucket token for
|
||||
nested category paths so the FE set-membership filter
|
||||
(`include_tags=models,checkpoints`) matches assets organized into
|
||||
subfolders (`models/checkpoints/flux/foo.safetensors`).
|
||||
"""
|
||||
|
||||
def test_nested_path_inserts_standalone_bucket(
|
||||
self, session: Session, temp_dir: Path
|
||||
):
|
||||
file_path = temp_dir / "flux.safetensors"
|
||||
file_path.write_bytes(b"content")
|
||||
|
||||
specs: list[SeedAssetSpec] = [
|
||||
{
|
||||
"abs_path": str(file_path),
|
||||
"size_bytes": 7,
|
||||
"mtime_ns": 1234567890000000000,
|
||||
"info_name": "flux",
|
||||
# Shape emitted by get_name_and_tags_from_asset_path for a
|
||||
# nested model path.
|
||||
"tags": ["models", "checkpoints/flux"],
|
||||
"fname": "flux.safetensors",
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": "application/safetensors",
|
||||
}
|
||||
]
|
||||
|
||||
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
|
||||
|
||||
assert result.inserted_refs == 1
|
||||
ref = session.query(AssetReference).filter_by(name="flux").one()
|
||||
stored = [
|
||||
row.tag_name
|
||||
for row in session.query(AssetReferenceTag)
|
||||
.filter_by(asset_reference_id=ref.id)
|
||||
.order_by(AssetReferenceTag.added_at.asc())
|
||||
.all()
|
||||
]
|
||||
assert stored == ["models", "checkpoints/flux", "checkpoints"]
|
||||
|
||||
def test_flat_path_remains_two_tags(
|
||||
self, session: Session, temp_dir: Path
|
||||
):
|
||||
file_path = temp_dir / "vanilla.safetensors"
|
||||
file_path.write_bytes(b"content")
|
||||
|
||||
specs: list[SeedAssetSpec] = [
|
||||
{
|
||||
"abs_path": str(file_path),
|
||||
"size_bytes": 7,
|
||||
"mtime_ns": 1234567890000000000,
|
||||
"info_name": "vanilla",
|
||||
"tags": ["models", "checkpoints"],
|
||||
"fname": "vanilla.safetensors",
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": "application/safetensors",
|
||||
}
|
||||
]
|
||||
|
||||
batch_insert_seed_assets(session, specs=specs, owner_id="")
|
||||
|
||||
ref = session.query(AssetReference).filter_by(name="vanilla").one()
|
||||
stored = {
|
||||
row.tag_name
|
||||
for row in session.query(AssetReferenceTag)
|
||||
.filter_by(asset_reference_id=ref.id)
|
||||
.all()
|
||||
}
|
||||
# Dedupe means flat layouts don't pick up a redundant `checkpoints`
|
||||
# row — tag[1] already serves both positional and set-membership.
|
||||
assert stored == {"models", "checkpoints"}
|
||||
|
||||
|
||||
class TestMetadataExtraction:
|
||||
def test_extracts_mime_type_for_model_files(self, temp_dir: Path):
|
||||
"""Verify metadata extraction returns correct mime_type for model files."""
|
||||
|
||||
@ -6,11 +6,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.services.path_utils import (
|
||||
get_asset_category_and_relative_path,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
)
|
||||
from app.assets.services.path_utils import get_asset_category_and_relative_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -42,50 +38,6 @@ def fake_dirs():
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_dirs_multi_bucket():
|
||||
"""Variant fixture with multiple model buckets (checkpoints + diffusers + loras)."""
|
||||
with tempfile.TemporaryDirectory() as root:
|
||||
root_path = Path(root)
|
||||
input_dir = root_path / "input"
|
||||
output_dir = root_path / "output"
|
||||
temp_dir = root_path / "temp"
|
||||
checkpoints_dir = root_path / "models" / "checkpoints"
|
||||
diffusers_dir = root_path / "models" / "diffusers"
|
||||
loras_dir = root_path / "models" / "loras"
|
||||
for d in (
|
||||
input_dir,
|
||||
output_dir,
|
||||
temp_dir,
|
||||
checkpoints_dir,
|
||||
diffusers_dir,
|
||||
loras_dir,
|
||||
):
|
||||
d.mkdir(parents=True)
|
||||
|
||||
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
|
||||
mock_fp.get_input_directory.return_value = str(input_dir)
|
||||
mock_fp.get_output_directory.return_value = str(output_dir)
|
||||
mock_fp.get_temp_directory.return_value = str(temp_dir)
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[
|
||||
("checkpoints", [str(checkpoints_dir)]),
|
||||
("diffusers", [str(diffusers_dir)]),
|
||||
("loras", [str(loras_dir)]),
|
||||
],
|
||||
):
|
||||
yield {
|
||||
"input": input_dir,
|
||||
"output": output_dir,
|
||||
"temp": temp_dir,
|
||||
"checkpoints": checkpoints_dir,
|
||||
"diffusers": diffusers_dir,
|
||||
"loras": loras_dir,
|
||||
}
|
||||
|
||||
|
||||
class TestGetAssetCategoryAndRelativePath:
|
||||
def test_input_file(self, fake_dirs):
|
||||
f = fake_dirs["input"] / "photo.png"
|
||||
@ -127,161 +79,3 @@ class TestGetAssetCategoryAndRelativePath:
|
||||
def test_unknown_path_raises(self, fake_dirs):
|
||||
with pytest.raises(ValueError, match="not within"):
|
||||
get_asset_category_and_relative_path("/some/random/path.png")
|
||||
|
||||
|
||||
class TestGetNameAndTagsFromAssetPath:
|
||||
"""tags collapse the parent subpath into a single slash-joined tag.
|
||||
|
||||
Consumers should be able to read ``tags[1]`` as a stable category
|
||||
identifier regardless of how deep the file lives in the bucket.
|
||||
"""
|
||||
|
||||
def test_flat_input(self, fake_dirs_multi_bucket):
|
||||
f = fake_dirs_multi_bucket["input"] / "photo.png"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "photo.png"
|
||||
assert tags == ["input"]
|
||||
|
||||
def test_flat_output(self, fake_dirs_multi_bucket):
|
||||
f = fake_dirs_multi_bucket["output"] / "result_00001.png"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "result_00001.png"
|
||||
assert tags == ["output"]
|
||||
|
||||
def test_flat_models_checkpoint(self, fake_dirs_multi_bucket):
|
||||
f = fake_dirs_multi_bucket["checkpoints"] / "flux.safetensors"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "flux.safetensors"
|
||||
assert tags == ["models", "checkpoints"]
|
||||
|
||||
def test_diffusers_nested_subpath_slash_joined(self, fake_dirs_multi_bucket):
|
||||
"""Diffusers components live in nested directories — the full subpath
|
||||
must collapse into one tag so consumers can look up the model category
|
||||
via tags[1] regardless of nesting depth.
|
||||
|
||||
The subpath is lowercased to match the canonicalization
|
||||
:func:`ensure_tags_exist` applies on the write side; without that,
|
||||
the asset_reference_tags.tag_name FK to tags.name would fail for
|
||||
any path containing uppercase letters.
|
||||
"""
|
||||
nested = (
|
||||
fake_dirs_multi_bucket["diffusers"]
|
||||
/ "Kolors"
|
||||
/ "text_encoder"
|
||||
)
|
||||
nested.mkdir(parents=True)
|
||||
f = nested / "model.safetensors"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "model.safetensors"
|
||||
assert tags == ["models", "diffusers/kolors/text_encoder"]
|
||||
|
||||
def test_deep_lora_user_subpath_slash_joined(self, fake_dirs_multi_bucket):
|
||||
"""User-created subdirectories under a model bucket also collapse to a
|
||||
single tag rather than one tag per directory."""
|
||||
nested = (
|
||||
fake_dirs_multi_bucket["loras"]
|
||||
/ "my"
|
||||
/ "custom"
|
||||
/ "path"
|
||||
)
|
||||
nested.mkdir(parents=True)
|
||||
f = nested / "v0001.safetensors"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "v0001.safetensors"
|
||||
assert tags == ["models", "loras/my/custom/path"]
|
||||
|
||||
|
||||
class TestResolveDestinationFromTags:
|
||||
"""resolve_destination_from_tags must accept both the legacy
|
||||
one-tag-per-directory shape and the new slash-joined shape so that an
|
||||
upload using the tags it just read back from /api/assets round-trips
|
||||
to the right on-disk destination.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def resolve_dirs(self):
|
||||
with tempfile.TemporaryDirectory() as root:
|
||||
root_path = Path(root)
|
||||
input_dir = root_path / "input"
|
||||
output_dir = root_path / "output"
|
||||
checkpoints_dir = root_path / "models" / "checkpoints"
|
||||
diffusers_dir = root_path / "models" / "diffusers"
|
||||
loras_dir = root_path / "models" / "loras"
|
||||
for d in (input_dir, output_dir, checkpoints_dir, diffusers_dir, loras_dir):
|
||||
d.mkdir(parents=True)
|
||||
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
|
||||
mock_fp.get_input_directory.return_value = str(input_dir)
|
||||
mock_fp.get_output_directory.return_value = str(output_dir)
|
||||
mock_fp.folder_names_and_paths = {
|
||||
"checkpoints": ([str(checkpoints_dir)], None),
|
||||
"diffusers": ([str(diffusers_dir)], None),
|
||||
"loras": ([str(loras_dir)], None),
|
||||
}
|
||||
yield {
|
||||
"input": input_dir,
|
||||
"output": output_dir,
|
||||
"checkpoints": checkpoints_dir,
|
||||
"diffusers": diffusers_dir,
|
||||
"loras": loras_dir,
|
||||
}
|
||||
|
||||
def test_models_flat_category(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["models", "checkpoints"])
|
||||
assert base == str(resolve_dirs["checkpoints"])
|
||||
assert subdirs == []
|
||||
|
||||
def test_models_slash_joined_new_shape(self, resolve_dirs):
|
||||
# The shape get_name_and_tags_from_asset_path now emits.
|
||||
base, subdirs = resolve_destination_from_tags(
|
||||
["models", "diffusers/kolors/text_encoder"]
|
||||
)
|
||||
assert base == str(resolve_dirs["diffusers"])
|
||||
assert subdirs == ["kolors", "text_encoder"]
|
||||
|
||||
def test_models_legacy_one_tag_per_dir(self, resolve_dirs):
|
||||
# The legacy shape must still resolve identically.
|
||||
base, subdirs = resolve_destination_from_tags(
|
||||
["models", "diffusers", "kolors", "text_encoder"]
|
||||
)
|
||||
assert base == str(resolve_dirs["diffusers"])
|
||||
assert subdirs == ["kolors", "text_encoder"]
|
||||
|
||||
def test_models_loras_slash_joined(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(
|
||||
["models", "loras/my/custom/path"]
|
||||
)
|
||||
assert base == str(resolve_dirs["loras"])
|
||||
assert subdirs == ["my", "custom", "path"]
|
||||
|
||||
def test_input_no_subdir(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["input"])
|
||||
assert base == str(resolve_dirs["input"])
|
||||
assert subdirs == []
|
||||
|
||||
def test_input_slash_joined_subdir(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["input", "portraits/2026"])
|
||||
assert base == str(resolve_dirs["input"])
|
||||
assert subdirs == ["portraits", "2026"]
|
||||
|
||||
def test_output_slash_joined_subdir(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["output", "runs/abc"])
|
||||
assert base == str(resolve_dirs["output"])
|
||||
assert subdirs == ["runs", "abc"]
|
||||
|
||||
def test_unknown_category_rejected(self, resolve_dirs):
|
||||
with pytest.raises(ValueError, match="unknown model category"):
|
||||
resolve_destination_from_tags(["models", "not_a_real_category"])
|
||||
|
||||
def test_unknown_category_via_slash_joined(self, resolve_dirs):
|
||||
# First segment of a slash-joined tag must still match a registered category.
|
||||
with pytest.raises(ValueError, match="unknown model category 'bogus'"):
|
||||
resolve_destination_from_tags(["models", "bogus/sub/path"])
|
||||
|
||||
def test_traversal_in_subdir_rejected(self, resolve_dirs):
|
||||
with pytest.raises(ValueError, match="invalid path component"):
|
||||
resolve_destination_from_tags(["models", "checkpoints/..", "evil"])
|
||||
|
||||
@ -32,7 +32,7 @@ def test_seed_asset_removed_when_file_is_deleted(
|
||||
# Verify it is visible via API and carries no hash (seed)
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests/syncseed", "name_contains": name},
|
||||
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||
timeout=120,
|
||||
)
|
||||
body1 = r1.json()
|
||||
@ -52,7 +52,7 @@ def test_seed_asset_removed_when_file_is_deleted(
|
||||
# It should disappear (AssetInfo and seed Asset gone)
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests/syncseed", "name_contains": name},
|
||||
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||
timeout=120,
|
||||
)
|
||||
body2 = r2.json()
|
||||
@ -332,7 +332,7 @@ def test_fastpass_removes_stale_state_row_no_missing(
|
||||
|
||||
rl = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests/{scope}"},
|
||||
params={"include_tags": f"unit-tests,{scope}"},
|
||||
timeout=120,
|
||||
)
|
||||
bl = rl.json()
|
||||
|
||||
@ -280,15 +280,9 @@ def test_metadata_filename_is_set_for_seed_asset_without_hash(
|
||||
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Scanner emits tags as ``[root, "<dir1>/<dir2>/..."]`` — the second tag
|
||||
# is the slash-joined parent subpath. For ``<root>/unit-tests/<scope>/a/b/<name>``
|
||||
# the second tag is ``"unit-tests/<scope>/a/b"``.
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": f"unit-tests/{scope}/a/b",
|
||||
"name_contains": name,
|
||||
},
|
||||
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
|
||||
timeout=120,
|
||||
)
|
||||
body = r1.json()
|
||||
|
||||
@ -1,69 +0,0 @@
|
||||
"""Unit tests for app.assets.helpers."""
|
||||
|
||||
from app.assets.helpers import expand_bucket_prefixes
|
||||
|
||||
|
||||
class TestExpandBucketPrefixes:
|
||||
def test_flat_category_unchanged(self):
|
||||
# `checkpoints` is already a standalone token, no expansion needed.
|
||||
assert expand_bucket_prefixes(["models", "checkpoints"]) == [
|
||||
"models",
|
||||
"checkpoints",
|
||||
]
|
||||
|
||||
def test_nested_category_inserts_bucket(self):
|
||||
# Path-derived shape for `models/checkpoints/flux/foo.safetensors` —
|
||||
# the standalone bucket has to be present so the FE set-membership
|
||||
# filter (`include_tags=models,checkpoints`) matches the asset.
|
||||
assert expand_bucket_prefixes(["models", "checkpoints/flux"]) == [
|
||||
"models",
|
||||
"checkpoints/flux",
|
||||
"checkpoints",
|
||||
]
|
||||
|
||||
def test_deeply_nested_only_first_segment_expands(self):
|
||||
# Only the FIRST slash segment ever gets emitted as a standalone —
|
||||
# intermediate path segments don't have routing significance.
|
||||
assert expand_bucket_prefixes(
|
||||
["models", "diffusers/kolors/text_encoder"]
|
||||
) == ["models", "diffusers/kolors/text_encoder", "diffusers"]
|
||||
|
||||
def test_unknown_prefix_does_not_expand(self):
|
||||
# Free-form user labels with slashes whose first segment is not a
|
||||
# registered bucket pass through opaquely.
|
||||
assert expand_bucket_prefixes(["models", "my-org/team-a"]) == [
|
||||
"models",
|
||||
"my-org/team-a",
|
||||
]
|
||||
|
||||
def test_idempotent(self):
|
||||
# Re-applying the helper is a no-op once the bucket is in the set.
|
||||
expanded = expand_bucket_prefixes(["models", "checkpoints/flux"])
|
||||
assert expand_bucket_prefixes(expanded) == expanded
|
||||
|
||||
def test_does_not_duplicate_existing_bucket(self):
|
||||
# If the caller already supplied the standalone bucket, don't add a
|
||||
# second copy.
|
||||
assert expand_bucket_prefixes(
|
||||
["models", "checkpoints/flux", "checkpoints"]
|
||||
) == ["models", "checkpoints/flux", "checkpoints"]
|
||||
|
||||
def test_preserves_caller_order(self):
|
||||
# User tags after path tags must stay after; the inserted bucket
|
||||
# token slots in immediately after its slash-joined parent so the
|
||||
# microsecond stagger lands it at path-tier before user-tier.
|
||||
assert expand_bucket_prefixes(
|
||||
["models", "loras/style", "favorite", "v2"]
|
||||
) == ["models", "loras/style", "loras", "favorite", "v2"]
|
||||
|
||||
def test_empty_input(self):
|
||||
assert expand_bucket_prefixes([]) == []
|
||||
|
||||
def test_input_root_with_subpath_no_expansion(self):
|
||||
# `portraits` isn't a registered model category, so the input
|
||||
# subpath stays opaque (FE filter doesn't have a checkpoint-loader
|
||||
# analogue for input subfolders).
|
||||
assert expand_bucket_prefixes(["input", "portraits/2026"]) == [
|
||||
"input",
|
||||
"portraits/2026",
|
||||
]
|
||||
@ -29,10 +29,7 @@ def create_seed_file(comfy_tmp_base_dir: Path):
|
||||
def find_asset(http: requests.Session, api_base: str):
|
||||
"""Query API for assets matching scope and optional name."""
|
||||
def _find(scope: str, name: str | None = None) -> list[dict]:
|
||||
# Scanner now emits tags as ``[root, "<dir1>/<dir2>/..."]`` rather than
|
||||
# one tag per directory. For files at ``<root>/unit-tests/<scope>/...``
|
||||
# the second tag is exactly ``"unit-tests/<scope>"``.
|
||||
params = {"include_tags": f"unit-tests/{scope}"}
|
||||
params = {"include_tags": f"unit-tests,{scope}"}
|
||||
if name:
|
||||
params["name_contains"] = name
|
||||
r = http.get(f"{api_base}/api/assets", params=params, timeout=120)
|
||||
@ -141,7 +138,4 @@ def test_special_chars_in_path_escaped_correctly(
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Scanner emits the full parent subpath as a single slash-joined tag, so
|
||||
# the lookup tag is ``unit-tests/<scope>`` even when <scope> itself
|
||||
# contains a slash (parent + special-char dirname).
|
||||
assert find_asset(scope, fp.name), "Asset with special chars should survive"
|
||||
assert find_asset(scope.split("/")[0], fp.name), "Asset with special chars should survive"
|
||||
|
||||
@ -1,135 +0,0 @@
|
||||
"""HTTP-layer smoke test: user-added tags via POST /api/assets/{id}/tags
|
||||
land after path tags when read back via GET /api/assets.
|
||||
|
||||
Exercises the full route handler -> service -> query path that the unit
|
||||
tests at tests-unit/assets_test/queries/test_asset_info.py only cover at
|
||||
the service layer.
|
||||
"""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def smoke_asset(http: requests.Session, api_base: str):
|
||||
"""Upload a single asset into models/checkpoints/unit-tests/smoke
|
||||
and delete it on teardown."""
|
||||
name = "smoke_user_tag.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "smoke"]
|
||||
files = {"file": (name, b"S" * 4096, "application/octet-stream")}
|
||||
form_data = {
|
||||
"tags": json.dumps(tags),
|
||||
"name": name,
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||
assert r.status_code == 201, r.text
|
||||
body = r.json()
|
||||
yield body
|
||||
http.delete(
|
||||
f"{api_base}/api/assets/{body['id']}?delete_content=true", timeout=30
|
||||
)
|
||||
|
||||
|
||||
def _fetch_asset_tags(http, api_base, ref_id):
|
||||
r = http.get(f"{api_base}/api/assets/{ref_id}", timeout=30)
|
||||
assert r.status_code == 200, r.text
|
||||
return r.json()["tags"]
|
||||
|
||||
|
||||
def test_user_tag_lands_after_path_tags_via_http(
|
||||
http: requests.Session, api_base: str, smoke_asset: dict
|
||||
):
|
||||
ref_id = smoke_asset["id"]
|
||||
|
||||
initial_tags = _fetch_asset_tags(http, api_base, ref_id)
|
||||
# Path tags should already be at the front in upload order.
|
||||
assert initial_tags[:2] == ["models", "checkpoints"]
|
||||
|
||||
# Add a user tag that would jump to position 0 under alphabetical sort.
|
||||
r = http.post(
|
||||
f"{api_base}/api/assets/{ref_id}/tags",
|
||||
json={"tags": ["aaa-user-tag"]},
|
||||
timeout=30,
|
||||
)
|
||||
assert r.status_code in (200, 201), r.text
|
||||
|
||||
tags_after = _fetch_asset_tags(http, api_base, ref_id)
|
||||
# Path tags must still be at the front; user tag goes to the end.
|
||||
assert tags_after[0] == "models"
|
||||
assert tags_after[1] == "checkpoints"
|
||||
assert "aaa-user-tag" in tags_after
|
||||
assert tags_after[-1] == "aaa-user-tag"
|
||||
|
||||
|
||||
def test_user_tag_batch_lands_after_path_tags_via_http(
|
||||
http: requests.Session, api_base: str, smoke_asset: dict
|
||||
):
|
||||
ref_id = smoke_asset["id"]
|
||||
|
||||
# Add three user tags in a single request, in non-alphabetical input
|
||||
# order. They should all land after the path tags (microsecond stagger
|
||||
# in set_reference_tags / add_tags_to_reference is what makes this
|
||||
# work — without it, "aaa" would jump to position 0).
|
||||
r = http.post(
|
||||
f"{api_base}/api/assets/{ref_id}/tags",
|
||||
json={"tags": ["zzz-z", "favorite", "aaa-experiment"]},
|
||||
timeout=30,
|
||||
)
|
||||
assert r.status_code in (200, 201), r.text
|
||||
|
||||
tags_after = _fetch_asset_tags(http, api_base, ref_id)
|
||||
assert tags_after[0] == "models"
|
||||
assert tags_after[1] == "checkpoints"
|
||||
user_tail = tags_after[len({"models", "checkpoints", "unit-tests", "smoke"}):]
|
||||
assert set(user_tail) >= {"zzz-z", "favorite", "aaa-experiment"}
|
||||
# Critically: alphabetical sort would put 'aaa-experiment' at position 0.
|
||||
assert tags_after.index("aaa-experiment") > tags_after.index("models")
|
||||
assert tags_after.index("aaa-experiment") > tags_after.index("checkpoints")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nested_checkpoint_asset(http: requests.Session, api_base: str):
|
||||
"""Upload a checkpoint at the slash-joined path shape cloud emits
|
||||
(`models/checkpoints/flux/...`), then delete it on teardown.
|
||||
"""
|
||||
name = "nested_checkpoint.safetensors"
|
||||
tags = ["models", "checkpoints/flux"]
|
||||
files = {"file": (name, b"S" * 4096, "application/octet-stream")}
|
||||
form_data = {
|
||||
"tags": json.dumps(tags),
|
||||
"name": name,
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||
assert r.status_code == 201, r.text
|
||||
body = r.json()
|
||||
yield body
|
||||
http.delete(
|
||||
f"{api_base}/api/assets/{body['id']}?delete_content=true", timeout=30
|
||||
)
|
||||
|
||||
|
||||
def test_nested_checkpoint_satisfies_fe_set_filter(
|
||||
http: requests.Session, api_base: str, nested_checkpoint_asset: dict
|
||||
):
|
||||
"""The case Simon flagged: a nested-path checkpoint must still match
|
||||
`include_tags=models,checkpoints` — the FE combo-widget filter.
|
||||
"""
|
||||
ref_id = nested_checkpoint_asset["id"]
|
||||
|
||||
stored = _fetch_asset_tags(http, api_base, ref_id)
|
||||
# tag[1] keeps cloud's slash-joined positional contract; tag[2] holds
|
||||
# the standalone bucket the FE filter looks for.
|
||||
assert stored[:3] == ["models", "checkpoints/flux", "checkpoints"]
|
||||
|
||||
# The actual FE query — exact set-membership across both tokens.
|
||||
r = http.get(
|
||||
f"{api_base}/api/assets",
|
||||
params=[("include_tags", "models"), ("include_tags", "checkpoints")],
|
||||
timeout=30,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
returned_ids = {a["id"] for a in r.json()["assets"]}
|
||||
assert ref_id in returned_ids
|
||||
Reference in New Issue
Block a user