Compare commits

..

4 Commits

Author SHA1 Message Date
8aabe2403e Add color type and Color to RGB Int node (#12145)
* add color type and color to rgb int node

* review fix for allowing output

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-01-30 15:01:33 -08:00
0167653781 feat(api-nodes): add RecraftCreateStyleNode node (#12055)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-01-30 14:04:43 -08:00
0a7993729c Remove NodeInfoV3-related code; we are almost 100% guaranteed to stick with NodeInfoV1 for the foreseeable future (#12147)
Co-authored-by: guill <jacob.e.segal@gmail.com>
2026-01-30 10:21:48 -08:00
bbe2c13a70 Make empty hunyuan latent 1.0 work with the 1.5 model. (#12171) 2026-01-29 23:52:22 -05:00
24 changed files with 184 additions and 4465 deletions

View File

@ -1,30 +0,0 @@
# CI workflow: run the asset-system test suite on pushes and pull requests
# targeting the main/master/release branches, on all three desktop OSes.
name: Assets Tests
on:
  push:
    branches: [ main, master, release/** ]
  pull_request:
    branches: [ main, master, release/** ]
jobs:
  test:
    strategy:
      matrix:
        os: [ubuntu-latest, windows-latest, macos-latest]
    runs-on: ${{ matrix.os }}
    # A failure on one OS must not fail or cancel the whole run.
    continue-on-error: true
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.12'
      - name: Install requirements
        run: |
          python -m pip install --upgrade pip
          pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
          pip install -r requirements.txt
      - name: Run Assets Tests
        run: |
          pip install -r tests-assets/requirements.txt
          python -m pytest tests-assets -v

View File

@ -1,8 +1,5 @@
import logging
import uuid
import urllib.parse
import os
import contextlib
from aiohttp import web
from pydantic import ValidationError
@ -11,9 +8,6 @@ import app.assets.manager as manager
from app import user_manager
from app.assets.api import schemas_in
from app.assets.helpers import get_query_dict
from app.assets.scanner import seed_assets
import folder_paths
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
@ -21,9 +15,6 @@ USER_MANAGER: user_manager.UserManager | None = None
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
# Note to any custom node developers reading this code:
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
global USER_MANAGER
USER_MANAGER = user_manager_instance
@ -37,18 +28,6 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
@ROUTES.head("/api/assets/hash/{hash}")
async def head_asset_by_hash(request: web.Request) -> web.Response:
    """Answer 200/404 (no body) depending on whether content with this hash exists."""
    raw = request.match_info.get("hash", "").strip().lower()
    # Require the canonical "blake3:<lowercase hex>" form before hitting storage.
    if not raw or ":" not in raw:
        return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
    algo, digest = raw.split(":", 1)
    digest_ok = bool(digest) and all(c in "0123456789abcdef" for c in digest)
    if algo != "blake3" or not digest_ok:
        return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
    found = manager.asset_exists(asset_hash=raw)
    return web.Response(status=200 if found else 404)
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
"""
@ -71,7 +50,7 @@ async def list_assets(request: web.Request) -> web.Response:
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
@ -97,314 +76,6 @@ async def get_asset(request: web.Request) -> web.Response:
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
async def download_asset_content(request: web.Request) -> web.Response:
    """Stream an asset's file content to the client.

    Query params:
        disposition: "inline" or "attachment" (anything else falls back to
            "attachment").

    Returns 404 when the asset or its on-disk file is missing, 501 when the
    storage backend cannot serve file content.
    """
    # question: do we need disposition? could we just stick with one of these?
    disposition = request.query.get("disposition", "attachment").lower().strip()
    if disposition not in {"inline", "attachment"}:
        disposition = "attachment"
    try:
        abs_path, content_type, filename = manager.resolve_asset_content_for_download(
            asset_info_id=str(uuid.UUID(request.match_info["id"])),
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
    except ValueError as ve:
        return _error_response(404, "ASSET_NOT_FOUND", str(ve))
    except NotImplementedError as nie:
        return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
    except FileNotFoundError:
        return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
    # Sanitize for the quoted filename= form: strip CR/LF (header injection)
    # and swap double quotes so the quoted string stays well-formed.
    quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
    # Bug fix: quote a guaranteed string -- urllib.parse.quote(None) raises
    # TypeError, and the surrounding code treats a None filename as possible.
    cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename or "")}'
    file_size = os.path.getsize(abs_path)
    logging.info(
        "download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s",
        abs_path,
        file_size,
        file_size / (1024 * 1024),
        content_type,
        filename,
    )

    async def file_sender():
        # Stream in 64 KiB chunks so large files are never fully buffered.
        chunk_size = 64 * 1024
        with open(abs_path, "rb") as f:
            while True:
                chunk = f.read(chunk_size)
                if not chunk:
                    break
                yield chunk

    return web.Response(
        body=file_sender(),
        content_type=content_type,
        headers={
            "Content-Disposition": cd,
            "Content-Length": str(file_size),
        },
    )
@ROUTES.post("/api/assets/from-hash")
async def create_asset_from_hash(request: web.Request) -> web.Response:
    """Create an AssetInfo referencing content that was already ingested, by hash."""
    try:
        body = schemas_in.CreateFromHashBody.model_validate(await request.json())
    except ValidationError as ve:
        return _validation_error_response("INVALID_BODY", ve)
    except Exception:
        # Covers malformed JSON as well as any other parse failure.
        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
    result = manager.create_asset_from_hash(
        hash_str=body.hash,
        name=body.name,
        tags=body.tags,
        user_metadata=body.user_metadata,
        owner_id=USER_MANAGER.get_request_user_id(request),
    )
    if result is None:
        return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
    return web.json_response(result.model_dump(mode="json"), status=201)
@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
"""Multipart/form-data endpoint for Asset uploads.

Recognized form fields: 'file', 'hash', 'tags' (repeatable), 'name',
'user_metadata'. If a 'hash' is supplied and that content already exists,
the file body is drained without being stored and the AssetInfo is created
from the existing content (fast path).
NOTE(review): field order matters — a 'hash' arriving after 'file' means the
temp file is written anyway and cleaned up later in the fast path.
"""
if not (request.content_type or "").lower().startswith("multipart/"):
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
reader = await request.multipart()
# Accumulators filled while streaming the multipart fields in arrival order.
file_present = False
file_client_name: str | None = None
tags_raw: list[str] = []
provided_name: str | None = None
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
file_written = 0
tmp_path: str | None = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
# Validate the client-supplied hash: canonical 'blake3:<lowercase hex>' only.
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if s:
if ":" not in s:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
provided_hash = f"{algo}:{digest}"
try:
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
except Exception:
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
continue # Do not create temp file; we will create AssetInfo from the existing content
# Otherwise, store to temp for hashing/ingest
# Each upload gets a unique directory so concurrent uploads never collide.
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
# Best-effort cleanup of the partial file, then report the I/O failure.
try:
if os.path.exists(tmp_path or ""):
os.remove(tmp_path)
finally:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
elif fname == "tags":
# 'tags' may appear multiple times; collect every occurrence.
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
# If client did not send file, and we are not doing a from-hash fast path -> error
if not file_present and not (provided_hash and provided_hash_exists):
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
# Empty upload is only acceptable if we are fast-pathing from existing hash
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
# Validate the collected form fields as one spec (tags/name/metadata/hash).
try:
spec = schemas_in.UploadAssetSpec.model_validate({
"tags": tags_raw,
"name": provided_name,
"user_metadata": user_metadata_raw,
"hash": provided_hash,
})
except ValidationError as ve:
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _validation_error_response("INVALID_BODY", ve)
# Validate models category against configured folders (consistent with previous behavior)
if spec.tags and spec.tags[0] == "models":
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
)
owner_id = USER_MANAGER.get_request_user_id(request)
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and provided_hash_exists is True:
try:
result = manager.create_asset_from_hash(
hash_str=spec.hash,
name=spec.name or (spec.hash.split(":", 1)[1]),
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
)
except Exception:
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
# Drain temp if we accidentally saved (e.g., hash field came after file)
if tmp_path and os.path.exists(tmp_path):
with contextlib.suppress(Exception):
os.remove(tmp_path)
# Reusing existing content -> 200; brand-new AssetInfo -> 201.
status = 200 if (not result.created_new) else 201
return web.json_response(result.model_dump(mode="json"), status=status)
# Otherwise, we must have a temp file path to ingest
if not tmp_path or not os.path.exists(tmp_path):
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
try:
created = manager.upload_asset_from_temp_path(
spec,
temp_path=tmp_path,
client_filename=file_client_name,
owner_id=owner_id,
expected_asset_hash=spec.hash,
)
status = 201 if created.created_new else 200
return web.json_response(created.model_dump(mode="json"), status=status)
except ValueError as e:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
msg = str(e)
# The manager signals a hash mismatch through the ValueError text.
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
return _error_response(
400,
"HASH_MISMATCH",
"Uploaded file hash does not match provided hash.",
)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except Exception:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset(request: web.Request) -> web.Response:
    """Update an AssetInfo's display name and/or user metadata."""
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
    try:
        parsed = schemas_in.UpdateAssetBody.model_validate(await request.json())
    except ValidationError as ve:
        return _validation_error_response("INVALID_BODY", ve)
    except Exception:
        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
    try:
        result = manager.update_asset(
            asset_info_id=asset_info_id,
            name=parsed.name,
            user_metadata=parsed.user_metadata,
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
    except (ValueError, PermissionError) as ve:
        # Not-found and not-owned are reported identically to avoid leaking ownership.
        return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
    except Exception:
        logging.exception(
            "update_asset failed for asset_info_id=%s, owner_id=%s",
            asset_info_id,
            USER_MANAGER.get_request_user_id(request),
        )
        return _error_response(500, "INTERNAL", "Unexpected server error.")
    return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
    """Delete an AssetInfo reference.

    Query param delete_content: defaults to true; "0"/"false"/"no" keeps the
    underlying content on disk even if this was the last reference.
    """
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
    flag = request.query.get("delete_content")
    if flag is None:
        delete_content = True
    else:
        delete_content = flag.lower() not in {"0", "false", "no"}
    try:
        deleted = manager.delete_asset_reference(
            asset_info_id=asset_info_id,
            owner_id=USER_MANAGER.get_request_user_id(request),
            delete_content_if_orphan=delete_content,
        )
    except Exception:
        logging.exception(
            "delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
            asset_info_id,
            USER_MANAGER.get_request_user_id(request),
        )
        return _error_response(500, "INTERNAL", "Unexpected server error.")
    if not deleted:
        return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
    return web.Response(status=204)
@ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response:
"""
@ -429,86 +100,3 @@ async def get_tags(request: web.Request) -> web.Response:
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(result.model_dump(mode="json"))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
    """Attach one or more tags to an AssetInfo; tag origin is 'manual'."""
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
    try:
        parsed = schemas_in.TagsAdd.model_validate(await request.json())
    except ValidationError as ve:
        return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
    except Exception:
        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
    try:
        result = manager.add_tags_to_asset(
            asset_info_id=asset_info_id,
            tags=parsed.tags,
            origin="manual",
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
    except (ValueError, PermissionError) as ve:
        # Not-found and not-owned look the same to the caller.
        return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
    except Exception:
        logging.exception(
            "add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
            asset_info_id,
            USER_MANAGER.get_request_user_id(request),
        )
        return _error_response(500, "INTERNAL", "Unexpected server error.")
    return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
    """Detach one or more tags from an AssetInfo."""
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
    try:
        parsed = schemas_in.TagsRemove.model_validate(await request.json())
    except ValidationError as ve:
        return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
    except Exception:
        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
    try:
        result = manager.remove_tags_from_asset(
            asset_info_id=asset_info_id,
            tags=parsed.tags,
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
    except ValueError as ve:
        return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
    except Exception:
        logging.exception(
            "remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
            asset_info_id,
            USER_MANAGER.get_request_user_id(request),
        )
        return _error_response(500, "INTERNAL", "Unexpected server error.")
    return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/seed")
async def seed_assets_endpoint(request: web.Request) -> web.Response:
    """Trigger asset seeding for specified roots (models, input, output).

    Body (optional JSON): {"roots": [...]}. Unknown roots are ignored;
    missing or invalid JSON seeds all three roots; no valid roots -> 400.
    """
    default_roots = ["models", "input", "output"]
    try:
        payload = await request.json()
        roots = payload.get("roots", default_roots)
    except Exception:
        roots = default_roots
    # Robustness fix: a non-list "roots" value (e.g. a number) previously
    # raised an uncaught TypeError (500); treat it as "no valid roots" -> 400.
    if not isinstance(roots, (list, tuple)):
        roots = []
    # Fix: dedupe while preserving order so a repeated root is not seeded
    # twice nor reported twice in the response.
    valid_roots = list(dict.fromkeys(r for r in roots if r in ("models", "input", "output")))
    if not valid_roots:
        return _error_response(400, "INVALID_BODY", "No valid roots specified")
    try:
        seed_assets(tuple(valid_roots))
    except Exception:
        logging.exception("seed_assets failed for roots=%s", valid_roots)
        return _error_response(500, "INTERNAL", "Seed operation failed")
    return web.json_response({"seeded": valid_roots}, status=200)

View File

@ -1,4 +1,5 @@
import json
import uuid
from typing import Any, Literal
from pydantic import (
@ -7,9 +8,9 @@ from pydantic import (
Field,
conint,
field_validator,
model_validator,
)
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
@ -56,57 +57,6 @@ class ListAssetsQuery(BaseModel):
return None
class UpdateAssetBody(BaseModel):
    """PATCH-style update body: at least one field must be supplied."""

    name: str | None = None
    user_metadata: dict[str, Any] | None = None

    @model_validator(mode="after")
    def _at_least_one(self):
        # Reject an empty update: there would be nothing to change.
        has_any = self.name is not None or self.user_metadata is not None
        if not has_any:
            raise ValueError("Provide at least one of: name, user_metadata.")
        return self
class CreateFromHashBody(BaseModel):
    """Body for POST /api/assets/from-hash: reference existing content by hash."""

    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

    hash: str
    name: str
    tags: list[str] = Field(default_factory=list)
    user_metadata: dict[str, Any] = Field(default_factory=dict)

    @field_validator("hash")
    @classmethod
    def _require_blake3(cls, v):
        """Accept only the canonical 'blake3:<lowercase hex>' form."""
        s = (v or "").strip().lower()
        if ":" not in s:
            raise ValueError("hash must be 'blake3:<hex>'")
        algo, digest = s.split(":", 1)
        if algo != "blake3":
            raise ValueError("only canonical 'blake3:<hex>' is accepted here")
        if not digest or any(c for c in digest if c not in "0123456789abcdef"):
            raise ValueError("hash digest must be lowercase hex")
        return s

    @field_validator("tags", mode="before")
    @classmethod
    def _tags_norm(cls, v):
        """Normalize tags: lowercase, strip, drop blanks, dedupe preserving order.

        Accepts a list of strings or a single comma-separated string.
        """
        if v is None:
            return []
        if isinstance(v, str):
            v = v.split(",")
        if not isinstance(v, list):
            return []
        cleaned = [str(t).strip().lower() for t in v if str(t).strip()]
        # Bug fix: the comma-separated form previously skipped deduplication
        # while the list form deduped; both forms now behave identically.
        return list(dict.fromkeys(cleaned))
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@ -125,140 +75,20 @@ class TagsListQuery(BaseModel):
return v.lower() or None
class TagsAdd(BaseModel):
model_config = ConfigDict(extra="ignore")
tags: list[str] = Field(..., min_length=1)
class SetPreviewBody(BaseModel):
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
preview_id: str | None = None
@field_validator("tags")
@field_validator("preview_id", mode="before")
@classmethod
def normalize_tags(cls, v: list[str]) -> list[str]:
out = []
for t in v:
if not isinstance(t, str):
raise TypeError("tags must be strings")
tnorm = t.strip().lower()
if tnorm:
out.append(tnorm)
seen = set()
deduplicated = []
for x in out:
if x not in seen:
seen.add(x)
deduplicated.append(x)
return deduplicated
class TagsRemove(TagsAdd):
pass
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
and the original extension is preserved when available.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None)
@field_validator("hash", mode="before")
@classmethod
def _parse_hash(cls, v):
def _norm_uuid(cls, v):
if v is None:
return None
s = str(v).strip().lower()
s = str(v).strip()
if not s:
return None
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return f"{algo}:{digest}"
@field_validator("tags", mode="before")
@classmethod
def _parse_tags(cls, v):
"""
Accepts a list of strings (possibly multiple form fields),
where each string can be:
- JSON array (e.g., '["models","loras","foo"]')
- comma-separated ('models, loras, foo')
- single token ('models')
Returns a normalized, deduplicated, ordered list.
"""
items: list[str] = []
if v is None:
return []
if isinstance(v, str):
v = [v]
if isinstance(v, list):
for item in v:
if item is None:
continue
s = str(item).strip()
if not s:
continue
if s.startswith("["):
try:
arr = json.loads(s)
if isinstance(arr, list):
items.extend(str(x) for x in arr)
continue
except Exception:
pass # fallback to CSV parse below
items.extend([p for p in s.split(",") if p.strip()])
else:
return []
# normalize + dedupe
norm = []
seen = set()
for t in items:
tnorm = str(t).strip().lower()
if tnorm and tnorm not in seen:
seen.add(tnorm)
norm.append(tnorm)
return norm
@field_validator("user_metadata", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v or {}
if isinstance(v, str):
s = v.strip()
if not s:
return {}
try:
parsed = json.loads(s)
except Exception as e:
raise ValueError(f"user_metadata must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("user_metadata must be a JSON object")
return parsed
return {}
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("tags must be provided and non-empty")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError("models uploads require a category tag as the second tag")
return self
try:
uuid.UUID(s)
except Exception:
raise ValueError("preview_id must be a UUID")
return s

View File

@ -29,21 +29,6 @@ class AssetsList(BaseModel):
has_more: bool
class AssetUpdated(BaseModel):
"""Response body returned after an asset metadata/name update."""
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
# Allow building this model directly from ORM row attributes.
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _ser_updated(self, v: datetime | None, _info):
# Emit ISO-8601 strings; None stays null in the JSON output.
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
@ -63,10 +48,6 @@ class AssetDetail(BaseModel):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
"""AssetDetail plus a flag reporting whether a new row was created (vs reused)."""
created_new: bool
class TagUsage(BaseModel):
"""A single tag name with its usage count, for tag listings."""
name: str
count: int
@ -77,17 +58,3 @@ class TagsList(BaseModel):
tags: list[TagUsage] = Field(default_factory=list)
total: int
has_more: bool
class TagsAdd(BaseModel):
"""Response for adding tags: which were added, which already existed, and the final set."""
model_config = ConfigDict(str_strip_whitespace=True)
added: list[str] = Field(default_factory=list)
already_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagsRemove(BaseModel):
"""Response for removing tags: which were removed, which were absent, and the final set."""
model_config = ConfigDict(str_strip_whitespace=True)
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)

View File

@ -1,17 +1,9 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy import select, exists, func
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import (
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import escape_like_prefix, normalize_tags
from typing import Sequence
@ -23,22 +15,6 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
return AssetInfo.owner_id.in_(["", owner_id])
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
    """
    Return the best on-disk path among cache states:
    1) Prefer a path that exists with needs_verify == False (already verified).
    2) Otherwise, pick the first path that exists.
    3) Otherwise return empty string.
    """
    existing = [
        s for s in states
        if getattr(s, "file_path", None) and os.path.isfile(s.file_path)
    ]
    if not existing:
        return ""
    verified = next(
        (s for s in existing if not getattr(s, "needs_verify", False)),
        None,
    )
    return (verified or existing[0]).file_path
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
@ -66,7 +42,6 @@ def apply_tag_filters(
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
@ -119,11 +94,7 @@ def apply_metadata_filter(
return stmt
def asset_exists_by_hash(
session: Session,
*,
asset_hash: str,
) -> bool:
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
@ -134,39 +105,9 @@ def asset_exists_by_hash(
).first()
return row is not None
def asset_info_exists_for_asset_id(
    session: Session,
    *,
    asset_id: str,
) -> bool:
    """Return True if at least one AssetInfo row references the given asset_id."""
    stmt = (
        select(sa.literal(True))
        .select_from(AssetInfo)
        .where(AssetInfo.asset_id == asset_id)
        .limit(1)
    )
    return session.execute(stmt).first() is not None
def get_asset_by_hash(
    session: Session,
    *,
    asset_hash: str,
) -> Asset | None:
    """Look up an Asset row by its content hash; None if absent."""
    stmt = select(Asset).where(Asset.hash == asset_hash).limit(1)
    return session.execute(stmt).scalars().first()
def get_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
) -> AssetInfo | None:
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def list_asset_infos_page(
session: Session,
owner_id: str = "",
@ -230,14 +171,12 @@ def list_asset_infos_page(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
.order_by(AssetInfoTag.added_at)
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
@ -269,494 +208,6 @@ def fetch_asset_info_asset_and_tags(
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
    session: Session,
    *,
    asset_info_id: str,
    owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
    """Fetch an (AssetInfo, Asset) pair visible to owner_id, or None if not found."""
    stmt = (
        select(AssetInfo, Asset)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .where(
            AssetInfo.id == asset_info_id,
            visible_owner_clause(owner_id),
        )
        .limit(1)
        # Tags are not needed by callers of this helper; skip the relationship load.
        .options(noload(AssetInfo.tags))
    )
    pair = session.execute(stmt).first()
    if pair is None:
        return None
    info, asset = pair
    return info, asset
def list_cache_states_by_asset_id(
    session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
    """Return all AssetCacheState rows for an asset, ordered by primary key."""
    stmt = (
        select(AssetCacheState)
        .where(AssetCacheState.asset_id == asset_id)
        .order_by(AssetCacheState.id.asc())
    )
    return session.execute(stmt).scalars().all()
def touch_asset_info_by_id(
    session: Session,
    *,
    asset_info_id: str,
    ts: datetime | None = None,
    only_if_newer: bool = True,
) -> None:
    """Bump an AssetInfo's last_access_time.

    With only_if_newer (default) the row is updated only when the stored
    timestamp is NULL or older than ts, so the access clock never moves
    backwards under concurrent touches.
    """
    when = ts or utcnow()
    stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
    if only_if_newer:
        is_older = sa.or_(
            AssetInfo.last_access_time.is_(None),
            AssetInfo.last_access_time < when,
        )
        stmt = stmt.where(is_older)
    session.execute(stmt.values(last_access_time=when))
def create_asset_info_for_existing_asset(
session: Session,
*,
asset_hash: str,
name: str,
user_metadata: dict | None = None,
tags: Sequence[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash.

Raises ValueError when no Asset row exists for asset_hash. On a unique
constraint conflict (same asset/name/owner) the pre-existing row is returned
unchanged; metadata and tags are written only for a newly created row.
"""
now = utcnow()
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
# SAVEPOINT so a unique-constraint violation does not poison the outer transaction.
with session.begin_nested():
session.add(info)
session.flush()
except IntegrityError:
# Conflict: another row with the same (asset_id, name, owner_id) won; return it.
existing = (
session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
# Best-effort: derive a relative filename from the best live on-disk copy.
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
# Any failure here just means we skip the "filename" metadata key.
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
def set_asset_info_tags(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
    origin: str = "manual",
) -> dict:
    """Reconcile the tag links on an AssetInfo to exactly the given tags.

    Returns {"added": [...], "removed": [...], "total": [...]} where total is
    the normalized desired tag list.
    """
    desired = normalize_tags(tags)
    existing = set(
        name
        for (name,) in session.execute(
            select(AssetInfoTag.tag_name).where(
                AssetInfoTag.asset_info_id == asset_info_id
            )
        ).all()
    )
    additions = [t for t in desired if t not in existing]
    removals = [t for t in existing if t not in desired]
    if additions:
        # Ensure Tag rows exist before creating links to them.
        ensure_tags_exist(session, additions, tag_type="user")
        session.add_all([
            AssetInfoTag(
                asset_info_id=asset_info_id,
                tag_name=t,
                origin=origin,
                added_at=utcnow(),
            )
            for t in additions
        ])
        session.flush()
    if removals:
        session.execute(
            delete(AssetInfoTag).where(
                AssetInfoTag.asset_info_id == asset_info_id,
                AssetInfoTag.tag_name.in_(removals),
            )
        )
        session.flush()
    return {"added": additions, "removed": removals, "total": desired}
def replace_asset_info_metadata_projection(
    session: Session,
    *,
    asset_info_id: str,
    user_metadata: dict | None = None,
) -> None:
    """Overwrite an AssetInfo's user_metadata and rebuild its key/value projection.

    The AssetInfoMeta rows are a queryable projection of the JSON metadata;
    they are dropped and re-created from scratch on every call.

    Raises ValueError if the AssetInfo does not exist.
    """
    info = session.get(AssetInfo, asset_info_id)
    if info is None:
        raise ValueError(f"AssetInfo {asset_info_id} not found")
    info.user_metadata = user_metadata or {}
    info.updated_at = utcnow()
    session.flush()
    # Clear the previous projection before writing the new one.
    session.execute(
        delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)
    )
    session.flush()
    if not user_metadata:
        return
    projected = [
        AssetInfoMeta(
            asset_info_id=asset_info_id,
            key=row["key"],
            ordinal=int(row["ordinal"]),
            val_str=row.get("val_str"),
            val_num=row.get("val_num"),
            val_bool=row.get("val_bool"),
            val_json=row.get("val_json"),
        )
        for key, value in user_metadata.items()
        for row in project_kv(key, value)
    ]
    if projected:
        session.add_all(projected)
        session.flush()
def ingest_fs_asset(
    session: Session,
    *,
    asset_hash: str,
    abs_path: str,
    size_bytes: int,
    mtime_ns: int,
    mime_type: str | None = None,
    info_name: str | None = None,
    owner_id: str = "",
    preview_id: str | None = None,
    user_metadata: dict | None = None,
    tags: Sequence[str] = (),
    tag_origin: str = "manual",
    require_existing_tags: bool = False,
) -> dict:
    """
    Idempotently upsert the database records describing one on-disk file.

    Steps:
      1. Asset row keyed by content hash (created if missing; size/mime refreshed).
      2. AssetCacheState row keyed by absolute file path, pointing at the Asset.
      3. Optionally an AssetInfo (info_name + owner), its tag links, and the
         user-metadata projection (including the computed "filename" entry).

    Returns a dict of flags:
      asset_created / asset_updated / state_created / state_updated /
      asset_info_id (None when no info_name was given).
    """
    locator = os.path.abspath(abs_path)
    now = utcnow()
    # A preview reference to a non-existent Asset is silently dropped.
    if preview_id:
        if not session.get(Asset, preview_id):
            preview_id = None
    out: dict[str, Any] = {
        "asset_created": False,
        "asset_updated": False,
        "state_created": False,
        "state_updated": False,
        "asset_info_id": None,
    }
    # 1) Asset by hash
    asset = (
        session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
    ).scalars().first()
    if not asset:
        vals = {
            "hash": asset_hash,
            "size_bytes": int(size_bytes),
            "mime_type": mime_type,
            "created_at": now,
        }
        # ON CONFLICT DO NOTHING makes the insert race-safe; rowcount tells us
        # whether this call actually created the row.
        res = session.execute(
            sqlite.insert(Asset)
            .values(**vals)
            .on_conflict_do_nothing(index_elements=[Asset.hash])
        )
        if int(res.rowcount or 0) > 0:
            out["asset_created"] = True
        # Re-select either way: a concurrent writer may have won the insert.
        asset = (
            session.execute(
                select(Asset).where(Asset.hash == asset_hash).limit(1)
            )
        ).scalars().first()
        if not asset:
            raise RuntimeError("Asset row not found after upsert.")
    else:
        # Refresh size/mime on the existing row; size 0 is treated as unknown.
        changed = False
        if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
            asset.size_bytes = int(size_bytes)
            changed = True
        if mime_type and asset.mime_type != mime_type:
            asset.mime_type = mime_type
            changed = True
        if changed:
            out["asset_updated"] = True
    # 2) AssetCacheState upsert by file_path (unique)
    vals = {
        "asset_id": asset.id,
        "file_path": locator,
        "mtime_ns": int(mtime_ns),
    }
    ins = (
        sqlite.insert(AssetCacheState)
        .values(**vals)
        .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
    )
    res = session.execute(ins)
    if int(res.rowcount or 0) > 0:
        out["state_created"] = True
    else:
        # Row already existed: repoint/refresh it only when something differs,
        # so rowcount reliably reports whether an update happened.
        upd = (
            sa.update(AssetCacheState)
            .where(AssetCacheState.file_path == locator)
            .where(
                sa.or_(
                    AssetCacheState.asset_id != asset.id,
                    AssetCacheState.mtime_ns.is_(None),
                    AssetCacheState.mtime_ns != int(mtime_ns),
                )
            )
            .values(asset_id=asset.id, mtime_ns=int(mtime_ns))
        )
        res2 = session.execute(upd)
        if int(res2.rowcount or 0) > 0:
            out["state_updated"] = True
    # 3) Optional AssetInfo + tags + metadata
    if info_name:
        # Try the insert inside a SAVEPOINT; a duplicate (asset, name, owner)
        # raises IntegrityError, which we swallow and fall through to lookup.
        try:
            with session.begin_nested():
                info = AssetInfo(
                    owner_id=owner_id,
                    name=info_name,
                    asset_id=asset.id,
                    preview_id=preview_id,
                    created_at=now,
                    updated_at=now,
                    last_access_time=now,
                )
                session.add(info)
                session.flush()
                out["asset_info_id"] = info.id
        except IntegrityError:
            pass
        existing_info = (
            session.execute(
                select(AssetInfo)
                .where(
                    AssetInfo.asset_id == asset.id,
                    AssetInfo.name == info_name,
                    (AssetInfo.owner_id == owner_id),
                )
                .limit(1)
            )
        ).unique().scalar_one_or_none()
        if not existing_info:
            raise RuntimeError("Failed to update or insert AssetInfo.")
        if preview_id and existing_info.preview_id != preview_id:
            existing_info.preview_id = preview_id
            existing_info.updated_at = now
        # last_access_time only ever moves forward.
        if existing_info.last_access_time < now:
            existing_info.last_access_time = now
        session.flush()
        out["asset_info_id"] = existing_info.id
    # Inline normalization (lowercase, stripped, non-empty) of requested tags.
    norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
    if norm and out["asset_info_id"] is not None:
        if not require_existing_tags:
            ensure_tags_exist(session, norm, tag_type="user")
        existing_tag_names = set(
            name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
        )
        missing = [t for t in norm if t not in existing_tag_names]
        if missing and require_existing_tags:
            raise ValueError(f"Unknown tags: {missing}")
        existing_links = set(
            tag_name
            for (tag_name,) in (
                session.execute(
                    select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
                )
            ).all()
        )
        # Only link tags that exist and are not already linked.
        to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
        if to_add:
            session.add_all(
                [
                    AssetInfoTag(
                        asset_info_id=out["asset_info_id"],
                        tag_name=t,
                        origin=tag_origin,
                        added_at=now,
                    )
                    for t in to_add
                ]
            )
            session.flush()
    # metadata["filename"] hack: mirror the best live path (relative form)
    # into the metadata so clients can show/filter by filename.
    if out["asset_info_id"] is not None:
        primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
        computed_filename = compute_relative_filename(primary_path) if primary_path else None
        current_meta = existing_info.user_metadata or {}
        new_meta = dict(current_meta)
        if user_metadata is not None:
            for k, v in user_metadata.items():
                new_meta[k] = v
        if computed_filename:
            new_meta["filename"] = computed_filename
        # Only rewrite the projection when the merged metadata actually changed.
        if new_meta != current_meta:
            replace_asset_info_metadata_projection(
                session,
                asset_info_id=out["asset_info_id"],
                user_metadata=new_meta,
            )
    # The file is demonstrably present, so best-effort drop any 'missing' tag.
    try:
        remove_missing_tag_for_asset_id(session, asset_id=asset.id)
    except Exception:
        logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
    return out
def update_asset_info_full(
    session: Session,
    *,
    asset_info_id: str,
    name: str | None = None,
    tags: Sequence[str] | None = None,
    user_metadata: dict | None = None,
    tag_origin: str = "manual",
    asset_info_row: Any = None,
) -> AssetInfo:
    """Apply a partial update (name / tags / metadata) to one AssetInfo.

    `asset_info_row` lets callers pass an already-loaded row to skip the
    lookup. When metadata is written, the computed "filename" entry is
    refreshed from the best live cache path. Returns the (flushed) AssetInfo.
    Raises ValueError when the row does not exist.
    """
    if not asset_info_row:
        info = session.get(AssetInfo, asset_info_id)
        if not info:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
    else:
        info = asset_info_row
    touched = False
    if name is not None and name != info.name:
        info.name = name
        touched = True
    # Best-effort: derive the display filename from the asset's live file path.
    computed_filename = None
    try:
        p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
        if p:
            computed_filename = compute_relative_filename(p)
    except Exception:
        computed_filename = None
    if user_metadata is not None:
        # Full metadata replacement; "filename" is force-refreshed on top.
        new_meta = dict(user_metadata)
        if computed_filename:
            new_meta["filename"] = computed_filename
        replace_asset_info_metadata_projection(
            session, asset_info_id=asset_info_id, user_metadata=new_meta
        )
        touched = True
    else:
        # No metadata given: still sync "filename" if it drifted.
        if computed_filename:
            current_meta = info.user_metadata or {}
            if current_meta.get("filename") != computed_filename:
                new_meta = dict(current_meta)
                new_meta["filename"] = computed_filename
                replace_asset_info_metadata_projection(
                    session, asset_info_id=asset_info_id, user_metadata=new_meta
                )
                touched = True
    if tags is not None:
        set_asset_info_tags(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
            origin=tag_origin,
        )
        touched = True
    # replace_asset_info_metadata_projection already bumps updated_at when
    # user_metadata was written; avoid a second, redundant bump here.
    if touched and user_metadata is None:
        info.updated_at = utcnow()
    session.flush()
    return info
def delete_asset_info_by_id(
    session: Session,
    *,
    asset_info_id: str,
    owner_id: str,
) -> bool:
    """Delete one AssetInfo visible to `owner_id`; True when a row was removed."""
    result = session.execute(
        sa.delete(AssetInfo).where(
            AssetInfo.id == asset_info_id,
            visible_owner_clause(owner_id),
        )
    )
    return int(result.rowcount or 0) > 0
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
@ -814,163 +265,3 @@ def list_tags_with_usage(
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
    """Insert any missing Tag rows for `names`; existing tags are untouched."""
    normalized = normalize_tags(list(names))
    if not normalized:
        return
    # Deduplicate while preserving first-seen order.
    unique_names = list(dict.fromkeys(normalized))
    session.execute(
        sqlite.insert(Tag)
        .values([{"name": n, "tag_type": tag_type} for n in unique_names])
        .on_conflict_do_nothing(index_elements=[Tag.name])
    )
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
    """Return every tag name linked to the given AssetInfo."""
    result = session.execute(
        select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
    )
    return [row[0] for row in result.all()]
def add_tags_to_asset_info(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
    origin: str = "manual",
    create_if_missing: bool = True,
    asset_info_row: Any = None,
) -> dict:
    """Link `tags` to an AssetInfo, optionally creating unknown tags first.

    Race-tolerant: the link insert runs in a SAVEPOINT and a concurrent
    duplicate insert is absorbed, after which the links are re-read.
    Returns {"added", "already_present", "total_tags"} (sorted lists).
    Raises ValueError when the AssetInfo does not exist.
    """
    if not asset_info_row:
        info = session.get(AssetInfo, asset_info_id)
        if not info:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
    norm = normalize_tags(tags)
    if not norm:
        # Nothing to add; report current state.
        total = get_asset_tags(session, asset_info_id=asset_info_id)
        return {"added": [], "already_present": [], "total_tags": total}
    if create_if_missing:
        ensure_tags_exist(session, norm, tag_type="user")
    current = {
        tag_name
        for (tag_name,) in (
            session.execute(
                sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
            )
        ).all()
    }
    want = set(norm)
    to_add = sorted(want - current)
    if to_add:
        # SAVEPOINT so a concurrent writer inserting the same link only rolls
        # back this insert, not the caller's outer transaction.
        with session.begin_nested() as nested:
            try:
                session.add_all(
                    [
                        AssetInfoTag(
                            asset_info_id=asset_info_id,
                            tag_name=t,
                            origin=origin,
                            added_at=utcnow(),
                        )
                        for t in to_add
                    ]
                )
                session.flush()
            except IntegrityError:
                nested.rollback()
    # Re-read to report what actually ended up linked (race-safe accounting).
    after = set(get_asset_tags(session, asset_info_id=asset_info_id))
    return {
        "added": sorted(((after - current) & want)),
        "already_present": sorted(want & current),
        "total_tags": sorted(after),
    }
def remove_tags_from_asset_info(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
) -> dict:
    """Unlink `tags` from an AssetInfo.

    Returns {"removed", "not_present", "total_tags"}. Raises ValueError when
    the AssetInfo does not exist.
    """
    if session.get(AssetInfo, asset_info_id) is None:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    requested = normalize_tags(tags)
    if not requested:
        return {
            "removed": [],
            "not_present": [],
            "total_tags": get_asset_tags(session, asset_info_id=asset_info_id),
        }

    linked = {
        row[0]
        for row in session.execute(
            sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
        ).all()
    }
    removable = sorted({name for name in requested if name in linked})
    missing = sorted({name for name in requested if name not in linked})

    if removable:
        session.execute(
            delete(AssetInfoTag)
            .where(
                AssetInfoTag.asset_info_id == asset_info_id,
                AssetInfoTag.tag_name.in_(removable),
            )
        )
        session.flush()

    return {
        "removed": removable,
        "not_present": missing,
        "total_tags": get_asset_tags(session, asset_info_id=asset_info_id),
    }
def remove_missing_tag_for_asset_id(
    session: Session,
    *,
    asset_id: str,
) -> None:
    """Strip the 'missing' tag from every AssetInfo belonging to this asset."""
    infos_for_asset = sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)
    session.execute(
        sa.delete(AssetInfoTag).where(
            AssetInfoTag.asset_info_id.in_(infos_for_asset),
            AssetInfoTag.tag_name == "missing",
        )
    )
def set_asset_info_preview(
    session: Session,
    *,
    asset_info_id: str,
    preview_asset_id: str | None = None,
) -> None:
    """Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
    info = session.get(AssetInfo, asset_info_id)
    if info is None:
        raise ValueError(f"AssetInfo {asset_info_id} not found")
    # A non-None preview must reference an existing Asset row.
    if preview_asset_id is not None and not session.get(Asset, preview_asset_id):
        raise ValueError(f"Preview Asset {preview_asset_id} not found")
    info.preview_id = preview_asset_id
    info.updated_at = utcnow()
    session.flush()

View File

@ -1,6 +1,5 @@
import contextlib
import os
from decimal import Decimal
from aiohttp import web
from datetime import datetime, timezone
from pathlib import Path
@ -88,40 +87,6 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
    """Validates and maps tags -> (base_dir, subdirs_for_fs).

    tags[0] selects the root ("models" requires a category tag; anything other
    than "input" maps to the output directory). Remaining tags become
    subdirectory components; "." and ".." are rejected.
    """
    root = tags[0]
    if root == "models":
        if len(tags) < 2:
            raise ValueError("at least two tags required for model asset")
        try:
            bases = folder_paths.folder_names_and_paths[tags[1]][0]
        except KeyError:
            raise ValueError(f"unknown model category '{tags[1]}'")
        if not bases:
            raise ValueError(f"no base path configured for category '{tags[1]}'")
        base_dir = os.path.abspath(bases[0])
        subdirs = tags[2:]
    else:
        root_dir = folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
        base_dir = os.path.abspath(root_dir)
        subdirs = tags[1:]
    if any(component in (".", "..") for component in subdirs):
        raise ValueError("invalid path component in tags")
    return base_dir, subdirs
def ensure_within_base(candidate: str, base: str) -> None:
    """Raise ValueError unless `candidate` resolves to a path inside `base`.

    Both paths are resolved with abspath before comparison, so relative
    segments ("..") cannot escape the base directory.

    Raises:
        ValueError: "invalid destination path" when the paths cannot be
            compared (e.g. different drives on Windows), or
            "destination escapes base directory" when candidate lies outside.
    """
    cand_abs = os.path.abspath(candidate)
    base_abs = os.path.abspath(base)
    try:
        common = os.path.commonpath([cand_abs, base_abs])
    except ValueError:
        # commonpath raises ValueError for incomparable paths (mixed drives
        # or mixed absolute/relative inputs).
        raise ValueError("invalid destination path")
    # Bug fix: this check used to live inside the try block, so its specific
    # ValueError was swallowed by `except Exception` and re-raised as the
    # generic "invalid destination path" message.
    if common != base_abs:
        raise ValueError("destination escapes base directory")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
@ -148,6 +113,7 @@ def compute_relative_filename(file_path: str) -> str | None:
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
@ -249,64 +215,3 @@ def collect_models_files() -> list[str]:
if allowed:
out.append(abs_path)
return out
def is_scalar(v):
    """True for values storable in a single typed projection column
    (None, bool, int, float, Decimal, str); False for containers/objects."""
    return v is None or isinstance(v, (bool, int, float, Decimal, str))
def project_kv(key: str, value):
    """
    Turn a metadata key/value into typed projection rows.

    Returns list[dict] with keys: key, ordinal, and one of
    val_str / val_num / val_bool / val_json. A None value (or None list
    element) yields a row with all four val_* columns explicitly None.
    Scalar lists get one typed row per element; mixed lists fall back to
    per-element JSON rows; anything else becomes a single JSON row.
    """
    def _is_scalar(v) -> bool:
        # Mirrors module-level is_scalar(); inlined so projection stands alone.
        return v is None or isinstance(v, (bool, int, float, Decimal, str))

    def _typed_row(ordinal: int, v) -> dict:
        # One typed projection row for a scalar. bool is tested before int
        # because bool is an int subclass. (The previous implementation
        # carried an unreachable trailing `else` val_json branch here.)
        if v is None:
            return {
                "key": key, "ordinal": ordinal,
                "val_str": None, "val_num": None, "val_bool": None, "val_json": None,
            }
        if isinstance(v, bool):
            return {"key": key, "ordinal": ordinal, "val_bool": bool(v)}
        if isinstance(v, (int, float, Decimal)):
            # Route floats through str() so Decimal keeps the short repr
            # instead of the full binary expansion.
            num = v if isinstance(v, Decimal) else Decimal(str(v))
            return {"key": key, "ordinal": ordinal, "val_num": num}
        return {"key": key, "ordinal": ordinal, "val_str": v}

    if _is_scalar(value):
        return [_typed_row(0, value)]
    if isinstance(value, list):
        if all(_is_scalar(x) for x in value):
            return [_typed_row(i, x) for i, x in enumerate(value)]
        # Mixed list: preserve element order as JSON rows.
        return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
    # Dicts and any other structure: a single JSON row.
    return [{"key": key, "ordinal": 0, "val_json": value}]

View File

@ -1,33 +1,13 @@
import os
import mimetypes
import contextlib
from typing import Sequence
from app.database.db import create_session
from app.assets.api import schemas_out, schemas_in
from app.assets.api import schemas_out
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
create_asset_info_for_existing_asset,
touch_asset_info_by_id,
update_asset_info_full,
delete_asset_info_by_id,
list_cache_states_by_asset_id,
list_asset_infos_page,
list_tags_with_usage,
get_asset_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
pick_best_live_path,
ingest_fs_asset,
set_asset_info_preview,
)
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
from app.assets.database.models import Asset
def _safe_sort_field(requested: str | None) -> str:
@ -39,28 +19,11 @@ def _safe_sort_field(requested: str | None) -> str:
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback
def asset_exists(*, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
def asset_exists(asset_hash: str) -> bool:
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def list_assets(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
@ -100,6 +63,7 @@ def list_assets(
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
preview_url=f"/api/assets/{info.id}/content",
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
@ -112,12 +76,7 @@ def list_assets(
has_more=(offset + len(summaries)) < total,
)
def get_asset(
*,
asset_info_id: str,
owner_id: str = "",
) -> schemas_out.AssetDetail:
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
with create_session() as session:
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
@ -138,358 +97,6 @@ def get_asset(
last_access_time=info.last_access_time,
)
def resolve_asset_content_for_download(
    *,
    asset_info_id: str,
    owner_id: str = "",
) -> tuple[str, str, str]:
    """Resolve (abs_path, content_type, download_name) for serving an asset.

    Bumps last_access_time as a side effect. Raises ValueError when the
    AssetInfo is not visible, FileNotFoundError when no live file backs it.
    """
    with create_session() as session:
        pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
        if pair is None:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        info, asset = pair

        abs_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
        if not abs_path:
            raise FileNotFoundError

        touch_asset_info_by_id(session, asset_info_id=asset_info_id)
        session.commit()

        ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
        download_name = info.name or os.path.basename(abs_path)
        return abs_path, ctype, download_name
def upload_asset_from_temp_path(
    spec: schemas_in.UploadAssetSpec,
    *,
    temp_path: str,
    client_filename: str | None = None,
    owner_id: str = "",
    expected_asset_hash: str | None = None,
) -> schemas_out.AssetCreated:
    """
    Create new asset or update existing asset from a temporary file path.

    Flow: hash the temp file (blake3); optionally verify against
    `expected_asset_hash`; if the content already exists, discard the temp
    file and just add a new AssetInfo reference; otherwise move the file into
    the destination derived from spec.tags (stored under its hash + original
    extension) and ingest it. Raises ValueError("HASH_MISMATCH") on a hash
    mismatch and RuntimeError on hashing/move/stat/DB failures.
    """
    try:
        # NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
        import app.assets.hashing as hashing
        digest = hashing.blake3_hash(temp_path)
    except Exception as e:
        raise RuntimeError(f"failed to hash uploaded file: {e}")
    asset_hash = "blake3:" + digest
    if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
        raise ValueError("HASH_MISMATCH")
    with create_session() as session:
        existing = get_asset_by_hash(session, asset_hash=asset_hash)
        if existing is not None:
            # Dedupe path: content already stored; temp file is redundant.
            with contextlib.suppress(Exception):
                if temp_path and os.path.exists(temp_path):
                    os.remove(temp_path)
            display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
            info = create_asset_info_for_existing_asset(
                session,
                asset_hash=asset_hash,
                name=display_name,
                user_metadata=spec.user_metadata or {},
                tags=spec.tags or [],
                tag_origin="manual",
                owner_id=owner_id,
            )
            tag_names = get_asset_tags(session, asset_info_id=info.id)
            session.commit()
            return schemas_out.AssetCreated(
                id=info.id,
                name=info.name,
                asset_hash=existing.hash,
                size=int(existing.size_bytes) if existing.size_bytes is not None else None,
                mime_type=existing.mime_type,
                tags=tag_names,
                user_metadata=info.user_metadata or {},
                preview_id=info.preview_id,
                created_at=info.created_at,
                last_access_time=info.last_access_time,
                created_new=False,
            )
    # New content: place the file at a tags-derived destination, named by hash.
    base_dir, subdirs = resolve_destination_from_tags(spec.tags)
    dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
    os.makedirs(dest_dir, exist_ok=True)
    src_for_ext = (client_filename or spec.name or "").strip()
    _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
    # Keep the client extension only when it looks sane (non-empty, <= 16 chars).
    ext = _ext if 0 < len(_ext) <= 16 else ""
    hashed_basename = f"{digest}{ext}"
    dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
    ensure_within_base(dest_abs, base_dir)
    content_type = (
        mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
        or mimetypes.guess_type(hashed_basename, strict=False)[0]
        or "application/octet-stream"
    )
    try:
        # os.replace is atomic on the same filesystem.
        os.replace(temp_path, dest_abs)
    except Exception as e:
        raise RuntimeError(f"failed to move uploaded file into place: {e}")
    try:
        size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
    except OSError as e:
        raise RuntimeError(f"failed to stat destination file: {e}")
    with create_session() as session:
        result = ingest_fs_asset(
            session,
            asset_hash=asset_hash,
            abs_path=dest_abs,
            size_bytes=size_bytes,
            mtime_ns=mtime_ns,
            mime_type=content_type,
            info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
            owner_id=owner_id,
            preview_id=None,
            user_metadata=spec.user_metadata or {},
            tags=spec.tags,
            tag_origin="manual",
            require_existing_tags=False,
        )
        info_id = result["asset_info_id"]
        if not info_id:
            raise RuntimeError("failed to create asset metadata")
        pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
        if not pair:
            raise RuntimeError("inconsistent DB state after ingest")
        info, asset = pair
        tag_names = get_asset_tags(session, asset_info_id=info.id)
        created_result = schemas_out.AssetCreated(
            id=info.id,
            name=info.name,
            asset_hash=asset.hash,
            size=int(asset.size_bytes),
            mime_type=asset.mime_type,
            tags=tag_names,
            user_metadata=info.user_metadata or {},
            preview_id=info.preview_id,
            created_at=info.created_at,
            last_access_time=info.last_access_time,
            created_new=result["asset_created"],
        )
        session.commit()
        return created_result
def update_asset(
    *,
    asset_info_id: str,
    name: str | None = None,
    tags: list[str] | None = None,
    user_metadata: dict | None = None,
    owner_id: str = "",
) -> schemas_out.AssetUpdated:
    """Apply a partial update (name / tags / metadata) to an owned AssetInfo.

    Raises ValueError when the row is missing, PermissionError when owned by
    someone else.
    """
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        if info_row is None:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        if info_row.owner_id and info_row.owner_id != owner_id:
            raise PermissionError("not owner")

        info = update_asset_info_full(
            session,
            asset_info_id=asset_info_id,
            name=name,
            tags=tags,
            user_metadata=user_metadata,
            tag_origin="manual",
            asset_info_row=info_row,
        )
        payload = schemas_out.AssetUpdated(
            id=info.id,
            name=info.name,
            asset_hash=info.asset.hash if info.asset else None,
            tags=get_asset_tags(session, asset_info_id=asset_info_id),
            user_metadata=info.user_metadata or {},
            updated_at=info.updated_at,
        )
        session.commit()
        return payload
def set_asset_preview(
    *,
    asset_info_id: str,
    preview_asset_id: str | None = None,
    owner_id: str = "",
) -> schemas_out.AssetDetail:
    """Set (or clear, when None) the preview asset of an owned AssetInfo.

    Returns the refreshed detail view. Raises ValueError / PermissionError on
    missing row or foreign ownership.
    """
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        if info_row is None:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        if info_row.owner_id and info_row.owner_id != owner_id:
            raise PermissionError("not owner")

        set_asset_info_preview(
            session,
            asset_info_id=asset_info_id,
            preview_asset_id=preview_asset_id,
        )

        refreshed = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
        if refreshed is None:
            raise RuntimeError("State changed during preview update")
        info, asset, tag_names = refreshed

        detail = schemas_out.AssetDetail(
            id=info.id,
            name=info.name,
            asset_hash=asset.hash if asset else None,
            size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
            mime_type=asset.mime_type if asset else None,
            tags=tag_names,
            user_metadata=info.user_metadata or {},
            preview_id=info.preview_id,
            created_at=info.created_at,
            last_access_time=info.last_access_time,
        )
        session.commit()
        return detail
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
    """Delete one AssetInfo reference; optionally reap the orphaned content.

    When the deleted reference was the last one for its Asset and
    `delete_content_if_orphan` is set, the Asset row and its cached files on
    disk are removed as well. Returns True when a reference was deleted.
    """
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        asset_id = info_row.asset_id if info_row else None
        deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
        if not deleted:
            session.commit()
            return False
        if not delete_content_if_orphan or not asset_id:
            session.commit()
            return True
        still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
        if still_exists:
            # Other references remain; keep the content.
            session.commit()
            return True
        # Collect file paths before the Asset row (and its cache states) go away.
        states = list_cache_states_by_asset_id(session, asset_id=asset_id)
        file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
        asset_row = session.get(Asset, asset_id)
        if asset_row is not None:
            session.delete(asset_row)
        session.commit()
        # Unlink files only after the commit succeeded; each removal is
        # best-effort so one failure doesn't abort the rest.
        for p in file_paths:
            with contextlib.suppress(Exception):
                if p and os.path.isfile(p):
                    os.remove(p)
        return True
def create_asset_from_hash(
    *,
    hash_str: str,
    name: str,
    tags: list[str] | None = None,
    user_metadata: dict | None = None,
    owner_id: str = "",
) -> schemas_out.AssetCreated | None:
    """Create a new AssetInfo referencing an already-ingested Asset by hash.

    Returns None when no Asset with that (canonicalized) hash exists.
    """
    canonical = hash_str.strip().lower()
    with create_session() as session:
        asset = get_asset_by_hash(session, asset_hash=canonical)
        if asset is None:
            return None
        # Fall back to the bare digest (text after "algo:") as the name.
        fallback_name = canonical.split(":", 1)[1]
        info = create_asset_info_for_existing_asset(
            session,
            asset_hash=canonical,
            name=_safe_filename(name, fallback=fallback_name),
            user_metadata=user_metadata or {},
            tags=tags or [],
            tag_origin="manual",
            owner_id=owner_id,
        )
        created = schemas_out.AssetCreated(
            id=info.id,
            name=info.name,
            asset_hash=asset.hash,
            size=int(asset.size_bytes),
            mime_type=asset.mime_type,
            tags=get_asset_tags(session, asset_info_id=info.id),
            user_metadata=info.user_metadata or {},
            preview_id=info.preview_id,
            created_at=info.created_at,
            last_access_time=info.last_access_time,
            created_new=False,
        )
        session.commit()
        return created
def add_tags_to_asset(
    *,
    asset_info_id: str,
    tags: list[str],
    origin: str = "manual",
    owner_id: str = "",
) -> schemas_out.TagsAdd:
    """Attach tags to an owned AssetInfo, creating unknown tags on the fly."""
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        if info_row is None:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        if info_row.owner_id and info_row.owner_id != owner_id:
            raise PermissionError("not owner")
        payload = add_tags_to_asset_info(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
            origin=origin,
            create_if_missing=True,
            asset_info_row=info_row,
        )
        session.commit()
        return schemas_out.TagsAdd(**payload)
def remove_tags_from_asset(
    *,
    asset_info_id: str,
    tags: list[str],
    owner_id: str = "",
) -> schemas_out.TagsRemove:
    """Detach tags from an owned AssetInfo; reports removed vs. not-present."""
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        if info_row is None:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        if info_row.owner_id and info_row.owner_id != owner_id:
            raise PermissionError("not owner")
        payload = remove_tags_from_asset_info(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
        )
        session.commit()
        return schemas_out.TagsRemove(**payload)
def list_tags(
prefix: str | None = None,
limit: int = 100,

View File

@ -27,7 +27,6 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
t_start = time.perf_counter()
created = 0
skipped_existing = 0
orphans_pruned = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
@ -39,11 +38,6 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", r, e)
try:
orphans_pruned = _prune_orphaned_assets(roots)
except Exception as e:
logging.exception("orphan pruning failed: %s", e)
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
@ -91,43 +85,15 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
finally:
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
skipped_existing,
orphans_pruned,
len(paths),
)
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
    """Prune cache states outside configured prefixes, then delete orphaned seed assets.

    Returns the number of Asset rows deleted. "Orphaned seed asset" means a
    row with a NULL hash and no remaining cache state after the prefix prune.
    """
    all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
    if not all_prefixes:
        return 0

    def make_prefix_condition(prefix: str):
        # Match paths strictly under the prefix directory; escape LIKE
        # metacharacters so user paths containing % or _ don't over-match.
        base = prefix if prefix.endswith(os.sep) else prefix + os.sep
        escaped, esc = escape_like_prefix(base)
        return AssetCacheState.file_path.like(escaped + "%", escape=esc)

    matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
    # Subquery of seed assets (hash IS NULL) with no cache state; evaluated at
    # execution time, i.e. after the cache-state prune below.
    orphan_subq = (
        sqlalchemy.select(Asset.id)
        .outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
        .where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
    ).scalar_subquery()
    with create_session() as sess:
        # Order matters: drop out-of-prefix cache states first so the orphan
        # subquery sees the post-prune state, then infos, then the assets.
        sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
        sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
        result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
        sess.commit()
    return result.rowcount
def _fast_db_consistency_pass(
root: RootType,
*,

View File

@ -1146,6 +1146,20 @@ class ImageCompare(ComfyTypeI):
def as_dict(self):
return super().as_dict()
@comfytype(io_type="COLOR")
class Color(ComfyTypeIO):
    """COLOR socket/widget type; values are color strings (default "#ffffff",
    presumably hex RGB — confirm against the frontend widget)."""
    Type = str
    class Input(WidgetInput):
        # Widget input for a color value; socketless (widget-only) by default.
        def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
                     socketless: bool=True, advanced: bool=None, default: str="#ffffff"):
            super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
            # Annotation only (no runtime effect): narrows the inherited
            # `default` attribute's type to str for type checkers.
            self.default: str
        def as_dict(self):
            # No extra serialized fields beyond the base widget dict.
            return super().as_dict()
# Registry of dynamic-input expansion callbacks, keyed by io_type string.
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}

def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
    # Register the handler for `io_type`; a later registration overwrites an
    # earlier one for the same io_type.
    DYNAMIC_INPUT_LOOKUP[io_type] = func
@ -1252,23 +1266,6 @@ class NodeInfoV1:
price_badge: dict | None = None
search_aliases: list[str]=None
@dataclass
class NodeInfoV3:
    """V3 node-info payload; serialized via dataclasses.asdict.

    All fields default to None so partial construction is allowed.
    """
    input: dict=None          # input id -> (io_type, options dict)
    output: dict=None         # output id -> (io_type, options dict)
    hidden: list[str]=None    # hidden input enum values
    name: str=None            # node id
    display_name: str=None
    description: str=None
    python_module: Any = None  # module path the node class came from
    category: str=None
    output_node: bool=None
    deprecated: bool=None
    experimental: bool=None
    dev_only: bool=None
    api_node: bool=None
    price_badge: dict | None = None
@dataclass
class PriceBadgeDepends:
@ -1497,40 +1494,6 @@ class Schema:
)
return info
def get_v3_info(self, cls) -> NodeInfoV3:
    """Build the V3 node-info payload for this schema.

    Collects per-input/output entries via add_to_dict_v3, flattens the hidden
    enum values, and packages the schema's flags into a NodeInfoV3.
    """
    input_dict: dict = {}
    output_dict: dict = {}
    hidden_list: list[str] = []
    # TODO: make sure dynamic types will be handled correctly
    # Renamed loop variables: the original `for input in self.inputs`
    # shadowed the builtin input().
    for schema_input in (self.inputs or []):
        add_to_dict_v3(schema_input, input_dict)
    for schema_output in (self.outputs or []):
        add_to_dict_v3(schema_output, output_dict)
    for hidden_entry in (self.hidden or []):
        hidden_list.append(hidden_entry.value)
    return NodeInfoV3(
        input=input_dict,
        output=output_dict,
        hidden=hidden_list,
        name=self.node_id,
        display_name=self.display_name,
        description=self.description,
        category=self.category,
        output_node=self.is_output_node,
        deprecated=self.is_deprecated,
        experimental=self.is_experimental,
        dev_only=self.is_dev_only,
        api_node=self.is_api_node,
        python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
        price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
    )
def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]:
out_dict = {
"required": {},
@ -1585,9 +1548,6 @@ def add_to_dict_v1(i: Input, d: dict):
as_dict.pop("optional", None)
d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
def add_to_dict_v3(io: Input | Output, d: dict):
    """Record an input/output under its id as a (io_type, options-dict) pair."""
    entry = (io.get_io_type(), io.as_dict())
    d[io.id] = entry
class DynamicPathsDefaultValue:
    # Sentinel string constants describing the default-value shape for
    # dynamic-path inputs; presumably consumed by the frontend — confirm usage.
    EMPTY_DICT = "empty_dict"
@ -1748,13 +1708,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
# set hidden
type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
return type_clone
@final
@classmethod
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
    """Return this node's V3 info as a plain dict (NodeInfoV3 via asdict)."""
    schema = cls.GET_SCHEMA()
    info = schema.get_v3_info(cls)
    return asdict(info)
#############################################
# V1 Backwards Compatibility code
#--------------------------------------------
@ -2099,6 +2052,7 @@ __all__ = [
"AnyType",
"MultiType",
"Tracks",
"Color",
# Dynamic Types
"MatchType",
"DynamicCombo",
@ -2107,12 +2061,10 @@ __all__ = [
"HiddenHolder",
"Hidden",
"NodeInfoV1",
"NodeInfoV3",
"Schema",
"ComfyNode",
"NodeOutput",
"add_to_dict_v1",
"add_to_dict_v3",
"V3Data",
"ImageCompare",
"PriceBadgeDepends",

View File

@ -1,11 +1,8 @@
from __future__ import annotations
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field, conint, confloat
from pydantic import BaseModel, Field
class RecraftColor:
@ -229,24 +226,24 @@ class RecraftColorObject(BaseModel):
class RecraftControlsObject(BaseModel):
colors: Optional[list[RecraftColorObject]] = Field(None, description='An array of preferable colors')
background_color: Optional[RecraftColorObject] = Field(None, description='Use given color as a desired background color')
no_text: Optional[bool] = Field(None, description='Do not embed text layouts')
artistic_level: Optional[conint(ge=0, le=5)] = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].')
colors: list[RecraftColorObject] | None = Field(None, description='An array of preferable colors')
background_color: RecraftColorObject | None = Field(None, description='Use given color as a desired background color')
no_text: bool | None = Field(None, description='Do not embed text layouts')
artistic_level: int | None = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].')
class RecraftImageGenerationRequest(BaseModel):
prompt: str = Field(..., description='The text prompt describing the image to generate')
size: Optional[RecraftImageSize] = Field(None, description='The size of the generated image (e.g., "1024x1024")')
n: conint(ge=1, le=6) = Field(..., description='The number of images to generate')
negative_prompt: Optional[str] = Field(None, description='A text description of undesired elements on an image')
model: Optional[RecraftModel] = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
style: Optional[str] = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
substyle: Optional[str] = Field(None, description='The substyle to apply to the generated image, depending on the style input')
controls: Optional[RecraftControlsObject] = Field(None, description='A set of custom parameters to tweak generation process')
style_id: Optional[str] = Field(None, description='Use a previously uploaded style as a reference; UUID')
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
random_seed: Optional[int] = Field(None, description="Seed for video generation")
size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
n: int = Field(..., description='The number of images to generate')
negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
random_seed: int | None = Field(None, description="Seed for video generation")
# text_layout
@ -258,5 +255,13 @@ class RecraftReturnedObject(BaseModel):
class RecraftImageGenerationResponse(BaseModel):
created: int = Field(..., description='Unix timestamp when the generation was created')
credits: int = Field(..., description='Number of credits used for the generation')
data: Optional[list[RecraftReturnedObject]] = Field(None, description='Array of generated image information')
image: Optional[RecraftReturnedObject] = Field(None, description='Single generated image')
data: list[RecraftReturnedObject] | None = Field(None, description='Array of generated image information')
image: RecraftReturnedObject | None = Field(None, description='Single generated image')
# Multipart form field sent alongside the reference images when creating a style.
# Note: a plain # comment (not a docstring) so the pydantic JSON schema is unchanged.
class RecraftCreateStyleRequest(BaseModel):
    style: str = Field(..., description="realistic_image, digital_illustration, vector_illustration, or icon")
# Response body of the style-creation endpoint; the id is reusable as style_id
# in subsequent generation requests.
class RecraftCreateStyleResponse(BaseModel):
    id: str = Field(..., description="UUID of the created style")

View File

@ -12,6 +12,8 @@ from comfy_api_nodes.apis.recraft import (
RecraftColor,
RecraftColorChain,
RecraftControls,
RecraftCreateStyleRequest,
RecraftCreateStyleResponse,
RecraftImageGenerationRequest,
RecraftImageGenerationResponse,
RecraftImageSize,
@ -323,6 +325,75 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
return IO.NodeOutput(RecraftStyle(style_id=style_id))
class RecraftCreateStyleNode(IO.ComfyNode):
    """Create a reusable Recraft style from 1-5 reference images.

    Outputs the UUID of the created style, which other Recraft nodes accept
    as a ``style_id``.
    """

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="RecraftCreateStyleNode",
            display_name="Recraft Create Style",
            category="api node/image/Recraft",
            description="Create a custom style from reference images. "
            "Upload 1-5 images to use as style references. "
            "Total size of all images is limited to 5 MB.",
            inputs=[
                IO.Combo.Input(
                    "style",
                    options=["realistic_image", "digital_illustration"],
                    tooltip="The base style of the generated images.",
                ),
                IO.Autogrow.Input(
                    "images",
                    template=IO.Autogrow.TemplatePrefix(
                        IO.Image.Input("image"),
                        prefix="image",
                        min=1,
                        max=5,
                    ),
                ),
            ],
            outputs=[
                IO.String.Output(display_name="style_id"),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"usd","usd": 0.04}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        style: str,
        images: IO.Autogrow.Type,
    ) -> IO.NodeOutput:
        """Upload the reference images and return the created style's UUID.

        Raises:
            ValueError: if the combined encoded size of the images exceeds 5 MB.
        """
        files = []
        total_size = 0
        max_total_size = 5 * 1024 * 1024  # 5 MB limit documented by the endpoint
        # Autogrow inputs arrive as a mapping; enumeration order yields file1..fileN.
        for i, img in enumerate(images.values()):
            # Re-encode as webp capped at 2048x2048 pixels to keep the upload small.
            file_bytes = tensor_to_bytesio(img, total_pixels=2048 * 2048, mime_type="image/webp").read()
            total_size += len(file_bytes)
            if total_size > max_total_size:
                # ValueError (an Exception subclass) keeps existing handlers working
                # while signaling bad input rather than an internal failure.
                raise ValueError("Total size of all images exceeds 5 MB limit.")
            files.append((f"file{i + 1}", file_bytes))
        response = await sync_op(
            cls,
            endpoint=ApiEndpoint(path="/proxy/recraft/styles", method="POST"),
            response_model=RecraftCreateStyleResponse,
            files=files,
            data=RecraftCreateStyleRequest(style=style),
            content_type="multipart/form-data",
            max_retries=1,
        )
        return IO.NodeOutput(response.id)
class RecraftTextToImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
@ -395,7 +466,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
negative_prompt: str = None,
recraft_controls: RecraftControls = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, max_length=1000)
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=1000)
default_style = RecraftStyle(RecraftStyleV3.realistic_image)
if recraft_style is None:
recraft_style = default_style
@ -1024,6 +1095,7 @@ class RecraftExtension(ComfyExtension):
RecraftStyleV3DigitalIllustrationNode,
RecraftStyleV3LogoRasterNode,
RecraftStyleInfiniteStyleLibrary,
RecraftCreateStyleNode,
RecraftColorRGBNode,
RecraftControlsNode,
]

View File

@ -0,0 +1,42 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class ColorToRGBInt(io.ComfyNode):
    """Utility node that packs a hex color string into a single RGB integer."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="ColorToRGBInt",
            display_name="Color to RGB Int",
            category="utils",
            description="Convert a color to a RGB integer value.",
            inputs=[
                io.Color.Input("color"),
            ],
            outputs=[
                io.Int.Output(display_name="rgb_int"),
            ],
        )

    @classmethod
    def execute(
        cls,
        color: str,
    ) -> io.NodeOutput:
        """Parse a '#RRGGBB' string and return 0xRRGGBB as an int."""
        if len(color) != 7 or not color.startswith("#"):
            raise ValueError("Color must be in format #RRGGBB")
        # Parse the three two-digit hex channels at offsets 1, 3 and 5.
        red, green, blue = (int(color[off:off + 2], 16) for off in (1, 3, 5))
        return io.NodeOutput((red << 16) | (green << 8) | blue)
class ColorExtension(ComfyExtension):
    """Extension bundle that registers the color utility nodes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes this extension provides."""
        return [ColorToRGBInt]
async def comfy_entrypoint() -> ColorExtension:
    """Loader entry point: instantiate and return this module's extension."""
    return ColorExtension()

View File

@ -56,7 +56,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
@classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples":latent})
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
generate = execute # TODO: remove
@ -73,7 +73,7 @@ class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo):
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
# Using scale factor of 16 instead of 8
latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent})
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 16})
class HunyuanVideo15ImageToVideo(io.ComfyNode):

View File

@ -2432,7 +2432,8 @@ async def init_builtin_extra_nodes():
"nodes_wanmove.py",
"nodes_image_compare.py",
"nodes_zimage.py",
"nodes_lora_debug.py"
"nodes_lora_debug.py",
"nodes_color.py"
]
import_failed = []

View File

@ -1,271 +0,0 @@
import contextlib
import json
import os
import socket
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Callable, Iterator, Optional
import pytest
import requests
def pytest_addoption(parser: pytest.Parser) -> None:
    """
    Allow overriding the database URL used by the spawned ComfyUI process.
    Priority:
    1) --db-url command line option
    2) ASSETS_TEST_DB_URL environment variable (used by CI)
    3) default: None (will use file-backed sqlite in temp dir)
    """
    parser.addoption(
        "--db-url",
        action="store",
        # env var acts as the default, so an explicit --db-url always wins
        default=os.environ.get("ASSETS_TEST_DB_URL"),
        help="SQLAlchemy DB URL (e.g. sqlite:///path/to/db.sqlite3)",
    )
def _free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def _make_base_dirs(root: Path) -> None:
for sub in ("models", "custom_nodes", "input", "output", "temp", "user"):
(root / sub).mkdir(parents=True, exist_ok=True)
def _wait_http_ready(base: str, session: requests.Session, timeout: float = 90.0) -> None:
    """Poll /api/assets until the server answers, raising after *timeout* seconds."""
    deadline = time.time() + timeout
    last_err = None
    while time.time() < deadline:
        try:
            resp = session.get(base + "/api/assets", timeout=5)
        except Exception as e:
            last_err = e
        else:
            # 400 also proves the HTTP stack is up (endpoint rejected the bare query).
            if resp.status_code in (200, 400):
                return
        time.sleep(0.25)
    raise RuntimeError(f"ComfyUI HTTP did not become ready: {last_err}")
@pytest.fixture(scope="session")
def comfy_tmp_base_dir() -> Path:
    """Session-scoped sandbox base directory for the spawned ComfyUI process.

    Honors ASSETS_TEST_BASE_DIR when set (and leaves that directory in place
    afterwards); otherwise creates a throwaway mkdtemp directory and removes
    it at session end.
    """
    env_base = os.environ.get("ASSETS_TEST_BASE_DIR")
    created_by_fixture = False
    if env_base:
        tmp = Path(env_base)
        tmp.mkdir(parents=True, exist_ok=True)
    else:
        tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-"))
        created_by_fixture = True
    _make_base_dirs(tmp)
    yield tmp
    # Teardown: only delete trees this fixture itself created.
    if created_by_fixture:
        with contextlib.suppress(Exception):
            # Remove files/symlinks first (deepest paths first) ...
            for p in sorted(tmp.rglob("*"), reverse=True):
                if p.is_file() or p.is_symlink():
                    p.unlink(missing_ok=True)
            # ... then prune the now-empty directories, tolerating stragglers.
            for p in sorted(tmp.glob("**/*"), reverse=True):
                with contextlib.suppress(Exception):
                    p.rmdir()
            tmp.rmdir()
@pytest.fixture(scope="session")
def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest):
    """
    Boot ComfyUI subprocess with:
    - sandbox base dir
    - file-backed sqlite DB in temp dir
    - autoscan disabled
    Returns (base_url, process, port)
    """
    port = _free_port()
    db_url = request.config.getoption("--db-url")
    if not db_url:
        # Use a file-backed sqlite database in the temp directory
        db_path = comfy_tmp_base_dir / "assets-test.sqlite3"
        db_url = f"sqlite:///{db_path}"
    logs_dir = comfy_tmp_base_dir / "logs"
    logs_dir.mkdir(exist_ok=True)
    # Line-buffered logs so progress is visible while the suite runs.
    out_log = open(logs_dir / "stdout.log", "w", buffering=1)
    err_log = open(logs_dir / "stderr.log", "w", buffering=1)
    # conftest lives one level below the repo root that contains main.py.
    comfy_root = Path(__file__).resolve().parent.parent
    if not (comfy_root / "main.py").is_file():
        raise FileNotFoundError(f"main.py not found under {comfy_root}")
    proc = subprocess.Popen(
        args=[
            sys.executable,
            "main.py",
            f"--base-directory={str(comfy_tmp_base_dir)}",
            f"--database-url={db_url}",
            "--disable-assets-autoscan",
            "--listen",
            "127.0.0.1",
            "--port",
            str(port),
            "--cpu",
        ],
        stdout=out_log,
        stderr=err_log,
        cwd=str(comfy_root),
        env={**os.environ},
    )
    # Watch for ~5s (50 x 0.1s) so an early crash (bad args, port clash)
    # surfaces immediately instead of timing out on the HTTP poll below.
    for _ in range(50):
        if proc.poll() is not None:
            out_log.flush()
            err_log.flush()
            raise RuntimeError(f"ComfyUI exited early with code {proc.returncode}")
        time.sleep(0.1)
    base_url = f"http://127.0.0.1:{port}"
    try:
        with requests.Session() as s:
            _wait_http_ready(base_url, s, timeout=90.0)
        yield base_url, proc, port
    except Exception as e:
        # Startup failed (or an error was thrown into the generator): kill the
        # child and surface a readiness error with the original cause attached.
        with contextlib.suppress(Exception):
            proc.terminate()
            proc.wait(timeout=10)
        with contextlib.suppress(Exception):
            out_log.flush()
            err_log.flush()
        raise RuntimeError(f"ComfyUI did not become ready: {e}")
    # Normal teardown after the session: stop the server if still running.
    if proc and proc.poll() is None:
        with contextlib.suppress(Exception):
            proc.terminate()
            proc.wait(timeout=15)
    out_log.close()
    err_log.close()
@pytest.fixture
def http() -> Iterator[requests.Session]:
    """Per-test HTTP session, closed automatically when the test ends."""
    with requests.Session() as s:
        # NOTE(review): requests.Session has no global `timeout` setting; this
        # attribute is not honored by the library, which is why call sites still
        # pass timeout= explicitly on every request.
        s.timeout = 120
        yield s
@pytest.fixture
def api_base(comfy_url_and_proc) -> str:
    """Base URL (http://127.0.0.1:<port>) of the running ComfyUI instance."""
    base_url, _proc, _port = comfy_url_and_proc
    return base_url
def _post_multipart_asset(
    session: requests.Session,
    base: str,
    *,
    name: str,
    tags: list[str],
    meta: dict,
    data: bytes,
    extra_fields: Optional[dict] = None,
) -> tuple[int, dict]:
    """Upload *data* as a multipart asset; return (status_code, json_body)."""
    form_data = {
        "tags": json.dumps(tags),
        "name": name,
        "user_metadata": json.dumps(meta),
        # Caller-supplied extras are merged last so they can override the basics.
        **(extra_fields or {}),
    }
    files = {"file": (name, data, "application/octet-stream")}
    resp = session.post(base + "/api/assets", files=files, data=form_data, timeout=120)
    return resp.status_code, resp.json()
@pytest.fixture
def make_asset_bytes() -> Callable[[str, int], bytes]:
    """Factory fixture: deterministic pseudo-random bytes derived from a name."""
    def _make(name: str, size: int = 8192) -> bytes:
        # Seed depends only on the name, so equal names give equal content.
        seed = sum(map(ord, name)) % 251
        return bytes((i * 31 + seed) % 256 for i in range(size))
    return _make
@pytest.fixture
def asset_factory(http: requests.Session, api_base: str):
    """
    Returns create(name, tags, meta, data) -> response dict
    Tracks created ids and deletes them after the test.
    """
    created: list[str] = []
    def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict:
        # Uploads via the multipart endpoint and remembers the id for teardown.
        status, body = _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data)
        assert status in (200, 201), body
        created.append(body["id"])
        return body
    yield create
    # Teardown: best-effort deletion of everything the test created.
    for aid in created:
        with contextlib.suppress(Exception):
            http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
@pytest.fixture
def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_base: str) -> dict:
    """
    Upload one asset with ".safetensors" extension into models/checkpoints/unit-tests/<name>.
    Returns response dict with id, asset_hash, tags, etc.
    """
    name = "unit_1_example.safetensors"
    # Supports indirect parametrization: request.param may carry {"tags": [...]}.
    p = getattr(request, "param", {}) or {}
    tags: Optional[list[str]] = p.get("tags")
    if tags is None:
        tags = ["models", "checkpoints", "unit-tests", "alpha"]
    meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None}
    files = {"file": (name, b"A" * 4096, "application/octet-stream")}
    form_data = {
        "tags": json.dumps(tags),
        "name": name,
        "user_metadata": json.dumps(meta),
    }
    r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
    body = r.json()
    assert r.status_code == 201, body
    return body
@pytest.fixture(autouse=True)
def autoclean_unit_test_assets(http: requests.Session, api_base: str):
    """Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test."""
    yield
    # Keep re-querying in pages of 500 until no tagged assets remain.
    while True:
        r = http.get(
            api_base + "/api/assets",
            params={"include_tags": "unit-tests", "limit": "500", "sort": "name"},
            timeout=30,
        )
        if r.status_code != 200:
            break
        body = r.json()
        ids = [a["id"] for a in body.get("assets", [])]
        if not ids:
            break
        for aid in ids:
            # Best-effort deletion; a failure must not fail the finished test.
            with contextlib.suppress(Exception):
                http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
    """Force a fast sync/seed pass by calling the seed endpoint."""
    payload = {"roots": ["models", "input", "output"]}
    session.post(base_url + "/api/assets/seed", json=payload, timeout=30)
    # Brief pause so the pass can settle before callers re-query the API.
    time.sleep(0.2)
def get_asset_filename(asset_hash: str, extension: str) -> str:
    """Map an asset hash (optionally 'blake3:'-prefixed) to its cache filename."""
    stem = asset_hash.removeprefix("blake3:")
    return f"{stem}{extension}"

View File

@ -1,2 +0,0 @@
pytest>=7.8.0
blake3

View File

@ -1,348 +0,0 @@
import os
import uuid
from pathlib import Path
import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
@pytest.mark.parametrize("root", ["input", "output"])
def test_seed_asset_removed_when_file_is_deleted(
    root: str,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
):
    """Asset without hash (seed) whose file disappears:
    after triggering sync_seed_assets, Asset + AssetInfo disappear.
    """
    # Create a file directly under input/unit-tests/<case> so tags include "unit-tests"
    case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed"
    case_dir.mkdir(parents=True, exist_ok=True)
    # Random suffix keeps the two parametrized runs (input/output) independent.
    name = f"seed_{uuid.uuid4().hex[:8]}.bin"
    fp = case_dir / name
    fp.write_bytes(b"Z" * 2048)
    # Trigger a seed sync so DB sees this path (seed asset => hash is NULL)
    trigger_sync_seed_assets(http, api_base)
    # Verify it is visible via API and carries no hash (seed)
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,syncseed", "name_contains": name},
        timeout=120,
    )
    body1 = r1.json()
    assert r1.status_code == 200
    # there should be exactly one with that name
    matches = [a for a in body1.get("assets", []) if a.get("name") == name]
    assert matches
    assert matches[0].get("asset_hash") is None
    asset_info_id = matches[0]["id"]
    # Remove the underlying file and sync again
    if fp.exists():
        fp.unlink()
    trigger_sync_seed_assets(http, api_base)
    # It should disappear (AssetInfo and seed Asset gone)
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,syncseed", "name_contains": name},
        timeout=120,
    )
    body2 = r2.json()
    assert r2.status_code == 200
    matches2 = [a for a in body2.get("assets", []) if a.get("name") == name]
    assert not matches2, f"Seed asset {asset_info_id} should be gone after sync"
@pytest.mark.skip(reason="Requires computing hashes of files in directories to verify and clear missing tags")
def test_hashed_asset_missing_tag_added_then_removed_after_scan(
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
):
    """Hashed asset with a single cache_state:
    1. delete its file -> sync adds 'missing'
    2. restore file -> sync removes 'missing'
    """
    name = "missing_tag_test.png"
    tags = ["input", "unit-tests", "msync2"]
    data = make_asset_bytes(name, 4096)
    a = asset_factory(name, tags, {}, data)
    # Compute its on-disk path and remove it
    dest = comfy_tmp_base_dir / "input" / "unit-tests" / "msync2" / get_asset_filename(a["asset_hash"], ".png")
    assert dest.exists(), f"Expected asset file at {dest}"
    dest.unlink()
    # Fast sync should add 'missing' to the AssetInfo
    trigger_sync_seed_assets(http, api_base)
    g1 = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120)
    d1 = g1.json()
    assert g1.status_code == 200, d1
    assert "missing" in set(d1.get("tags", [])), "Expected 'missing' tag after deletion"
    # Restore the file with the exact same content and sync again
    dest.parent.mkdir(parents=True, exist_ok=True)
    dest.write_bytes(data)
    trigger_sync_seed_assets(http, api_base)
    g2 = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120)
    d2 = g2.json()
    assert g2.status_code == 200, d2
    assert "missing" not in set(d2.get("tags", [])), "Missing tag should be cleared after verify"
def test_hashed_asset_two_asset_infos_both_get_missing(
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
):
    """Hashed asset with a single cache_state, but two AssetInfo rows:
    deleting the single file then syncing should add 'missing' to both infos.
    """
    # Upload one hashed asset
    name = "two_infos_one_path.png"
    base_tags = ["input", "unit-tests", "multiinfo"]
    created = asset_factory(name, base_tags, {}, b"A" * 2048)
    # Create second AssetInfo for the same Asset via from-hash
    payload = {
        "hash": created["asset_hash"],
        "name": "two_infos_one_path_copy.png",
        "tags": base_tags,  # keep it in our unit-tests scope for cleanup
        "user_metadata": {"k": "v"},
    }
    r2 = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 201, b2
    second_id = b2["id"]
    # Remove the single underlying file
    p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png")
    assert p.exists()
    p.unlink()
    # Record the baseline usage count of the 'missing' tag before syncing.
    r0 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
    tags0 = r0.json()
    assert r0.status_code == 200, tags0
    byname0 = {t["name"]: t for t in tags0.get("tags", [])}
    old_missing = int(byname0.get("missing", {}).get("count", 0))
    # Sync -> both AssetInfos for this asset must receive 'missing'
    trigger_sync_seed_assets(http, api_base)
    ga = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
    da = ga.json()
    assert ga.status_code == 200, da
    assert "missing" in set(da.get("tags", []))
    gb = http.get(f"{api_base}/api/assets/{second_id}", timeout=120)
    db = gb.json()
    assert gb.status_code == 200, db
    assert "missing" in set(db.get("tags", []))
    # Tag usage for 'missing' increased by exactly 2 (two AssetInfos)
    r1 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
    tags1 = r1.json()
    assert r1.status_code == 200, tags1
    byname1 = {t["name"]: t for t in tags1.get("tags", [])}
    new_missing = int(byname1.get("missing", {}).get("count", 0))
    assert new_missing == old_missing + 2
@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
def test_hashed_asset_two_cache_states_partial_delete_then_full_delete(
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
    run_scan_and_wait,
):
    """Hashed asset with two cache_state rows:
    1. delete one file -> sync should NOT add 'missing'
    2. delete second file -> sync should add 'missing'
    """
    name = "two_cache_states_partial_delete.png"
    tags = ["input", "unit-tests", "dual"]
    data = make_asset_bytes(name, 3072)
    created = asset_factory(name, tags, {}, data)
    path1 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual" / get_asset_filename(created["asset_hash"], ".png")
    assert path1.exists()
    # Create a second on-disk copy under the same root but different subfolder
    path2 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual_copy" / name
    path2.parent.mkdir(parents=True, exist_ok=True)
    path2.write_bytes(data)
    # Fast seed so the second path appears (as a seed initially)
    trigger_sync_seed_assets(http, api_base)
    # Deduplication of AssetInfo-s will not happen as first AssetInfo has owner='default' and second has empty owner.
    run_scan_and_wait("input")
    # Remove only one file and sync -> asset should still be healthy (no 'missing')
    path1.unlink()
    trigger_sync_seed_assets(http, api_base)
    g1 = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
    d1 = g1.json()
    assert g1.status_code == 200, d1
    assert "missing" not in set(d1.get("tags", [])), "Should not be missing while one valid path remains"
    # Baseline 'missing' usage count just before last file removal
    r0 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
    tags0 = r0.json()
    assert r0.status_code == 200, tags0
    old_missing = int({t["name"]: t for t in tags0.get("tags", [])}.get("missing", {}).get("count", 0))
    # Remove the second (last) file and sync -> now we expect 'missing' on this AssetInfo
    path2.unlink()
    trigger_sync_seed_assets(http, api_base)
    g2 = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
    d2 = g2.json()
    assert g2.status_code == 200, d2
    assert "missing" in set(d2.get("tags", [])), "Missing must be set once no valid paths remain"
    # Tag usage for 'missing' increased by exactly 2 (two AssetInfo for one Asset)
    r1 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
    tags1 = r1.json()
    assert r1.status_code == 200, tags1
    new_missing = int({t["name"]: t for t in tags1.get("tags", [])}.get("missing", {}).get("count", 0))
    assert new_missing == old_missing + 2
@pytest.mark.parametrize("root", ["input", "output"])
def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match(
    root: str,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
):
    """
    Fast pass alone clears 'missing' when size and mtime match exactly:
    1) upload (hashed), record original mtime_ns
    2) delete -> fast pass adds 'missing'
    3) restore same bytes and set mtime back to the original value
    4) run fast pass again -> 'missing' is removed (no slow scan)
    """
    scope = f"fastclear-{uuid.uuid4().hex[:6]}"
    name = "fastpass_clear.bin"
    data = make_asset_bytes(name, 3072)
    a = asset_factory(name, [root, "unit-tests", scope], {}, data)
    aid = a["id"]
    base = comfy_tmp_base_dir / root / "unit-tests" / scope
    p = base / get_asset_filename(a["asset_hash"], ".bin")
    st0 = p.stat()
    # st_mtime_ns may be absent on some platforms; fall back to seconds -> ns.
    orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000))
    # Delete -> fast pass adds 'missing'
    p.unlink()
    trigger_sync_seed_assets(http, api_base)
    g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    d1 = g1.json()
    assert g1.status_code == 200, d1
    assert "missing" in set(d1.get("tags", []))
    # Restore same bytes and revert mtime to the original value
    p.parent.mkdir(parents=True, exist_ok=True)
    p.write_bytes(data)
    # set both atime and mtime in ns to ensure exact match
    os.utime(p, ns=(orig_mtime_ns, orig_mtime_ns))
    # Fast pass should clear 'missing' without a scan
    trigger_sync_seed_assets(http, api_base)
    g2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    d2 = g2.json()
    assert g2.status_code == 200, d2
    assert "missing" not in set(d2.get("tags", [])), "Fast pass should clear 'missing' when size+mtime match"
@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
@pytest.mark.parametrize("root", ["input", "output"])
def test_fastpass_removes_stale_state_row_no_missing(
    root: str,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
    run_scan_and_wait,
):
    """
    Hashed asset with two states:
    - delete one file
    - run fast pass only
    Expect:
    - asset stays healthy (no 'missing')
    - stale AssetCacheState row for the deleted path is removed.
    We verify this behaviorally by recreating the deleted path and running fast pass again:
    a new *seed* AssetInfo is created, which proves the old state row was not reused.
    """
    scope = f"stale-{uuid.uuid4().hex[:6]}"
    name = "two_states.bin"
    data = make_asset_bytes(name, 2048)
    # Upload hashed asset at path1
    a = asset_factory(name, [root, "unit-tests", scope], {}, data)
    base = comfy_tmp_base_dir / root / "unit-tests" / scope
    a1_filename = get_asset_filename(a["asset_hash"], ".bin")
    p1 = base / a1_filename
    assert p1.exists()
    aid = a["id"]
    h = a["asset_hash"]
    # Create second state path2, seed+scan to dedupe into the same Asset
    p2 = base / "copy" / name
    p2.parent.mkdir(parents=True, exist_ok=True)
    p2.write_bytes(data)
    trigger_sync_seed_assets(http, api_base)
    run_scan_and_wait(root)
    # Delete path1 and run fast pass -> no 'missing' and stale state row should be removed
    p1.unlink()
    trigger_sync_seed_assets(http, api_base)
    g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    d1 = g1.json()
    assert g1.status_code == 200, d1
    assert "missing" not in set(d1.get("tags", []))
    # Recreate path1 and run fast pass again.
    # If the stale state row was removed, a NEW seed AssetInfo will appear for this path.
    p1.write_bytes(data)
    trigger_sync_seed_assets(http, api_base)
    rl = http.get(
        api_base + "/api/assets",
        params={"include_tags": f"unit-tests,{scope}"},
        timeout=120,
    )
    bl = rl.json()
    assert rl.status_code == 200, bl
    items = bl.get("assets", [])
    # one hashed AssetInfo (asset_hash == h) + one seed AssetInfo (asset_hash == null)
    hashes = [it.get("asset_hash") for it in items if it.get("name") in (name, a1_filename)]
    assert h in hashes
    assert any(x is None for x in hashes), "Expected a new seed AssetInfo for the recreated path"
    # Asset identity still healthy
    rh = http.head(f"{api_base}/api/assets/hash/{h}", timeout=120)
    assert rh.status_code == 200

View File

@ -1,306 +0,0 @@
import uuid
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
def test_create_from_hash_success(
    http: requests.Session, api_base: str, seeded_asset: dict
):
    """Creating an AssetInfo from an existing hash succeeds and is idempotent."""
    src_hash = seeded_asset["asset_hash"]
    payload = {
        "hash": src_hash,
        "name": "from_hash_ok.safetensors",
        "tags": ["models", "checkpoints", "unit-tests", "from-hash"],
        "user_metadata": {"k": "v"},
    }
    first = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
    first_body = first.json()
    assert first.status_code == 201, first_body
    assert first_body["asset_hash"] == src_hash
    assert first_body["created_new"] is False
    # Calling again with the same name should return the same AssetInfo id
    second = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
    second_body = second.json()
    assert second.status_code == 201, second_body
    assert second_body["id"] == first_body["id"]
def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asset: dict):
    """GET returns the detail payload; DELETE removes it; a later GET 404s."""
    asset_id = seeded_asset["id"]
    # GET detail
    got = http.get(f"{api_base}/api/assets/{asset_id}", timeout=120)
    detail = got.json()
    assert got.status_code == 200, detail
    assert detail["id"] == asset_id
    assert "user_metadata" in detail
    assert "filename" in detail["user_metadata"]
    # DELETE
    deleted = http.delete(f"{api_base}/api/assets/{asset_id}", timeout=120)
    assert deleted.status_code == 204
    # GET again -> 404
    again = http.get(f"{api_base}/api/assets/{asset_id}", timeout=120)
    err = again.json()
    assert again.status_code == 404
    assert err["error"]["code"] == "ASSET_NOT_FOUND"
def test_delete_upon_reference_count(
    http: requests.Session, api_base: str, seeded_asset: dict
):
    """The asset identity survives until its last AssetInfo reference is deleted."""
    # Create a second reference to the same asset via from-hash
    src_hash = seeded_asset["asset_hash"]
    payload = {
        "hash": src_hash,
        "name": "unit_ref_copy.safetensors",
        "tags": ["models", "checkpoints", "unit-tests", "del-flow"],
        "user_metadata": {"note": "copy"},
    }
    created = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
    copy_body = created.json()
    assert created.status_code == 201, copy_body
    assert copy_body["asset_hash"] == src_hash
    assert copy_body["created_new"] is False
    # Delete original reference -> asset identity must remain
    first_delete = http.delete(f"{api_base}/api/assets/{seeded_asset['id']}", timeout=120)
    assert first_delete.status_code == 204
    still_there = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120)
    assert still_there.status_code == 200  # identity still present
    # Delete the last reference with default semantics -> identity and cached files removed
    second_delete = http.delete(f"{api_base}/api/assets/{copy_body['id']}", timeout=120)
    assert second_delete.status_code == 204
    gone = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120)
    assert gone.status_code == 404  # orphan content removed
def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict):
    """PUT updates name and user_metadata while leaving tags untouched."""
    aid = seeded_asset["id"]
    original_tags = seeded_asset["tags"]
    payload = {
        "name": "unit_1_renamed.safetensors",
        "user_metadata": {"purpose": "updated", "epoch": 2},
    }
    ru = http.put(f"{api_base}/api/assets/{aid}", json=payload, timeout=120)
    body = ru.json()
    assert ru.status_code == 200, body
    assert body["name"] == payload["name"]
    assert body["tags"] == original_tags  # tags unchanged
    assert body["user_metadata"]["purpose"] == "updated"
    # filename should still be present and normalized by server
    assert "filename" in body["user_metadata"]
def test_head_asset_by_hash(http: requests.Session, api_base: str, seeded_asset: dict):
    """HEAD /api/assets/hash/{hash}: 200 for a known hash, 404 for an unknown one."""
    known_hash = seeded_asset["asset_hash"]
    unknown_hash = "blake3:" + "0" * 64
    resp_known = http.head(f"{api_base}/api/assets/hash/{known_hash}", timeout=120)
    assert resp_known.status_code == 200
    resp_unknown = http.head(f"{api_base}/api/assets/hash/{unknown_hash}", timeout=120)
    assert resp_unknown.status_code == 404
def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api_base: str):
    """A malformed hash yields 400, and the HEAD response carries no payload.

    The handler produces a JSON error, but HEAD semantics forbid a body, so
    requests exposes empty content; validate both the status and the emptiness.
    """
    resp = http.head(f"{api_base}/api/assets/hash/not_a_hash", timeout=120)
    assert resp.status_code == 400
    assert resp.content == b""
def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str):
    """Deleting a random, never-created asset id reports ASSET_NOT_FOUND."""
    missing_id = uuid.uuid4()
    resp = http.delete(f"{api_base}/api/assets/{missing_id}", timeout=120)
    payload = resp.json()
    assert resp.status_code == 404
    assert payload["error"]["code"] == "ASSET_NOT_FOUND"
def test_create_from_hash_invalids(http: requests.Session, api_base: str):
    """from-hash rejects unsupported hash algorithms and malformed JSON bodies."""
    # Bad hash algorithm
    bad = {
        "hash": "sha256:" + "0" * 64,
        "name": "x.bin",
        "tags": ["models", "checkpoints", "unit-tests"],
    }
    r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 400
    assert b1["error"]["code"] == "INVALID_BODY"
    # Invalid JSON body
    r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120)
    b2 = r2.json()
    assert r2.status_code == 400
    assert b2["error"]["code"] == "INVALID_JSON"
def test_get_update_download_bad_ids(http: requests.Session, api_base: str):
    """Non-UUID ids never reach the handlers: the UUID route regex rejects them with 404."""
    bad_id = "not-a-uuid"
    for url in (
        f"{api_base}/api/assets/{bad_id}",
        f"{api_base}/api/assets/{bad_id}/content",
    ):
        assert http.get(url, timeout=120).status_code == 404
def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict):
    """PUT with an empty JSON object is rejected as INVALID_BODY."""
    resp = http.put(f"{api_base}/api/assets/{seeded_asset['id']}", json={}, timeout=120)
    payload = resp.json()
    assert resp.status_code == 400
    assert payload["error"]["code"] == "INVALID_BODY"
@pytest.mark.parametrize("root", ["input", "output"])
def test_concurrent_delete_same_asset_info_single_204(
    root: str,
    http: requests.Session,
    api_base: str,
    asset_factory,
    make_asset_bytes,
):
    """
    Many concurrent DELETE for the same AssetInfo should result in:
      - exactly one 204 No Content (the one that actually deleted)
      - all others 404 Not Found (row already gone)
    """
    scope = f"conc-del-{uuid.uuid4().hex[:6]}"
    name = "to_delete.bin"
    data = make_asset_bytes(name, 1536)
    created = asset_factory(name, [root, "unit-tests", scope], {}, data)
    aid = created["id"]
    # Hit the same endpoint N times in parallel.
    n_tests = 4
    url = f"{api_base}/api/assets/{aid}?delete_content=false"
    def _do_delete(delete_url):
        # Each worker gets its own Session so requests are truly independent
        # (no shared connection-pool serialization between threads).
        with requests.Session() as s:
            return s.delete(delete_url, timeout=120).status_code
    with ThreadPoolExecutor(max_workers=n_tests) as ex:
        statuses = list(ex.map(_do_delete, [url] * n_tests))
    # Exactly one actual delete, the rest must be 404
    assert statuses.count(204) == 1, f"Expected exactly one 204; got: {statuses}"
    assert statuses.count(404) == n_tests - 1, f"Expected {n_tests-1} 404; got: {statuses}"
    # The resource must be gone.
    rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    assert rg.status_code == 404
@pytest.mark.parametrize("root", ["input", "output"])
def test_metadata_filename_is_set_for_seed_asset_without_hash(
    root: str,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
):
    """Seed ingest (no hash yet) must compute user_metadata['filename'] immediately."""
    scope = f"seedmeta-{uuid.uuid4().hex[:6]}"
    name = "seed_filename.bin"
    # Drop a file two directories deep so the relative-path computation is exercised.
    base = comfy_tmp_base_dir / root / "unit-tests" / scope / "a" / "b"
    base.mkdir(parents=True, exist_ok=True)
    fp = base / name
    fp.write_bytes(b"Z" * 2048)
    trigger_sync_seed_assets(http, api_base)
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
        timeout=120,
    )
    body = r1.json()
    assert r1.status_code == 200, body
    matches = [a for a in body.get("assets", []) if a.get("name") == name]
    assert matches, "Seed asset should be visible after sync"
    assert matches[0].get("asset_hash") is None  # still a seed
    aid = matches[0]["id"]
    r2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    detail = r2.json()
    assert r2.status_code == 200, detail
    filename = (detail.get("user_metadata") or {}).get("filename")
    # filename must be the forward-slash path relative to the root directory.
    expected = str(fp.relative_to(comfy_tmp_base_dir / root)).replace("\\", "/")
    assert filename == expected, f"expected filename={expected}, got {filename!r}"
@pytest.mark.skip(reason="Requires computing hashes of files in directories to retarget cache states")
@pytest.mark.parametrize("root", ["input", "output"])
def test_metadata_filename_computed_and_updated_on_retarget(
    root: str,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
    run_scan_and_wait,
):
    """
    1) Ingest under {root}/unit-tests/<scope>/a/b/<name> -> filename reflects relative path.
    2) Retarget by copying to {root}/unit-tests/<scope>/x/<new_name>, remove old file,
       run fast pass + scan -> filename updates to new relative path.
    """
    scope = f"meta-fn-{uuid.uuid4().hex[:6]}"
    name1 = "compute_metadata_filename.png"
    name2 = "compute_changed_metadata_filename.png"
    data = make_asset_bytes(name1, 2100)
    # Upload into nested path a/b
    a = asset_factory(name1, [root, "unit-tests", scope, "a", "b"], {}, data)
    aid = a["id"]
    root_base = comfy_tmp_base_dir / root
    # On-disk name is derived from the content hash, not the upload name.
    p1 = (root_base / "unit-tests" / scope / "a" / "b" / get_asset_filename(a["asset_hash"], ".png"))
    assert p1.exists()
    # filename at ingest should be the path relative to root
    rel1 = str(p1.relative_to(root_base)).replace("\\", "/")
    g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    d1 = g1.json()
    assert g1.status_code == 200, d1
    fn1 = d1["user_metadata"].get("filename")
    assert fn1 == rel1
    # Retarget: copy to x/<name2>, remove old, then sync+scan
    p2 = root_base / "unit-tests" / scope / "x" / name2
    p2.parent.mkdir(parents=True, exist_ok=True)
    p2.write_bytes(data)
    if p1.exists():
        p1.unlink()
    trigger_sync_seed_assets(http, api_base)  # seed the new path
    run_scan_and_wait(root)  # verify/hash and reconcile
    # filename should now point at x/<name2>
    rel2 = str(p2.relative_to(root_base)).replace("\\", "/")
    g2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    d2 = g2.json()
    assert g2.status_code == 200, d2
    fn2 = d2["user_metadata"].get("filename")
    assert fn2 == rel2

View File

@ -1,166 +0,0 @@
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Optional
import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict):
    """Content endpoint defaults to attachment disposition; ?disposition=inline switches it."""
    aid = seeded_asset["id"]
    # default attachment
    r1 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
    data = r1.content
    assert r1.status_code == 200
    cd = r1.headers.get("Content-Disposition", "")
    assert "attachment" in cd
    # seeded_asset fixture bytes are 4096 long; verify full payload round-trip.
    assert data and len(data) == 4096
    # inline requested
    r2 = http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline", timeout=120)
    r2.content  # drain the body so the connection can be reused
    assert r2.status_code == 200
    cd2 = r2.headers.get("Content-Disposition", "")
    assert "inline" in cd2
@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
@pytest.mark.parametrize("root", ["input", "output"])
def test_download_chooses_existing_state_and_updates_access_time(
    root: str,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
    run_scan_and_wait,
):
    """
    Hashed asset with two state paths: if the first one disappears,
    GET /content still serves from the remaining path and bumps last_access_time.
    """
    scope = f"dl-first-{uuid.uuid4().hex[:6]}"
    name = "first_existing_state.bin"
    data = make_asset_bytes(name, 3072)
    # Upload -> path1
    a = asset_factory(name, [root, "unit-tests", scope], {}, data)
    aid = a["id"]
    base = comfy_tmp_base_dir / root / "unit-tests" / scope
    path1 = base / get_asset_filename(a["asset_hash"], ".bin")
    assert path1.exists()
    # Seed path2 by copying, then scan to dedupe into a second state
    path2 = base / "alt" / name
    path2.parent.mkdir(parents=True, exist_ok=True)
    path2.write_bytes(data)
    trigger_sync_seed_assets(http, api_base)
    run_scan_and_wait(root)
    # Remove path1 so server must fall back to path2
    path1.unlink()
    # last_access_time before
    rg0 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    d0 = rg0.json()
    assert rg0.status_code == 200, d0
    ts0 = d0.get("last_access_time")
    # Small pause so a bumped timestamp is strictly greater than the baseline.
    time.sleep(0.05)
    r = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
    blob = r.content
    assert r.status_code == 200
    assert blob == data  # must serve from the surviving state (same bytes)
    rg1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    d1 = rg1.json()
    assert rg1.status_code == 200, d1
    ts1 = d1.get("last_access_time")
    def _parse_iso8601(s: Optional[str]) -> Optional[float]:
        # Tolerate a trailing 'Z' (UTC designator) that fromisoformat rejects on older Pythons.
        if not s:
            return None
        s = s[:-1] if s.endswith("Z") else s
        return datetime.fromisoformat(s).timestamp()
    t0 = _parse_iso8601(ts0)
    t1 = _parse_iso8601(ts1)
    assert t1 is not None
    # Baseline may legitimately be null (never accessed); only compare when present.
    if t0 is not None:
        assert t1 > t0
@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True)
def test_download_missing_file_returns_404(
    http: requests.Session, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict
):
    """Removing the backing file on disk makes GET /content report FILE_NOT_FOUND.

    The fixture is parametrized without extra tags so the on-disk location of
    the asset file is known exactly.
    """
    # Bind `aid` BEFORE the try block: the finally clause references it, and
    # assigning it inside `try` would raise NameError in `finally` (masking the
    # real failure) if an earlier statement threw.
    aid = seeded_asset["id"]
    try:
        rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
        detail = rg.json()
        assert rg.status_code == 200
        asset_filename = get_asset_filename(detail["asset_hash"], ".safetensors")
        abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / asset_filename
        assert abs_path.exists()
        abs_path.unlink()
        r2 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
        assert r2.status_code == 404
        body = r2.json()
        assert body["error"]["code"] == "FILE_NOT_FOUND"
    finally:
        # Asset was created without the "unit-tests" tag (see `autoclean_unit_test_assets`),
        # so it must be cleaned up manually.
        dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120)
        dr.content  # drain body so the connection can be reused
@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
@pytest.mark.parametrize("root", ["input", "output"])
def test_download_404_if_all_states_missing(
    root: str,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
    run_scan_and_wait,
):
    """Multi-state asset: after the last remaining on-disk file is removed, download must return 404."""
    scope = f"dl-404-{uuid.uuid4().hex[:6]}"
    name = "missing_all_states.bin"
    data = make_asset_bytes(name, 2048)
    # Upload -> path1
    a = asset_factory(name, [root, "unit-tests", scope], {}, data)
    aid = a["id"]
    base = comfy_tmp_base_dir / root / "unit-tests" / scope
    # On-disk name is derived from the content hash.
    p1 = base / get_asset_filename(a["asset_hash"], ".bin")
    assert p1.exists()
    # Seed a second state and dedupe
    p2 = base / "copy" / name
    p2.parent.mkdir(parents=True, exist_ok=True)
    p2.write_bytes(data)
    trigger_sync_seed_assets(http, api_base)
    run_scan_and_wait(root)
    # Remove first file -> download should still work via the second state
    p1.unlink()
    ok1 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
    b1 = ok1.content
    assert ok1.status_code == 200 and b1 == data
    # Remove the last file -> download must 404
    p2.unlink()
    r2 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
    body = r2.json()
    assert r2.status_code == 404
    assert body["error"]["code"] == "FILE_NOT_FOUND"

View File

@ -1,342 +0,0 @@
import time
import uuid
import requests
def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
    """Paging with sort=name: two pages of 2+1 assets with correct has_more flags."""
    names = ["a1_u.safetensors", "a2_u.safetensors", "a3_u.safetensors"]
    for n in names:
        asset_factory(
            n,
            ["models", "checkpoints", "unit-tests", "paging"],
            {"epoch": 1},
            make_asset_bytes(n, size=2048),
        )
    # name ascending for stable order
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "0"},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    got1 = [a["name"] for a in b1["assets"]]
    assert got1 == sorted(names)[:2]
    assert b1["has_more"] is True
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "2"},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200
    got2 = [a["name"] for a in b2["assets"]]
    assert got2 == sorted(names)[2:]
    assert b2["has_more"] is False
def test_list_assets_include_exclude_and_name_contains(http: requests.Session, api_base: str, asset_factory):
    """include_tags/exclude_tags filter correctly; name_contains matches substrings;
    an unknown tag yields an empty list."""
    a = asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
    b = asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)
    r = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,alpha", "exclude_tags": "beta", "limit": "50"},
        timeout=120,
    )
    body = r.json()
    assert r.status_code == 200
    names = [x["name"] for x in body["assets"]]
    assert a["name"] in names
    assert b["name"] not in names
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests", "name_contains": "inc_"},
        timeout=120,
    )
    body2 = r2.json()
    assert r2.status_code == 200
    names2 = [x["name"] for x in body2["assets"]]
    assert a["name"] in names2
    assert b["name"] in names2
    # NOTE(review): `r2` is rebound here for the third request while the body
    # variable jumps to `body3` — works, but the shadowing is easy to misread.
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "non-existing-tag"},
        timeout=120,
    )
    body3 = r2.json()
    assert r2.status_code == 200
    assert not body3["assets"]
def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes):
    """sort=size must order assets by byte length ascending and descending."""
    t = ["models", "checkpoints", "unit-tests", "lf-size"]
    n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors"
    asset_factory(n1, t, {}, make_asset_bytes(n1, 1024))
    asset_factory(n2, t, {}, make_asset_bytes(n2, 2048))
    asset_factory(n3, t, {}, make_asset_bytes(n3, 3072))
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "asc"},
        timeout=120,
    )
    b1 = r1.json()
    # Assert the HTTP status explicitly (consistent with the sibling list tests)
    # so a failed request surfaces as such rather than as a confusing
    # name-ordering mismatch further down.
    assert r1.status_code == 200, b1
    names = [a["name"] for a in b1["assets"]]
    assert names[:3] == [n1, n2, n3]
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "desc"},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200, b2
    names2 = [a["name"] for a in b2["assets"]]
    assert names2[:3] == [n3, n2, n1]
def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes):
    """Renaming an asset bumps updated_at, so it sorts first with sort=updated_at desc."""
    t = ["models", "checkpoints", "unit-tests", "lf-upd"]
    a1 = asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200))
    a2 = asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200))
    # Rename the second asset to bump updated_at
    rp = http.put(f"{api_base}/api/assets/{a2['id']}", json={"name": "upd_b_renamed.safetensors"}, timeout=120)
    upd = rp.json()
    assert rp.status_code == 200, upd
    r = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-upd", "sort": "updated_at", "order": "desc"},
        timeout=120,
    )
    body = r.json()
    assert r.status_code == 200
    names = [x["name"] for x in body["assets"]]
    assert names[0] == "upd_b_renamed.safetensors"
    assert a1["name"] in names
def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes):
    """Downloading an asset bumps last_access_time, so it sorts first under that key."""
    t = ["models", "checkpoints", "unit-tests", "lf-access"]
    asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100))
    # Short sleeps keep the two assets' timestamps strictly ordered.
    time.sleep(0.02)
    a2 = asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100))
    # Touch last_access_time of b by downloading its content
    time.sleep(0.02)
    dl = http.get(f"{api_base}/api/assets/{a2['id']}/content", timeout=120)
    assert dl.status_code == 200
    dl.content  # drain body so the connection can be reused
    r = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-access", "sort": "last_access_time", "order": "desc"},
        timeout=120,
    )
    body = r.json()
    assert r.status_code == 200
    names = [x["name"] for x in body["assets"]]
    assert names[0] == a2["name"]
def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes):
    """include_tags accepts CSV, repeated params, duplicates, and surrounding
    whitespace, all matched case-insensitively."""
    t = ["models", "checkpoints", "unit-tests", "lf-include"]
    a = asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva"))
    asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb"))
    # CSV + case-insensitive
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    names1 = [x["name"] for x in b1["assets"]]
    assert a["name"] in names1
    assert not any("beta" in x for x in names1)
    # Repeated query params for include_tags
    params_multi = [
        ("include_tags", "unit-tests"),
        ("include_tags", "lf-include"),
        ("include_tags", "alpha"),
    ]
    r2 = http.get(api_base + "/api/assets", params=params_multi, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 200
    names2 = [x["name"] for x in b2["assets"]]
    assert a["name"] in names2
    assert not any("beta" in x for x in names2)
    # Duplicates and spaces in CSV
    r3 = http.get(
        api_base + "/api/assets",
        params={"include_tags": " unit-tests , lf-include , alpha , alpha "},
        timeout=120,
    )
    b3 = r3.json()
    assert r3.status_code == 200
    names3 = [x["name"] for x in b3["assets"]]
    assert a["name"] in names3
def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes):
    """exclude_tags matches case-insensitively and tolerates duplicate values."""
    t = ["models", "checkpoints", "unit-tests", "lf-exclude"]
    a = asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900))
    asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900))
    # Exclude uppercase should work
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    names1 = [x["name"] for x in b1["assets"]]
    assert a["name"] in names1
    # Repeated excludes with duplicates
    params_multi = [
        ("include_tags", "unit-tests"),
        ("include_tags", "lf-exclude"),
        ("exclude_tags", "beta"),
        ("exclude_tags", "beta"),
    ]
    r2 = http.get(api_base + "/api/assets", params=params_multi, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 200
    names2 = [x["name"] for x in b2["assets"]]
    assert all("beta" not in x for x in names2)
def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes):
    """name_contains matches case-insensitively and treats dots/hyphens literally."""
    t = ["models", "checkpoints", "unit-tests", "lf-name"]
    a1 = asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800))
    a2 = asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800))
    # lowercase query must match mixed-case name
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-name", "name_contains": "casemix"},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    names1 = [x["name"] for x in b1["assets"]]
    assert a1["name"] in names1
    # dot in the query is a literal character
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-name", "name_contains": ".SAFE"},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200
    names2 = [x["name"] for x in b2["assets"]]
    assert a1["name"] in names2
    # hyphen in the query is a literal character
    r3 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-name", "name_contains": "case-"},
        timeout=120,
    )
    b3 = r3.json()
    assert r3.status_code == 200
    names3 = [x["name"] for x in b3["assets"]]
    assert a2["name"] in names3
def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes):
    """Offset past the end returns an empty page; the maximum limit (500) is accepted."""
    t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"]
    asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600))
    asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600))
    asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600))
    # Offset far beyond total
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-pagelimits", "limit": "2", "offset": "10"},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    assert not b1["assets"]
    assert b1["has_more"] is False
    # Boundary large limit (<=500 is valid)
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-pagelimits", "limit": "500"},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200
    assert len(b2["assets"]) == 3
    assert b2["has_more"] is False
def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base):
    """A negative offset and a non-integer limit are both INVALID_QUERY errors."""
    for params in ({"offset": "-1"}, {"limit": "abc"}):
        resp = http.get(api_base + "/api/assets", params=params, timeout=120)
        payload = resp.json()
        assert resp.status_code == 400
        assert payload["error"]["code"] == "INVALID_QUERY"
def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str):
    """A too-small limit and malformed metadata_filter JSON are INVALID_QUERY errors."""
    for params in (
        {"limit": "0"},                    # limit too small
        {"metadata_filter": "{not json"},  # bad metadata JSON
    ):
        resp = http.get(api_base + "/api/assets", params=params, timeout=120)
        payload = resp.json()
        assert resp.status_code == 400
        assert payload["error"]["code"] == "INVALID_QUERY"
def test_list_assets_name_contains_literal_underscore(
    http,
    api_base,
    asset_factory,
    make_asset_bytes,
):
    """'name_contains' must treat '_' literally, not as a SQL wildcard.
    We create:
      - foo_bar.safetensors   (should match)
      - fooxbar.safetensors   (must NOT match if '_' is escaped)
      - foobar.safetensors    (must NOT match)
    """
    # Fresh scope tag keeps this run isolated from prior data.
    scope = f"lf-underscore-{uuid.uuid4().hex[:6]}"
    tags = ["models", "checkpoints", "unit-tests", scope]
    a = asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700))
    b = asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700))
    c = asset_factory("foobar.safetensors", tags, {}, make_asset_bytes("c", 700))
    r = http.get(
        api_base + "/api/assets",
        params={"include_tags": f"unit-tests,{scope}", "name_contains": "foo_bar"},
        timeout=120,
    )
    body = r.json()
    assert r.status_code == 200, body
    names = [x["name"] for x in body["assets"]]
    assert a["name"] in names, f"Expected literal underscore match to include {a['name']}"
    assert b["name"] not in names, "Underscore must be escaped — should not match 'fooxbar'"
    assert c["name"] not in names, "Underscore must be escaped — should not match 'foobar'"
    # total counts matches across all pages, so exactly the one literal match.
    assert body["total"] == 1

View File

@ -1,395 +0,0 @@
import json
def test_meta_and_across_keys_and_types(
    http, api_base: str, asset_factory, make_asset_bytes
):
    """metadata_filter combines keys with AND semantics across mixed value types."""
    name = "mf_and_mix.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "mf-and"]
    meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
    asset_factory(name, tags, meta, make_asset_bytes(name, 4096))
    # All keys must match (AND semantics)
    f_ok = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
    r1 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-and",
            "metadata_filter": json.dumps(f_ok),
        },
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    names = [a["name"] for a in b1["assets"]]
    assert name in names
    # One key mismatched -> no result
    f_bad = {"purpose": "mix", "epoch": 2, "active": True}
    r2 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-and",
            "metadata_filter": json.dumps(f_bad),
        },
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200
    assert not b2["assets"]
def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes):
    """metadata_filter comparisons are type-strict: 1 != "1" and True != "true"."""
    name = "mf_types.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "mf-types"]
    meta = {"epoch": 1, "active": True}
    asset_factory(name, tags, meta, make_asset_bytes(name))
    # int filter matches numeric
    r1 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-types",
            "metadata_filter": json.dumps({"epoch": 1}),
        },
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
    # string "1" must NOT match numeric 1
    r2 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-types",
            "metadata_filter": json.dumps({"epoch": "1"}),
        },
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and not b2["assets"]
    # bool True matches, string "true" must NOT match
    r3 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-types",
            "metadata_filter": json.dumps({"active": True}),
        },
        timeout=120,
    )
    b3 = r3.json()
    assert r3.status_code == 200 and any(a["name"] == name for a in b3["assets"])
    r4 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-types",
            "metadata_filter": json.dumps({"active": "true"}),
        },
        timeout=120,
    )
    b4 = r4.json()
    assert r4.status_code == 200 and not b4["assets"]
def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes):
    """A list filter value uses any-of semantics against a stored list of scalars."""
    name = "mf_list_scalars.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "mf-list"]
    meta = {"flags": ["red", "green"]}
    asset_factory(name, tags, meta, make_asset_bytes(name, 3000))
    # Any-of should match because "green" is present
    filt_ok = {"flags": ["blue", "green"]}
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_ok)},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
    # None of provided flags present -> no match
    filt_miss = {"flags": ["blue", "yellow"]}
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_miss)},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and not b2["assets"]
    # Duplicates in list should not break matching
    filt_dup = {"flags": ["green", "green", "green"]}
    r3 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_dup)},
        timeout=120,
    )
    b3 = r3.json()
    assert r3.status_code == 200 and any(a["name"] == name for a in b3["assets"])
def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
    http, api_base, asset_factory, make_asset_bytes
):
    """A None filter value matches both a missing key and an explicit null;
    any-of lists containing None extend that to concrete matches too."""
    # a1: key missing; a2: explicit null; a3: concrete value
    t = ["models", "checkpoints", "unit-tests", "mf-none"]
    a1 = asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1"))
    a2 = asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2"))
    a3 = asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3"))
    # Filter {maybe: None} must match a1 and a2, not a3
    filt = {"maybe": None}
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt), "sort": "name"},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    got = [a["name"] for a in b1["assets"]]
    assert a1["name"] in got and a2["name"] in got and a3["name"] not in got
    # Any-of with None should include missing/null plus value matches
    filt_any = {"maybe": [None, "x"]}
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt_any), "sort": "name"},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200
    got2 = [a["name"] for a in b2["assets"]]
    assert a1["name"] in got2 and a2["name"] in got2 and a3["name"] in got2
def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes):
    """An object filter value compares by whole-object equality, not partial match."""
    name = "mf_nested_json.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "mf-nested"]
    cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}}
    asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200))
    # Exact JSON object equality (same structure)
    r1 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-nested",
            "metadata_filter": json.dumps({"config": cfg}),
        },
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
    # Different JSON object should not match
    r2 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-nested",
            "metadata_filter": json.dumps({"config": {"optimizer": "sgd"}}),
        },
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and not b2["assets"]
def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes):
    """Filtering a stored list of objects matches when one element equals the filter object."""
    name = "mf_list_objects.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "mf-objlist"]
    transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}]
    asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048))
    # Any-of for list of objects should match when one element equals the filter object
    r1 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-objlist",
            "metadata_filter": json.dumps({"transforms": {"type": "flip", "p": 0.5}}),
        },
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
    # Non-matching object -> no match
    r2 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-objlist",
            "metadata_filter": json.dumps({"transforms": {"type": "rotate", "deg": 90}}),
        },
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and not b2["assets"]
def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes):
    """Metadata keys with dots, slashes, colons, Cyrillic, and emoji filter correctly."""
    name = "mf_keys_unicode.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "mf-keys"]
    meta = {
        "weird.key": "v1",
        "path/like": 7,
        "with:colon": True,
        "ключ": "значение",
        "emoji": "🐍",
    }
    asset_factory(name, tags, meta, make_asset_bytes(name, 1500))
    # Match all the special keys
    filt = {"weird.key": "v1", "path/like": 7, "with:colon": True, "emoji": "🐍"}
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps(filt)},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
    # Unicode key match
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps({"ключ": "значение"})},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and any(a["name"] == name for a in b2["assets"])
def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes):
    """Zero is matched as a real value (not falsy-missing); booleans match inside lists."""
    t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"]
    a0 = asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025))
    a1 = asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026))
    # count == 0 must match only a0
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"count": 0})},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    names1 = [a["name"] for a in b1["assets"]]
    assert a0["name"] in names1 and a1["name"] not in names1
    # Any-of list of booleans: True matches second asset
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"choices": True})},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and any(a["name"] == a1["name"] for a in b2["assets"])
def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes):
    """A metadata list mixing str/int/bool/None matches any-of on members, but not absent values."""
    name = "mf_mixed_list.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "mf-mixed"]
    meta = {"mix": ["1", 1, True, None]}
    asset_factory(name, tags, meta, make_asset_bytes(name, 1999))
    # Should match because 1 is present
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": [2, 1]})},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
    # Should NOT match for False
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": False})},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and not b2["assets"]
def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes):
    """Filtering an unknown key by None matches all assets (missing == null); a concrete value matches none."""
    # Use a unique scope tag to avoid interference
    t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"]
    x = asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua"))
    y = asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub"))
    # Filtering by unknown key with None should return both (missing key OR null)
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": None})},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    names = {a["name"] for a in b1["assets"]}
    assert x["name"] in names and y["name"] in names
    # Filtering by unknown key with concrete value should return none
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": "x"})},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and not b2["assets"]
def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_factory, make_asset_bytes):
    """metadata_filter composes (AND) with include_tags, exclude_tags and name_contains."""
    # alpha matches epoch=1; beta has epoch=2
    a = asset_factory(
        "mf_tag_alpha.safetensors",
        ["models", "checkpoints", "unit-tests", "mf-tag", "alpha"],
        {"epoch": 1},
        make_asset_bytes("alpha"),
    )
    b = asset_factory(
        "mf_tag_beta.safetensors",
        ["models", "checkpoints", "unit-tests", "mf-tag", "beta"],
        {"epoch": 2},
        make_asset_bytes("beta"),
    )
    params = {
        "include_tags": "unit-tests,mf-tag,alpha",
        "exclude_tags": "beta",
        "name_contains": "mf_tag_",
        "metadata_filter": json.dumps({"epoch": 1}),
    }
    r = http.get(api_base + "/api/assets", params=params, timeout=120)
    body = r.json()
    assert r.status_code == 200
    names = [x["name"] for x in body["assets"]]
    assert a["name"] in names
    assert b["name"] not in names
def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes):
    """Sorting by size and limit/offset paging work while a metadata_filter is active."""
    # Three assets in same scope with different sizes and a common filter key
    t = ["models", "checkpoints", "unit-tests", "mf-sort"]
    n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors"
    asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024))
    asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048))
    asset_factory(n3, t, {"group": "g"}, make_asset_bytes(n3, 3072))
    # Sort by size ascending with paging
    q = {
        "include_tags": "unit-tests,mf-sort",
        "metadata_filter": json.dumps({"group": "g"}),
        "sort": "size", "order": "asc", "limit": "2",
    }
    r1 = http.get(api_base + "/api/assets", params=q, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 200
    got1 = [a["name"] for a in b1["assets"]]
    assert got1 == [n1, n2]
    assert b1["has_more"] is True
    # Second page: remaining single asset, no further results
    q2 = {**q, "offset": "2"}
    r2 = http.get(api_base + "/api/assets", params=q2, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 200
    got2 = [a["name"] for a in b2["assets"]]
    assert got2 == [n3]
    assert b2["has_more"] is False

View File

@ -1,141 +0,0 @@
import uuid
from pathlib import Path
import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
@pytest.fixture
def create_seed_file(comfy_tmp_base_dir: Path):
    """Create a file on disk that will become a seed asset after sync."""
    # Track every created path so teardown can remove files tests left behind.
    created: list[Path] = []
    def _create(root: str, scope: str, name: str | None = None, data: bytes = b"TEST") -> Path:
        # Random default name keeps parallel/repeated runs from colliding.
        name = name or f"seed_{uuid.uuid4().hex[:8]}.bin"
        path = comfy_tmp_base_dir / root / "unit-tests" / scope / name
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(data)
        created.append(path)
        return path
    yield _create
    # Teardown: tolerate files already deleted by the test itself.
    for p in created:
        p.unlink(missing_ok=True)
@pytest.fixture
def find_asset(http: requests.Session, api_base: str):
    """Return a helper that lists assets in a unit-test scope, optionally narrowed to an exact name."""
    def _query(scope: str, name: str | None = None) -> list[dict]:
        query = {"include_tags": f"unit-tests,{scope}"}
        if name:
            query["name_contains"] = name
        resp = http.get(f"{api_base}/api/assets", params=query, timeout=120)
        assert resp.status_code == 200
        found = resp.json().get("assets", [])
        if not name:
            return found
        # name_contains is a substring filter; narrow to exact matches.
        return [entry for entry in found if entry.get("name") == name]
    return _query
@pytest.mark.parametrize("root", ["input", "output"])
def test_orphaned_seed_asset_is_pruned(
    root: str,
    create_seed_file,
    find_asset,
    http: requests.Session,
    api_base: str,
):
    """Seed asset with deleted file is removed; with file present, it survives."""
    scope = f"prune-{uuid.uuid4().hex[:6]}"
    fp = create_seed_file(root, scope)
    name = fp.name
    # First sync registers the on-disk file as a seed asset.
    trigger_sync_seed_assets(http, api_base)
    assert find_asset(scope, name), "Seed asset should exist"
    # Delete the backing file, then sync again so prune can detect the orphan.
    fp.unlink()
    trigger_sync_seed_assets(http, api_base)
    assert not find_asset(scope, name), "Orphaned seed should be pruned"
def test_seed_asset_with_file_survives_prune(
    create_seed_file,
    find_asset,
    http: requests.Session,
    api_base: str,
):
    """Seed asset with file still on disk is NOT pruned."""
    scope_tag = f"keep-{uuid.uuid4().hex[:6]}"
    seed_path = create_seed_file("input", scope_tag)
    # Sync twice: the second pass gives prune the chance to (wrongly) remove it.
    for _ in range(2):
        trigger_sync_seed_assets(http, api_base)
    assert find_asset(scope_tag, seed_path.name), "Seed with valid file should survive"
def test_hashed_asset_not_pruned_when_file_missing(
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
):
    """Hashed assets are never deleted by prune, even without file."""
    scope = f"hashed-{uuid.uuid4().hex[:6]}"
    data = make_asset_bytes("test", 2048)
    a = asset_factory("test.bin", ["input", "unit-tests", scope], {}, data)
    # Hashed uploads are stored under a hash-derived filename, not the display name.
    path = comfy_tmp_base_dir / "input" / "unit-tests" / scope / get_asset_filename(a["asset_hash"], ".bin")
    path.unlink()
    trigger_sync_seed_assets(http, api_base)
    r = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120)
    assert r.status_code == 200, "Hashed asset should NOT be pruned"
def test_prune_across_multiple_roots(
    create_seed_file,
    find_asset,
    http: requests.Session,
    api_base: str,
):
    """Prune correctly handles assets across input and output roots."""
    scope = f"multi-{uuid.uuid4().hex[:6]}"
    input_fp = create_seed_file("input", scope, "input.bin")
    create_seed_file("output", scope, "output.bin")
    trigger_sync_seed_assets(http, api_base)
    assert len(find_asset(scope)) == 2
    # Remove only the input-root file; the output-root seed must survive.
    input_fp.unlink()
    trigger_sync_seed_assets(http, api_base)
    remaining = find_asset(scope)
    assert len(remaining) == 1
    assert remaining[0]["name"] == "output.bin"
@pytest.mark.parametrize("dirname", ["100%_done", "my_folder_name", "has spaces"])
def test_special_chars_in_path_escaped_correctly(
    dirname: str,
    create_seed_file,
    find_asset,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
):
    """SQL LIKE wildcards (%, _) and spaces in paths don't cause false matches."""
    scope = f"special-{uuid.uuid4().hex[:6]}/{dirname}"
    fp = create_seed_file("input", scope)
    # Two syncs: register the seed, then let prune run over the registered path.
    trigger_sync_seed_assets(http, api_base)
    trigger_sync_seed_assets(http, api_base)
    assert find_asset(scope.split("/")[0], fp.name), "Asset with special chars should survive"

View File

@ -1,225 +0,0 @@
import json
import uuid
import requests
def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict):
    """GET /api/tags lists system tags, honors include_zero, prefix and name ordering."""
    # Include zero-usage tags by default
    r1 = http.get(api_base + "/api/tags", params={"limit": "50"}, timeout=120)
    body1 = r1.json()
    assert r1.status_code == 200
    names = [t["name"] for t in body1["tags"]]
    # A few system tags from migration should exist:
    assert "models" in names
    assert "checkpoints" in names
    # Only used tags before we add anything new from this test cycle
    r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120)
    body2 = r2.json()
    assert r2.status_code == 200
    # We already seeded one asset via fixture, so used tags must be non-empty
    used_names = [t["name"] for t in body2["tags"]]
    assert "models" in used_names
    assert "checkpoints" in used_names
    # Prefix filter should refine the list
    r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120)
    b3 = r3.json()
    assert r3.status_code == 200
    names3 = [t["name"] for t in b3["tags"]]
    assert "unit-tests" in names3
    assert "models" not in names3  # filtered out by prefix
    # Order by name ascending should be stable
    r4 = http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}, timeout=120)
    b4 = r4.json()
    assert r4.status_code == 200
    names4 = [t["name"] for t in b4["tags"]]
    assert names4 == sorted(names4)
def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
    """A custom tag appears under include_zero=false only while an asset uses it."""
    # Baseline: system tags exist when include_zero (default) is true
    r1 = http.get(api_base + "/api/tags", params={"limit": "500"}, timeout=120)
    body1 = r1.json()
    assert r1.status_code == 200
    names = [t["name"] for t in body1["tags"]]
    assert "models" in names and "checkpoints" in names
    # Create a short-lived asset under input with a unique custom tag
    scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}"
    custom_tag = f"temp-{uuid.uuid4().hex[:8]}"
    name = "tag_seed.bin"
    _asset = asset_factory(
        name,
        ["input", "unit-tests", scope, custom_tag],
        {},
        make_asset_bytes(name, 512),
    )
    # While the asset exists, the custom tag must appear when include_zero=false
    r2 = http.get(
        api_base + "/api/tags",
        params={"include_zero": "false", "prefix": custom_tag, "limit": "50"},
        timeout=120,
    )
    body2 = r2.json()
    assert r2.status_code == 200
    used_names = [t["name"] for t in body2["tags"]]
    assert custom_tag in used_names
    # Delete the asset so the tag usage drops to zero
    rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120)
    assert rd.status_code == 204
    # Now the custom tag must not be returned when include_zero=false
    r3 = http.get(
        api_base + "/api/tags",
        params={"include_zero": "false", "prefix": custom_tag, "limit": "50"},
        timeout=120,
    )
    body3 = r3.json()
    assert r3.status_code == 200
    names_after = [t["name"] for t in body3["tags"]]
    assert custom_tag not in names_after
    assert not names_after  # filtered view should be empty now
def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict):
    """POST/DELETE /api/assets/{id}/tags normalize case, dedupe, and report added/removed sets."""
    aid = seeded_asset["id"]
    # Add tags with duplicates and mixed case
    payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]}
    r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 200, b1
    # normalized, deduplicated; 'unit-tests' was already present from the seed
    assert set(b1["added"]) == {"newtag", "beta"}
    assert set(b1["already_present"]) == {"unit-tests"}
    assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]
    rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    g = rg.json()
    assert rg.status_code == 200
    tags_now = set(g["tags"])
    assert {"newtag", "beta"}.issubset(tags_now)
    # Remove a tag and a non-existent tag
    payload_del = {"tags": ["newtag", "does-not-exist"]}
    r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 200
    assert set(b2["removed"]) == {"newtag"}
    assert set(b2["not_present"]) == {"does-not-exist"}
    # Verify remaining tags after deletion
    rg2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    g2 = rg2.json()
    assert rg2.status_code == 200
    tags_later = set(g2["tags"])
    assert "newtag" not in tags_later
    assert "beta" in tags_later  # still present
def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict):
    """/api/tags ordering: count_desc ranks by usage, name_asc lexically; bad limit is rejected."""
    aid = seeded_asset["id"]
    h = seeded_asset["asset_hash"]
    # Add both tags to the seeded asset (usage: orderaaa=1, orderbbb=1)
    r_add = http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": ["orderaaa", "orderbbb"]}, timeout=120)
    add_body = r_add.json()
    assert r_add.status_code == 200, add_body
    # Create another AssetInfo from the same content but tagged ONLY with 'orderbbb'.
    payload = {
        "hash": h,
        "name": "order_only_bbb.safetensors",
        "tags": ["input", "unit-tests", "orderbbb"],
        "user_metadata": {},
    }
    r_copy = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
    copy_body = r_copy.json()
    assert r_copy.status_code == 201, copy_body
    # 1) Default order (count_desc): 'orderbbb' should come before 'orderaaa'
    # because it has higher usage (2 vs 1).
    r1 = http.get(api_base + "/api/tags", params={"prefix": "order", "include_zero": "false"}, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 200, b1
    names1 = [t["name"] for t in b1["tags"]]
    counts1 = {t["name"]: t["count"] for t in b1["tags"]}
    # Both must be present within the prefix subset
    assert "orderaaa" in names1 and "orderbbb" in names1
    # Usage of 'orderbbb' must be >= 'orderaaa'; in our setup it's 2 vs 1
    assert counts1["orderbbb"] >= counts1["orderaaa"]
    # And with count_desc, 'orderbbb' appears earlier than 'orderaaa'
    assert names1.index("orderbbb") < names1.index("orderaaa")
    # 2) name_asc: lexical order should flip the relative order
    r2 = http.get(
        api_base + "/api/tags",
        params={"prefix": "order", "include_zero": "false", "order": "name_asc"},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200, b2
    names2 = [t["name"] for t in b2["tags"]]
    assert "orderaaa" in names2 and "orderbbb" in names2
    assert names2.index("orderaaa") < names2.index("orderbbb")
    # 3) invalid limit rejected (existing negative case retained)
    r3 = http.get(api_base + "/api/tags", params={"limit": "1001"}, timeout=120)
    b3 = r3.json()
    assert r3.status_code == 400
    assert b3["error"]["code"] == "INVALID_QUERY"
def test_tags_endpoints_invalid_bodies(http: requests.Session, api_base: str, seeded_asset: dict):
    """Malformed tag bodies and a non-object metadata_filter are rejected with 400."""
    aid = seeded_asset["id"]
    # Add with empty list
    r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": []}, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 400
    assert b1["error"]["code"] == "INVALID_BODY"
    # Remove with wrong type
    r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json={"tags": [123]}, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 400
    assert b2["error"]["code"] == "INVALID_BODY"
    # metadata_filter provided as JSON array should be rejected (must be object)
    r3 = http.get(
        api_base + "/api/assets",
        params={"metadata_filter": json.dumps([{"x": 1}])},
        timeout=120,
    )
    b3 = r3.json()
    assert r3.status_code == 400
    assert b3["error"]["code"] == "INVALID_QUERY"
def test_tags_prefix_treats_underscore_literal(
    http,
    api_base,
    asset_factory,
    make_asset_bytes,
):
    """'prefix' for /api/tags must treat '_' literally, not as a wildcard."""
    base = f"pref_{uuid.uuid4().hex[:6]}"
    tag_ok = f"{base}_ok"  # should match prefix=f"{base}_"
    tag_bad = f"{base}xok"  # must NOT match if '_' is escaped
    scope = f"tags-underscore-{uuid.uuid4().hex[:6]}"
    # Two assets that differ only in whether the '_' appears at the wildcard position.
    asset_factory("t1.bin", ["input", "unit-tests", scope, tag_ok], {}, make_asset_bytes("t1", 512))
    asset_factory("t2.bin", ["input", "unit-tests", scope, tag_bad], {}, make_asset_bytes("t2", 512))
    r = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": f"{base}_"}, timeout=120)
    body = r.json()
    assert r.status_code == 200, body
    names = [t["name"] for t in body["tags"]]
    assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'"
    assert tag_bad not in names, f"'{tag_bad}' must not match — '_' is not a wildcard"
    assert body["total"] == 1

View File

@ -1,281 +0,0 @@
import json
import uuid
from concurrent.futures import ThreadPoolExecutor
import requests
import pytest
def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes):
    """Re-uploading identical bytes dedupes: same name reuses the reference, a new name makes a new one."""
    name = "dup_a.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "alpha"]
    meta = {"purpose": "dup"}
    data = make_asset_bytes(name)
    files = {"file": (name, data, "application/octet-stream")}
    form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
    r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    a1 = r1.json()
    assert r1.status_code == 201, a1
    assert a1["created_new"] is True
    # Second upload with the same data and name should return created_new == False and the same asset
    files = {"file": (name, data, "application/octet-stream")}
    form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
    r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    a2 = r2.json()
    assert r2.status_code == 200, a2
    assert a2["created_new"] is False
    assert a2["asset_hash"] == a1["asset_hash"]
    assert a2["id"] == a1["id"]  # same (old) reference reused
    # Third upload with the same data but new name should return created_new == False and the new AssetReference
    files = {"file": (name, data, "application/octet-stream")}
    form = {"tags": json.dumps(tags), "name": name + "_d", "user_metadata": json.dumps(meta)}
    r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    a3 = r2.json()
    assert r2.status_code == 200, a3
    assert a3["created_new"] is False
    assert a3["asset_hash"] == a1["asset_hash"]
    assert a3["id"] != a1["id"]  # new reference to the same content
def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str):
    """POST /api/assets with only a known hash (no file part) creates a reference via the fast path."""
    # Seed a small file first
    name = "fastpath_seed.safetensors"
    tags = ["models", "checkpoints", "unit-tests"]
    meta = {}
    files = {"file": (name, b"B" * 1024, "application/octet-stream")}
    form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
    r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 201, b1
    h = b1["asset_hash"]
    # Now POST /api/assets with only hash and no file
    # (None, value) tuples send plain multipart form fields without a filename.
    files = [
        ("hash", (None, h)),
        ("tags", (None, json.dumps(tags))),
        ("name", (None, "fastpath_copy.safetensors")),
        ("user_metadata", (None, json.dumps({"purpose": "copy"}))),
    ]
    r2 = http.post(api_base + "/api/assets", files=files, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 200, b2  # fast path returns 200 with created_new == False
    assert b2["created_new"] is False
    assert b2["asset_hash"] == h
def test_upload_fastpath_with_known_hash_and_file(
    http: requests.Session, api_base: str
):
    """When both a file and a known hash are sent, the server ignores the file and uses the hash."""
    # Seed
    files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")}
    form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})}
    r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 201, b1
    h = b1["asset_hash"]
    # Send both file and hash of existing content -> server must drain file and create from hash (200)
    files = {"file": ("ignored.bin", b"ignored" * 10, "application/octet-stream")}
    form = {"hash": h, "tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "copy_from_hash.safetensors", "user_metadata": json.dumps({})}
    r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 200, b2
    assert b2["created_new"] is False
    assert b2["asset_hash"] == h
def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str):
    """Repeated 'tags' form fields (CSV and JSON-list syntax) are merged into one tag set."""
    data = [
        ("tags", "models,checkpoints"),
        ("tags", json.dumps(["unit-tests", "alpha"])),
        ("name", "merge.safetensors"),
        ("user_metadata", json.dumps({"u": 1})),
    ]
    files = {"file": ("merge.safetensors", b"B" * 256, "application/octet-stream")}
    r1 = http.post(api_base + "/api/assets", data=data, files=files, timeout=120)
    created = r1.json()
    # 200 when identical bytes already exist from another test; 201 on first creation.
    assert r1.status_code in (200, 201), created
    aid = created["id"]
    # Verify all tags are present on the resource
    rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    detail = rg.json()
    assert rg.status_code == 200, detail
    tags = set(detail["tags"])
    assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
@pytest.mark.parametrize("root", ["input", "output"])
def test_concurrent_upload_identical_bytes_different_names(
    root: str,
    http: requests.Session,
    api_base: str,
    make_asset_bytes,
):
    """
    Two concurrent uploads of identical bytes but different names.
    Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True.
    """
    scope = f"concupload-{uuid.uuid4().hex[:6]}"
    name1, name2 = "cu_a.bin", "cu_b.bin"
    data = make_asset_bytes("concurrent", 4096)
    tags = [root, "unit-tests", scope]
    def _do_upload(args):
        # Each worker uses its own Session so the two requests truly run in parallel.
        url, form_data, files_data = args
        with requests.Session() as s:
            return s.post(url, data=form_data, files=files_data, timeout=120)
    url = api_base + "/api/assets"
    form1 = {"tags": json.dumps(tags), "name": name1, "user_metadata": json.dumps({})}
    files1 = {"file": (name1, data, "application/octet-stream")}
    form2 = {"tags": json.dumps(tags), "name": name2, "user_metadata": json.dumps({})}
    files2 = {"file": (name2, data, "application/octet-stream")}
    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = list(executor.map(_do_upload, [(url, form1, files1), (url, form2, files2)]))
    r1, r2 = futures
    b1, b2 = r1.json(), r2.json()
    assert r1.status_code in (200, 201), b1
    assert r2.status_code in (200, 201), b2
    # Same content hash, distinct references.
    assert b1["asset_hash"] == b2["asset_hash"]
    assert b1["id"] != b2["id"]
    # Exactly one of the two racing uploads creates the underlying asset.
    created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))])
    assert created_flags == [False, True]
    rl = http.get(
        api_base + "/api/assets",
        params={"include_tags": f"unit-tests,{scope}", "sort": "name"},
        timeout=120,
    )
    bl = rl.json()
    assert rl.status_code == 200, bl
    names = [a["name"] for a in bl.get("assets", [])]
    assert set([name1, name2]).issubset(names)
def test_create_from_hash_endpoint_404(http: requests.Session, api_base: str):
    """POST /api/assets/from-hash with an unknown hash yields 404 ASSET_NOT_FOUND."""
    unknown_hash = "blake3:" + "0" * 64
    body = {
        "hash": unknown_hash,
        "name": "nonexistent.bin",
        "tags": ["models", "checkpoints", "unit-tests"],
    }
    resp = http.post(api_base + "/api/assets/from-hash", json=body, timeout=120)
    payload = resp.json()
    assert resp.status_code == 404
    assert payload["error"]["code"] == "ASSET_NOT_FOUND"
def test_upload_zero_byte_rejected(http: requests.Session, api_base: str):
    """Uploading an empty file must be rejected with EMPTY_UPLOAD."""
    form_fields = {
        "tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]),
        "name": "empty.safetensors",
        "user_metadata": json.dumps({}),
    }
    upload = {"file": ("empty.safetensors", b"", "application/octet-stream")}
    resp = http.post(api_base + "/api/assets", data=form_fields, files=upload, timeout=120)
    payload = resp.json()
    assert resp.status_code == 400
    assert payload["error"]["code"] == "EMPTY_UPLOAD"
def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str):
    """An upload whose tags contain no valid root is rejected with INVALID_BODY."""
    form_fields = {
        "tags": json.dumps(["not-a-root", "whatever"]),
        "name": "badroot.bin",
        "user_metadata": json.dumps({}),
    }
    upload = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")}
    resp = http.post(api_base + "/api/assets", data=form_fields, files=upload, timeout=120)
    payload = resp.json()
    assert resp.status_code == 400
    assert payload["error"]["code"] == "INVALID_BODY"
def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str):
    """A user_metadata form field that is not valid JSON is rejected with INVALID_BODY."""
    form_fields = {
        "tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]),
        "name": "badmeta.bin",
        "user_metadata": "{not json}",
    }
    upload = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")}
    resp = http.post(api_base + "/api/assets", data=form_fields, files=upload, timeout=120)
    payload = resp.json()
    assert resp.status_code == 400
    assert payload["error"]["code"] == "INVALID_BODY"
def test_upload_requires_multipart(http: requests.Session, api_base: str):
    """A JSON (non-multipart) POST to /api/assets is rejected with 415."""
    resp = http.post(api_base + "/api/assets", json={"foo": "bar"}, timeout=120)
    payload = resp.json()
    assert resp.status_code == 415
    assert payload["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE"
def test_upload_missing_file_and_hash(http: requests.Session, api_base: str):
    """A multipart upload with neither a file part nor a hash fails with MISSING_FILE."""
    # (None, value) tuples send plain form fields without a filename.
    parts = [
        ("tags", (None, json.dumps(["models", "checkpoints", "unit-tests"]))),
        ("name", (None, "x.safetensors")),
    ]
    resp = http.post(api_base + "/api/assets", files=parts, timeout=120)
    payload = resp.json()
    assert resp.status_code == 400
    assert payload["error"]["code"] == "MISSING_FILE"
def test_upload_models_unknown_category(http: requests.Session, api_base: str):
    """A 'models' upload with an unrecognized category tag is rejected with INVALID_BODY."""
    upload = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")}
    form_fields = {
        "tags": json.dumps(["models", "no_such_category", "unit-tests"]),
        "name": "m.safetensors",
    }
    resp = http.post(api_base + "/api/assets", data=form_fields, files=upload, timeout=120)
    payload = resp.json()
    assert resp.status_code == 400
    assert payload["error"]["code"] == "INVALID_BODY"
    assert payload["error"]["message"].startswith("unknown models category")
def test_upload_models_requires_category(http: requests.Session, api_base: str):
    """A bare 'models' tag without a category tag is rejected with INVALID_BODY."""
    upload = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")}
    form_fields = {
        "tags": json.dumps(["models"]),
        "name": "nocat.safetensors",
        "user_metadata": json.dumps({}),
    }
    resp = http.post(api_base + "/api/assets", data=form_fields, files=upload, timeout=120)
    payload = resp.json()
    assert resp.status_code == 400
    assert payload["error"]["code"] == "INVALID_BODY"
def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
    """Path-traversal components ('..') in tags must be rejected."""
    upload = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")}
    form_fields = {
        "tags": json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]),
        "name": "evil.safetensors",
    }
    resp = http.post(api_base + "/api/assets", data=form_fields, files=upload, timeout=120)
    payload = resp.json()
    assert resp.status_code == 400
    assert payload["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
@pytest.mark.parametrize("root", ["input", "output"])
def test_duplicate_upload_same_display_name_does_not_clobber(
    root: str,
    http: requests.Session,
    api_base: str,
    asset_factory,
    make_asset_bytes,
):
    """
    Two uploads use the same tags and the same display name but different bytes.
    With hash-based filenames, they must NOT overwrite each other. Both assets
    remain accessible and serve their original content.
    """
    scope = f"dup-path-{uuid.uuid4().hex[:6]}"
    display_name = "same_display.bin"
    d1 = make_asset_bytes(scope + "-v1", 1536)
    d2 = make_asset_bytes(scope + "-v2", 2048)
    tags = [root, "unit-tests", scope]
    first = asset_factory(display_name, tags, {}, d1)
    second = asset_factory(display_name, tags, {}, d2)
    # Distinct references and distinct content hashes, same display name.
    assert first["id"] != second["id"]
    assert first["asset_hash"] != second["asset_hash"]  # different content
    assert first["name"] == second["name"] == display_name
    # Both must be independently retrievable
    r1 = http.get(f"{api_base}/api/assets/{first['id']}/content", timeout=120)
    b1 = r1.content
    assert r1.status_code == 200
    assert b1 == d1
    r2 = http.get(f"{api_base}/api/assets/{second['id']}/content", timeout=120)
    b2 = r2.content
    assert r2.status_code == 200
    assert b2 == d2