mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-31 09:07:02 +08:00
Compare commits
4 Commits
assets-red
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 8aabe2403e | |||
| 0167653781 | |||
| 0a7993729c | |||
| bbe2c13a70 |
30
.github/workflows/test-assets.yml
vendored
30
.github/workflows/test-assets.yml
vendored
@ -1,30 +0,0 @@
|
||||
name: Assets Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master, release/** ]
|
||||
pull_request:
|
||||
branches: [ main, master, release/** ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.12'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
- name: Run Assets Tests
|
||||
run: |
|
||||
pip install -r tests-assets/requirements.txt
|
||||
python -m pytest tests-assets -v
|
||||
@ -1,8 +1,5 @@
|
||||
import logging
|
||||
import uuid
|
||||
import urllib.parse
|
||||
import os
|
||||
import contextlib
|
||||
from aiohttp import web
|
||||
|
||||
from pydantic import ValidationError
|
||||
@ -11,9 +8,6 @@ import app.assets.manager as manager
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in
|
||||
from app.assets.helpers import get_query_dict
|
||||
from app.assets.scanner import seed_assets
|
||||
|
||||
import folder_paths
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
@ -21,9 +15,6 @@ USER_MANAGER: user_manager.UserManager | None = None
|
||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
# Note to any custom node developers reading this code:
|
||||
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.
|
||||
|
||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||
global USER_MANAGER
|
||||
USER_MANAGER = user_manager_instance
|
||||
@ -37,18 +28,6 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
hash_str = request.match_info.get("hash", "").strip().lower()
|
||||
if not hash_str or ":" not in hash_str:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = hash_str.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
exists = manager.asset_exists(asset_hash=hash_str)
|
||||
return web.Response(status=200 if exists else 404)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
"""
|
||||
@ -71,7 +50,7 @@ async def list_assets(request: web.Request) -> web.Response:
|
||||
order=q.order,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
@ -97,314 +76,6 @@ async def get_asset(request: web.Request) -> web.Response:
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||
async def download_asset_content(request: web.Request) -> web.Response:
|
||||
# question: do we need disposition? could we just stick with one of these?
|
||||
disposition = request.query.get("disposition", "attachment").lower().strip()
|
||||
if disposition not in {"inline", "attachment"}:
|
||||
disposition = "attachment"
|
||||
|
||||
try:
|
||||
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
|
||||
asset_info_id=str(uuid.UUID(request.match_info["id"])),
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||
except NotImplementedError as nie:
|
||||
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
|
||||
except FileNotFoundError:
|
||||
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
|
||||
|
||||
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
|
||||
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
|
||||
|
||||
file_size = os.path.getsize(abs_path)
|
||||
logging.info(
|
||||
"download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s",
|
||||
abs_path,
|
||||
file_size,
|
||||
file_size / (1024 * 1024),
|
||||
content_type,
|
||||
filename,
|
||||
)
|
||||
|
||||
async def file_sender():
|
||||
chunk_size = 64 * 1024
|
||||
with open(abs_path, "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
return web.Response(
|
||||
body=file_sender(),
|
||||
content_type=content_type,
|
||||
headers={
|
||||
"Content-Disposition": cd,
|
||||
"Content-Length": str(file_size),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/from-hash")
|
||||
async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
body = schemas_in.CreateFromHashBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
result = manager.create_asset_from_hash(
|
||||
hash_str=body.hash,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||
return web.json_response(result.model_dump(mode="json"), status=201)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets")
|
||||
async def upload_asset(request: web.Request) -> web.Response:
|
||||
"""Multipart/form-data endpoint for Asset uploads."""
|
||||
if not (request.content_type or "").lower().startswith("multipart/"):
|
||||
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
|
||||
|
||||
reader = await request.multipart()
|
||||
|
||||
file_present = False
|
||||
file_client_name: str | None = None
|
||||
tags_raw: list[str] = []
|
||||
provided_name: str | None = None
|
||||
user_metadata_raw: str | None = None
|
||||
provided_hash: str | None = None
|
||||
provided_hash_exists: bool | None = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: str | None = None
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
fname = getattr(field, "name", "") or ""
|
||||
|
||||
if fname == "hash":
|
||||
try:
|
||||
s = ((await field.text()) or "").strip().lower()
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
|
||||
if s:
|
||||
if ":" not in s:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
provided_hash = f"{algo}:{digest}"
|
||||
try:
|
||||
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
|
||||
except Exception:
|
||||
provided_hash_exists = None # do not fail the whole request here
|
||||
|
||||
elif fname == "file":
|
||||
file_present = True
|
||||
file_client_name = (field.filename or "").strip()
|
||||
|
||||
if provided_hash and provided_hash_exists is True:
|
||||
# If client supplied a hash that we know exists, drain but do not write to disk
|
||||
try:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
|
||||
continue # Do not create temp file; we will create AssetInfo from the existing content
|
||||
|
||||
# Otherwise, store to temp for hashing/ingest
|
||||
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||
os.makedirs(unique_dir, exist_ok=True)
|
||||
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||
|
||||
try:
|
||||
with open(tmp_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
try:
|
||||
if os.path.exists(tmp_path or ""):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
|
||||
elif fname == "tags":
|
||||
tags_raw.append((await field.text()) or "")
|
||||
elif fname == "name":
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
|
||||
# If client did not send file, and we are not doing a from-hash fast path -> error
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
|
||||
|
||||
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
|
||||
# Empty upload is only acceptable if we are fast-pathing from existing hash
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||
|
||||
try:
|
||||
spec = schemas_in.UploadAssetSpec.model_validate({
|
||||
"tags": tags_raw,
|
||||
"name": provided_name,
|
||||
"user_metadata": user_metadata_raw,
|
||||
"hash": provided_hash,
|
||||
})
|
||||
except ValidationError as ve:
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
# Validate models category against configured folders (consistent with previous behavior)
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
return _error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
|
||||
)
|
||||
|
||||
owner_id = USER_MANAGER.get_request_user_id(request)
|
||||
|
||||
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
|
||||
if spec.hash and provided_hash_exists is True:
|
||||
try:
|
||||
result = manager.create_asset_from_hash(
|
||||
hash_str=spec.hash,
|
||||
name=spec.name or (spec.hash.split(":", 1)[1]),
|
||||
tags=spec.tags,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
owner_id=owner_id,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
|
||||
|
||||
# Drain temp if we accidentally saved (e.g., hash field came after file)
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
with contextlib.suppress(Exception):
|
||||
os.remove(tmp_path)
|
||||
|
||||
status = 200 if (not result.created_new) else 201
|
||||
return web.json_response(result.model_dump(mode="json"), status=status)
|
||||
|
||||
# Otherwise, we must have a temp file path to ingest
|
||||
if not tmp_path or not os.path.exists(tmp_path):
|
||||
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
|
||||
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
|
||||
|
||||
try:
|
||||
created = manager.upload_asset_from_temp_path(
|
||||
spec,
|
||||
temp_path=tmp_path,
|
||||
client_filename=file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_asset_hash=spec.hash,
|
||||
)
|
||||
status = 201 if created.created_new else 200
|
||||
return web.json_response(created.model_dump(mode="json"), status=status)
|
||||
except ValueError as e:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
msg = str(e)
|
||||
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
|
||||
return _error_response(
|
||||
400,
|
||||
"HASH_MISMATCH",
|
||||
"Uploaded file hash does not match provided hash.",
|
||||
)
|
||||
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
|
||||
except Exception:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def update_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.update_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
name=body.name,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"update_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def delete_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
delete_content = request.query.get("delete_content")
|
||||
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
|
||||
|
||||
try:
|
||||
deleted = manager.delete_asset_reference(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
delete_content_if_orphan=delete_content,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if not deleted:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
|
||||
return web.Response(status=204)
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
"""
|
||||
@ -429,86 +100,3 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsAdd.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.add_tags_to_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
origin="manual",
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsRemove.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.remove_tags_from_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/seed")
|
||||
async def seed_assets_endpoint(request: web.Request) -> web.Response:
|
||||
"""Trigger asset seeding for specified roots (models, input, output)."""
|
||||
try:
|
||||
payload = await request.json()
|
||||
roots = payload.get("roots", ["models", "input", "output"])
|
||||
except Exception:
|
||||
roots = ["models", "input", "output"]
|
||||
|
||||
valid_roots = [r for r in roots if r in ("models", "input", "output")]
|
||||
if not valid_roots:
|
||||
return _error_response(400, "INVALID_BODY", "No valid roots specified")
|
||||
|
||||
try:
|
||||
seed_assets(tuple(valid_roots))
|
||||
except Exception:
|
||||
logging.exception("seed_assets failed for roots=%s", valid_roots)
|
||||
return _error_response(500, "INTERNAL", "Seed operation failed")
|
||||
|
||||
return web.json_response({"seeded": valid_roots}, status=200)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import (
|
||||
@ -7,9 +8,9 @@ from pydantic import (
|
||||
Field,
|
||||
conint,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
@ -56,57 +57,6 @@ class ListAssetsQuery(BaseModel):
|
||||
return None
|
||||
|
||||
|
||||
class UpdateAssetBody(BaseModel):
|
||||
name: str | None = None
|
||||
user_metadata: dict[str, Any] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _at_least_one(self):
|
||||
if self.name is None and self.user_metadata is None:
|
||||
raise ValueError("Provide at least one of: name, user_metadata.")
|
||||
return self
|
||||
|
||||
|
||||
class CreateFromHashBody(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
hash: str
|
||||
name: str
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("hash")
|
||||
@classmethod
|
||||
def _require_blake3(cls, v):
|
||||
s = (v or "").strip().lower()
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return s
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _tags_norm(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
out = [str(t).strip().lower() for t in v if str(t).strip()]
|
||||
seen = set()
|
||||
dedup = []
|
||||
for t in out:
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
dedup.append(t)
|
||||
return dedup
|
||||
if isinstance(v, str):
|
||||
return [t.strip().lower() for t in v.split(",") if t.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
@ -125,140 +75,20 @@ class TagsListQuery(BaseModel):
|
||||
return v.lower() or None
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
class SetPreviewBody(BaseModel):
|
||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
||||
preview_id: str | None = None
|
||||
|
||||
@field_validator("tags")
|
||||
@field_validator("preview_id", mode="before")
|
||||
@classmethod
|
||||
def normalize_tags(cls, v: list[str]) -> list[str]:
|
||||
out = []
|
||||
for t in v:
|
||||
if not isinstance(t, str):
|
||||
raise TypeError("tags must be strings")
|
||||
tnorm = t.strip().lower()
|
||||
if tnorm:
|
||||
out.append(tnorm)
|
||||
seen = set()
|
||||
deduplicated = []
|
||||
for x in out:
|
||||
if x not in seen:
|
||||
seen.add(x)
|
||||
deduplicated.append(x)
|
||||
return deduplicated
|
||||
|
||||
|
||||
class TagsRemove(TagsAdd):
|
||||
pass
|
||||
|
||||
|
||||
class UploadAssetSpec(BaseModel):
|
||||
"""Upload Asset operation.
|
||||
- tags: ordered; first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
|
||||
- name: display name
|
||||
- user_metadata: arbitrary JSON object (optional)
|
||||
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
|
||||
|
||||
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
|
||||
and the original extension is preserved when available.
|
||||
"""
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
hash: str | None = Field(default=None)
|
||||
|
||||
@field_validator("hash", mode="before")
|
||||
@classmethod
|
||||
def _parse_hash(cls, v):
|
||||
def _norm_uuid(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip().lower()
|
||||
s = str(v).strip()
|
||||
if not s:
|
||||
return None
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return f"{algo}:{digest}"
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _parse_tags(cls, v):
|
||||
"""
|
||||
Accepts a list of strings (possibly multiple form fields),
|
||||
where each string can be:
|
||||
- JSON array (e.g., '["models","loras","foo"]')
|
||||
- comma-separated ('models, loras, foo')
|
||||
- single token ('models')
|
||||
Returns a normalized, deduplicated, ordered list.
|
||||
"""
|
||||
items: list[str] = []
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
v = [v]
|
||||
|
||||
if isinstance(v, list):
|
||||
for item in v:
|
||||
if item is None:
|
||||
continue
|
||||
s = str(item).strip()
|
||||
if not s:
|
||||
continue
|
||||
if s.startswith("["):
|
||||
try:
|
||||
arr = json.loads(s)
|
||||
if isinstance(arr, list):
|
||||
items.extend(str(x) for x in arr)
|
||||
continue
|
||||
except Exception:
|
||||
pass # fallback to CSV parse below
|
||||
items.extend([p for p in s.split(",") if p.strip()])
|
||||
else:
|
||||
return []
|
||||
|
||||
# normalize + dedupe
|
||||
norm = []
|
||||
seen = set()
|
||||
for t in items:
|
||||
tnorm = str(t).strip().lower()
|
||||
if tnorm and tnorm not in seen:
|
||||
seen.add(tnorm)
|
||||
norm.append(tnorm)
|
||||
return norm
|
||||
|
||||
@field_validator("user_metadata", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v or {}
|
||||
if isinstance(v, str):
|
||||
s = v.strip()
|
||||
if not s:
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(s)
|
||||
except Exception as e:
|
||||
raise ValueError(f"user_metadata must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("user_metadata must be a JSON object")
|
||||
return parsed
|
||||
return {}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_order(self):
|
||||
if not self.tags:
|
||||
raise ValueError("tags must be provided and non-empty")
|
||||
root = self.tags[0]
|
||||
if root not in {"models", "input", "output"}:
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
if root == "models":
|
||||
if len(self.tags) < 2:
|
||||
raise ValueError("models uploads require a category tag as the second tag")
|
||||
return self
|
||||
try:
|
||||
uuid.UUID(s)
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
|
||||
|
||||
@ -29,21 +29,6 @@ class AssetsList(BaseModel):
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetUpdated(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("updated_at")
|
||||
def _ser_updated(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -63,10 +48,6 @@ class AssetDetail(BaseModel):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(AssetDetail):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
@ -77,17 +58,3 @@ class TagsList(BaseModel):
|
||||
tags: list[TagUsage] = Field(default_factory=list)
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
added: list[str] = Field(default_factory=list)
|
||||
already_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TagsRemove(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
removed: list[str] = Field(default_factory=list)
|
||||
not_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
@ -1,17 +1,9 @@
|
||||
import os
|
||||
import logging
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Iterable, Any
|
||||
from sqlalchemy import select, delete, exists, func
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy import select, exists, func
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import (
|
||||
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
|
||||
)
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import escape_like_prefix, normalize_tags
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
@ -23,22 +15,6 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
|
||||
"""
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
@ -66,7 +42,6 @@ def apply_tag_filters(
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
@ -119,11 +94,7 @@ def apply_metadata_filter(
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_exists_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> bool:
|
||||
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
@ -134,39 +105,9 @@ def asset_exists_by_hash(
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
def asset_info_exists_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> bool:
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetInfo)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.limit(1)
|
||||
)
|
||||
return (session.execute(q)).first() is not None
|
||||
|
||||
|
||||
def get_asset_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> Asset | None:
|
||||
return (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
|
||||
def get_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
) -> AssetInfo | None:
|
||||
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
|
||||
return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
@ -230,14 +171,12 @@ def list_asset_infos_page(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
.order_by(AssetInfoTag.added_at)
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
@ -269,494 +208,6 @@ def fetch_asset_info_asset_and_tags(
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
|
||||
def fetch_asset_info_and_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
.options(noload(AssetInfo.tags))
|
||||
)
|
||||
row = session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
return pair[0], pair[1]
|
||||
|
||||
def list_cache_states_by_asset_id(
|
||||
session: Session, *, asset_id: str
|
||||
) -> Sequence[AssetCacheState]:
|
||||
return (
|
||||
session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
def touch_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
ts: datetime | None = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||
)
|
||||
session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
def create_asset_info_for_existing_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetInfo:
|
||||
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||
now = utcnow()
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
preview_id=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
try:
|
||||
with session.begin_nested():
|
||||
session.add(info)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
existing = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == name,
|
||||
AssetInfo.owner_id == owner_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
if not existing:
|
||||
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||
return existing
|
||||
|
||||
# metadata["filename"] hack
|
||||
new_meta = dict(user_metadata or {})
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
if new_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
def set_asset_info_tags(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> dict:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||
for t in to_add
|
||||
])
|
||||
session.flush()
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
def replace_asset_info_metadata_projection(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
user_metadata: dict | None = None,
|
||||
) -> None:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
def ingest_fs_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: str | None = None,
|
||||
info_name: str | None = None,
|
||||
owner_id: str = "",
|
||||
preview_id: str | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Idempotently upsert:
|
||||
- Asset by content hash (create if missing)
|
||||
- AssetCacheState(file_path) pointing to asset_id
|
||||
- Optionally AssetInfo + tag links and metadata projection
|
||||
Returns flags and ids.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
now = utcnow()
|
||||
|
||||
if preview_id:
|
||||
if not session.get(Asset, preview_id):
|
||||
preview_id = None
|
||||
|
||||
out: dict[str, Any] = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
}
|
||||
|
||||
# 1) Asset by hash
|
||||
asset = (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
vals = {
|
||||
"hash": asset_hash,
|
||||
"size_bytes": int(size_bytes),
|
||||
"mime_type": mime_type,
|
||||
"created_at": now,
|
||||
}
|
||||
res = session.execute(
|
||||
sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["asset_created"] = True
|
||||
asset = (
|
||||
session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
else:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
out["asset_updated"] = True
|
||||
|
||||
# 2) AssetCacheState upsert by file_path (unique)
|
||||
vals = {
|
||||
"asset_id": asset.id,
|
||||
"file_path": locator,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
}
|
||||
ins = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
|
||||
res = session.execute(ins)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["state_created"] = True
|
||||
else:
|
||||
upd = (
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.file_path == locator)
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetCacheState.asset_id != asset.id,
|
||||
AssetCacheState.mtime_ns.is_(None),
|
||||
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||
)
|
||||
)
|
||||
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||
)
|
||||
res2 = session.execute(upd)
|
||||
if int(res2.rowcount or 0) > 0:
|
||||
out["state_updated"] = True
|
||||
|
||||
# 3) Optional AssetInfo + tags + metadata
|
||||
if info_name:
|
||||
try:
|
||||
with session.begin_nested():
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_id=asset.id,
|
||||
preview_id=preview_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(info)
|
||||
session.flush()
|
||||
out["asset_info_id"] = info.id
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
existing_info = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == info_name,
|
||||
(AssetInfo.owner_id == owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalar_one_or_none()
|
||||
if not existing_info:
|
||||
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||
|
||||
if preview_id and existing_info.preview_id != preview_id:
|
||||
existing_info.preview_id = preview_id
|
||||
|
||||
existing_info.updated_at = now
|
||||
if existing_info.last_access_time < now:
|
||||
existing_info.last_access_time = now
|
||||
session.flush()
|
||||
out["asset_info_id"] = existing_info.id
|
||||
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
if not require_existing_tags:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_at=now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
# metadata["filename"] hack
|
||||
if out["asset_info_id"] is not None:
|
||||
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
computed_filename = compute_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
current_meta = existing_info.user_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
if user_metadata is not None:
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
return out
|
||||
|
||||
|
||||
def update_asset_info_full(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tag_origin: str = "manual",
|
||||
asset_info_row: Any = None,
|
||||
) -> AssetInfo:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
else:
|
||||
info = asset_info_row
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
else:
|
||||
if computed_filename:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def delete_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str,
|
||||
) -> bool:
|
||||
stmt = sa.delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
return int((session.execute(stmt)).rowcount or 0) > 0
|
||||
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
@ -814,163 +265,3 @@ def list_tags_with_usage(
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
session.execute(ins)
|
||||
|
||||
|
||||
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def add_tags_to_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: Any = None,
|
||||
) -> dict:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": [], "already_present": [], "total_tags": total}
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=utcnow(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
nested.rollback()
|
||||
|
||||
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
|
||||
return {
|
||||
"added": sorted(((after - current) & want)),
|
||||
"already_present": sorted(want & current),
|
||||
"total_tags": sorted(after),
|
||||
}
|
||||
|
||||
|
||||
def remove_tags_from_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> dict:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": [], "not_present": [], "total_tags": total}
|
||||
|
||||
existing = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sa.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def set_asset_info_preview(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
if preview_asset_id is None:
|
||||
info.preview_id = None
|
||||
else:
|
||||
# validate preview asset exists
|
||||
if not session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
info.preview_id = preview_asset_id
|
||||
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import contextlib
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
@ -88,40 +87,6 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
root = tags[0]
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
else:
|
||||
base_dir = os.path.abspath(
|
||||
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||
)
|
||||
raw_subdirs = tags[1:]
|
||||
for i in raw_subdirs:
|
||||
if i in (".", ".."):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = os.path.abspath(candidate)
|
||||
base_abs = os.path.abspath(base)
|
||||
try:
|
||||
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||
raise ValueError("destination escapes base directory")
|
||||
except Exception:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
@ -148,6 +113,7 @@ def compute_relative_filename(file_path: str) -> str | None:
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
@ -249,64 +215,3 @@ def collect_models_files() -> list[str]:
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def project_kv(key: str, value):
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
@ -1,33 +1,13 @@
|
||||
import os
|
||||
import mimetypes
|
||||
import contextlib
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out, schemas_in
|
||||
from app.assets.api import schemas_out
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
fetch_asset_info_and_asset,
|
||||
create_asset_info_for_existing_asset,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
delete_asset_info_by_id,
|
||||
list_cache_states_by_asset_id,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
get_asset_tags,
|
||||
add_tags_to_asset_info,
|
||||
remove_tags_from_asset_info,
|
||||
pick_best_live_path,
|
||||
ingest_fs_asset,
|
||||
set_asset_info_preview,
|
||||
)
|
||||
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
|
||||
from app.assets.database.models import Asset
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
@ -39,28 +19,11 @@ def _safe_sort_field(requested: str | None) -> str:
|
||||
return "created_at"
|
||||
|
||||
|
||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
def _safe_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
if n:
|
||||
return n
|
||||
return fallback
|
||||
|
||||
|
||||
def asset_exists(*, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def list_assets(
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
@ -100,6 +63,7 @@ def list_assets(
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
preview_url=f"/api/assets/{info.id}/content",
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
@ -112,12 +76,7 @@ def list_assets(
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
|
||||
def get_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
@ -138,358 +97,6 @@ def get_asset(
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_content_for_download(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
with create_session() as session:
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = pick_best_live_path(states)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError
|
||||
|
||||
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
session.commit()
|
||||
|
||||
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return abs_path, ctype, download_name
|
||||
|
||||
|
||||
def upload_asset_from_temp_path(
|
||||
spec: schemas_in.UploadAssetSpec,
|
||||
*,
|
||||
temp_path: str,
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_asset_hash: str | None = None,
|
||||
) -> schemas_out.AssetCreated:
|
||||
"""
|
||||
Create new asset or update existing asset from a temporary file path.
|
||||
"""
|
||||
try:
|
||||
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
|
||||
import app.assets.hashing as hashing
|
||||
digest = hashing.blake3_hash(temp_path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||
raise ValueError("HASH_MISMATCH")
|
||||
|
||||
with create_session() as session:
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=existing.hash,
|
||||
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||
mime_type=existing.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or spec.name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
ensure_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
with create_session() as session:
|
||||
result = ingest_fs_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
info_id = result["asset_info_id"]
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
info, asset = pair
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
created_result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return created_result
|
||||
|
||||
|
||||
def update_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetUpdated:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
info = update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
|
||||
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
result = schemas_out.AssetUpdated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset.hash if info.asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
updated_at=info.updated_at,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def set_asset_preview(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
info, asset, tags = res
|
||||
result = schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not deleted:
|
||||
session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||
|
||||
asset_row = session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
session.delete(asset_row)
|
||||
|
||||
session.commit()
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
return True
|
||||
|
||||
|
||||
def create_asset_from_hash(
|
||||
*,
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetCreated | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=canonical,
|
||||
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def add_tags_to_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
data = add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsAdd(**data)
|
||||
|
||||
|
||||
def remove_tags_from_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
data = remove_tags_from_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsRemove(**data)
|
||||
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
|
||||
@ -27,7 +27,6 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
t_start = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
orphans_pruned = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
@ -39,11 +38,6 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
||||
|
||||
try:
|
||||
orphans_pruned = _prune_orphaned_assets(roots)
|
||||
except Exception as e:
|
||||
logging.exception("orphan pruning failed: %s", e)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
@ -91,43 +85,15 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
finally:
|
||||
if enable_logging:
|
||||
logging.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_start,
|
||||
created,
|
||||
skipped_existing,
|
||||
orphans_pruned,
|
||||
len(paths),
|
||||
)
|
||||
|
||||
|
||||
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
|
||||
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
|
||||
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
|
||||
if not all_prefixes:
|
||||
return 0
|
||||
|
||||
def make_prefix_condition(prefix: str):
|
||||
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
|
||||
|
||||
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
|
||||
|
||||
orphan_subq = (
|
||||
sqlalchemy.select(Asset.id)
|
||||
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
|
||||
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
|
||||
).scalar_subquery()
|
||||
|
||||
with create_session() as sess:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
|
||||
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
|
||||
sess.commit()
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def _fast_db_consistency_pass(
|
||||
root: RootType,
|
||||
*,
|
||||
|
||||
@ -1146,6 +1146,20 @@ class ImageCompare(ComfyTypeI):
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
|
||||
|
||||
@comfytype(io_type="COLOR")
|
||||
class Color(ComfyTypeIO):
|
||||
Type = str
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, advanced: bool=None, default: str="#ffffff"):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
self.default: str
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
@ -1252,23 +1266,6 @@ class NodeInfoV1:
|
||||
price_badge: dict | None = None
|
||||
search_aliases: list[str]=None
|
||||
|
||||
@dataclass
|
||||
class NodeInfoV3:
|
||||
input: dict=None
|
||||
output: dict=None
|
||||
hidden: list[str]=None
|
||||
name: str=None
|
||||
display_name: str=None
|
||||
description: str=None
|
||||
python_module: Any = None
|
||||
category: str=None
|
||||
output_node: bool=None
|
||||
deprecated: bool=None
|
||||
experimental: bool=None
|
||||
dev_only: bool=None
|
||||
api_node: bool=None
|
||||
price_badge: dict | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PriceBadgeDepends:
|
||||
@ -1497,40 +1494,6 @@ class Schema:
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
def get_v3_info(self, cls) -> NodeInfoV3:
|
||||
input_dict = {}
|
||||
output_dict = {}
|
||||
hidden_list = []
|
||||
# TODO: make sure dynamic types will be handled correctly
|
||||
if self.inputs:
|
||||
for input in self.inputs:
|
||||
add_to_dict_v3(input, input_dict)
|
||||
if self.outputs:
|
||||
for output in self.outputs:
|
||||
add_to_dict_v3(output, output_dict)
|
||||
if self.hidden:
|
||||
for hidden in self.hidden:
|
||||
hidden_list.append(hidden.value)
|
||||
|
||||
info = NodeInfoV3(
|
||||
input=input_dict,
|
||||
output=output_dict,
|
||||
hidden=hidden_list,
|
||||
name=self.node_id,
|
||||
display_name=self.display_name,
|
||||
description=self.description,
|
||||
category=self.category,
|
||||
output_node=self.is_output_node,
|
||||
deprecated=self.is_deprecated,
|
||||
experimental=self.is_experimental,
|
||||
dev_only=self.is_dev_only,
|
||||
api_node=self.is_api_node,
|
||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
|
||||
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
|
||||
)
|
||||
return info
|
||||
|
||||
def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]:
|
||||
out_dict = {
|
||||
"required": {},
|
||||
@ -1585,9 +1548,6 @@ def add_to_dict_v1(i: Input, d: dict):
|
||||
as_dict.pop("optional", None)
|
||||
d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
|
||||
|
||||
def add_to_dict_v3(io: Input | Output, d: dict):
|
||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||
|
||||
class DynamicPathsDefaultValue:
|
||||
EMPTY_DICT = "empty_dict"
|
||||
|
||||
@ -1748,13 +1708,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
# set hidden
|
||||
type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
|
||||
return type_clone
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
|
||||
schema = cls.GET_SCHEMA()
|
||||
info = schema.get_v3_info(cls)
|
||||
return asdict(info)
|
||||
#############################################
|
||||
# V1 Backwards Compatibility code
|
||||
#--------------------------------------------
|
||||
@ -2099,6 +2052,7 @@ __all__ = [
|
||||
"AnyType",
|
||||
"MultiType",
|
||||
"Tracks",
|
||||
"Color",
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
"DynamicCombo",
|
||||
@ -2107,12 +2061,10 @@ __all__ = [
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
"NodeInfoV1",
|
||||
"NodeInfoV3",
|
||||
"Schema",
|
||||
"ComfyNode",
|
||||
"NodeOutput",
|
||||
"add_to_dict_v1",
|
||||
"add_to_dict_v3",
|
||||
"V3Data",
|
||||
"ImageCompare",
|
||||
"PriceBadgeDepends",
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, conint, confloat
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RecraftColor:
|
||||
@ -229,24 +226,24 @@ class RecraftColorObject(BaseModel):
|
||||
|
||||
|
||||
class RecraftControlsObject(BaseModel):
|
||||
colors: Optional[list[RecraftColorObject]] = Field(None, description='An array of preferable colors')
|
||||
background_color: Optional[RecraftColorObject] = Field(None, description='Use given color as a desired background color')
|
||||
no_text: Optional[bool] = Field(None, description='Do not embed text layouts')
|
||||
artistic_level: Optional[conint(ge=0, le=5)] = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].')
|
||||
colors: list[RecraftColorObject] | None = Field(None, description='An array of preferable colors')
|
||||
background_color: RecraftColorObject | None = Field(None, description='Use given color as a desired background color')
|
||||
no_text: bool | None = Field(None, description='Do not embed text layouts')
|
||||
artistic_level: int | None = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].')
|
||||
|
||||
|
||||
class RecraftImageGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt describing the image to generate')
|
||||
size: Optional[RecraftImageSize] = Field(None, description='The size of the generated image (e.g., "1024x1024")')
|
||||
n: conint(ge=1, le=6) = Field(..., description='The number of images to generate')
|
||||
negative_prompt: Optional[str] = Field(None, description='A text description of undesired elements on an image')
|
||||
model: Optional[RecraftModel] = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
|
||||
style: Optional[str] = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
|
||||
substyle: Optional[str] = Field(None, description='The substyle to apply to the generated image, depending on the style input')
|
||||
controls: Optional[RecraftControlsObject] = Field(None, description='A set of custom parameters to tweak generation process')
|
||||
style_id: Optional[str] = Field(None, description='Use a previously uploaded style as a reference; UUID')
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
|
||||
random_seed: Optional[int] = Field(None, description="Seed for video generation")
|
||||
size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
|
||||
n: int = Field(..., description='The number of images to generate')
|
||||
negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
|
||||
model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
|
||||
style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
|
||||
substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
|
||||
controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
|
||||
style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
|
||||
strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
|
||||
random_seed: int | None = Field(None, description="Seed for video generation")
|
||||
# text_layout
|
||||
|
||||
|
||||
@ -258,5 +255,13 @@ class RecraftReturnedObject(BaseModel):
|
||||
class RecraftImageGenerationResponse(BaseModel):
|
||||
created: int = Field(..., description='Unix timestamp when the generation was created')
|
||||
credits: int = Field(..., description='Number of credits used for the generation')
|
||||
data: Optional[list[RecraftReturnedObject]] = Field(None, description='Array of generated image information')
|
||||
image: Optional[RecraftReturnedObject] = Field(None, description='Single generated image')
|
||||
data: list[RecraftReturnedObject] | None = Field(None, description='Array of generated image information')
|
||||
image: RecraftReturnedObject | None = Field(None, description='Single generated image')
|
||||
|
||||
|
||||
class RecraftCreateStyleRequest(BaseModel):
|
||||
style: str = Field(..., description="realistic_image, digital_illustration, vector_illustration, or icon")
|
||||
|
||||
|
||||
class RecraftCreateStyleResponse(BaseModel):
|
||||
id: str = Field(..., description="UUID of the created style")
|
||||
|
||||
@ -12,6 +12,8 @@ from comfy_api_nodes.apis.recraft import (
|
||||
RecraftColor,
|
||||
RecraftColorChain,
|
||||
RecraftControls,
|
||||
RecraftCreateStyleRequest,
|
||||
RecraftCreateStyleResponse,
|
||||
RecraftImageGenerationRequest,
|
||||
RecraftImageGenerationResponse,
|
||||
RecraftImageSize,
|
||||
@ -323,6 +325,75 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
|
||||
return IO.NodeOutput(RecraftStyle(style_id=style_id))
|
||||
|
||||
|
||||
class RecraftCreateStyleNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RecraftCreateStyleNode",
|
||||
display_name="Recraft Create Style",
|
||||
category="api node/image/Recraft",
|
||||
description="Create a custom style from reference images. "
|
||||
"Upload 1-5 images to use as style references. "
|
||||
"Total size of all images is limited to 5 MB.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"style",
|
||||
options=["realistic_image", "digital_illustration"],
|
||||
tooltip="The base style of the generated images.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplatePrefix(
|
||||
IO.Image.Input("image"),
|
||||
prefix="image",
|
||||
min=1,
|
||||
max=5,
|
||||
),
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="style_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd": 0.04}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
style: str,
|
||||
images: IO.Autogrow.Type,
|
||||
) -> IO.NodeOutput:
|
||||
files = []
|
||||
total_size = 0
|
||||
max_total_size = 5 * 1024 * 1024 # 5 MB limit
|
||||
for i, img in enumerate(list(images.values())):
|
||||
file_bytes = tensor_to_bytesio(img, total_pixels=2048 * 2048, mime_type="image/webp").read()
|
||||
total_size += len(file_bytes)
|
||||
if total_size > max_total_size:
|
||||
raise Exception("Total size of all images exceeds 5 MB limit.")
|
||||
files.append((f"file{i + 1}", file_bytes))
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/recraft/styles", method="POST"),
|
||||
response_model=RecraftCreateStyleResponse,
|
||||
files=files,
|
||||
data=RecraftCreateStyleRequest(style=style),
|
||||
content_type="multipart/form-data",
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
return IO.NodeOutput(response.id)
|
||||
|
||||
|
||||
class RecraftTextToImageNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -395,7 +466,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
|
||||
negative_prompt: str = None,
|
||||
recraft_controls: RecraftControls = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, max_length=1000)
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=1000)
|
||||
default_style = RecraftStyle(RecraftStyleV3.realistic_image)
|
||||
if recraft_style is None:
|
||||
recraft_style = default_style
|
||||
@ -1024,6 +1095,7 @@ class RecraftExtension(ComfyExtension):
|
||||
RecraftStyleV3DigitalIllustrationNode,
|
||||
RecraftStyleV3LogoRasterNode,
|
||||
RecraftStyleInfiniteStyleLibrary,
|
||||
RecraftCreateStyleNode,
|
||||
RecraftColorRGBNode,
|
||||
RecraftControlsNode,
|
||||
]
|
||||
|
||||
42
comfy_extras/nodes_color.py
Normal file
42
comfy_extras/nodes_color.py
Normal file
@ -0,0 +1,42 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class ColorToRGBInt(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="ColorToRGBInt",
|
||||
display_name="Color to RGB Int",
|
||||
category="utils",
|
||||
description="Convert a color to a RGB integer value.",
|
||||
inputs=[
|
||||
io.Color.Input("color"),
|
||||
],
|
||||
outputs=[
|
||||
io.Int.Output(display_name="rgb_int"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
color: str,
|
||||
) -> io.NodeOutput:
|
||||
# expect format #RRGGBB
|
||||
if len(color) != 7 or color[0] != "#":
|
||||
raise ValueError("Color must be in format #RRGGBB")
|
||||
r = int(color[1:3], 16)
|
||||
g = int(color[3:5], 16)
|
||||
b = int(color[5:7], 16)
|
||||
return io.NodeOutput(r * 256 * 256 + g * 256 + b)
|
||||
|
||||
|
||||
class ColorExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [ColorToRGBInt]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ColorExtension:
|
||||
return ColorExtension()
|
||||
@ -56,7 +56,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples":latent})
|
||||
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
|
||||
@ -73,7 +73,7 @@ class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo):
|
||||
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
|
||||
# Using scale factor of 16 instead of 8
|
||||
latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent})
|
||||
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 16})
|
||||
|
||||
|
||||
class HunyuanVideo15ImageToVideo(io.ComfyNode):
|
||||
|
||||
3
nodes.py
3
nodes.py
@ -2432,7 +2432,8 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_wanmove.py",
|
||||
"nodes_image_compare.py",
|
||||
"nodes_zimage.py",
|
||||
"nodes_lora_debug.py"
|
||||
"nodes_lora_debug.py",
|
||||
"nodes_color.py"
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@ -1,271 +0,0 @@
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterator, Optional
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
"""
|
||||
Allow overriding the database URL used by the spawned ComfyUI process.
|
||||
Priority:
|
||||
1) --db-url command line option
|
||||
2) ASSETS_TEST_DB_URL environment variable (used by CI)
|
||||
3) default: None (will use file-backed sqlite in temp dir)
|
||||
"""
|
||||
parser.addoption(
|
||||
"--db-url",
|
||||
action="store",
|
||||
default=os.environ.get("ASSETS_TEST_DB_URL"),
|
||||
help="SQLAlchemy DB URL (e.g. sqlite:///path/to/db.sqlite3)",
|
||||
)
|
||||
|
||||
|
||||
def _free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def _make_base_dirs(root: Path) -> None:
|
||||
for sub in ("models", "custom_nodes", "input", "output", "temp", "user"):
|
||||
(root / sub).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def _wait_http_ready(base: str, session: requests.Session, timeout: float = 90.0) -> None:
|
||||
start = time.time()
|
||||
last_err = None
|
||||
while time.time() - start < timeout:
|
||||
try:
|
||||
r = session.get(base + "/api/assets", timeout=5)
|
||||
if r.status_code in (200, 400):
|
||||
return
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
time.sleep(0.25)
|
||||
raise RuntimeError(f"ComfyUI HTTP did not become ready: {last_err}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def comfy_tmp_base_dir() -> Path:
|
||||
env_base = os.environ.get("ASSETS_TEST_BASE_DIR")
|
||||
created_by_fixture = False
|
||||
if env_base:
|
||||
tmp = Path(env_base)
|
||||
tmp.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-"))
|
||||
created_by_fixture = True
|
||||
_make_base_dirs(tmp)
|
||||
yield tmp
|
||||
if created_by_fixture:
|
||||
with contextlib.suppress(Exception):
|
||||
for p in sorted(tmp.rglob("*"), reverse=True):
|
||||
if p.is_file() or p.is_symlink():
|
||||
p.unlink(missing_ok=True)
|
||||
for p in sorted(tmp.glob("**/*"), reverse=True):
|
||||
with contextlib.suppress(Exception):
|
||||
p.rmdir()
|
||||
tmp.rmdir()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest):
|
||||
"""
|
||||
Boot ComfyUI subprocess with:
|
||||
- sandbox base dir
|
||||
- file-backed sqlite DB in temp dir
|
||||
- autoscan disabled
|
||||
Returns (base_url, process, port)
|
||||
"""
|
||||
port = _free_port()
|
||||
db_url = request.config.getoption("--db-url")
|
||||
if not db_url:
|
||||
# Use a file-backed sqlite database in the temp directory
|
||||
db_path = comfy_tmp_base_dir / "assets-test.sqlite3"
|
||||
db_url = f"sqlite:///{db_path}"
|
||||
|
||||
logs_dir = comfy_tmp_base_dir / "logs"
|
||||
logs_dir.mkdir(exist_ok=True)
|
||||
out_log = open(logs_dir / "stdout.log", "w", buffering=1)
|
||||
err_log = open(logs_dir / "stderr.log", "w", buffering=1)
|
||||
|
||||
comfy_root = Path(__file__).resolve().parent.parent
|
||||
if not (comfy_root / "main.py").is_file():
|
||||
raise FileNotFoundError(f"main.py not found under {comfy_root}")
|
||||
|
||||
proc = subprocess.Popen(
|
||||
args=[
|
||||
sys.executable,
|
||||
"main.py",
|
||||
f"--base-directory={str(comfy_tmp_base_dir)}",
|
||||
f"--database-url={db_url}",
|
||||
"--disable-assets-autoscan",
|
||||
"--listen",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
str(port),
|
||||
"--cpu",
|
||||
],
|
||||
stdout=out_log,
|
||||
stderr=err_log,
|
||||
cwd=str(comfy_root),
|
||||
env={**os.environ},
|
||||
)
|
||||
|
||||
for _ in range(50):
|
||||
if proc.poll() is not None:
|
||||
out_log.flush()
|
||||
err_log.flush()
|
||||
raise RuntimeError(f"ComfyUI exited early with code {proc.returncode}")
|
||||
time.sleep(0.1)
|
||||
|
||||
base_url = f"http://127.0.0.1:{port}"
|
||||
try:
|
||||
with requests.Session() as s:
|
||||
_wait_http_ready(base_url, s, timeout=90.0)
|
||||
yield base_url, proc, port
|
||||
except Exception as e:
|
||||
with contextlib.suppress(Exception):
|
||||
proc.terminate()
|
||||
proc.wait(timeout=10)
|
||||
with contextlib.suppress(Exception):
|
||||
out_log.flush()
|
||||
err_log.flush()
|
||||
raise RuntimeError(f"ComfyUI did not become ready: {e}")
|
||||
|
||||
if proc and proc.poll() is None:
|
||||
with contextlib.suppress(Exception):
|
||||
proc.terminate()
|
||||
proc.wait(timeout=15)
|
||||
out_log.close()
|
||||
err_log.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def http() -> Iterator[requests.Session]:
|
||||
with requests.Session() as s:
|
||||
s.timeout = 120
|
||||
yield s
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_base(comfy_url_and_proc) -> str:
|
||||
base_url, _proc, _port = comfy_url_and_proc
|
||||
return base_url
|
||||
|
||||
|
||||
def _post_multipart_asset(
|
||||
session: requests.Session,
|
||||
base: str,
|
||||
*,
|
||||
name: str,
|
||||
tags: list[str],
|
||||
meta: dict,
|
||||
data: bytes,
|
||||
extra_fields: Optional[dict] = None,
|
||||
) -> tuple[int, dict]:
|
||||
files = {"file": (name, data, "application/octet-stream")}
|
||||
form_data = {
|
||||
"tags": json.dumps(tags),
|
||||
"name": name,
|
||||
"user_metadata": json.dumps(meta),
|
||||
}
|
||||
if extra_fields:
|
||||
for k, v in extra_fields.items():
|
||||
form_data[k] = v
|
||||
r = session.post(base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||
return r.status_code, r.json()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_asset_bytes() -> Callable[[str, int], bytes]:
|
||||
def _make(name: str, size: int = 8192) -> bytes:
|
||||
seed = sum(ord(c) for c in name) % 251
|
||||
return bytes((i * 31 + seed) % 256 for i in range(size))
|
||||
return _make
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def asset_factory(http: requests.Session, api_base: str):
|
||||
"""
|
||||
Returns create(name, tags, meta, data) -> response dict
|
||||
Tracks created ids and deletes them after the test.
|
||||
"""
|
||||
created: list[str] = []
|
||||
|
||||
def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict:
|
||||
status, body = _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data)
|
||||
assert status in (200, 201), body
|
||||
created.append(body["id"])
|
||||
return body
|
||||
|
||||
yield create
|
||||
|
||||
for aid in created:
|
||||
with contextlib.suppress(Exception):
|
||||
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_base: str) -> dict:
|
||||
"""
|
||||
Upload one asset with ".safetensors" extension into models/checkpoints/unit-tests/<name>.
|
||||
Returns response dict with id, asset_hash, tags, etc.
|
||||
"""
|
||||
name = "unit_1_example.safetensors"
|
||||
p = getattr(request, "param", {}) or {}
|
||||
tags: Optional[list[str]] = p.get("tags")
|
||||
if tags is None:
|
||||
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||
meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None}
|
||||
files = {"file": (name, b"A" * 4096, "application/octet-stream")}
|
||||
form_data = {
|
||||
"tags": json.dumps(tags),
|
||||
"name": name,
|
||||
"user_metadata": json.dumps(meta),
|
||||
}
|
||||
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 201, body
|
||||
return body
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def autoclean_unit_test_assets(http: requests.Session, api_base: str):
|
||||
"""Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test."""
|
||||
yield
|
||||
|
||||
while True:
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests", "limit": "500", "sort": "name"},
|
||||
timeout=30,
|
||||
)
|
||||
if r.status_code != 200:
|
||||
break
|
||||
body = r.json()
|
||||
ids = [a["id"] for a in body.get("assets", [])]
|
||||
if not ids:
|
||||
break
|
||||
for aid in ids:
|
||||
with contextlib.suppress(Exception):
|
||||
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
|
||||
|
||||
|
||||
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
|
||||
"""Force a fast sync/seed pass by calling the seed endpoint."""
|
||||
session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30)
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
def get_asset_filename(asset_hash: str, extension: str) -> str:
|
||||
return asset_hash.removeprefix("blake3:") + extension
|
||||
@ -1,2 +0,0 @@
|
||||
pytest>=7.8.0
|
||||
blake3
|
||||
@ -1,348 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
def test_seed_asset_removed_when_file_is_deleted(
|
||||
root: str,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
):
|
||||
"""Asset without hash (seed) whose file disappears:
|
||||
after triggering sync_seed_assets, Asset + AssetInfo disappear.
|
||||
"""
|
||||
# Create a file directly under input/unit-tests/<case> so tags include "unit-tests"
|
||||
case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed"
|
||||
case_dir.mkdir(parents=True, exist_ok=True)
|
||||
name = f"seed_{uuid.uuid4().hex[:8]}.bin"
|
||||
fp = case_dir / name
|
||||
fp.write_bytes(b"Z" * 2048)
|
||||
|
||||
# Trigger a seed sync so DB sees this path (seed asset => hash is NULL)
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Verify it is visible via API and carries no hash (seed)
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||
timeout=120,
|
||||
)
|
||||
body1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
# there should be exactly one with that name
|
||||
matches = [a for a in body1.get("assets", []) if a.get("name") == name]
|
||||
assert matches
|
||||
assert matches[0].get("asset_hash") is None
|
||||
asset_info_id = matches[0]["id"]
|
||||
|
||||
# Remove the underlying file and sync again
|
||||
if fp.exists():
|
||||
fp.unlink()
|
||||
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# It should disappear (AssetInfo and seed Asset gone)
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||
timeout=120,
|
||||
)
|
||||
body2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
matches2 = [a for a in body2.get("assets", []) if a.get("name") == name]
|
||||
assert not matches2, f"Seed asset {asset_info_id} should be gone after sync"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires computing hashes of files in directories to verify and clear missing tags")
|
||||
def test_hashed_asset_missing_tag_added_then_removed_after_scan(
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""Hashed asset with a single cache_state:
|
||||
1. delete its file -> sync adds 'missing'
|
||||
2. restore file -> sync removes 'missing'
|
||||
"""
|
||||
name = "missing_tag_test.png"
|
||||
tags = ["input", "unit-tests", "msync2"]
|
||||
data = make_asset_bytes(name, 4096)
|
||||
a = asset_factory(name, tags, {}, data)
|
||||
|
||||
# Compute its on-disk path and remove it
|
||||
dest = comfy_tmp_base_dir / "input" / "unit-tests" / "msync2" / get_asset_filename(a["asset_hash"], ".png")
|
||||
assert dest.exists(), f"Expected asset file at {dest}"
|
||||
dest.unlink()
|
||||
|
||||
# Fast sync should add 'missing' to the AssetInfo
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
g1 = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120)
|
||||
d1 = g1.json()
|
||||
assert g1.status_code == 200, d1
|
||||
assert "missing" in set(d1.get("tags", [])), "Expected 'missing' tag after deletion"
|
||||
|
||||
# Restore the file with the exact same content and sync again
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_bytes(data)
|
||||
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
g2 = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120)
|
||||
d2 = g2.json()
|
||||
assert g2.status_code == 200, d2
|
||||
assert "missing" not in set(d2.get("tags", [])), "Missing tag should be cleared after verify"
|
||||
|
||||
|
||||
def test_hashed_asset_two_asset_infos_both_get_missing(
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
):
|
||||
"""Hashed asset with a single cache_state, but two AssetInfo rows:
|
||||
deleting the single file then syncing should add 'missing' to both infos.
|
||||
"""
|
||||
# Upload one hashed asset
|
||||
name = "two_infos_one_path.png"
|
||||
base_tags = ["input", "unit-tests", "multiinfo"]
|
||||
created = asset_factory(name, base_tags, {}, b"A" * 2048)
|
||||
|
||||
# Create second AssetInfo for the same Asset via from-hash
|
||||
payload = {
|
||||
"hash": created["asset_hash"],
|
||||
"name": "two_infos_one_path_copy.png",
|
||||
"tags": base_tags, # keep it in our unit-tests scope for cleanup
|
||||
"user_metadata": {"k": "v"},
|
||||
}
|
||||
r2 = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 201, b2
|
||||
second_id = b2["id"]
|
||||
|
||||
# Remove the single underlying file
|
||||
p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png")
|
||||
assert p.exists()
|
||||
p.unlink()
|
||||
|
||||
r0 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
|
||||
tags0 = r0.json()
|
||||
assert r0.status_code == 200, tags0
|
||||
byname0 = {t["name"]: t for t in tags0.get("tags", [])}
|
||||
old_missing = int(byname0.get("missing", {}).get("count", 0))
|
||||
|
||||
# Sync -> both AssetInfos for this asset must receive 'missing'
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
ga = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
|
||||
da = ga.json()
|
||||
assert ga.status_code == 200, da
|
||||
assert "missing" in set(da.get("tags", []))
|
||||
|
||||
gb = http.get(f"{api_base}/api/assets/{second_id}", timeout=120)
|
||||
db = gb.json()
|
||||
assert gb.status_code == 200, db
|
||||
assert "missing" in set(db.get("tags", []))
|
||||
|
||||
# Tag usage for 'missing' increased by exactly 2 (two AssetInfos)
|
||||
r1 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
|
||||
tags1 = r1.json()
|
||||
assert r1.status_code == 200, tags1
|
||||
byname1 = {t["name"]: t for t in tags1.get("tags", [])}
|
||||
new_missing = int(byname1.get("missing", {}).get("count", 0))
|
||||
assert new_missing == old_missing + 2
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
|
||||
def test_hashed_asset_two_cache_states_partial_delete_then_full_delete(
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""Hashed asset with two cache_state rows:
|
||||
1. delete one file -> sync should NOT add 'missing'
|
||||
2. delete second file -> sync should add 'missing'
|
||||
"""
|
||||
name = "two_cache_states_partial_delete.png"
|
||||
tags = ["input", "unit-tests", "dual"]
|
||||
data = make_asset_bytes(name, 3072)
|
||||
|
||||
created = asset_factory(name, tags, {}, data)
|
||||
path1 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual" / get_asset_filename(created["asset_hash"], ".png")
|
||||
assert path1.exists()
|
||||
|
||||
# Create a second on-disk copy under the same root but different subfolder
|
||||
path2 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual_copy" / name
|
||||
path2.parent.mkdir(parents=True, exist_ok=True)
|
||||
path2.write_bytes(data)
|
||||
|
||||
# Fast seed so the second path appears (as a seed initially)
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Deduplication of AssetInfo-s will not happen as first AssetInfo has owner='default' and second has empty owner.
|
||||
run_scan_and_wait("input")
|
||||
|
||||
# Remove only one file and sync -> asset should still be healthy (no 'missing')
|
||||
path1.unlink()
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
g1 = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
|
||||
d1 = g1.json()
|
||||
assert g1.status_code == 200, d1
|
||||
assert "missing" not in set(d1.get("tags", [])), "Should not be missing while one valid path remains"
|
||||
|
||||
# Baseline 'missing' usage count just before last file removal
|
||||
r0 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
|
||||
tags0 = r0.json()
|
||||
assert r0.status_code == 200, tags0
|
||||
old_missing = int({t["name"]: t for t in tags0.get("tags", [])}.get("missing", {}).get("count", 0))
|
||||
|
||||
# Remove the second (last) file and sync -> now we expect 'missing' on this AssetInfo
|
||||
path2.unlink()
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
g2 = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
|
||||
d2 = g2.json()
|
||||
assert g2.status_code == 200, d2
|
||||
assert "missing" in set(d2.get("tags", [])), "Missing must be set once no valid paths remain"
|
||||
|
||||
# Tag usage for 'missing' increased by exactly 2 (two AssetInfo for one Asset)
|
||||
r1 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
|
||||
tags1 = r1.json()
|
||||
assert r1.status_code == 200, tags1
|
||||
new_missing = int({t["name"]: t for t in tags1.get("tags", [])}.get("missing", {}).get("count", 0))
|
||||
assert new_missing == old_missing + 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match(
|
||||
root: str,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""
|
||||
Fast pass alone clears 'missing' when size and mtime match exactly:
|
||||
1) upload (hashed), record original mtime_ns
|
||||
2) delete -> fast pass adds 'missing'
|
||||
3) restore same bytes and set mtime back to the original value
|
||||
4) run fast pass again -> 'missing' is removed (no slow scan)
|
||||
"""
|
||||
scope = f"fastclear-{uuid.uuid4().hex[:6]}"
|
||||
name = "fastpass_clear.bin"
|
||||
data = make_asset_bytes(name, 3072)
|
||||
|
||||
a = asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
aid = a["id"]
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||
p = base / get_asset_filename(a["asset_hash"], ".bin")
|
||||
st0 = p.stat()
|
||||
orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000))
|
||||
|
||||
# Delete -> fast pass adds 'missing'
|
||||
p.unlink()
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
d1 = g1.json()
|
||||
assert g1.status_code == 200, d1
|
||||
assert "missing" in set(d1.get("tags", []))
|
||||
|
||||
# Restore same bytes and revert mtime to the original value
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
p.write_bytes(data)
|
||||
# set both atime and mtime in ns to ensure exact match
|
||||
os.utime(p, ns=(orig_mtime_ns, orig_mtime_ns))
|
||||
|
||||
# Fast pass should clear 'missing' without a scan
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
g2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
d2 = g2.json()
|
||||
assert g2.status_code == 200, d2
|
||||
assert "missing" not in set(d2.get("tags", [])), "Fast pass should clear 'missing' when size+mtime match"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
def test_fastpass_removes_stale_state_row_no_missing(
|
||||
root: str,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""
|
||||
Hashed asset with two states:
|
||||
- delete one file
|
||||
- run fast pass only
|
||||
Expect:
|
||||
- asset stays healthy (no 'missing')
|
||||
- stale AssetCacheState row for the deleted path is removed.
|
||||
We verify this behaviorally by recreating the deleted path and running fast pass again:
|
||||
a new *seed* AssetInfo is created, which proves the old state row was not reused.
|
||||
"""
|
||||
scope = f"stale-{uuid.uuid4().hex[:6]}"
|
||||
name = "two_states.bin"
|
||||
data = make_asset_bytes(name, 2048)
|
||||
|
||||
# Upload hashed asset at path1
|
||||
a = asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||
a1_filename = get_asset_filename(a["asset_hash"], ".bin")
|
||||
p1 = base / a1_filename
|
||||
assert p1.exists()
|
||||
|
||||
aid = a["id"]
|
||||
h = a["asset_hash"]
|
||||
|
||||
# Create second state path2, seed+scan to dedupe into the same Asset
|
||||
p2 = base / "copy" / name
|
||||
p2.parent.mkdir(parents=True, exist_ok=True)
|
||||
p2.write_bytes(data)
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
run_scan_and_wait(root)
|
||||
|
||||
# Delete path1 and run fast pass -> no 'missing' and stale state row should be removed
|
||||
p1.unlink()
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
d1 = g1.json()
|
||||
assert g1.status_code == 200, d1
|
||||
assert "missing" not in set(d1.get("tags", []))
|
||||
|
||||
# Recreate path1 and run fast pass again.
|
||||
# If the stale state row was removed, a NEW seed AssetInfo will appear for this path.
|
||||
p1.write_bytes(data)
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
rl = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}"},
|
||||
timeout=120,
|
||||
)
|
||||
bl = rl.json()
|
||||
assert rl.status_code == 200, bl
|
||||
items = bl.get("assets", [])
|
||||
# one hashed AssetInfo (asset_hash == h) + one seed AssetInfo (asset_hash == null)
|
||||
hashes = [it.get("asset_hash") for it in items if it.get("name") in (name, a1_filename)]
|
||||
assert h in hashes
|
||||
assert any(x is None for x in hashes), "Expected a new seed AssetInfo for the recreated path"
|
||||
|
||||
# Asset identity still healthy
|
||||
rh = http.head(f"{api_base}/api/assets/hash/{h}", timeout=120)
|
||||
assert rh.status_code == 200
|
||||
@ -1,306 +0,0 @@
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
def test_create_from_hash_success(
|
||||
http: requests.Session, api_base: str, seeded_asset: dict
|
||||
):
|
||||
h = seeded_asset["asset_hash"]
|
||||
payload = {
|
||||
"hash": h,
|
||||
"name": "from_hash_ok.safetensors",
|
||||
"tags": ["models", "checkpoints", "unit-tests", "from-hash"],
|
||||
"user_metadata": {"k": "v"},
|
||||
}
|
||||
r1 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 201, b1
|
||||
assert b1["asset_hash"] == h
|
||||
assert b1["created_new"] is False
|
||||
aid = b1["id"]
|
||||
|
||||
# Calling again with the same name should return the same AssetInfo id
|
||||
r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 201, b2
|
||||
assert b2["id"] == aid
|
||||
|
||||
|
||||
def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
# GET detail
|
||||
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
detail = rg.json()
|
||||
assert rg.status_code == 200, detail
|
||||
assert detail["id"] == aid
|
||||
assert "user_metadata" in detail
|
||||
assert "filename" in detail["user_metadata"]
|
||||
|
||||
# DELETE
|
||||
rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
assert rd.status_code == 204
|
||||
|
||||
# GET again -> 404
|
||||
rg2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
body = rg2.json()
|
||||
assert rg2.status_code == 404
|
||||
assert body["error"]["code"] == "ASSET_NOT_FOUND"
|
||||
|
||||
|
||||
def test_delete_upon_reference_count(
|
||||
http: requests.Session, api_base: str, seeded_asset: dict
|
||||
):
|
||||
# Create a second reference to the same asset via from-hash
|
||||
src_hash = seeded_asset["asset_hash"]
|
||||
payload = {
|
||||
"hash": src_hash,
|
||||
"name": "unit_ref_copy.safetensors",
|
||||
"tags": ["models", "checkpoints", "unit-tests", "del-flow"],
|
||||
"user_metadata": {"note": "copy"},
|
||||
}
|
||||
r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
|
||||
copy = r2.json()
|
||||
assert r2.status_code == 201, copy
|
||||
assert copy["asset_hash"] == src_hash
|
||||
assert copy["created_new"] is False
|
||||
|
||||
# Delete original reference -> asset identity must remain
|
||||
aid1 = seeded_asset["id"]
|
||||
rd1 = http.delete(f"{api_base}/api/assets/{aid1}", timeout=120)
|
||||
assert rd1.status_code == 204
|
||||
|
||||
rh1 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120)
|
||||
assert rh1.status_code == 200 # identity still present
|
||||
|
||||
# Delete the last reference with default semantics -> identity and cached files removed
|
||||
aid2 = copy["id"]
|
||||
rd2 = http.delete(f"{api_base}/api/assets/{aid2}", timeout=120)
|
||||
assert rd2.status_code == 204
|
||||
|
||||
rh2 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120)
|
||||
assert rh2.status_code == 404 # orphan content removed
|
||||
|
||||
|
||||
def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
original_tags = seeded_asset["tags"]
|
||||
|
||||
payload = {
|
||||
"name": "unit_1_renamed.safetensors",
|
||||
"user_metadata": {"purpose": "updated", "epoch": 2},
|
||||
}
|
||||
ru = http.put(f"{api_base}/api/assets/{aid}", json=payload, timeout=120)
|
||||
body = ru.json()
|
||||
assert ru.status_code == 200, body
|
||||
assert body["name"] == payload["name"]
|
||||
assert body["tags"] == original_tags # tags unchanged
|
||||
assert body["user_metadata"]["purpose"] == "updated"
|
||||
# filename should still be present and normalized by server
|
||||
assert "filename" in body["user_metadata"]
|
||||
|
||||
|
||||
def test_head_asset_by_hash(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
h = seeded_asset["asset_hash"]
|
||||
|
||||
# Existing
|
||||
rh1 = http.head(f"{api_base}/api/assets/hash/{h}", timeout=120)
|
||||
assert rh1.status_code == 200
|
||||
|
||||
# Non-existent
|
||||
rh2 = http.head(f"{api_base}/api/assets/hash/blake3:{'0'*64}", timeout=120)
|
||||
assert rh2.status_code == 404
|
||||
|
||||
|
||||
def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api_base: str):
|
||||
# Invalid format; handler returns a JSON error, but HEAD responses must not carry a payload.
|
||||
# requests exposes an empty body for HEAD, so validate status and that there is no payload.
|
||||
rh = http.head(f"{api_base}/api/assets/hash/not_a_hash", timeout=120)
|
||||
assert rh.status_code == 400
|
||||
body = rh.content
|
||||
assert body == b""
|
||||
|
||||
|
||||
def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str):
|
||||
bogus = str(uuid.uuid4())
|
||||
r = http.delete(f"{api_base}/api/assets/{bogus}", timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 404
|
||||
assert body["error"]["code"] == "ASSET_NOT_FOUND"
|
||||
|
||||
|
||||
def test_create_from_hash_invalids(http: requests.Session, api_base: str):
|
||||
# Bad hash algorithm
|
||||
bad = {
|
||||
"hash": "sha256:" + "0" * 64,
|
||||
"name": "x.bin",
|
||||
"tags": ["models", "checkpoints", "unit-tests"],
|
||||
}
|
||||
r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 400
|
||||
assert b1["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
# Invalid JSON body
|
||||
r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 400
|
||||
assert b2["error"]["code"] == "INVALID_JSON"
|
||||
|
||||
|
||||
def test_get_update_download_bad_ids(http: requests.Session, api_base: str):
|
||||
# All endpoints should be not found, as we UUID regex directly in the route definition.
|
||||
bad_id = "not-a-uuid"
|
||||
|
||||
r1 = http.get(f"{api_base}/api/assets/{bad_id}", timeout=120)
|
||||
assert r1.status_code == 404
|
||||
|
||||
r3 = http.get(f"{api_base}/api/assets/{bad_id}/content", timeout=120)
|
||||
assert r3.status_code == 404
|
||||
|
||||
|
||||
def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
r = http.put(f"{api_base}/api/assets/{aid}", json={}, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
def test_concurrent_delete_same_asset_info_single_204(
|
||||
root: str,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""
|
||||
Many concurrent DELETE for the same AssetInfo should result in:
|
||||
- exactly one 204 No Content (the one that actually deleted)
|
||||
- all others 404 Not Found (row already gone)
|
||||
"""
|
||||
scope = f"conc-del-{uuid.uuid4().hex[:6]}"
|
||||
name = "to_delete.bin"
|
||||
data = make_asset_bytes(name, 1536)
|
||||
|
||||
created = asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
aid = created["id"]
|
||||
|
||||
# Hit the same endpoint N times in parallel.
|
||||
n_tests = 4
|
||||
url = f"{api_base}/api/assets/{aid}?delete_content=false"
|
||||
|
||||
def _do_delete(delete_url):
|
||||
with requests.Session() as s:
|
||||
return s.delete(delete_url, timeout=120).status_code
|
||||
|
||||
with ThreadPoolExecutor(max_workers=n_tests) as ex:
|
||||
statuses = list(ex.map(_do_delete, [url] * n_tests))
|
||||
|
||||
# Exactly one actual delete, the rest must be 404
|
||||
assert statuses.count(204) == 1, f"Expected exactly one 204; got: {statuses}"
|
||||
assert statuses.count(404) == n_tests - 1, f"Expected {n_tests-1} 404; got: {statuses}"
|
||||
|
||||
# The resource must be gone.
|
||||
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
assert rg.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
def test_metadata_filename_is_set_for_seed_asset_without_hash(
|
||||
root: str,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
):
|
||||
"""Seed ingest (no hash yet) must compute user_metadata['filename'] immediately."""
|
||||
scope = f"seedmeta-{uuid.uuid4().hex[:6]}"
|
||||
name = "seed_filename.bin"
|
||||
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope / "a" / "b"
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
fp = base / name
|
||||
fp.write_bytes(b"Z" * 2048)
|
||||
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
|
||||
timeout=120,
|
||||
)
|
||||
body = r1.json()
|
||||
assert r1.status_code == 200, body
|
||||
matches = [a for a in body.get("assets", []) if a.get("name") == name]
|
||||
assert matches, "Seed asset should be visible after sync"
|
||||
assert matches[0].get("asset_hash") is None # still a seed
|
||||
aid = matches[0]["id"]
|
||||
|
||||
r2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
detail = r2.json()
|
||||
assert r2.status_code == 200, detail
|
||||
filename = (detail.get("user_metadata") or {}).get("filename")
|
||||
expected = str(fp.relative_to(comfy_tmp_base_dir / root)).replace("\\", "/")
|
||||
assert filename == expected, f"expected filename={expected}, got {filename!r}"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires computing hashes of files in directories to retarget cache states")
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
def test_metadata_filename_computed_and_updated_on_retarget(
|
||||
root: str,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""
|
||||
1) Ingest under {root}/unit-tests/<scope>/a/b/<name> -> filename reflects relative path.
|
||||
2) Retarget by copying to {root}/unit-tests/<scope>/x/<new_name>, remove old file,
|
||||
run fast pass + scan -> filename updates to new relative path.
|
||||
"""
|
||||
scope = f"meta-fn-{uuid.uuid4().hex[:6]}"
|
||||
name1 = "compute_metadata_filename.png"
|
||||
name2 = "compute_changed_metadata_filename.png"
|
||||
data = make_asset_bytes(name1, 2100)
|
||||
|
||||
# Upload into nested path a/b
|
||||
a = asset_factory(name1, [root, "unit-tests", scope, "a", "b"], {}, data)
|
||||
aid = a["id"]
|
||||
|
||||
root_base = comfy_tmp_base_dir / root
|
||||
p1 = (root_base / "unit-tests" / scope / "a" / "b" / get_asset_filename(a["asset_hash"], ".png"))
|
||||
assert p1.exists()
|
||||
|
||||
# filename at ingest should be the path relative to root
|
||||
rel1 = str(p1.relative_to(root_base)).replace("\\", "/")
|
||||
g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
d1 = g1.json()
|
||||
assert g1.status_code == 200, d1
|
||||
fn1 = d1["user_metadata"].get("filename")
|
||||
assert fn1 == rel1
|
||||
|
||||
# Retarget: copy to x/<name2>, remove old, then sync+scan
|
||||
p2 = root_base / "unit-tests" / scope / "x" / name2
|
||||
p2.parent.mkdir(parents=True, exist_ok=True)
|
||||
p2.write_bytes(data)
|
||||
if p1.exists():
|
||||
p1.unlink()
|
||||
|
||||
trigger_sync_seed_assets(http, api_base) # seed the new path
|
||||
run_scan_and_wait(root) # verify/hash and reconcile
|
||||
|
||||
# filename should now point at x/<name2>
|
||||
rel2 = str(p2.relative_to(root_base)).replace("\\", "/")
|
||||
g2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
d2 = g2.json()
|
||||
assert g2.status_code == 200, d2
|
||||
fn2 = d2["user_metadata"].get("filename")
|
||||
assert fn2 == rel2
|
||||
@ -1,166 +0,0 @@
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
# default attachment
|
||||
r1 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
|
||||
data = r1.content
|
||||
assert r1.status_code == 200
|
||||
cd = r1.headers.get("Content-Disposition", "")
|
||||
assert "attachment" in cd
|
||||
assert data and len(data) == 4096
|
||||
|
||||
# inline requested
|
||||
r2 = http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline", timeout=120)
|
||||
r2.content
|
||||
assert r2.status_code == 200
|
||||
cd2 = r2.headers.get("Content-Disposition", "")
|
||||
assert "inline" in cd2
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
def test_download_chooses_existing_state_and_updates_access_time(
|
||||
root: str,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""
|
||||
Hashed asset with two state paths: if the first one disappears,
|
||||
GET /content still serves from the remaining path and bumps last_access_time.
|
||||
"""
|
||||
scope = f"dl-first-{uuid.uuid4().hex[:6]}"
|
||||
name = "first_existing_state.bin"
|
||||
data = make_asset_bytes(name, 3072)
|
||||
|
||||
# Upload -> path1
|
||||
a = asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
aid = a["id"]
|
||||
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||
path1 = base / get_asset_filename(a["asset_hash"], ".bin")
|
||||
assert path1.exists()
|
||||
|
||||
# Seed path2 by copying, then scan to dedupe into a second state
|
||||
path2 = base / "alt" / name
|
||||
path2.parent.mkdir(parents=True, exist_ok=True)
|
||||
path2.write_bytes(data)
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
run_scan_and_wait(root)
|
||||
|
||||
# Remove path1 so server must fall back to path2
|
||||
path1.unlink()
|
||||
|
||||
# last_access_time before
|
||||
rg0 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
d0 = rg0.json()
|
||||
assert rg0.status_code == 200, d0
|
||||
ts0 = d0.get("last_access_time")
|
||||
|
||||
time.sleep(0.05)
|
||||
r = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
|
||||
blob = r.content
|
||||
assert r.status_code == 200
|
||||
assert blob == data # must serve from the surviving state (same bytes)
|
||||
|
||||
rg1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
d1 = rg1.json()
|
||||
assert rg1.status_code == 200, d1
|
||||
ts1 = d1.get("last_access_time")
|
||||
|
||||
def _parse_iso8601(s: Optional[str]) -> Optional[float]:
|
||||
if not s:
|
||||
return None
|
||||
s = s[:-1] if s.endswith("Z") else s
|
||||
return datetime.fromisoformat(s).timestamp()
|
||||
|
||||
t0 = _parse_iso8601(ts0)
|
||||
t1 = _parse_iso8601(ts1)
|
||||
assert t1 is not None
|
||||
if t0 is not None:
|
||||
assert t1 > t0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True)
|
||||
def test_download_missing_file_returns_404(
|
||||
http: requests.Session, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict
|
||||
):
|
||||
# Remove the underlying file then attempt download.
|
||||
# We initialize fixture without additional tags to know exactly the asset file path.
|
||||
try:
|
||||
aid = seeded_asset["id"]
|
||||
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
detail = rg.json()
|
||||
assert rg.status_code == 200
|
||||
asset_filename = get_asset_filename(detail["asset_hash"], ".safetensors")
|
||||
abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / asset_filename
|
||||
assert abs_path.exists()
|
||||
abs_path.unlink()
|
||||
|
||||
r2 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
|
||||
assert r2.status_code == 404
|
||||
body = r2.json()
|
||||
assert body["error"]["code"] == "FILE_NOT_FOUND"
|
||||
finally:
|
||||
# We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually.
|
||||
dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
dr.content
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
def test_download_404_if_all_states_missing(
|
||||
root: str,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""Multi-state asset: after the last remaining on-disk file is removed, download must return 404."""
|
||||
scope = f"dl-404-{uuid.uuid4().hex[:6]}"
|
||||
name = "missing_all_states.bin"
|
||||
data = make_asset_bytes(name, 2048)
|
||||
|
||||
# Upload -> path1
|
||||
a = asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
aid = a["id"]
|
||||
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||
p1 = base / get_asset_filename(a["asset_hash"], ".bin")
|
||||
assert p1.exists()
|
||||
|
||||
# Seed a second state and dedupe
|
||||
p2 = base / "copy" / name
|
||||
p2.parent.mkdir(parents=True, exist_ok=True)
|
||||
p2.write_bytes(data)
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
run_scan_and_wait(root)
|
||||
|
||||
# Remove first file -> download should still work via the second state
|
||||
p1.unlink()
|
||||
ok1 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
|
||||
b1 = ok1.content
|
||||
assert ok1.status_code == 200 and b1 == data
|
||||
|
||||
# Remove the last file -> download must 404
|
||||
p2.unlink()
|
||||
r2 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
|
||||
body = r2.json()
|
||||
assert r2.status_code == 404
|
||||
assert body["error"]["code"] == "FILE_NOT_FOUND"
|
||||
@ -1,342 +0,0 @@
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = ["a1_u.safetensors", "a2_u.safetensors", "a3_u.safetensors"]
|
||||
for n in names:
|
||||
asset_factory(
|
||||
n,
|
||||
["models", "checkpoints", "unit-tests", "paging"],
|
||||
{"epoch": 1},
|
||||
make_asset_bytes(n, size=2048),
|
||||
)
|
||||
|
||||
# name ascending for stable order
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "0"},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
got1 = [a["name"] for a in b1["assets"]]
|
||||
assert got1 == sorted(names)[:2]
|
||||
assert b1["has_more"] is True
|
||||
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "2"},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
got2 = [a["name"] for a in b2["assets"]]
|
||||
assert got2 == sorted(names)[2:]
|
||||
assert b2["has_more"] is False
|
||||
|
||||
|
||||
def test_list_assets_include_exclude_and_name_contains(http: requests.Session, api_base: str, asset_factory):
|
||||
a = asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
|
||||
b = asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)
|
||||
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,alpha", "exclude_tags": "beta", "limit": "50"},
|
||||
timeout=120,
|
||||
)
|
||||
body = r.json()
|
||||
assert r.status_code == 200
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert a["name"] in names
|
||||
assert b["name"] not in names
|
||||
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests", "name_contains": "inc_"},
|
||||
timeout=120,
|
||||
)
|
||||
body2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
names2 = [x["name"] for x in body2["assets"]]
|
||||
assert a["name"] in names2
|
||||
assert b["name"] in names2
|
||||
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "non-existing-tag"},
|
||||
timeout=120,
|
||||
)
|
||||
body3 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
assert not body3["assets"]
|
||||
|
||||
|
||||
def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-size"]
|
||||
n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors"
|
||||
asset_factory(n1, t, {}, make_asset_bytes(n1, 1024))
|
||||
asset_factory(n2, t, {}, make_asset_bytes(n2, 2048))
|
||||
asset_factory(n3, t, {}, make_asset_bytes(n3, 3072))
|
||||
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "asc"},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
names = [a["name"] for a in b1["assets"]]
|
||||
assert names[:3] == [n1, n2, n3]
|
||||
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "desc"},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
names2 = [a["name"] for a in b2["assets"]]
|
||||
assert names2[:3] == [n3, n2, n1]
|
||||
|
||||
|
||||
|
||||
def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-upd"]
|
||||
a1 = asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200))
|
||||
a2 = asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200))
|
||||
|
||||
# Rename the second asset to bump updated_at
|
||||
rp = http.put(f"{api_base}/api/assets/{a2['id']}", json={"name": "upd_b_renamed.safetensors"}, timeout=120)
|
||||
upd = rp.json()
|
||||
assert rp.status_code == 200, upd
|
||||
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-upd", "sort": "updated_at", "order": "desc"},
|
||||
timeout=120,
|
||||
)
|
||||
body = r.json()
|
||||
assert r.status_code == 200
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert names[0] == "upd_b_renamed.safetensors"
|
||||
assert a1["name"] in names
|
||||
|
||||
|
||||
|
||||
def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-access"]
|
||||
asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100))
|
||||
time.sleep(0.02)
|
||||
a2 = asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100))
|
||||
|
||||
# Touch last_access_time of b by downloading its content
|
||||
time.sleep(0.02)
|
||||
dl = http.get(f"{api_base}/api/assets/{a2['id']}/content", timeout=120)
|
||||
assert dl.status_code == 200
|
||||
dl.content
|
||||
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-access", "sort": "last_access_time", "order": "desc"},
|
||||
timeout=120,
|
||||
)
|
||||
body = r.json()
|
||||
assert r.status_code == 200
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert names[0] == a2["name"]
|
||||
|
||||
|
||||
def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-include"]
|
||||
a = asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva"))
|
||||
asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb"))
|
||||
|
||||
# CSV + case-insensitive
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
names1 = [x["name"] for x in b1["assets"]]
|
||||
assert a["name"] in names1
|
||||
assert not any("beta" in x for x in names1)
|
||||
|
||||
# Repeated query params for include_tags
|
||||
params_multi = [
|
||||
("include_tags", "unit-tests"),
|
||||
("include_tags", "lf-include"),
|
||||
("include_tags", "alpha"),
|
||||
]
|
||||
r2 = http.get(api_base + "/api/assets", params=params_multi, timeout=120)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
names2 = [x["name"] for x in b2["assets"]]
|
||||
assert a["name"] in names2
|
||||
assert not any("beta" in x for x in names2)
|
||||
|
||||
# Duplicates and spaces in CSV
|
||||
r3 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": " unit-tests , lf-include , alpha , alpha "},
|
||||
timeout=120,
|
||||
)
|
||||
b3 = r3.json()
|
||||
assert r3.status_code == 200
|
||||
names3 = [x["name"] for x in b3["assets"]]
|
||||
assert a["name"] in names3
|
||||
|
||||
|
||||
def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-exclude"]
|
||||
a = asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900))
|
||||
asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900))
|
||||
|
||||
# Exclude uppercase should work
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
names1 = [x["name"] for x in b1["assets"]]
|
||||
assert a["name"] in names1
|
||||
# Repeated excludes with duplicates
|
||||
params_multi = [
|
||||
("include_tags", "unit-tests"),
|
||||
("include_tags", "lf-exclude"),
|
||||
("exclude_tags", "beta"),
|
||||
("exclude_tags", "beta"),
|
||||
]
|
||||
r2 = http.get(api_base + "/api/assets", params=params_multi, timeout=120)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
names2 = [x["name"] for x in b2["assets"]]
|
||||
assert all("beta" not in x for x in names2)
|
||||
|
||||
|
||||
def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-name"]
|
||||
a1 = asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800))
|
||||
a2 = asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800))
|
||||
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-name", "name_contains": "casemix"},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
names1 = [x["name"] for x in b1["assets"]]
|
||||
assert a1["name"] in names1
|
||||
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-name", "name_contains": ".SAFE"},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
names2 = [x["name"] for x in b2["assets"]]
|
||||
assert a1["name"] in names2
|
||||
|
||||
r3 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-name", "name_contains": "case-"},
|
||||
timeout=120,
|
||||
)
|
||||
b3 = r3.json()
|
||||
assert r3.status_code == 200
|
||||
names3 = [x["name"] for x in b3["assets"]]
|
||||
assert a2["name"] in names3
|
||||
|
||||
|
||||
def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"]
|
||||
asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600))
|
||||
asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600))
|
||||
asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600))
|
||||
|
||||
# Offset far beyond total
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-pagelimits", "limit": "2", "offset": "10"},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
assert not b1["assets"]
|
||||
assert b1["has_more"] is False
|
||||
|
||||
# Boundary large limit (<=500 is valid)
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-pagelimits", "limit": "500"},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
assert len(b2["assets"]) == 3
|
||||
assert b2["has_more"] is False
|
||||
|
||||
|
||||
def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base):
|
||||
r1 = http.get(api_base + "/api/assets", params={"offset": "-1"}, timeout=120)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 400
|
||||
assert b1["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
r2 = http.get(api_base + "/api/assets", params={"limit": "abc"}, timeout=120)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 400
|
||||
assert b2["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
|
||||
def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str):
|
||||
# limit too small
|
||||
r1 = http.get(api_base + "/api/assets", params={"limit": "0"}, timeout=120)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 400
|
||||
assert b1["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
# bad metadata JSON
|
||||
r2 = http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}, timeout=120)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 400
|
||||
assert b2["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
|
||||
def test_list_assets_name_contains_literal_underscore(
|
||||
http,
|
||||
api_base,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""'name_contains' must treat '_' literally, not as a SQL wildcard.
|
||||
We create:
|
||||
- foo_bar.safetensors (should match)
|
||||
- fooxbar.safetensors (must NOT match if '_' is escaped)
|
||||
- foobar.safetensors (must NOT match)
|
||||
"""
|
||||
scope = f"lf-underscore-{uuid.uuid4().hex[:6]}"
|
||||
tags = ["models", "checkpoints", "unit-tests", scope]
|
||||
|
||||
a = asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700))
|
||||
b = asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700))
|
||||
c = asset_factory("foobar.safetensors", tags, {}, make_asset_bytes("c", 700))
|
||||
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "name_contains": "foo_bar"},
|
||||
timeout=120,
|
||||
)
|
||||
body = r.json()
|
||||
assert r.status_code == 200, body
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert a["name"] in names, f"Expected literal underscore match to include {a['name']}"
|
||||
assert b["name"] not in names, "Underscore must be escaped — should not match 'fooxbar'"
|
||||
assert c["name"] not in names, "Underscore must be escaped — should not match 'foobar'"
|
||||
assert body["total"] == 1
|
||||
@ -1,395 +0,0 @@
|
||||
import json
|
||||
|
||||
|
||||
def test_meta_and_across_keys_and_types(
|
||||
http, api_base: str, asset_factory, make_asset_bytes
|
||||
):
|
||||
name = "mf_and_mix.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-and"]
|
||||
meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
|
||||
asset_factory(name, tags, meta, make_asset_bytes(name, 4096))
|
||||
|
||||
# All keys must match (AND semantics)
|
||||
f_ok = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-and",
|
||||
"metadata_filter": json.dumps(f_ok),
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
names = [a["name"] for a in b1["assets"]]
|
||||
assert name in names
|
||||
|
||||
# One key mismatched -> no result
|
||||
f_bad = {"purpose": "mix", "epoch": 2, "active": True}
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-and",
|
||||
"metadata_filter": json.dumps(f_bad),
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
assert not b2["assets"]
|
||||
|
||||
|
||||
def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_types.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-types"]
|
||||
meta = {"epoch": 1, "active": True}
|
||||
asset_factory(name, tags, meta, make_asset_bytes(name))
|
||||
|
||||
# int filter matches numeric
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-types",
|
||||
"metadata_filter": json.dumps({"epoch": 1}),
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# string "1" must NOT match numeric 1
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-types",
|
||||
"metadata_filter": json.dumps({"epoch": "1"}),
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200 and not b2["assets"]
|
||||
|
||||
# bool True matches, string "true" must NOT match
|
||||
r3 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-types",
|
||||
"metadata_filter": json.dumps({"active": True}),
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
b3 = r3.json()
|
||||
assert r3.status_code == 200 and any(a["name"] == name for a in b3["assets"])
|
||||
|
||||
r4 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-types",
|
||||
"metadata_filter": json.dumps({"active": "true"}),
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
b4 = r4.json()
|
||||
assert r4.status_code == 200 and not b4["assets"]
|
||||
|
||||
|
||||
def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_list_scalars.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-list"]
|
||||
meta = {"flags": ["red", "green"]}
|
||||
asset_factory(name, tags, meta, make_asset_bytes(name, 3000))
|
||||
|
||||
# Any-of should match because "green" is present
|
||||
filt_ok = {"flags": ["blue", "green"]}
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_ok)},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# None of provided flags present -> no match
|
||||
filt_miss = {"flags": ["blue", "yellow"]}
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_miss)},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200 and not b2["assets"]
|
||||
|
||||
# Duplicates in list should not break matching
|
||||
filt_dup = {"flags": ["green", "green", "green"]}
|
||||
r3 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_dup)},
|
||||
timeout=120,
|
||||
)
|
||||
b3 = r3.json()
|
||||
assert r3.status_code == 200 and any(a["name"] == name for a in b3["assets"])
|
||||
|
||||
|
||||
def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
|
||||
http, api_base, asset_factory, make_asset_bytes
|
||||
):
|
||||
# a1: key missing; a2: explicit null; a3: concrete value
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-none"]
|
||||
a1 = asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1"))
|
||||
a2 = asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2"))
|
||||
a3 = asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3"))
|
||||
|
||||
# Filter {maybe: None} must match a1 and a2, not a3
|
||||
filt = {"maybe": None}
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt), "sort": "name"},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
got = [a["name"] for a in b1["assets"]]
|
||||
assert a1["name"] in got and a2["name"] in got and a3["name"] not in got
|
||||
|
||||
# Any-of with None should include missing/null plus value matches
|
||||
filt_any = {"maybe": [None, "x"]}
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt_any), "sort": "name"},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
got2 = [a["name"] for a in b2["assets"]]
|
||||
assert a1["name"] in got2 and a2["name"] in got2 and a3["name"] in got2
|
||||
|
||||
|
||||
def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_nested_json.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-nested"]
|
||||
cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}}
|
||||
asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200))
|
||||
|
||||
# Exact JSON object equality (same structure)
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-nested",
|
||||
"metadata_filter": json.dumps({"config": cfg}),
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# Different JSON object should not match
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-nested",
|
||||
"metadata_filter": json.dumps({"config": {"optimizer": "sgd"}}),
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200 and not b2["assets"]
|
||||
|
||||
|
||||
def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_list_objects.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-objlist"]
|
||||
transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}]
|
||||
asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048))
|
||||
|
||||
# Any-of for list of objects should match when one element equals the filter object
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-objlist",
|
||||
"metadata_filter": json.dumps({"transforms": {"type": "flip", "p": 0.5}}),
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# Non-matching object -> no match
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-objlist",
|
||||
"metadata_filter": json.dumps({"transforms": {"type": "rotate", "deg": 90}}),
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200 and not b2["assets"]
|
||||
|
||||
|
||||
def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_keys_unicode.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-keys"]
|
||||
meta = {
|
||||
"weird.key": "v1",
|
||||
"path/like": 7,
|
||||
"with:colon": True,
|
||||
"ключ": "значение",
|
||||
"emoji": "🐍",
|
||||
}
|
||||
asset_factory(name, tags, meta, make_asset_bytes(name, 1500))
|
||||
|
||||
# Match all the special keys
|
||||
filt = {"weird.key": "v1", "path/like": 7, "with:colon": True, "emoji": "🐍"}
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps(filt)},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# Unicode key match
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps({"ключ": "значение"})},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200 and any(a["name"] == name for a in b2["assets"])
|
||||
|
||||
|
||||
def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"]
|
||||
a0 = asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025))
|
||||
a1 = asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026))
|
||||
|
||||
# count == 0 must match only a0
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"count": 0})},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
names1 = [a["name"] for a in b1["assets"]]
|
||||
assert a0["name"] in names1 and a1["name"] not in names1
|
||||
|
||||
# Any-of list of booleans: True matches second asset
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"choices": True})},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200 and any(a["name"] == a1["name"] for a in b2["assets"])
|
||||
|
||||
|
||||
def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_mixed_list.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-mixed"]
|
||||
meta = {"mix": ["1", 1, True, None]}
|
||||
asset_factory(name, tags, meta, make_asset_bytes(name, 1999))
|
||||
|
||||
# Should match because 1 is present
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": [2, 1]})},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# Should NOT match for False
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": False})},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200 and not b2["assets"]
|
||||
|
||||
|
||||
def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes):
|
||||
# Use a unique scope tag to avoid interference
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"]
|
||||
x = asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua"))
|
||||
y = asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub"))
|
||||
|
||||
# Filtering by unknown key with None should return both (missing key OR null)
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": None})},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
names = {a["name"] for a in b1["assets"]}
|
||||
assert x["name"] in names and y["name"] in names
|
||||
|
||||
# Filtering by unknown key with concrete value should return none
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": "x"})},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200 and not b2["assets"]
|
||||
|
||||
|
||||
def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_factory, make_asset_bytes):
|
||||
# alpha matches epoch=1; beta has epoch=2
|
||||
a = asset_factory(
|
||||
"mf_tag_alpha.safetensors",
|
||||
["models", "checkpoints", "unit-tests", "mf-tag", "alpha"],
|
||||
{"epoch": 1},
|
||||
make_asset_bytes("alpha"),
|
||||
)
|
||||
b = asset_factory(
|
||||
"mf_tag_beta.safetensors",
|
||||
["models", "checkpoints", "unit-tests", "mf-tag", "beta"],
|
||||
{"epoch": 2},
|
||||
make_asset_bytes("beta"),
|
||||
)
|
||||
|
||||
params = {
|
||||
"include_tags": "unit-tests,mf-tag,alpha",
|
||||
"exclude_tags": "beta",
|
||||
"name_contains": "mf_tag_",
|
||||
"metadata_filter": json.dumps({"epoch": 1}),
|
||||
}
|
||||
r = http.get(api_base + "/api/assets", params=params, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 200
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert a["name"] in names
|
||||
assert b["name"] not in names
|
||||
|
||||
|
||||
def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes):
|
||||
# Three assets in same scope with different sizes and a common filter key
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-sort"]
|
||||
n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors"
|
||||
asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024))
|
||||
asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048))
|
||||
asset_factory(n3, t, {"group": "g"}, make_asset_bytes(n3, 3072))
|
||||
|
||||
# Sort by size ascending with paging
|
||||
q = {
|
||||
"include_tags": "unit-tests,mf-sort",
|
||||
"metadata_filter": json.dumps({"group": "g"}),
|
||||
"sort": "size", "order": "asc", "limit": "2",
|
||||
}
|
||||
r1 = http.get(api_base + "/api/assets", params=q, timeout=120)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
got1 = [a["name"] for a in b1["assets"]]
|
||||
assert got1 == [n1, n2]
|
||||
assert b1["has_more"] is True
|
||||
|
||||
q2 = {**q, "offset": "2"}
|
||||
r2 = http.get(api_base + "/api/assets", params=q2, timeout=120)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
got2 = [a["name"] for a in b2["assets"]]
|
||||
assert got2 == [n3]
|
||||
assert b2["has_more"] is False
|
||||
@ -1,141 +0,0 @@
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_seed_file(comfy_tmp_base_dir: Path):
|
||||
"""Create a file on disk that will become a seed asset after sync."""
|
||||
created: list[Path] = []
|
||||
|
||||
def _create(root: str, scope: str, name: str | None = None, data: bytes = b"TEST") -> Path:
|
||||
name = name or f"seed_{uuid.uuid4().hex[:8]}.bin"
|
||||
path = comfy_tmp_base_dir / root / "unit-tests" / scope / name
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(data)
|
||||
created.append(path)
|
||||
return path
|
||||
|
||||
yield _create
|
||||
|
||||
for p in created:
|
||||
p.unlink(missing_ok=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def find_asset(http: requests.Session, api_base: str):
|
||||
"""Query API for assets matching scope and optional name."""
|
||||
def _find(scope: str, name: str | None = None) -> list[dict]:
|
||||
params = {"include_tags": f"unit-tests,{scope}"}
|
||||
if name:
|
||||
params["name_contains"] = name
|
||||
r = http.get(f"{api_base}/api/assets", params=params, timeout=120)
|
||||
assert r.status_code == 200
|
||||
assets = r.json().get("assets", [])
|
||||
if name:
|
||||
return [a for a in assets if a.get("name") == name]
|
||||
return assets
|
||||
|
||||
return _find
|
||||
|
||||
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
def test_orphaned_seed_asset_is_pruned(
|
||||
root: str,
|
||||
create_seed_file,
|
||||
find_asset,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
):
|
||||
"""Seed asset with deleted file is removed; with file present, it survives."""
|
||||
scope = f"prune-{uuid.uuid4().hex[:6]}"
|
||||
fp = create_seed_file(root, scope)
|
||||
name = fp.name
|
||||
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
assert find_asset(scope, name), "Seed asset should exist"
|
||||
|
||||
fp.unlink()
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
assert not find_asset(scope, name), "Orphaned seed should be pruned"
|
||||
|
||||
|
||||
def test_seed_asset_with_file_survives_prune(
|
||||
create_seed_file,
|
||||
find_asset,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
):
|
||||
"""Seed asset with file still on disk is NOT pruned."""
|
||||
scope = f"keep-{uuid.uuid4().hex[:6]}"
|
||||
fp = create_seed_file("input", scope)
|
||||
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
assert find_asset(scope, fp.name), "Seed with valid file should survive"
|
||||
|
||||
|
||||
def test_hashed_asset_not_pruned_when_file_missing(
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""Hashed assets are never deleted by prune, even without file."""
|
||||
scope = f"hashed-{uuid.uuid4().hex[:6]}"
|
||||
data = make_asset_bytes("test", 2048)
|
||||
a = asset_factory("test.bin", ["input", "unit-tests", scope], {}, data)
|
||||
|
||||
path = comfy_tmp_base_dir / "input" / "unit-tests" / scope / get_asset_filename(a["asset_hash"], ".bin")
|
||||
path.unlink()
|
||||
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
r = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120)
|
||||
assert r.status_code == 200, "Hashed asset should NOT be pruned"
|
||||
|
||||
|
||||
def test_prune_across_multiple_roots(
|
||||
create_seed_file,
|
||||
find_asset,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
):
|
||||
"""Prune correctly handles assets across input and output roots."""
|
||||
scope = f"multi-{uuid.uuid4().hex[:6]}"
|
||||
input_fp = create_seed_file("input", scope, "input.bin")
|
||||
create_seed_file("output", scope, "output.bin")
|
||||
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
assert len(find_asset(scope)) == 2
|
||||
|
||||
input_fp.unlink()
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
remaining = find_asset(scope)
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0]["name"] == "output.bin"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dirname", ["100%_done", "my_folder_name", "has spaces"])
|
||||
def test_special_chars_in_path_escaped_correctly(
|
||||
dirname: str,
|
||||
create_seed_file,
|
||||
find_asset,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
):
|
||||
"""SQL LIKE wildcards (%, _) and spaces in paths don't cause false matches."""
|
||||
scope = f"special-{uuid.uuid4().hex[:6]}/{dirname}"
|
||||
fp = create_seed_file("input", scope)
|
||||
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
assert find_asset(scope.split("/")[0], fp.name), "Asset with special chars should survive"
|
||||
@ -1,225 +0,0 @@
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
# Include zero-usage tags by default
|
||||
r1 = http.get(api_base + "/api/tags", params={"limit": "50"}, timeout=120)
|
||||
body1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
names = [t["name"] for t in body1["tags"]]
|
||||
# A few system tags from migration should exist:
|
||||
assert "models" in names
|
||||
assert "checkpoints" in names
|
||||
|
||||
# Only used tags before we add anything new from this test cycle
|
||||
r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120)
|
||||
body2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
# We already seeded one asset via fixture, so used tags must be non-empty
|
||||
used_names = [t["name"] for t in body2["tags"]]
|
||||
assert "models" in used_names
|
||||
assert "checkpoints" in used_names
|
||||
|
||||
# Prefix filter should refine the list
|
||||
r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120)
|
||||
b3 = r3.json()
|
||||
assert r3.status_code == 200
|
||||
names3 = [t["name"] for t in b3["tags"]]
|
||||
assert "unit-tests" in names3
|
||||
assert "models" not in names3 # filtered out by prefix
|
||||
|
||||
# Order by name ascending should be stable
|
||||
r4 = http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}, timeout=120)
|
||||
b4 = r4.json()
|
||||
assert r4.status_code == 200
|
||||
names4 = [t["name"] for t in b4["tags"]]
|
||||
assert names4 == sorted(names4)
|
||||
|
||||
|
||||
def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
# Baseline: system tags exist when include_zero (default) is true
|
||||
r1 = http.get(api_base + "/api/tags", params={"limit": "500"}, timeout=120)
|
||||
body1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
names = [t["name"] for t in body1["tags"]]
|
||||
assert "models" in names and "checkpoints" in names
|
||||
|
||||
# Create a short-lived asset under input with a unique custom tag
|
||||
scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}"
|
||||
custom_tag = f"temp-{uuid.uuid4().hex[:8]}"
|
||||
name = "tag_seed.bin"
|
||||
_asset = asset_factory(
|
||||
name,
|
||||
["input", "unit-tests", scope, custom_tag],
|
||||
{},
|
||||
make_asset_bytes(name, 512),
|
||||
)
|
||||
|
||||
# While the asset exists, the custom tag must appear when include_zero=false
|
||||
r2 = http.get(
|
||||
api_base + "/api/tags",
|
||||
params={"include_zero": "false", "prefix": custom_tag, "limit": "50"},
|
||||
timeout=120,
|
||||
)
|
||||
body2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
used_names = [t["name"] for t in body2["tags"]]
|
||||
assert custom_tag in used_names
|
||||
|
||||
# Delete the asset so the tag usage drops to zero
|
||||
rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120)
|
||||
assert rd.status_code == 204
|
||||
|
||||
# Now the custom tag must not be returned when include_zero=false
|
||||
r3 = http.get(
|
||||
api_base + "/api/tags",
|
||||
params={"include_zero": "false", "prefix": custom_tag, "limit": "50"},
|
||||
timeout=120,
|
||||
)
|
||||
body3 = r3.json()
|
||||
assert r3.status_code == 200
|
||||
names_after = [t["name"] for t in body3["tags"]]
|
||||
assert custom_tag not in names_after
|
||||
assert not names_after # filtered view should be empty now
|
||||
|
||||
|
||||
def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
# Add tags with duplicates and mixed case
|
||||
payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]}
|
||||
r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200, b1
|
||||
# normalized, deduplicated; 'unit-tests' was already present from the seed
|
||||
assert set(b1["added"]) == {"newtag", "beta"}
|
||||
assert set(b1["already_present"]) == {"unit-tests"}
|
||||
assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]
|
||||
|
||||
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
g = rg.json()
|
||||
assert rg.status_code == 200
|
||||
tags_now = set(g["tags"])
|
||||
assert {"newtag", "beta"}.issubset(tags_now)
|
||||
|
||||
# Remove a tag and a non-existent tag
|
||||
payload_del = {"tags": ["newtag", "does-not-exist"]}
|
||||
r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del, timeout=120)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
assert set(b2["removed"]) == {"newtag"}
|
||||
assert set(b2["not_present"]) == {"does-not-exist"}
|
||||
|
||||
# Verify remaining tags after deletion
|
||||
rg2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
g2 = rg2.json()
|
||||
assert rg2.status_code == 200
|
||||
tags_later = set(g2["tags"])
|
||||
assert "newtag" not in tags_later
|
||||
assert "beta" in tags_later # still present
|
||||
|
||||
|
||||
def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
h = seeded_asset["asset_hash"]
|
||||
|
||||
# Add both tags to the seeded asset (usage: orderaaa=1, orderbbb=1)
|
||||
r_add = http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": ["orderaaa", "orderbbb"]}, timeout=120)
|
||||
add_body = r_add.json()
|
||||
assert r_add.status_code == 200, add_body
|
||||
|
||||
# Create another AssetInfo from the same content but tagged ONLY with 'orderbbb'.
|
||||
payload = {
|
||||
"hash": h,
|
||||
"name": "order_only_bbb.safetensors",
|
||||
"tags": ["input", "unit-tests", "orderbbb"],
|
||||
"user_metadata": {},
|
||||
}
|
||||
r_copy = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
|
||||
copy_body = r_copy.json()
|
||||
assert r_copy.status_code == 201, copy_body
|
||||
|
||||
# 1) Default order (count_desc): 'orderbbb' should come before 'orderaaa'
|
||||
# because it has higher usage (2 vs 1).
|
||||
r1 = http.get(api_base + "/api/tags", params={"prefix": "order", "include_zero": "false"}, timeout=120)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200, b1
|
||||
names1 = [t["name"] for t in b1["tags"]]
|
||||
counts1 = {t["name"]: t["count"] for t in b1["tags"]}
|
||||
# Both must be present within the prefix subset
|
||||
assert "orderaaa" in names1 and "orderbbb" in names1
|
||||
# Usage of 'orderbbb' must be >= 'orderaaa'; in our setup it's 2 vs 1
|
||||
assert counts1["orderbbb"] >= counts1["orderaaa"]
|
||||
# And with count_desc, 'orderbbb' appears earlier than 'orderaaa'
|
||||
assert names1.index("orderbbb") < names1.index("orderaaa")
|
||||
|
||||
# 2) name_asc: lexical order should flip the relative order
|
||||
r2 = http.get(
|
||||
api_base + "/api/tags",
|
||||
params={"prefix": "order", "include_zero": "false", "order": "name_asc"},
|
||||
timeout=120,
|
||||
)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200, b2
|
||||
names2 = [t["name"] for t in b2["tags"]]
|
||||
assert "orderaaa" in names2 and "orderbbb" in names2
|
||||
assert names2.index("orderaaa") < names2.index("orderbbb")
|
||||
|
||||
# 3) invalid limit rejected (existing negative case retained)
|
||||
r3 = http.get(api_base + "/api/tags", params={"limit": "1001"}, timeout=120)
|
||||
b3 = r3.json()
|
||||
assert r3.status_code == 400
|
||||
assert b3["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
|
||||
def test_tags_endpoints_invalid_bodies(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
# Add with empty list
|
||||
r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": []}, timeout=120)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 400
|
||||
assert b1["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
# Remove with wrong type
|
||||
r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json={"tags": [123]}, timeout=120)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 400
|
||||
assert b2["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
# metadata_filter provided as JSON array should be rejected (must be object)
|
||||
r3 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"metadata_filter": json.dumps([{"x": 1}])},
|
||||
timeout=120,
|
||||
)
|
||||
b3 = r3.json()
|
||||
assert r3.status_code == 400
|
||||
assert b3["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
|
||||
def test_tags_prefix_treats_underscore_literal(
|
||||
http,
|
||||
api_base,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""'prefix' for /api/tags must treat '_' literally, not as a wildcard."""
|
||||
base = f"pref_{uuid.uuid4().hex[:6]}"
|
||||
tag_ok = f"{base}_ok" # should match prefix=f"{base}_"
|
||||
tag_bad = f"{base}xok" # must NOT match if '_' is escaped
|
||||
scope = f"tags-underscore-{uuid.uuid4().hex[:6]}"
|
||||
|
||||
asset_factory("t1.bin", ["input", "unit-tests", scope, tag_ok], {}, make_asset_bytes("t1", 512))
|
||||
asset_factory("t2.bin", ["input", "unit-tests", scope, tag_bad], {}, make_asset_bytes("t2", 512))
|
||||
|
||||
r = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": f"{base}_"}, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 200, body
|
||||
names = [t["name"] for t in body["tags"]]
|
||||
assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'"
|
||||
assert tag_bad not in names, f"'{tag_bad}' must not match — '_' is not a wildcard"
|
||||
assert body["total"] == 1
|
||||
@ -1,281 +0,0 @@
|
||||
import json
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import requests
|
||||
import pytest
|
||||
|
||||
|
||||
def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes):
|
||||
name = "dup_a.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||
meta = {"purpose": "dup"}
|
||||
data = make_asset_bytes(name)
|
||||
files = {"file": (name, data, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
|
||||
r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
a1 = r1.json()
|
||||
assert r1.status_code == 201, a1
|
||||
assert a1["created_new"] is True
|
||||
|
||||
# Second upload with the same data and name should return created_new == False and the same asset
|
||||
files = {"file": (name, data, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
|
||||
r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
a2 = r2.json()
|
||||
assert r2.status_code == 200, a2
|
||||
assert a2["created_new"] is False
|
||||
assert a2["asset_hash"] == a1["asset_hash"]
|
||||
assert a2["id"] == a1["id"] # old reference
|
||||
|
||||
# Third upload with the same data but new name should return created_new == False and the new AssetReference
|
||||
files = {"file": (name, data, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(tags), "name": name + "_d", "user_metadata": json.dumps(meta)}
|
||||
r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
a3 = r2.json()
|
||||
assert r2.status_code == 200, a3
|
||||
assert a3["created_new"] is False
|
||||
assert a3["asset_hash"] == a1["asset_hash"]
|
||||
assert a3["id"] != a1["id"] # old reference
|
||||
|
||||
|
||||
def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str):
|
||||
# Seed a small file first
|
||||
name = "fastpath_seed.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests"]
|
||||
meta = {}
|
||||
files = {"file": (name, b"B" * 1024, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
|
||||
r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 201, b1
|
||||
h = b1["asset_hash"]
|
||||
|
||||
# Now POST /api/assets with only hash and no file
|
||||
files = [
|
||||
("hash", (None, h)),
|
||||
("tags", (None, json.dumps(tags))),
|
||||
("name", (None, "fastpath_copy.safetensors")),
|
||||
("user_metadata", (None, json.dumps({"purpose": "copy"}))),
|
||||
]
|
||||
r2 = http.post(api_base + "/api/assets", files=files, timeout=120)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200, b2 # fast path returns 200 with created_new == False
|
||||
assert b2["created_new"] is False
|
||||
assert b2["asset_hash"] == h
|
||||
|
||||
|
||||
def test_upload_fastpath_with_known_hash_and_file(
|
||||
http: requests.Session, api_base: str
|
||||
):
|
||||
# Seed
|
||||
files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})}
|
||||
r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 201, b1
|
||||
h = b1["asset_hash"]
|
||||
|
||||
# Send both file and hash of existing content -> server must drain file and create from hash (200)
|
||||
files = {"file": ("ignored.bin", b"ignored" * 10, "application/octet-stream")}
|
||||
form = {"hash": h, "tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "copy_from_hash.safetensors", "user_metadata": json.dumps({})}
|
||||
r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200, b2
|
||||
assert b2["created_new"] is False
|
||||
assert b2["asset_hash"] == h
|
||||
|
||||
|
||||
def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str):
|
||||
data = [
|
||||
("tags", "models,checkpoints"),
|
||||
("tags", json.dumps(["unit-tests", "alpha"])),
|
||||
("name", "merge.safetensors"),
|
||||
("user_metadata", json.dumps({"u": 1})),
|
||||
]
|
||||
files = {"file": ("merge.safetensors", b"B" * 256, "application/octet-stream")}
|
||||
r1 = http.post(api_base + "/api/assets", data=data, files=files, timeout=120)
|
||||
created = r1.json()
|
||||
assert r1.status_code in (200, 201), created
|
||||
aid = created["id"]
|
||||
|
||||
# Verify all tags are present on the resource
|
||||
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
detail = rg.json()
|
||||
assert rg.status_code == 200, detail
|
||||
tags = set(detail["tags"])
|
||||
assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
def test_concurrent_upload_identical_bytes_different_names(
|
||||
root: str,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""
|
||||
Two concurrent uploads of identical bytes but different names.
|
||||
Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True.
|
||||
"""
|
||||
scope = f"concupload-{uuid.uuid4().hex[:6]}"
|
||||
name1, name2 = "cu_a.bin", "cu_b.bin"
|
||||
data = make_asset_bytes("concurrent", 4096)
|
||||
tags = [root, "unit-tests", scope]
|
||||
|
||||
def _do_upload(args):
|
||||
url, form_data, files_data = args
|
||||
with requests.Session() as s:
|
||||
return s.post(url, data=form_data, files=files_data, timeout=120)
|
||||
|
||||
url = api_base + "/api/assets"
|
||||
form1 = {"tags": json.dumps(tags), "name": name1, "user_metadata": json.dumps({})}
|
||||
files1 = {"file": (name1, data, "application/octet-stream")}
|
||||
form2 = {"tags": json.dumps(tags), "name": name2, "user_metadata": json.dumps({})}
|
||||
files2 = {"file": (name2, data, "application/octet-stream")}
|
||||
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
futures = list(executor.map(_do_upload, [(url, form1, files1), (url, form2, files2)]))
|
||||
r1, r2 = futures
|
||||
|
||||
b1, b2 = r1.json(), r2.json()
|
||||
assert r1.status_code in (200, 201), b1
|
||||
assert r2.status_code in (200, 201), b2
|
||||
assert b1["asset_hash"] == b2["asset_hash"]
|
||||
assert b1["id"] != b2["id"]
|
||||
|
||||
created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))])
|
||||
assert created_flags == [False, True]
|
||||
|
||||
rl = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "sort": "name"},
|
||||
timeout=120,
|
||||
)
|
||||
bl = rl.json()
|
||||
assert rl.status_code == 200, bl
|
||||
names = [a["name"] for a in bl.get("assets", [])]
|
||||
assert set([name1, name2]).issubset(names)
|
||||
|
||||
|
||||
def test_create_from_hash_endpoint_404(http: requests.Session, api_base: str):
|
||||
payload = {
|
||||
"hash": "blake3:" + "0" * 64,
|
||||
"name": "nonexistent.bin",
|
||||
"tags": ["models", "checkpoints", "unit-tests"],
|
||||
}
|
||||
r = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 404
|
||||
assert body["error"]["code"] == "ASSET_NOT_FOUND"
|
||||
|
||||
|
||||
def test_upload_zero_byte_rejected(http: requests.Session, api_base: str):
|
||||
files = {"file": ("empty.safetensors", b"", "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] == "EMPTY_UPLOAD"
|
||||
|
||||
|
||||
def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str):
|
||||
files = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["not-a-root", "whatever"]), "name": "badroot.bin", "user_metadata": json.dumps({})}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str):
|
||||
files = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
def test_upload_requires_multipart(http: requests.Session, api_base: str):
|
||||
r = http.post(api_base + "/api/assets", json={"foo": "bar"}, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 415
|
||||
assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE"
|
||||
|
||||
|
||||
def test_upload_missing_file_and_hash(http: requests.Session, api_base: str):
|
||||
files = [
|
||||
("tags", (None, json.dumps(["models", "checkpoints", "unit-tests"]))),
|
||||
("name", (None, "x.safetensors")),
|
||||
]
|
||||
r = http.post(api_base + "/api/assets", files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] == "MISSING_FILE"
|
||||
|
||||
|
||||
def test_upload_models_unknown_category(http: requests.Session, api_base: str):
|
||||
files = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "no_such_category", "unit-tests"]), "name": "m.safetensors"}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
assert body["error"]["message"].startswith("unknown models category")
|
||||
|
||||
|
||||
def test_upload_models_requires_category(http: requests.Session, api_base: str):
|
||||
files = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models"]), "name": "nocat.safetensors", "user_metadata": json.dumps({})}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
|
||||
files = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
def test_duplicate_upload_same_display_name_does_not_clobber(
|
||||
root: str,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""
|
||||
Two uploads use the same tags and the same display name but different bytes.
|
||||
With hash-based filenames, they must NOT overwrite each other. Both assets
|
||||
remain accessible and serve their original content.
|
||||
"""
|
||||
scope = f"dup-path-{uuid.uuid4().hex[:6]}"
|
||||
display_name = "same_display.bin"
|
||||
|
||||
d1 = make_asset_bytes(scope + "-v1", 1536)
|
||||
d2 = make_asset_bytes(scope + "-v2", 2048)
|
||||
tags = [root, "unit-tests", scope]
|
||||
|
||||
first = asset_factory(display_name, tags, {}, d1)
|
||||
second = asset_factory(display_name, tags, {}, d2)
|
||||
|
||||
assert first["id"] != second["id"]
|
||||
assert first["asset_hash"] != second["asset_hash"] # different content
|
||||
assert first["name"] == second["name"] == display_name
|
||||
|
||||
# Both must be independently retrievable
|
||||
r1 = http.get(f"{api_base}/api/assets/{first['id']}/content", timeout=120)
|
||||
b1 = r1.content
|
||||
assert r1.status_code == 200
|
||||
assert b1 == d1
|
||||
r2 = http.get(f"{api_base}/api/assets/{second['id']}/content", timeout=120)
|
||||
b2 = r2.content
|
||||
assert r2.status_code == 200
|
||||
assert b2 == d2
|
||||
Reference in New Issue
Block a user