mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-23 01:28:21 +08:00
Compare commits
91 Commits
v0.3.61
...
asset-mana
| Author | SHA1 | Date | |
|---|---|---|---|
| 917177e821 | |||
| fd6ac0a765 | |||
| 94941c50b3 | |||
| fbba2e59e5 | |||
| adccfb2dfd | |||
| 9f4c0f3afe | |||
| ca39552954 | |||
| 4dd843d36f | |||
| 46fdd636de | |||
| 283cd27bdc | |||
| 1a37d1476d | |||
| f9602457d6 | |||
| 85ef08449d | |||
| 5b6810a2c6 | |||
| 621faaa195 | |||
| d0aa64d57b | |||
| 677a0e2508 | |||
| 31ec744317 | |||
| a336c7c165 | |||
| 77332d3054 | |||
| 24a95f5ca4 | |||
| 0be513b213 | |||
| f1fb7432a0 | |||
| f3cf99d10c | |||
| 5f187fe6fb | |||
| 025fc49b4e | |||
| 7becb84341 | |||
| dda31de690 | |||
| 1d970382f0 | |||
| a2fc2bbae4 | |||
| a7f2546558 | |||
| 6cfa94ec58 | |||
| a2ec1f7637 | |||
| 0b795dc7a7 | |||
| 47f7c7ee8c | |||
| cdd8d16075 | |||
| 37b81e6658 | |||
| 975650060f | |||
| 4a713654cd | |||
| 9b8e88ba6e | |||
| bb9ed04758 | |||
| 934377ac1e | |||
| 3c9bf39c20 | |||
| 0df1ccac6f | |||
| 72548a8ac4 | |||
| 6eaed072c7 | |||
| a9096f6c97 | |||
| 964de8a8ad | |||
| 1886f10e19 | |||
| 357193f7b5 | |||
| 0ef73e95fd | |||
| faa1e4de17 | |||
| dfb5703d40 | |||
| 0e9de2b7c9 | |||
| e3311c9229 | |||
| 3fa0fc496c | |||
| 6282d495ca | |||
| b8ef9bb92c | |||
| 2d9be462d3 | |||
| 789a62ce35 | |||
| 84384ca0b4 | |||
| ce270ba090 | |||
| bf8363ec87 | |||
| 6b86be320a | |||
| bdf4ba24ce | |||
| 871e41aec6 | |||
| eb7008a4d3 | |||
| 0379eff0b5 | |||
| 026b7f209c | |||
| 7c1b0be496 | |||
| 6fade5da38 | |||
| a763cbd39d | |||
| 09dabf95bc | |||
| d7464e9e73 | |||
| a82577f64a | |||
| f2ea0bc22c | |||
| 0755e5320a | |||
| 8d46bec951 | |||
| 5c1b5973ac | |||
| f92307cd4c | |||
| c708d0a433 | |||
| 1aa089e0b6 | |||
| f032c1a50a | |||
| 3089936a2c | |||
| cd679129e3 | |||
| d7062277a7 | |||
| 54cf14cbbb | |||
| 7d5160f92c | |||
| 7f7b3f1695 | |||
| 9da6aca0d0 | |||
| 1cb3c98947 |
@ -1,24 +0,0 @@
|
||||
As of the time of writing this you need this preview driver for best results:
|
||||
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html
|
||||
|
||||
HOW TO RUN:
|
||||
|
||||
if you have a AMD gpu:
|
||||
|
||||
run_amd_gpu.bat
|
||||
|
||||
|
||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||
|
||||
You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors
|
||||
|
||||
|
||||
RECOMMENDED WAY TO UPDATE:
|
||||
To update the ComfyUI code: update\update_comfyui.bat
|
||||
|
||||
|
||||
TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
|
||||
In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
|
||||
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.
|
||||
|
||||
|
||||
@ -1,2 +0,0 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||
pause
|
||||
49
.github/workflows/release-stable-all.yml
vendored
49
.github/workflows/release-stable-all.yml
vendored
@ -1,49 +0,0 @@
|
||||
name: "Release Stable All Portable Versions"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
git_tag:
|
||||
description: 'Git tag'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
release_nvidia_default:
|
||||
name: "Release NVIDIA Default (cu129)"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu129"
|
||||
python_minor: "13"
|
||||
python_patch: "6"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: ""
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_nvidia_cu128:
|
||||
name: "Release NVIDIA cu128"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu128"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: "_cu128"
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_amd_rocm:
|
||||
name: "Release AMD ROCm 6.4.4"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "rocm644"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "amd"
|
||||
rel_extra_name: ""
|
||||
test_release: false
|
||||
secrets: inherit
|
||||
98
.github/workflows/stable-release.yml
vendored
98
.github/workflows/stable-release.yml
vendored
@ -2,53 +2,17 @@
|
||||
name: "Release Stable Version"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
git_tag:
|
||||
description: 'Git tag'
|
||||
required: true
|
||||
type: string
|
||||
cache_tag:
|
||||
description: 'Cached dependencies tag'
|
||||
required: true
|
||||
type: string
|
||||
default: "cu129"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "13"
|
||||
python_patch:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
rel_name:
|
||||
description: 'Release name'
|
||||
required: true
|
||||
type: string
|
||||
default: "nvidia"
|
||||
rel_extra_name:
|
||||
description: 'Release extra name'
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
test_release:
|
||||
description: 'Test Release'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
git_tag:
|
||||
description: 'Git tag'
|
||||
required: true
|
||||
type: string
|
||||
cache_tag:
|
||||
description: 'Cached dependencies tag'
|
||||
cu:
|
||||
description: 'CUDA version'
|
||||
required: true
|
||||
type: string
|
||||
default: "cu129"
|
||||
default: "129"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
@ -59,21 +23,7 @@ on:
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
rel_name:
|
||||
description: 'Release name'
|
||||
required: true
|
||||
type: string
|
||||
default: "nvidia"
|
||||
rel_extra_name:
|
||||
description: 'Release extra name'
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
test_release:
|
||||
description: 'Test Release'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
|
||||
|
||||
jobs:
|
||||
package_comfy_windows:
|
||||
@ -92,15 +42,15 @@ jobs:
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
${{ inputs.cache_tag }}_python_deps.tar
|
||||
cu${{ inputs.cu }}_python_deps.tar
|
||||
update_comfyui_and_python_dependencies.bat
|
||||
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
|
||||
key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
|
||||
- shell: bash
|
||||
run: |
|
||||
mv ${{ inputs.cache_tag }}_python_deps.tar ../
|
||||
mv cu${{ inputs.cu }}_python_deps.tar ../
|
||||
mv update_comfyui_and_python_dependencies.bat ../
|
||||
cd ..
|
||||
tar xf ${{ inputs.cache_tag }}_python_deps.tar
|
||||
tar xf cu${{ inputs.cu }}_python_deps.tar
|
||||
pwd
|
||||
ls
|
||||
|
||||
@ -115,19 +65,12 @@ jobs:
|
||||
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
|
||||
|
||||
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
|
||||
./python.exe -s -m pip install -r requirements_comfyui.txt
|
||||
rm requirements_comfyui.txt
|
||||
|
||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
fi
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
|
||||
cd ..
|
||||
|
||||
@ -142,18 +85,14 @@ jobs:
|
||||
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp ../update_comfyui_and_python_dependencies.bat ./update/
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
|
||||
|
||||
- shell: bash
|
||||
if: ${{ inputs.test_release }}
|
||||
run: |
|
||||
cd ..
|
||||
cd ComfyUI_windows_portable
|
||||
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
||||
|
||||
@ -162,9 +101,10 @@ jobs:
|
||||
ls
|
||||
|
||||
- name: Upload binaries to release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
|
||||
tag_name: ${{ inputs.git_tag }}
|
||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
file: ComfyUI_windows_portable_nvidia.7z
|
||||
tag: ${{ inputs.git_tag }}
|
||||
overwrite: true
|
||||
draft: true
|
||||
overwrite_files: true
|
||||
|
||||
173
.github/workflows/test-assets.yml
vendored
Normal file
173
.github/workflows/test-assets.yml
vendored
Normal file
@ -0,0 +1,173 @@
|
||||
name: Asset System Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'app/**'
|
||||
- 'tests-assets/**'
|
||||
- '.github/workflows/test-assets.yml'
|
||||
- 'requirements.txt'
|
||||
pull_request:
|
||||
branches: [master]
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
PIP_DISABLE_PIP_VERSION_CHECK: '1'
|
||||
PYTHONUNBUFFERED: '1'
|
||||
|
||||
jobs:
|
||||
sqlite:
|
||||
name: SQLite (${{ matrix.sqlite_mode }}) • Python ${{ matrix.python }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 40
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python: ['3.9', '3.12']
|
||||
sqlite_mode: ['memory', 'file']
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install -U pip wheel
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-aiohttp pytest-asyncio
|
||||
|
||||
- name: Set deterministic test base dir
|
||||
id: basedir
|
||||
shell: bash
|
||||
run: |
|
||||
BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}"
|
||||
echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV"
|
||||
echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV"
|
||||
mkdir -p "$BASE/logs"
|
||||
echo "ASSETS_TEST_BASE_DIR=$BASE"
|
||||
|
||||
- name: Set DB URL for SQLite
|
||||
id: setdb
|
||||
shell: bash
|
||||
run: |
|
||||
if [ "${{ matrix.sqlite_mode }}" = "memory" ]; then
|
||||
echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///:memory:" >> "$GITHUB_ENV"
|
||||
else
|
||||
DBFILE="$RUNNER_TEMP/assets-tests.sqlite"
|
||||
mkdir -p "$(dirname "$DBFILE")"
|
||||
echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///$DBFILE" >> "$GITHUB_ENV"
|
||||
fi
|
||||
|
||||
- name: Run tests
|
||||
run: python -m pytest tests-assets
|
||||
|
||||
- name: Show ComfyUI logs
|
||||
if: always()
|
||||
shell: bash
|
||||
run: |
|
||||
echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ===="
|
||||
echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ===="
|
||||
ls -la "$ASSETS_TEST_LOGS" || true
|
||||
for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do
|
||||
if [ -f "$f" ]; then
|
||||
echo "----- BEGIN $f -----"
|
||||
sed -n '1,400p' "$f"
|
||||
echo "----- END $f -----"
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Upload ComfyUI logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: asset-logs-sqlite-${{ matrix.sqlite_mode }}-py${{ matrix.python }}
|
||||
path: ${{ env.ASSETS_TEST_LOGS }}/*.log
|
||||
if-no-files-found: warn
|
||||
|
||||
postgres:
|
||||
name: PostgreSQL ${{ matrix.pgsql }} • Python ${{ matrix.python }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 40
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python: ['3.9', '3.12']
|
||||
pgsql: ['16', '18']
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:${{ matrix.pgsql }}
|
||||
env:
|
||||
POSTGRES_DB: assets
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd "pg_isready -U postgres -d assets"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 12
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install -U pip wheel
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-aiohttp pytest-asyncio
|
||||
pip install greenlet psycopg
|
||||
|
||||
- name: Set deterministic test base dir
|
||||
id: basedir
|
||||
shell: bash
|
||||
run: |
|
||||
BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}"
|
||||
echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV"
|
||||
echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV"
|
||||
mkdir -p "$BASE/logs"
|
||||
echo "ASSETS_TEST_BASE_DIR=$BASE"
|
||||
|
||||
- name: Set DB URL for PostgreSQL
|
||||
shell: bash
|
||||
run: |
|
||||
echo "ASSETS_TEST_DB_URL=postgresql+psycopg://postgres:postgres@localhost:5432/assets" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Run tests
|
||||
run: python -m pytest tests-assets
|
||||
|
||||
- name: Show ComfyUI logs
|
||||
if: always()
|
||||
shell: bash
|
||||
run: |
|
||||
echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ===="
|
||||
echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ===="
|
||||
ls -la "$ASSETS_TEST_LOGS" || true
|
||||
for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do
|
||||
if [ -f "$f" ]; then
|
||||
echo "----- BEGIN $f -----"
|
||||
sed -n '1,400p' "$f"
|
||||
echo "----- END $f -----"
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Upload ComfyUI logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: asset-logs-pgsql-${{ matrix.pgsql }}-py${{ matrix.python }}
|
||||
path: ${{ env.ASSETS_TEST_LOGS }}/*.log
|
||||
if-no-files-found: warn
|
||||
@ -56,8 +56,7 @@ jobs:
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
||||
pause" > update_comfyui_and_python_dependencies.bat
|
||||
|
||||
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||
echo installed basic
|
||||
ls -lah temp_wheel_dir
|
||||
|
||||
@ -1,64 +0,0 @@
|
||||
name: "Windows Release dependencies Manual"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
torch_dependencies:
|
||||
description: 'torch dependencies'
|
||||
required: false
|
||||
type: string
|
||||
default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128"
|
||||
cache_tag:
|
||||
description: 'Cached dependencies tag'
|
||||
required: true
|
||||
type: string
|
||||
default: "cu128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
|
||||
jobs:
|
||||
build_dependencies:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
|
||||
|
||||
- shell: bash
|
||||
run: |
|
||||
echo "@echo off
|
||||
call update_comfyui.bat nopause
|
||||
echo -
|
||||
echo This will try to update pytorch and all python dependencies.
|
||||
echo -
|
||||
echo If you just want to update normally, close this and run update_comfyui.bat instead.
|
||||
echo -
|
||||
pause
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2
|
||||
pause" > update_comfyui_and_python_dependencies.bat
|
||||
|
||||
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
|
||||
python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||
echo installed basic
|
||||
ls -lah temp_wheel_dir
|
||||
mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps
|
||||
tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps
|
||||
|
||||
- uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
${{ inputs.cache_tag }}_python_deps.tar
|
||||
update_comfyui_and_python_dependencies.bat
|
||||
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
|
||||
@ -68,7 +68,7 @@ jobs:
|
||||
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
|
||||
|
||||
echo "call update_comfyui.bat nopause
|
||||
|
||||
@ -81,7 +81,7 @@ jobs:
|
||||
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp ../update_comfyui_and_python_dependencies.bat ./update/
|
||||
|
||||
cd ..
|
||||
|
||||
@ -233,7 +233,7 @@ Nvidia users should install stable pytorch using this command:
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
# Use forward slashes (/) also on windows to provide an os agnostic path
|
||||
script_location = alembic_db
|
||||
script_location = app/alembic_db
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
|
||||
@ -2,13 +2,12 @@ from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
from app.assets.database.models import Base
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
|
||||
from app.database.models import Base
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
175
app/alembic_db/versions/0001_assets.py
Normal file
175
app/alembic_db/versions/0001_assets.py
Normal file
@ -0,0 +1,175 @@
|
||||
"""initial assets schema
|
||||
|
||||
Revision ID: 0001_assets
|
||||
Revises:
|
||||
Create Date: 2025-08-20 00:00:00
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision = "0001_assets"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ASSETS: content identity
|
||||
op.create_table(
|
||||
"assets",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("hash", sa.String(length=256), nullable=True),
|
||||
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column("mime_type", sa.String(length=255), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
|
||||
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
|
||||
|
||||
# ASSETS_INFO: user-visible references
|
||||
op.create_table(
|
||||
"assets_info",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
|
||||
sa.Column("name", sa.String(length=512), nullable=False),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("user_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
)
|
||||
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
|
||||
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
|
||||
op.create_index("ix_assets_info_name", "assets_info", ["name"])
|
||||
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
|
||||
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
|
||||
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
|
||||
|
||||
# TAGS: normalized tag vocabulary
|
||||
op.create_table(
|
||||
"tags",
|
||||
sa.Column("name", sa.String(length=512), primary_key=True),
|
||||
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
|
||||
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
|
||||
)
|
||||
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
|
||||
|
||||
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
|
||||
op.create_table(
|
||||
"asset_info_tags",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
|
||||
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
|
||||
)
|
||||
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
|
||||
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
|
||||
|
||||
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
|
||||
op.create_table(
|
||||
"asset_cache_state",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
|
||||
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
|
||||
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
|
||||
|
||||
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
|
||||
op.create_table(
|
||||
"asset_info_meta",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("key", sa.String(length=256), nullable=False),
|
||||
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("val_str", sa.String(length=2048), nullable=True),
|
||||
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
|
||||
sa.Column("val_bool", sa.Boolean(), nullable=True),
|
||||
sa.Column("val_json", sa.JSON().with_variant(postgresql.JSONB(), 'postgresql'), nullable=True),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
|
||||
)
|
||||
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
|
||||
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
|
||||
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
|
||||
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
|
||||
|
||||
# Tags vocabulary
|
||||
tags_table = sa.table(
|
||||
"tags",
|
||||
sa.column("name", sa.String(length=512)),
|
||||
sa.column("tag_type", sa.String()),
|
||||
)
|
||||
op.bulk_insert(
|
||||
tags_table,
|
||||
[
|
||||
{"name": "models", "tag_type": "system"},
|
||||
{"name": "input", "tag_type": "system"},
|
||||
{"name": "output", "tag_type": "system"},
|
||||
|
||||
{"name": "configs", "tag_type": "system"},
|
||||
{"name": "checkpoints", "tag_type": "system"},
|
||||
{"name": "loras", "tag_type": "system"},
|
||||
{"name": "vae", "tag_type": "system"},
|
||||
{"name": "text_encoders", "tag_type": "system"},
|
||||
{"name": "diffusion_models", "tag_type": "system"},
|
||||
{"name": "clip_vision", "tag_type": "system"},
|
||||
{"name": "style_models", "tag_type": "system"},
|
||||
{"name": "embeddings", "tag_type": "system"},
|
||||
{"name": "diffusers", "tag_type": "system"},
|
||||
{"name": "vae_approx", "tag_type": "system"},
|
||||
{"name": "controlnet", "tag_type": "system"},
|
||||
{"name": "gligen", "tag_type": "system"},
|
||||
{"name": "upscale_models", "tag_type": "system"},
|
||||
{"name": "hypernetworks", "tag_type": "system"},
|
||||
{"name": "photomaker", "tag_type": "system"},
|
||||
{"name": "classifiers", "tag_type": "system"},
|
||||
|
||||
{"name": "encoder", "tag_type": "system"},
|
||||
{"name": "decoder", "tag_type": "system"},
|
||||
|
||||
{"name": "missing", "tag_type": "system"},
|
||||
{"name": "rescan", "tag_type": "system"},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
|
||||
op.drop_table("asset_info_meta")
|
||||
|
||||
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
|
||||
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_table("asset_cache_state")
|
||||
|
||||
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
|
||||
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
|
||||
op.drop_table("asset_info_tags")
|
||||
|
||||
op.drop_index("ix_tags_tag_type", table_name="tags")
|
||||
op.drop_table("tags")
|
||||
|
||||
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
|
||||
op.drop_table("assets_info")
|
||||
|
||||
op.drop_index("uq_assets_hash", table_name="assets")
|
||||
op.drop_index("ix_assets_mime_type", table_name="assets")
|
||||
op.drop_table("assets")
|
||||
4
app/assets/__init__.py
Normal file
4
app/assets/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .api.routes import register_assets_system
|
||||
from .scanner import sync_seed_assets
|
||||
|
||||
__all__ = ["sync_seed_assets", "register_assets_system"]
|
||||
225
app/assets/_helpers.py
Normal file
225
app/assets/_helpers.py
Normal file
@ -0,0 +1,225 @@
|
||||
import contextlib
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Sequence
|
||||
|
||||
import folder_paths
|
||||
|
||||
from .api import schemas_in
|
||||
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
|
||||
|
||||
We trust `folder_paths.folder_names_and_paths` and include a category if
|
||||
*any* of its base paths lies under the Comfy `models_dir`.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
models_root = os.path.abspath(folder_paths.models_dir)
|
||||
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
|
||||
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
- 'output' if the file resides under `folder_paths.get_output_directory()`
|
||||
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
|
||||
|
||||
Returns:
|
||||
(root_category, relative_path_inside_that_root)
|
||||
For 'models', the relative path is prefixed with the category name:
|
||||
e.g. ('models', 'vae/test/sub/ae.safetensors')
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _is_within(child: str, parent: str) -> bool:
|
||||
try:
|
||||
return os.path.commonpath([child, parent]) == parent
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _rel(child: str, parent: str) -> str:
|
||||
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if _is_within(fp_abs, input_base):
|
||||
return "input", _rel(fp_abs, input_base)
|
||||
|
||||
# 2) output
|
||||
output_base = os.path.abspath(folder_paths.get_output_directory())
|
||||
if _is_within(fp_abs, output_base):
|
||||
return "output", _rel(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: Optional[tuple[int, str, str]] = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
|
||||
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return a tuple (name, tags) derived from a filesystem path.
|
||||
|
||||
Semantics:
|
||||
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
|
||||
- The returned `name` is the base filename with extension from the relative path.
|
||||
- The returned `tags` are:
|
||||
[root_category] + parent folders of the relative path (in order)
|
||||
For 'models', this means:
|
||||
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
|
||||
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
|
||||
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
|
||||
|
||||
def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
root = tags[0]
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
else:
|
||||
base_dir = os.path.abspath(
|
||||
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||
)
|
||||
raw_subdirs = tags[1:]
|
||||
for i in raw_subdirs:
|
||||
if i in (".", ".."):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = os.path.abspath(candidate)
|
||||
base_abs = os.path.abspath(base)
|
||||
try:
|
||||
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||
raise ValueError("destination escapes base directory")
|
||||
except Exception:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
|
||||
def compute_relative_filename(file_path: str) -> Optional[str]:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
For non-model paths, returns None.
|
||||
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc) – drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def list_tree(base_dir: str) -> list[str]:
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
for name in filenames:
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
|
||||
|
||||
def prefixes_for_root(root: schemas_in.RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
|
||||
|
||||
def ts_to_iso(ts: Optional[float]) -> Optional[str]:
|
||||
if ts is None:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def new_scan_id(root: schemas_in.RootType) -> str:
|
||||
return f"scan-{root}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
with contextlib.suppress(Exception):
|
||||
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
544
app/assets/api/routes.py
Normal file
544
app/assets/api/routes.py
Normal file
@ -0,0 +1,544 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from aiohttp import web
|
||||
from pydantic import ValidationError
|
||||
|
||||
import folder_paths
|
||||
|
||||
from ... import user_manager
|
||||
from .. import manager, scanner
|
||||
from . import schemas_in, schemas_out
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: Optional[user_manager.UserManager] = None
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
hash_str = request.match_info.get("hash", "").strip().lower()
|
||||
if not hash_str or ":" not in hash_str:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = hash_str.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
exists = await manager.asset_exists(asset_hash=hash_str)
|
||||
return web.Response(status=200 if exists else 404)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
qp = request.rel_url.query
|
||||
query_dict = {}
|
||||
if "include_tags" in qp:
|
||||
query_dict["include_tags"] = qp.getall("include_tags")
|
||||
if "exclude_tags" in qp:
|
||||
query_dict["exclude_tags"] = qp.getall("exclude_tags")
|
||||
for k in ("name_contains", "metadata_filter", "limit", "offset", "sort", "order"):
|
||||
v = qp.get(k)
|
||||
if v is not None:
|
||||
query_dict[k] = v
|
||||
|
||||
try:
|
||||
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_QUERY", ve)
|
||||
|
||||
payload = await manager.list_assets(
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=q.sort,
|
||||
order=q.order,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||
async def download_asset_content(request: web.Request) -> web.Response:
|
||||
disposition = request.query.get("disposition", "attachment").lower().strip()
|
||||
if disposition not in {"inline", "attachment"}:
|
||||
disposition = "attachment"
|
||||
|
||||
try:
|
||||
abs_path, content_type, filename = await manager.resolve_asset_content_for_download(
|
||||
asset_info_id=str(uuid.UUID(request.match_info["id"])),
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||
except NotImplementedError as nie:
|
||||
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
|
||||
except FileNotFoundError:
|
||||
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
|
||||
|
||||
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
|
||||
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
|
||||
|
||||
resp = web.FileResponse(abs_path)
|
||||
resp.content_type = content_type
|
||||
resp.headers["Content-Disposition"] = cd
|
||||
return resp
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/from-hash")
|
||||
async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
body = schemas_in.CreateFromHashBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
result = await manager.create_asset_from_hash(
|
||||
hash_str=body.hash,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||
return web.json_response(result.model_dump(mode="json"), status=201)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets")
|
||||
async def upload_asset(request: web.Request) -> web.Response:
|
||||
"""Multipart/form-data endpoint for Asset uploads."""
|
||||
|
||||
if not (request.content_type or "").lower().startswith("multipart/"):
|
||||
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
|
||||
|
||||
reader = await request.multipart()
|
||||
|
||||
file_present = False
|
||||
file_client_name: Optional[str] = None
|
||||
tags_raw: list[str] = []
|
||||
provided_name: Optional[str] = None
|
||||
user_metadata_raw: Optional[str] = None
|
||||
provided_hash: Optional[str] = None
|
||||
provided_hash_exists: Optional[bool] = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: Optional[str] = None
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
fname = getattr(field, "name", "") or ""
|
||||
|
||||
if fname == "hash":
|
||||
try:
|
||||
s = ((await field.text()) or "").strip().lower()
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
|
||||
if s:
|
||||
if ":" not in s:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
provided_hash = f"{algo}:{digest}"
|
||||
try:
|
||||
provided_hash_exists = await manager.asset_exists(asset_hash=provided_hash)
|
||||
except Exception:
|
||||
provided_hash_exists = None # do not fail the whole request here
|
||||
|
||||
elif fname == "file":
|
||||
file_present = True
|
||||
file_client_name = (field.filename or "").strip()
|
||||
|
||||
if provided_hash and provided_hash_exists is True:
|
||||
# If client supplied a hash that we know exists, drain but do not write to disk
|
||||
try:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
|
||||
continue # Do not create temp file; we will create AssetInfo from the existing content
|
||||
|
||||
# Otherwise, store to temp for hashing/ingest
|
||||
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||
os.makedirs(unique_dir, exist_ok=True)
|
||||
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||
|
||||
try:
|
||||
with open(tmp_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
try:
|
||||
if os.path.exists(tmp_path or ""):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
|
||||
elif fname == "tags":
|
||||
tags_raw.append((await field.text()) or "")
|
||||
elif fname == "name":
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
|
||||
# If client did not send file, and we are not doing a from-hash fast path -> error
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
|
||||
|
||||
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
|
||||
# Empty upload is only acceptable if we are fast-pathing from existing hash
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||
|
||||
try:
|
||||
spec = schemas_in.UploadAssetSpec.model_validate({
|
||||
"tags": tags_raw,
|
||||
"name": provided_name,
|
||||
"user_metadata": user_metadata_raw,
|
||||
"hash": provided_hash,
|
||||
})
|
||||
except ValidationError as ve:
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
# Validate models category against configured folders (consistent with previous behavior)
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
return _error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
|
||||
)
|
||||
|
||||
owner_id = USER_MANAGER.get_request_user_id(request)
|
||||
|
||||
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
|
||||
if spec.hash and provided_hash_exists is True:
|
||||
try:
|
||||
result = await manager.create_asset_from_hash(
|
||||
hash_str=spec.hash,
|
||||
name=spec.name or (spec.hash.split(":", 1)[1]),
|
||||
tags=spec.tags,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
owner_id=owner_id,
|
||||
)
|
||||
except Exception:
|
||||
LOGGER.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
|
||||
|
||||
# Drain temp if we accidentally saved (e.g., hash field came after file)
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
with contextlib.suppress(Exception):
|
||||
os.remove(tmp_path)
|
||||
|
||||
status = 200 if (not result.created_new) else 201
|
||||
return web.json_response(result.model_dump(mode="json"), status=status)
|
||||
|
||||
# Otherwise, we must have a temp file path to ingest
|
||||
if not tmp_path or not os.path.exists(tmp_path):
|
||||
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
|
||||
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
|
||||
|
||||
try:
|
||||
created = await manager.upload_asset_from_temp_path(
|
||||
spec,
|
||||
temp_path=tmp_path,
|
||||
client_filename=file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_asset_hash=spec.hash,
|
||||
)
|
||||
status = 201 if created.created_new else 200
|
||||
return web.json_response(created.model_dump(mode="json"), status=status)
|
||||
except ValueError as e:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
msg = str(e)
|
||||
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
|
||||
return _error_response(
|
||||
400,
|
||||
"HASH_MISMATCH",
|
||||
"Uploaded file hash does not match provided hash.",
|
||||
)
|
||||
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
|
||||
except Exception:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
LOGGER.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def get_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
result = await manager.get_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"get_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def update_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = await manager.update_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"update_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview")
|
||||
async def set_asset_preview(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.SetPreviewBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = await manager.set_asset_preview(
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=body.preview_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (PermissionError, ValueError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"set_asset_preview failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def delete_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
delete_content = request.query.get("delete_content")
|
||||
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
|
||||
|
||||
try:
|
||||
deleted = await manager.delete_asset_reference(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
delete_content_if_orphan=delete_content,
|
||||
)
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if not deleted:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
|
||||
return web.Response(status=204)
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
query_map = dict(request.rel_url.query)
|
||||
|
||||
try:
|
||||
query = schemas_in.TagsListQuery.model_validate(query_map)
|
||||
except ValidationError as ve:
|
||||
return web.json_response(
|
||||
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": ve.errors()}},
|
||||
status=400,
|
||||
)
|
||||
|
||||
result = await manager.list_tags(
|
||||
prefix=query.prefix,
|
||||
limit=query.limit,
|
||||
offset=query.offset,
|
||||
order=query.order,
|
||||
include_zero=query.include_zero,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsAdd.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = await manager.add_tags_to_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
origin="manual",
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsRemove.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = await manager.remove_tags_from_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/scan/seed")
|
||||
async def seed_assets(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
payload = {}
|
||||
|
||||
try:
|
||||
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
try:
|
||||
await scanner.sync_seed_assets(body.roots)
|
||||
except Exception:
|
||||
LOGGER.exception("sync_seed_assets failed for roots=%s", body.roots)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response({"synced": True, "roots": body.roots}, status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/scan/schedule")
|
||||
async def schedule_asset_scan(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
payload = {}
|
||||
|
||||
try:
|
||||
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
states = await scanner.schedule_scans(body.roots)
|
||||
return web.json_response(states.model_dump(mode="json"), status=202)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets/scan")
|
||||
async def get_asset_scan_status(request: web.Request) -> web.Response:
|
||||
root = request.query.get("root", "").strip().lower()
|
||||
states = scanner.current_statuses()
|
||||
if root in {"models", "input", "output"}:
|
||||
states = [s for s in states.scans if s.root == root] # type: ignore
|
||||
states = schemas_out.AssetScanStatusResponse(scans=states)
|
||||
return web.json_response(states.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||
global USER_MANAGER
|
||||
USER_MANAGER = user_manager_instance
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
|
||||
def _error_response(status: int, code: str, message: str, details: Optional[dict] = None) -> web.Response:
|
||||
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
|
||||
|
||||
|
||||
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||
297
app/assets/api/schemas_in.py
Normal file
297
app/assets/api/schemas_in.py
Normal file
@ -0,0 +1,297 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
conint,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: Optional[str] = None
|
||||
|
||||
# Accept either a JSON string (query param) or a dict
|
||||
metadata_filter: Optional[dict[str, Any]] = None
|
||||
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
|
||||
order: Literal["asc", "desc"] = "desc"
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@classmethod
|
||||
def _split_csv_tags(cls, v):
|
||||
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [t.strip() for t in v.split(",") if t.strip()]
|
||||
if isinstance(v, list):
|
||||
out: list[str] = []
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v
|
||||
if isinstance(v, str) and v.strip():
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("metadata_filter must be a JSON object")
|
||||
return parsed
|
||||
return None
|
||||
|
||||
|
||||
class UpdateAssetBody(BaseModel):
|
||||
name: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
user_metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _at_least_one(self):
|
||||
if self.name is None and self.tags is None and self.user_metadata is None:
|
||||
raise ValueError("Provide at least one of: name, tags, user_metadata.")
|
||||
if self.tags is not None:
|
||||
if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags):
|
||||
raise ValueError("Field 'tags' must be an array of strings.")
|
||||
return self
|
||||
|
||||
|
||||
class CreateFromHashBody(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
hash: str
|
||||
name: str
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("hash")
|
||||
@classmethod
|
||||
def _require_blake3(cls, v):
|
||||
s = (v or "").strip().lower()
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return s
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _tags_norm(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
out = [str(t).strip().lower() for t in v if str(t).strip()]
|
||||
seen = set()
|
||||
dedup = []
|
||||
for t in out:
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
dedup.append(t)
|
||||
return dedup
|
||||
if isinstance(v, str):
|
||||
return [t.strip().lower() for t in v.split(",") if t.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
prefix: Optional[str] = Field(None, min_length=1, max_length=256)
|
||||
limit: int = Field(100, ge=1, le=1000)
|
||||
offset: int = Field(0, ge=0, le=10_000_000)
|
||||
order: Literal["count_desc", "name_asc"] = "count_desc"
|
||||
include_zero: bool = True
|
||||
|
||||
@field_validator("prefix")
|
||||
@classmethod
|
||||
def normalize_prefix(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
v = v.strip()
|
||||
return v.lower() or None
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
|
||||
@field_validator("tags")
|
||||
@classmethod
|
||||
def normalize_tags(cls, v: list[str]) -> list[str]:
|
||||
out = []
|
||||
for t in v:
|
||||
if not isinstance(t, str):
|
||||
raise TypeError("tags must be strings")
|
||||
tnorm = t.strip().lower()
|
||||
if tnorm:
|
||||
out.append(tnorm)
|
||||
seen = set()
|
||||
deduplicated = []
|
||||
for x in out:
|
||||
if x not in seen:
|
||||
seen.add(x)
|
||||
deduplicated.append(x)
|
||||
return deduplicated
|
||||
|
||||
|
||||
class TagsRemove(TagsAdd):
|
||||
pass
|
||||
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
|
||||
|
||||
|
||||
class ScheduleAssetScanBody(BaseModel):
|
||||
roots: list[RootType] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class UploadAssetSpec(BaseModel):
|
||||
"""Upload Asset operation.
|
||||
- tags: ordered; first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
|
||||
- name: display name
|
||||
- user_metadata: arbitrary JSON object (optional)
|
||||
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
|
||||
|
||||
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
|
||||
and the original extension is preserved when available.
|
||||
"""
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
name: Optional[str] = Field(default=None, max_length=512, description="Display Name")
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
hash: Optional[str] = Field(default=None)
|
||||
|
||||
@field_validator("hash", mode="before")
|
||||
@classmethod
|
||||
def _parse_hash(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip().lower()
|
||||
if not s:
|
||||
return None
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return f"{algo}:{digest}"
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _parse_tags(cls, v):
|
||||
"""
|
||||
Accepts a list of strings (possibly multiple form fields),
|
||||
where each string can be:
|
||||
- JSON array (e.g., '["models","loras","foo"]')
|
||||
- comma-separated ('models, loras, foo')
|
||||
- single token ('models')
|
||||
Returns a normalized, deduplicated, ordered list.
|
||||
"""
|
||||
items: list[str] = []
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
v = [v]
|
||||
|
||||
if isinstance(v, list):
|
||||
for item in v:
|
||||
if item is None:
|
||||
continue
|
||||
s = str(item).strip()
|
||||
if not s:
|
||||
continue
|
||||
if s.startswith("["):
|
||||
try:
|
||||
arr = json.loads(s)
|
||||
if isinstance(arr, list):
|
||||
items.extend(str(x) for x in arr)
|
||||
continue
|
||||
except Exception:
|
||||
pass # fallback to CSV parse below
|
||||
items.extend([p for p in s.split(",") if p.strip()])
|
||||
else:
|
||||
return []
|
||||
|
||||
# normalize + dedupe
|
||||
norm = []
|
||||
seen = set()
|
||||
for t in items:
|
||||
tnorm = str(t).strip().lower()
|
||||
if tnorm and tnorm not in seen:
|
||||
seen.add(tnorm)
|
||||
norm.append(tnorm)
|
||||
return norm
|
||||
|
||||
@field_validator("user_metadata", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v or {}
|
||||
if isinstance(v, str):
|
||||
s = v.strip()
|
||||
if not s:
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(s)
|
||||
except Exception as e:
|
||||
raise ValueError(f"user_metadata must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("user_metadata must be a JSON object")
|
||||
return parsed
|
||||
return {}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_order(self):
|
||||
if not self.tags:
|
||||
raise ValueError("tags must be provided and non-empty")
|
||||
root = self.tags[0]
|
||||
if root not in {"models", "input", "output"}:
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
if root == "models":
|
||||
if len(self.tags) < 2:
|
||||
raise ValueError("models uploads require a category tag as the second tag")
|
||||
return self
|
||||
|
||||
|
||||
class SetPreviewBody(BaseModel):
|
||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
||||
preview_id: Optional[str] = None
|
||||
|
||||
@field_validator("preview_id", mode="before")
|
||||
@classmethod
|
||||
def _norm_uuid(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip()
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
uuid.UUID(s)
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
|
||||
115
app/assets/api/schemas_out.py
Normal file
115
app/assets/api/schemas_out.py
Normal file
@ -0,0 +1,115 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
|
||||
|
||||
class AssetSummary(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: Optional[str]
|
||||
size: Optional[int] = None
|
||||
mime_type: Optional[str] = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
preview_url: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
last_access_time: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "updated_at", "last_access_time")
|
||||
def _ser_dt(self, v: Optional[datetime], _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetsList(BaseModel):
|
||||
assets: list[AssetSummary]
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetUpdated(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: Optional[str]
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("updated_at")
|
||||
def _ser_updated(self, v: Optional[datetime], _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: Optional[str]
|
||||
size: Optional[int] = None
|
||||
mime_type: Optional[str] = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
preview_id: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
last_access_time: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "last_access_time")
|
||||
def _ser_dt(self, v: Optional[datetime], _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(AssetDetail):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
type: str
|
||||
|
||||
|
||||
class TagsList(BaseModel):
|
||||
tags: list[TagUsage] = Field(default_factory=list)
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
added: list[str] = Field(default_factory=list)
|
||||
already_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TagsRemove(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
removed: list[str] = Field(default_factory=list)
|
||||
not_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AssetScanError(BaseModel):
|
||||
path: str
|
||||
message: str
|
||||
at: Optional[str] = Field(None, description="ISO timestamp")
|
||||
|
||||
|
||||
class AssetScanStatus(BaseModel):
|
||||
scan_id: str
|
||||
root: Literal["models", "input", "output"]
|
||||
status: Literal["scheduled", "running", "completed", "failed", "cancelled"]
|
||||
scheduled_at: Optional[str] = None
|
||||
started_at: Optional[str] = None
|
||||
finished_at: Optional[str] = None
|
||||
discovered: int = 0
|
||||
processed: int = 0
|
||||
file_errors: list[AssetScanError] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AssetScanStatusResponse(BaseModel):
|
||||
scans: list[AssetScanStatus] = Field(default_factory=list)
|
||||
0
app/assets/database/__init__.py
Normal file
0
app/assets/database/__init__.py
Normal file
25
app/assets/database/helpers/__init__.py
Normal file
25
app/assets/database/helpers/__init__.py
Normal file
@ -0,0 +1,25 @@
|
||||
from .bulk_ops import seed_from_paths_batch
|
||||
from .escape_like import escape_like_prefix
|
||||
from .fast_check import fast_asset_file_check
|
||||
from .filters import apply_metadata_filter, apply_tag_filters
|
||||
from .ownership import visible_owner_clause
|
||||
from .projection import is_scalar, project_kv
|
||||
from .tags import (
|
||||
add_missing_tag_for_asset_id,
|
||||
ensure_tags_exist,
|
||||
remove_missing_tag_for_asset_id,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"apply_tag_filters",
|
||||
"apply_metadata_filter",
|
||||
"escape_like_prefix",
|
||||
"fast_asset_file_check",
|
||||
"is_scalar",
|
||||
"project_kv",
|
||||
"ensure_tags_exist",
|
||||
"add_missing_tag_for_asset_id",
|
||||
"remove_missing_tag_for_asset_id",
|
||||
"seed_from_paths_batch",
|
||||
"visible_owner_clause",
|
||||
]
|
||||
230
app/assets/database/helpers/bulk_ops.py
Normal file
230
app/assets/database/helpers/bulk_ops.py
Normal file
@ -0,0 +1,230 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql as d_pg
|
||||
from sqlalchemy.dialects import sqlite as d_sqlite
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag
|
||||
from ..timeutil import utcnow
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
|
||||
async def seed_from_paths_batch(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
specs: Sequence[dict],
|
||||
owner_id: str = "",
|
||||
) -> dict:
|
||||
"""Each spec is a dict with keys:
|
||||
- abs_path: str
|
||||
- size_bytes: int
|
||||
- mtime_ns: int
|
||||
- info_name: str
|
||||
- tags: list[str]
|
||||
- fname: Optional[str]
|
||||
"""
|
||||
if not specs:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
|
||||
|
||||
now = utcnow()
|
||||
dialect = session.bind.dialect.name
|
||||
if dialect not in ("sqlite", "postgresql"):
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
|
||||
asset_rows: list[dict] = []
|
||||
state_rows: list[dict] = []
|
||||
path_to_asset: dict[str, str] = {}
|
||||
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
|
||||
path_list: list[str] = []
|
||||
|
||||
for sp in specs:
|
||||
ap = os.path.abspath(sp["abs_path"])
|
||||
aid = str(uuid.uuid4())
|
||||
iid = str(uuid.uuid4())
|
||||
path_list.append(ap)
|
||||
path_to_asset[ap] = aid
|
||||
|
||||
asset_rows.append(
|
||||
{
|
||||
"id": aid,
|
||||
"hash": None,
|
||||
"size_bytes": sp["size_bytes"],
|
||||
"mime_type": None,
|
||||
"created_at": now,
|
||||
}
|
||||
)
|
||||
state_rows.append(
|
||||
{
|
||||
"asset_id": aid,
|
||||
"file_path": ap,
|
||||
"mtime_ns": sp["mtime_ns"],
|
||||
}
|
||||
)
|
||||
asset_to_info[aid] = {
|
||||
"id": iid,
|
||||
"owner_id": owner_id,
|
||||
"name": sp["info_name"],
|
||||
"asset_id": aid,
|
||||
"preview_id": None,
|
||||
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
"_tags": sp["tags"],
|
||||
"_filename": sp["fname"],
|
||||
}
|
||||
|
||||
# insert all seed Assets (hash=NULL)
|
||||
ins_asset = d_sqlite.insert(Asset) if dialect == "sqlite" else d_pg.insert(Asset)
|
||||
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
|
||||
await session.execute(ins_asset, chunk)
|
||||
|
||||
# try to claim AssetCacheState (file_path)
|
||||
winners_by_path: set[str] = set()
|
||||
if dialect == "sqlite":
|
||||
ins_state = (
|
||||
d_sqlite.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
.returning(AssetCacheState.file_path)
|
||||
)
|
||||
else:
|
||||
ins_state = (
|
||||
d_pg.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
.returning(AssetCacheState.file_path)
|
||||
)
|
||||
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
|
||||
winners_by_path.update((await session.execute(ins_state, chunk)).scalars().all())
|
||||
|
||||
all_paths_set = set(path_list)
|
||||
losers_by_path = all_paths_set - winners_by_path
|
||||
lost_assets = [path_to_asset[p] for p in losers_by_path]
|
||||
if lost_assets: # losers get their Asset removed
|
||||
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
|
||||
await session.execute(sa.delete(Asset).where(Asset.id.in_(id_chunk)))
|
||||
|
||||
if not winners_by_path:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
|
||||
|
||||
# insert AssetInfo only for winners
|
||||
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
|
||||
if dialect == "sqlite":
|
||||
ins_info = (
|
||||
d_sqlite.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
.returning(AssetInfo.id)
|
||||
)
|
||||
else:
|
||||
ins_info = (
|
||||
d_pg.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
.returning(AssetInfo.id)
|
||||
)
|
||||
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
|
||||
inserted_info_ids.update((await session.execute(ins_info, chunk)).scalars().all())
|
||||
|
||||
# build and insert tag + meta rows for the AssetInfo
|
||||
tag_rows: list[dict] = []
|
||||
meta_rows: list[dict] = []
|
||||
if inserted_info_ids:
|
||||
for row in winner_info_rows:
|
||||
iid = row["id"]
|
||||
if iid not in inserted_info_ids:
|
||||
continue
|
||||
for t in row["_tags"]:
|
||||
tag_rows.append({
|
||||
"asset_info_id": iid,
|
||||
"tag_name": t,
|
||||
"origin": "automatic",
|
||||
"added_at": now,
|
||||
})
|
||||
if row["_filename"]:
|
||||
meta_rows.append(
|
||||
{
|
||||
"asset_info_id": iid,
|
||||
"key": "filename",
|
||||
"ordinal": 0,
|
||||
"val_str": row["_filename"],
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
)
|
||||
|
||||
await bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
|
||||
return {
|
||||
"inserted_infos": len(inserted_info_ids),
|
||||
"won_states": len(winners_by_path),
|
||||
"lost_states": len(losers_by_path),
|
||||
}
|
||||
|
||||
|
||||
async def bulk_insert_tags_and_meta(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
tag_rows: list[dict],
|
||||
meta_rows: list[dict],
|
||||
max_bind_params: int,
|
||||
) -> None:
|
||||
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
|
||||
- tag_rows keys: asset_info_id, tag_name, origin, added_at
|
||||
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
|
||||
"""
|
||||
dialect = session.bind.dialect.name
|
||||
if tag_rows:
|
||||
if dialect == "sqlite":
|
||||
ins_links = (
|
||||
d_sqlite.insert(AssetInfoTag)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins_links = (
|
||||
d_pg.insert(AssetInfoTag)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
|
||||
await session.execute(ins_links, chunk)
|
||||
if meta_rows:
|
||||
if dialect == "sqlite":
|
||||
ins_meta = (
|
||||
d_sqlite.insert(AssetInfoMeta)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
|
||||
)
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins_meta = (
|
||||
d_pg.insert(AssetInfoMeta)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
|
||||
await session.execute(ins_meta, chunk)
|
||||
|
||||
|
||||
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
|
||||
if not rows:
|
||||
return []
|
||||
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
|
||||
for i in range(0, len(rows), rows_per_stmt):
|
||||
yield rows[i:i + rows_per_stmt]
|
||||
|
||||
|
||||
def _iter_chunks(seq, n: int):
|
||||
for i in range(0, len(seq), n):
|
||||
yield seq[i:i + n]
|
||||
|
||||
|
||||
def _rows_per_stmt(cols: int) -> int:
|
||||
return max(1, MAX_BIND_PARAMS // max(1, cols))
|
||||
7
app/assets/database/helpers/escape_like.py
Normal file
7
app/assets/database/helpers/escape_like.py
Normal file
@ -0,0 +1,7 @@
|
||||
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
"""Escapes %, _ and the escape char itself in a LIKE prefix.
|
||||
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
|
||||
"""
|
||||
s = s.replace(escape, escape + escape) # escape the escape char first
|
||||
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
|
||||
return s, escape
|
||||
19
app/assets/database/helpers/fast_check.py
Normal file
19
app/assets/database/helpers/fast_check.py
Normal file
@ -0,0 +1,19 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def fast_asset_file_check(
|
||||
*,
|
||||
mtime_db: Optional[int],
|
||||
size_db: Optional[int],
|
||||
stat_result: os.stat_result,
|
||||
) -> bool:
|
||||
if mtime_db is None:
|
||||
return False
|
||||
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
|
||||
if int(mtime_db) != int(actual_mtime_ns):
|
||||
return False
|
||||
sz = int(size_db or 0)
|
||||
if sz > 0:
|
||||
return int(stat_result.st_size) == sz
|
||||
return True
|
||||
87
app/assets/database/helpers/filters.py
Normal file
87
app/assets/database/helpers/filters.py
Normal file
@ -0,0 +1,87 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exists
|
||||
|
||||
from ..._helpers import normalize_tags
|
||||
from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Optional[Sequence[str]],
|
||||
exclude_tags: Optional[Sequence[str]],
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: Optional[dict],
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_info_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float)):
|
||||
from decimal import Decimal
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
12
app/assets/database/helpers/ownership.py
Normal file
12
app/assets/database/helpers/ownership.py
Normal file
@ -0,0 +1,12 @@
|
||||
import sqlalchemy as sa
|
||||
|
||||
from ..models import AssetInfo
|
||||
|
||||
|
||||
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
|
||||
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
64
app/assets/database/helpers/projection.py
Normal file
64
app/assets/database/helpers/projection.py
Normal file
@ -0,0 +1,64 @@
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def project_kv(key: str, value):
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
90
app/assets/database/helpers/tags.py
Normal file
90
app/assets/database/helpers/tags.py
Normal file
@ -0,0 +1,90 @@
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql as d_pg
|
||||
from sqlalchemy.dialects import sqlite as d_sqlite
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..._helpers import normalize_tags
|
||||
from ..models import AssetInfo, AssetInfoTag, Tag
|
||||
from ..timeutil import utcnow
|
||||
|
||||
|
||||
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
dialect = session.bind.dialect.name
|
||||
if dialect == "sqlite":
|
||||
ins = (
|
||||
d_sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins = (
|
||||
d_pg.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
await session.execute(ins)
|
||||
|
||||
|
||||
async def add_missing_tag_for_asset_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sa.select(
|
||||
AssetInfo.id.label("asset_info_id"),
|
||||
sa.literal("missing").label("tag_name"),
|
||||
sa.literal(origin).label("origin"),
|
||||
sa.literal(utcnow()).label("added_at"),
|
||||
)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.where(
|
||||
sa.not_(
|
||||
sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
|
||||
)
|
||||
)
|
||||
)
|
||||
dialect = session.bind.dialect.name
|
||||
if dialect == "sqlite":
|
||||
ins = (
|
||||
d_sqlite.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins = (
|
||||
d_pg.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
await session.execute(ins)
|
||||
|
||||
|
||||
async def remove_missing_tag_for_asset_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
await session.execute(
|
||||
sa.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
251
app/assets/database/models.py
Normal file
251
app/assets/database/models.py
Normal file
@ -0,0 +1,251 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, foreign, mapped_column, relationship
|
||||
|
||||
from .timeutil import utcnow
|
||||
|
||||
JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql')
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||
fields = obj.__table__.columns.keys()
|
||||
out: dict[str, Any] = {}
|
||||
for field in fields:
|
||||
val = getattr(obj, field)
|
||||
if val is None and not include_none:
|
||||
continue
|
||||
if isinstance(val, datetime):
|
||||
out[field] = val.isoformat()
|
||||
else:
|
||||
out[field] = val
|
||||
return out
|
||||
|
||||
|
||||
class Asset(Base):
|
||||
__tablename__ = "assets"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
|
||||
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
mime_type: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
infos: Mapped[list["AssetInfo"]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
|
||||
foreign_keys=lambda: [AssetInfo.asset_id],
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
preview_of: Mapped[list["AssetInfo"]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
|
||||
foreign_keys=lambda: [AssetInfo.preview_id],
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
cache_states: Mapped[list["AssetCacheState"]] = relationship(
|
||||
back_populates="asset",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("uq_assets_hash", "hash", unique=True),
|
||||
Index("ix_assets_mime_type", "mime_type"),
|
||||
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
|
||||
|
||||
|
||||
class AssetCacheState(Base):
|
||||
__tablename__ = "asset_cache_state"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
|
||||
file_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
|
||||
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
asset: Mapped["Asset"] = relationship(back_populates="cache_states")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_cache_state_file_path", "file_path"),
|
||||
Index("ix_asset_cache_state_asset_id", "asset_id"),
|
||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
|
||||
|
||||
|
||||
class AssetInfo(Base):
|
||||
__tablename__ = "assets_info"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
|
||||
preview_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
|
||||
user_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON(none_as_null=True))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
|
||||
asset: Mapped[Asset] = relationship(
|
||||
"Asset",
|
||||
back_populates="infos",
|
||||
foreign_keys=[asset_id],
|
||||
lazy="selectin",
|
||||
)
|
||||
preview_asset: Mapped[Optional[Asset]] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
foreign_keys=[preview_id],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
tag_links: Mapped[list["AssetInfoTag"]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
overlaps="tags,asset_infos",
|
||||
)
|
||||
|
||||
tags: Mapped[list["Tag"]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="asset_infos",
|
||||
lazy="selectin",
|
||||
viewonly=True,
|
||||
overlaps="tag_links,asset_info_links,asset_infos,tag",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
Index("ix_assets_info_owner_name", "owner_id", "name"),
|
||||
Index("ix_assets_info_owner_id", "owner_id"),
|
||||
Index("ix_assets_info_asset_id", "asset_id"),
|
||||
Index("ix_assets_info_name", "name"),
|
||||
Index("ix_assets_info_created_at", "created_at"),
|
||||
Index("ix_assets_info_last_access_time", "last_access_time"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
data = to_dict(self, include_none=include_none)
|
||||
data["tags"] = [t.name for t in self.tags]
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
|
||||
|
||||
|
||||
class AssetInfoMeta(Base):
|
||||
__tablename__ = "asset_info_meta"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
key: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
|
||||
|
||||
val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
|
||||
val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True)
|
||||
val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
|
||||
val_json: Mapped[Optional[Any]] = mapped_column(JSONB_V, nullable=True)
|
||||
|
||||
asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_meta_key", "key"),
|
||||
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
|
||||
)
|
||||
|
||||
|
||||
class AssetInfoTag(Base):
|
||||
__tablename__ = "asset_info_tags"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
tag_name: Mapped[str] = mapped_column(
|
||||
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
|
||||
)
|
||||
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
|
||||
added_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
|
||||
tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_tags_tag_name", "tag_name"),
|
||||
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(512), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
|
||||
back_populates="tag",
|
||||
overlaps="asset_infos,tags",
|
||||
)
|
||||
asset_infos: Mapped[list["AssetInfo"]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="tags",
|
||||
viewonly=True,
|
||||
overlaps="asset_info_links,tag_links,tags,asset_info",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_tags_tag_type", "tag_type"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
57
app/assets/database/services/__init__.py
Normal file
57
app/assets/database/services/__init__.py
Normal file
@ -0,0 +1,57 @@
|
||||
from .content import (
|
||||
check_fs_asset_exists_quick,
|
||||
compute_hash_and_dedup_for_cache_state,
|
||||
ingest_fs_asset,
|
||||
list_cache_states_with_asset_under_prefixes,
|
||||
list_unhashed_candidates_under_prefixes,
|
||||
list_verify_candidates_under_prefixes,
|
||||
redirect_all_references_then_delete_asset,
|
||||
touch_asset_infos_by_fs_path,
|
||||
)
|
||||
from .info import (
|
||||
add_tags_to_asset_info,
|
||||
create_asset_info_for_existing_asset,
|
||||
delete_asset_info_by_id,
|
||||
fetch_asset_info_and_asset,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
get_asset_tags,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
remove_tags_from_asset_info,
|
||||
replace_asset_info_metadata_projection,
|
||||
set_asset_info_preview,
|
||||
set_asset_info_tags,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
)
|
||||
from .queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
get_cache_state_by_asset_id,
|
||||
list_cache_states_by_asset_id,
|
||||
pick_best_live_path,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# queries
|
||||
"asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id",
|
||||
"get_cache_state_by_asset_id",
|
||||
"list_cache_states_by_asset_id",
|
||||
"pick_best_live_path",
|
||||
# info
|
||||
"list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags",
|
||||
"update_asset_info_full", "replace_asset_info_metadata_projection",
|
||||
"touch_asset_info_by_id", "delete_asset_info_by_id",
|
||||
"add_tags_to_asset_info", "remove_tags_from_asset_info",
|
||||
"get_asset_tags", "list_tags_with_usage", "set_asset_info_preview",
|
||||
"fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags",
|
||||
# content
|
||||
"check_fs_asset_exists_quick",
|
||||
"redirect_all_references_then_delete_asset",
|
||||
"compute_hash_and_dedup_for_cache_state",
|
||||
"list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes",
|
||||
"ingest_fs_asset", "touch_asset_infos_by_fs_path",
|
||||
"list_cache_states_with_asset_under_prefixes",
|
||||
]
|
||||
721
app/assets/database/services/content.py
Normal file
721
app/assets/database/services/content.py
Normal file
@ -0,0 +1,721 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects import postgresql as d_pg
|
||||
from sqlalchemy.dialects import sqlite as d_sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import noload
|
||||
|
||||
from ..._helpers import compute_relative_filename
|
||||
from ...storage import hashing as hashing_mod
|
||||
from ..helpers import (
|
||||
ensure_tags_exist,
|
||||
escape_like_prefix,
|
||||
remove_missing_tag_for_asset_id,
|
||||
)
|
||||
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag
|
||||
from ..timeutil import utcnow
|
||||
from .info import replace_asset_info_metadata_projection
|
||||
from .queries import list_cache_states_by_asset_id, pick_best_live_path
|
||||
|
||||
|
||||
async def check_fs_asset_exists_quick(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
file_path: str,
|
||||
size_bytes: Optional[int] = None,
|
||||
mtime_ns: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""Returns True if we already track this absolute path with a HASHED asset and the cached mtime/size match."""
|
||||
locator = os.path.abspath(file_path)
|
||||
|
||||
stmt = (
|
||||
sa.select(sa.literal(True))
|
||||
.select_from(AssetCacheState)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(
|
||||
AssetCacheState.file_path == locator,
|
||||
Asset.hash.isnot(None),
|
||||
AssetCacheState.needs_verify.is_(False),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
conds = []
|
||||
if mtime_ns is not None:
|
||||
conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
|
||||
if size_bytes is not None:
|
||||
conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
|
||||
if conds:
|
||||
stmt = stmt.where(*conds)
|
||||
return (await session.execute(stmt)).first() is not None
|
||||
|
||||
|
||||
async def redirect_all_references_then_delete_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
duplicate_asset_id: str,
|
||||
canonical_asset_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Safely migrate all references from duplicate_asset_id to canonical_asset_id.
|
||||
|
||||
- If an AssetInfo for (owner_id, name) already exists on the canonical asset,
|
||||
merge tags, metadata, times, and preview, then delete the duplicate AssetInfo.
|
||||
- Otherwise, simply repoint the AssetInfo.asset_id.
|
||||
- Always retarget AssetCacheState rows.
|
||||
- Finally delete the duplicate Asset row.
|
||||
"""
|
||||
if duplicate_asset_id == canonical_asset_id:
|
||||
return
|
||||
|
||||
# 1) Migrate AssetInfo rows one-by-one to avoid UNIQUE conflicts.
|
||||
dup_infos = (
|
||||
await session.execute(
|
||||
select(AssetInfo).options(noload(AssetInfo.tags)).where(AssetInfo.asset_id == duplicate_asset_id)
|
||||
)
|
||||
).unique().scalars().all()
|
||||
|
||||
for info in dup_infos:
|
||||
# Try to find an existing collision on canonical
|
||||
existing = (
|
||||
await session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == canonical_asset_id,
|
||||
AssetInfo.owner_id == info.owner_id,
|
||||
AssetInfo.name == info.name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
|
||||
if existing:
|
||||
merged_meta = dict(existing.user_metadata or {})
|
||||
other_meta = info.user_metadata or {}
|
||||
for k, v in other_meta.items():
|
||||
if k not in merged_meta:
|
||||
merged_meta[k] = v
|
||||
if merged_meta != (existing.user_metadata or {}):
|
||||
await replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=existing.id,
|
||||
user_metadata=merged_meta,
|
||||
)
|
||||
|
||||
existing_tags = {
|
||||
t for (t,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == existing.id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
from_tags = {
|
||||
t for (t,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == info.id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
to_add = sorted(from_tags - existing_tags)
|
||||
if to_add:
|
||||
await ensure_tags_exist(session, to_add, tag_type="user")
|
||||
now = utcnow()
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=existing.id, tag_name=t, origin="automatic", added_at=now)
|
||||
for t in to_add
|
||||
])
|
||||
await session.flush()
|
||||
|
||||
if existing.preview_id is None and info.preview_id is not None:
|
||||
existing.preview_id = info.preview_id
|
||||
if info.last_access_time and (
|
||||
existing.last_access_time is None or info.last_access_time > existing.last_access_time
|
||||
):
|
||||
existing.last_access_time = info.last_access_time
|
||||
existing.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
# Delete the duplicate AssetInfo (cascades will clean its tags/meta)
|
||||
await session.delete(info)
|
||||
await session.flush()
|
||||
else:
|
||||
# Simple retarget
|
||||
info.asset_id = canonical_asset_id
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
# 2) Repoint cache states and previews
|
||||
await session.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == duplicate_asset_id)
|
||||
.values(asset_id=canonical_asset_id)
|
||||
)
|
||||
await session.execute(
|
||||
sa.update(AssetInfo)
|
||||
.where(AssetInfo.preview_id == duplicate_asset_id)
|
||||
.values(preview_id=canonical_asset_id)
|
||||
)
|
||||
|
||||
# 3) Remove duplicate Asset
|
||||
dup = await session.get(Asset, duplicate_asset_id)
|
||||
if dup:
|
||||
await session.delete(dup)
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def compute_hash_and_dedup_for_cache_state(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
state_id: int,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Compute hash for the given cache state, deduplicate, and settle verify cases.
|
||||
|
||||
Returns the asset_id that this state ends up pointing to, or None if file disappeared.
|
||||
"""
|
||||
state = await session.get(AssetCacheState, state_id)
|
||||
if not state:
|
||||
return None
|
||||
|
||||
path = state.file_path
|
||||
try:
|
||||
if not os.path.isfile(path):
|
||||
# File vanished: drop the state. If the Asset has hash=NULL and has no other states, drop the Asset too.
|
||||
asset = await session.get(Asset, state.asset_id)
|
||||
await session.delete(state)
|
||||
await session.flush()
|
||||
|
||||
if asset and asset.hash is None:
|
||||
remaining = (
|
||||
await session.execute(
|
||||
sa.select(sa.func.count())
|
||||
.select_from(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset.id)
|
||||
)
|
||||
).scalar_one()
|
||||
if int(remaining or 0) == 0:
|
||||
await session.delete(asset)
|
||||
await session.flush()
|
||||
else:
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=asset.id)
|
||||
return None
|
||||
|
||||
digest = await hashing_mod.blake3_hash(path)
|
||||
new_hash = f"blake3:{digest}"
|
||||
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
new_size = int(st.st_size)
|
||||
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
# Current asset of this state
|
||||
this_asset = await session.get(Asset, state.asset_id)
|
||||
|
||||
# If the state got orphaned somehow (race), just reattach appropriately.
|
||||
if not this_asset:
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
if canonical:
|
||||
state.asset_id = canonical.id
|
||||
else:
|
||||
now = utcnow()
|
||||
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
|
||||
session.add(new_asset)
|
||||
await session.flush()
|
||||
state.asset_id = new_asset.id
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=state.asset_id)
|
||||
await session.flush()
|
||||
return state.asset_id
|
||||
|
||||
# 1) Seed asset case (hash is NULL): claim or merge into canonical
|
||||
if this_asset.hash is None:
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
if canonical and canonical.id != this_asset.id:
|
||||
# Merge seed asset into canonical (safe, collision-aware)
|
||||
await redirect_all_references_then_delete_asset(
|
||||
session,
|
||||
duplicate_asset_id=this_asset.id,
|
||||
canonical_asset_id=canonical.id,
|
||||
)
|
||||
state = await session.get(AssetCacheState, state_id)
|
||||
if state:
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
|
||||
await session.flush()
|
||||
return canonical.id
|
||||
|
||||
# No canonical: try to claim the hash; handle races with a SAVEPOINT
|
||||
try:
|
||||
async with session.begin_nested():
|
||||
this_asset.hash = new_hash
|
||||
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
|
||||
this_asset.size_bytes = new_size
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
# Someone else claimed it concurrently; fetch canonical and merge
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
if canonical and canonical.id != this_asset.id:
|
||||
await redirect_all_references_then_delete_asset(
|
||||
session,
|
||||
duplicate_asset_id=this_asset.id,
|
||||
canonical_asset_id=canonical.id,
|
||||
)
|
||||
state = await session.get(AssetCacheState, state_id)
|
||||
if state:
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
|
||||
await session.flush()
|
||||
return canonical.id
|
||||
# If we got here, the integrity error was not about hash uniqueness
|
||||
raise
|
||||
|
||||
# Claimed successfully
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
|
||||
await session.flush()
|
||||
return this_asset.id
|
||||
|
||||
# 2) Verify case for hashed assets
|
||||
if this_asset.hash == new_hash:
|
||||
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
|
||||
this_asset.size_bytes = new_size
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
|
||||
await session.flush()
|
||||
return this_asset.id
|
||||
|
||||
# Content changed on this path only: retarget THIS state, do not move AssetInfo rows
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
if canonical:
|
||||
target_id = canonical.id
|
||||
else:
|
||||
now = utcnow()
|
||||
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
|
||||
session.add(new_asset)
|
||||
await session.flush()
|
||||
target_id = new_asset.id
|
||||
|
||||
state.asset_id = target_id
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=target_id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=target_id)
|
||||
await session.flush()
|
||||
return target_id
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
|
||||
async def list_unhashed_candidates_under_prefixes(session: AsyncSession, *, prefixes: list[str]) -> list[int]:
|
||||
if not prefixes:
|
||||
return []
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
path_filter = sa.or_(*conds) if len(conds) > 1 else conds[0]
|
||||
if session.bind.dialect.name == "postgresql":
|
||||
stmt = (
|
||||
sa.select(AssetCacheState.id)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(Asset.hash.is_(None), path_filter)
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
.distinct(AssetCacheState.asset_id)
|
||||
)
|
||||
else:
|
||||
first_id = sa.func.min(AssetCacheState.id).label("first_id")
|
||||
stmt = (
|
||||
sa.select(first_id)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(Asset.hash.is_(None), path_filter)
|
||||
.group_by(AssetCacheState.asset_id)
|
||||
.order_by(first_id.asc())
|
||||
)
|
||||
return [int(x) for x in (await session.execute(stmt)).scalars().all()]
|
||||
|
||||
|
||||
async def list_verify_candidates_under_prefixes(
|
||||
session: AsyncSession, *, prefixes: Sequence[str]
|
||||
) -> Union[list[int], Sequence[int]]:
|
||||
if not prefixes:
|
||||
return []
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
return (
|
||||
await session.execute(
|
||||
sa.select(AssetCacheState.id)
|
||||
.where(AssetCacheState.needs_verify.is_(True))
|
||||
.where(sa.or_(*conds))
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
async def ingest_fs_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: Optional[str] = None,
|
||||
info_name: Optional[str] = None,
|
||||
owner_id: str = "",
|
||||
preview_id: Optional[str] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Idempotently upsert:
|
||||
- Asset by content hash (create if missing)
|
||||
- AssetCacheState(file_path) pointing to asset_id
|
||||
- Optionally AssetInfo + tag links and metadata projection
|
||||
Returns flags and ids.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
now = utcnow()
|
||||
dialect = session.bind.dialect.name
|
||||
|
||||
if preview_id:
|
||||
if not await session.get(Asset, preview_id):
|
||||
preview_id = None
|
||||
|
||||
out: dict[str, Any] = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
}
|
||||
|
||||
# 1) Asset by hash
|
||||
asset = (
|
||||
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
vals = {
|
||||
"hash": asset_hash,
|
||||
"size_bytes": int(size_bytes),
|
||||
"mime_type": mime_type,
|
||||
"created_at": now,
|
||||
}
|
||||
if dialect == "sqlite":
|
||||
res = await session.execute(
|
||||
d_sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["asset_created"] = True
|
||||
asset = (
|
||||
await session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
elif dialect == "postgresql":
|
||||
res = await session.execute(
|
||||
d_pg.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[Asset.hash],
|
||||
index_where=Asset.__table__.c.hash.isnot(None),
|
||||
)
|
||||
.returning(Asset.id)
|
||||
)
|
||||
inserted_id = res.scalar_one_or_none()
|
||||
if inserted_id:
|
||||
out["asset_created"] = True
|
||||
asset = await session.get(Asset, inserted_id)
|
||||
else:
|
||||
asset = (
|
||||
await session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
else:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
out["asset_updated"] = True
|
||||
|
||||
# 2) AssetCacheState upsert by file_path (unique)
|
||||
vals = {
|
||||
"asset_id": asset.id,
|
||||
"file_path": locator,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
}
|
||||
if dialect == "sqlite":
|
||||
ins = (
|
||||
d_sqlite.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins = (
|
||||
d_pg.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
|
||||
res = await session.execute(ins)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["state_created"] = True
|
||||
else:
|
||||
upd = (
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.file_path == locator)
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetCacheState.asset_id != asset.id,
|
||||
AssetCacheState.mtime_ns.is_(None),
|
||||
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||
)
|
||||
)
|
||||
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||
)
|
||||
res2 = await session.execute(upd)
|
||||
if int(res2.rowcount or 0) > 0:
|
||||
out["state_updated"] = True
|
||||
|
||||
# 3) Optional AssetInfo + tags + metadata
|
||||
if info_name:
|
||||
try:
|
||||
async with session.begin_nested():
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_id=asset.id,
|
||||
preview_id=preview_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(info)
|
||||
await session.flush()
|
||||
out["asset_info_id"] = info.id
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
existing_info = (
|
||||
await session.execute(
|
||||
select(AssetInfo)
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == info_name,
|
||||
(AssetInfo.owner_id == owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalar_one_or_none()
|
||||
if not existing_info:
|
||||
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||
|
||||
if preview_id and existing_info.preview_id != preview_id:
|
||||
existing_info.preview_id = preview_id
|
||||
|
||||
existing_info.updated_at = now
|
||||
if existing_info.last_access_time < now:
|
||||
existing_info.last_access_time = now
|
||||
await session.flush()
|
||||
out["asset_info_id"] = existing_info.id
|
||||
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
if not require_existing_tags:
|
||||
await ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_at=now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
# metadata["filename"] hack
|
||||
if out["asset_info_id"] is not None:
|
||||
primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
computed_filename = compute_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
current_meta = existing_info.user_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
if user_metadata is not None:
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
await replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
try:
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
return out
|
||||
|
||||
|
||||
async def touch_asset_infos_by_fs_path(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
file_path: str,
|
||||
ts: Optional[datetime] = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
locator = os.path.abspath(file_path)
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(
|
||||
sa.exists(
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(AssetCacheState)
|
||||
.where(
|
||||
AssetCacheState.asset_id == AssetInfo.asset_id,
|
||||
AssetCacheState.file_path == locator,
|
||||
)
|
||||
)
|
||||
)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(
|
||||
AssetInfo.last_access_time.is_(None),
|
||||
AssetInfo.last_access_time < ts,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
async def list_cache_states_with_asset_under_prefixes(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
prefixes: Sequence[str],
|
||||
) -> list[tuple[AssetCacheState, Optional[str], int]]:
|
||||
"""Return (AssetCacheState, asset_hash, size_bytes) for rows under any prefix."""
|
||||
if not prefixes:
|
||||
return []
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
if not p:
|
||||
continue
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base = base + os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
if not conds:
|
||||
return []
|
||||
|
||||
rows = (
|
||||
await session.execute(
|
||||
select(AssetCacheState, Asset.hash, Asset.size_bytes)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sa.or_(*conds))
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).all()
|
||||
return [(r[0], r[1], int(r[2] or 0)) for r in rows]
|
||||
|
||||
|
||||
async def _recompute_and_apply_filename_for_asset(session: AsyncSession, *, asset_id: str) -> None:
|
||||
"""Compute filename from the first *existing* cache state path and apply it to all AssetInfo (if changed)."""
|
||||
try:
|
||||
primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset_id))
|
||||
if not primary_path:
|
||||
return
|
||||
new_filename = compute_relative_filename(primary_path)
|
||||
if not new_filename:
|
||||
return
|
||||
infos = (
|
||||
await session.execute(select(AssetInfo).where(AssetInfo.asset_id == asset_id))
|
||||
).scalars().all()
|
||||
for info in infos:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") == new_filename:
|
||||
continue
|
||||
updated = dict(current_meta)
|
||||
updated["filename"] = new_filename
|
||||
await replace_asset_info_metadata_projection(session, asset_info_id=info.id, user_metadata=updated)
|
||||
except Exception:
|
||||
logging.exception("Failed to recompute filename metadata for asset %s", asset_id)
|
||||
586
app/assets/database/services/info.py
Normal file
586
app/assets/database/services/info.py
Normal file
@ -0,0 +1,586 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import contains_eager, noload
|
||||
|
||||
from ..._helpers import compute_relative_filename, normalize_tags
|
||||
from ..helpers import (
|
||||
apply_metadata_filter,
|
||||
apply_tag_filters,
|
||||
ensure_tags_exist,
|
||||
escape_like_prefix,
|
||||
project_kv,
|
||||
visible_owner_clause,
|
||||
)
|
||||
from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from ..timeutil import utcnow
|
||||
from .queries import (
|
||||
get_asset_by_hash,
|
||||
list_cache_states_by_asset_id,
|
||||
pick_best_live_path,
|
||||
)
|
||||
|
||||
|
||||
async def list_asset_infos_page(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
owner_id: str = "",
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
name_contains: Optional[str] = None,
|
||||
metadata_filter: Optional[dict] = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int((await session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
infos = (await session.execute(base)).unique().scalars().all()
|
||||
|
||||
id_list: list[str] = [i.id for i in infos]
|
||||
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = await session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
async def fetch_asset_info_and_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> Optional[tuple[AssetInfo, Asset]]:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
.options(noload(AssetInfo.tags))
|
||||
)
|
||||
row = await session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
return pair[0], pair[1]
|
||||
|
||||
|
||||
async def fetch_asset_info_asset_and_tags(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset, Tag.name)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = (await session.execute(stmt)).all()
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
first_info, first_asset, _ = rows[0]
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for _info, _asset, tag_name in rows:
|
||||
if tag_name and tag_name not in seen:
|
||||
seen.add(tag_name)
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
|
||||
async def create_asset_info_for_existing_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetInfo:
|
||||
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||
now = utcnow()
|
||||
asset = await get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
preview_id=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
try:
|
||||
async with session.begin_nested():
|
||||
session.add(info)
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
existing = (
|
||||
await session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == name,
|
||||
AssetInfo.owner_id == owner_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
if not existing:
|
||||
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||
return existing
|
||||
|
||||
# metadata["filename"] hack
|
||||
new_meta = dict(user_metadata or {})
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
if new_meta:
|
||||
await replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
await set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
async def set_asset_info_tags(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> dict:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
await ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||
for t in to_add
|
||||
])
|
||||
await session.flush()
|
||||
|
||||
if to_remove:
|
||||
await session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
async def update_asset_info_full(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tag_origin: str = "manual",
|
||||
asset_info_row: Any = None,
|
||||
) -> AssetInfo:
|
||||
if not asset_info_row:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
else:
|
||||
info = asset_info_row
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=info.asset_id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
await replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
else:
|
||||
if computed_filename:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
new_meta["filename"] = computed_filename
|
||||
await replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
await set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
async def replace_asset_info_metadata_projection(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
user_metadata: Optional[dict],
|
||||
) -> None:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
await session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def touch_asset_info_by_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
ts: Optional[datetime] = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||
)
|
||||
await session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool:
|
||||
stmt = sa.delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
return int((await session.execute(stmt)).rowcount or 0) > 0
|
||||
|
||||
|
||||
async def add_tags_to_asset_info(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: Any = None,
|
||||
) -> dict:
|
||||
if not asset_info_row:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": [], "already_present": [], "total_tags": total}
|
||||
|
||||
if create_if_missing:
|
||||
await ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
async with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=utcnow(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
await nested.rollback()
|
||||
|
||||
after = set(await get_asset_tags(session, asset_info_id=asset_info_id))
|
||||
return {
|
||||
"added": sorted(((after - current) & want)),
|
||||
"already_present": sorted(want & current),
|
||||
"total_tags": sorted(after),
|
||||
}
|
||||
|
||||
|
||||
async def remove_tags_from_asset_info(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> dict:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": [], "not_present": [], "total_tags": total}
|
||||
|
||||
existing = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
await session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
async def list_tags_with_usage(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
prefix: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetInfoTag)
|
||||
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
total_q = total_q.where(
|
||||
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||
)
|
||||
|
||||
rows = (await session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (await session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
async def set_asset_info_preview(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: Optional[str],
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
if preview_asset_id is None:
|
||||
info.preview_id = None
|
||||
else:
|
||||
# validate preview asset exists
|
||||
if not await session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
info.preview_id = preview_asset_id
|
||||
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
76
app/assets/database/services/queries.py
Normal file
76
app/assets/database/services/queries.py
Normal file
@ -0,0 +1,76 @@
|
||||
import os
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import Asset, AssetCacheState, AssetInfo
|
||||
|
||||
|
||||
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
|
||||
row = (
|
||||
await session.execute(
|
||||
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]:
|
||||
return (
|
||||
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
|
||||
async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]:
|
||||
return await session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
async def asset_info_exists_for_asset_id(session: AsyncSession, *, asset_id: str) -> bool:
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetInfo)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.limit(1)
|
||||
)
|
||||
return (await session.execute(q)).first() is not None
|
||||
|
||||
|
||||
async def get_cache_state_by_asset_id(session: AsyncSession, *, asset_id: str) -> Optional[AssetCacheState]:
|
||||
return (
|
||||
await session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
.limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
|
||||
|
||||
async def list_cache_states_by_asset_id(
|
||||
session: AsyncSession, *, asset_id: str
|
||||
) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]:
|
||||
return (
|
||||
await session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
def pick_best_live_path(states: Union[list[AssetCacheState], Sequence[AssetCacheState]]) -> str:
|
||||
"""
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
|
||||
6
app/assets/database/timeutil.py
Normal file
6
app/assets/database/timeutil.py
Normal file
@ -0,0 +1,6 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
556
app/assets/manager.py
Normal file
556
app/assets/manager.py
Normal file
@ -0,0 +1,556 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from comfy_api.internal import async_to_sync
|
||||
|
||||
from ..db import create_session
|
||||
from ._helpers import (
|
||||
ensure_within_base,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
)
|
||||
from .api import schemas_in, schemas_out
|
||||
from .database.models import Asset
|
||||
from .database.services import (
|
||||
add_tags_to_asset_info,
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
check_fs_asset_exists_quick,
|
||||
create_asset_info_for_existing_asset,
|
||||
delete_asset_info_by_id,
|
||||
fetch_asset_info_and_asset,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
get_asset_tags,
|
||||
ingest_fs_asset,
|
||||
list_asset_infos_page,
|
||||
list_cache_states_by_asset_id,
|
||||
list_tags_with_usage,
|
||||
pick_best_live_path,
|
||||
remove_tags_from_asset_info,
|
||||
set_asset_info_preview,
|
||||
touch_asset_info_by_id,
|
||||
touch_asset_infos_by_fs_path,
|
||||
update_asset_info_full,
|
||||
)
|
||||
from .storage import hashing
|
||||
|
||||
|
||||
async def asset_exists(*, asset_hash: str) -> bool:
|
||||
async with await create_session() as session:
|
||||
return await asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
|
||||
if tags is None:
|
||||
tags = []
|
||||
try:
|
||||
asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
|
||||
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
|
||||
add_local_asset,
|
||||
tags=list(dict.fromkeys([*path_tags, *tags])),
|
||||
file_name=asset_name,
|
||||
file_path=file_path,
|
||||
)
|
||||
except ValueError as e:
|
||||
logging.warning("Skipping non-asset path %s: %s", file_path, e)
|
||||
|
||||
|
||||
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
|
||||
abs_path = os.path.abspath(file_path)
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
|
||||
if not size_bytes:
|
||||
return
|
||||
|
||||
async with await create_session() as session:
|
||||
if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns):
|
||||
await touch_asset_infos_by_fs_path(session, file_path=abs_path)
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
asset_hash = hashing.blake3_hash_sync(abs_path)
|
||||
|
||||
async with await create_session() as session:
|
||||
await ingest_fs_asset(
|
||||
session,
|
||||
asset_hash="blake3:" + asset_hash,
|
||||
abs_path=abs_path,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=None,
|
||||
info_name=file_name,
|
||||
tag_origin="automatic",
|
||||
tags=tags,
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def list_assets(
|
||||
*,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
name_contains: Optional[str] = None,
|
||||
metadata_filter: Optional[dict] = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetsList:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
|
||||
async with await create_session() as session:
|
||||
infos, tag_map, total = await list_asset_infos_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries: list[schemas_out.AssetSummary] = []
|
||||
for info in infos:
|
||||
asset = info.asset
|
||||
tags = tag_map.get(info.id, [])
|
||||
summaries.append(
|
||||
schemas_out.AssetSummary(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
preview_url=f"/api/assets/{info.id}/content",
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
)
|
||||
|
||||
return schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=total,
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
|
||||
async def get_asset(*, asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
async with await create_session() as session:
|
||||
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
info, asset, tag_names = res
|
||||
preview_id = info.preview_id
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
async def resolve_asset_content_for_download(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
async with await create_session() as session:
|
||||
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = await list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = pick_best_live_path(states)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError
|
||||
|
||||
await touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
await session.commit()
|
||||
|
||||
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return abs_path, ctype, download_name
|
||||
|
||||
|
||||
async def upload_asset_from_temp_path(
|
||||
spec: schemas_in.UploadAssetSpec,
|
||||
*,
|
||||
temp_path: str,
|
||||
client_filename: Optional[str] = None,
|
||||
owner_id: str = "",
|
||||
expected_asset_hash: Optional[str] = None,
|
||||
) -> schemas_out.AssetCreated:
|
||||
try:
|
||||
digest = await hashing.blake3_hash(temp_path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||
raise ValueError("HASH_MISMATCH")
|
||||
|
||||
async with await create_session() as session:
|
||||
existing = await get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||
info = await create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = await get_asset_tags(session, asset_info_id=info.id)
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=existing.hash,
|
||||
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||
mime_type=existing.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or spec.name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
ensure_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
async with await create_session() as session:
|
||||
result = await ingest_fs_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
info_id = result["asset_info_id"]
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
pair = await fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
info, asset = pair
|
||||
tag_names = await get_asset_tags(session, asset_info_id=info.id)
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
)
|
||||
|
||||
|
||||
async def update_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetUpdated:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
info = await update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
|
||||
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetUpdated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset.hash if info.asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
updated_at=info.updated_at,
|
||||
)
|
||||
|
||||
|
||||
async def set_asset_preview(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: Optional[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
await set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
info, asset, tags = res
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not deleted:
|
||||
await session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
still_exists = await asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
states = await list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||
|
||||
asset_row = await session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
await session.delete(asset_row)
|
||||
|
||||
await session.commit()
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
return True
|
||||
|
||||
|
||||
async def create_asset_from_hash(
|
||||
*,
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: Optional[list[str]] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
owner_id: str = "",
|
||||
) -> Optional[schemas_out.AssetCreated]:
|
||||
canonical = hash_str.strip().lower()
|
||||
async with await create_session() as session:
|
||||
asset = await get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
info = await create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=canonical,
|
||||
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = await get_asset_tags(session, asset_info_id=info.id)
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
|
||||
async def list_tags(
|
||||
*,
|
||||
prefix: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsList:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
async with await create_session() as session:
|
||||
rows, total = await list_tags_with_usage(
|
||||
session,
|
||||
prefix=prefix,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
|
||||
|
||||
|
||||
async def add_tags_to_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
data = await add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
await session.commit()
|
||||
return schemas_out.TagsAdd(**data)
|
||||
|
||||
|
||||
async def remove_tags_from_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
data = await remove_tags_from_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
)
|
||||
await session.commit()
|
||||
return schemas_out.TagsRemove(**data)
|
||||
|
||||
|
||||
def _safe_sort_field(requested: Optional[str]) -> str:
|
||||
if not requested:
|
||||
return "created_at"
|
||||
v = requested.lower()
|
||||
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
|
||||
return v
|
||||
return "created_at"
|
||||
|
||||
|
||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
def _safe_filename(name: Optional[str], fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
if n:
|
||||
return n
|
||||
return fallback
|
||||
501
app/assets/scanner.py
Normal file
501
app/assets/scanner.py
Normal file
@ -0,0 +1,501 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
import folder_paths
|
||||
|
||||
from ..db import create_session
|
||||
from ._helpers import (
|
||||
collect_models_files,
|
||||
compute_relative_filename,
|
||||
get_comfy_models_folders,
|
||||
get_name_and_tags_from_asset_path,
|
||||
list_tree,
|
||||
new_scan_id,
|
||||
prefixes_for_root,
|
||||
ts_to_iso,
|
||||
)
|
||||
from .api import schemas_in, schemas_out
|
||||
from .database.helpers import (
|
||||
add_missing_tag_for_asset_id,
|
||||
ensure_tags_exist,
|
||||
escape_like_prefix,
|
||||
fast_asset_file_check,
|
||||
remove_missing_tag_for_asset_id,
|
||||
seed_from_paths_batch,
|
||||
)
|
||||
from .database.models import Asset, AssetCacheState, AssetInfo
|
||||
from .database.services import (
|
||||
compute_hash_and_dedup_for_cache_state,
|
||||
list_cache_states_by_asset_id,
|
||||
list_cache_states_with_asset_under_prefixes,
|
||||
list_unhashed_candidates_under_prefixes,
|
||||
list_verify_candidates_under_prefixes,
|
||||
)
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SLOW_HASH_CONCURRENCY = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanProgress:
|
||||
scan_id: str
|
||||
root: schemas_in.RootType
|
||||
status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled"
|
||||
scheduled_at: float = field(default_factory=lambda: time.time())
|
||||
started_at: Optional[float] = None
|
||||
finished_at: Optional[float] = None
|
||||
discovered: int = 0
|
||||
processed: int = 0
|
||||
file_errors: list[dict] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlowQueueState:
|
||||
queue: asyncio.Queue
|
||||
workers: list[asyncio.Task] = field(default_factory=list)
|
||||
closed: bool = False
|
||||
|
||||
|
||||
RUNNING_TASKS: dict[schemas_in.RootType, asyncio.Task] = {}
|
||||
PROGRESS_BY_ROOT: dict[schemas_in.RootType, ScanProgress] = {}
|
||||
SLOW_STATE_BY_ROOT: dict[schemas_in.RootType, SlowQueueState] = {}
|
||||
|
||||
|
||||
def current_statuses() -> schemas_out.AssetScanStatusResponse:
|
||||
scans = []
|
||||
for root in schemas_in.ALLOWED_ROOTS:
|
||||
prog = PROGRESS_BY_ROOT.get(root)
|
||||
if not prog:
|
||||
continue
|
||||
scans.append(_scan_progress_to_scan_status_model(prog))
|
||||
return schemas_out.AssetScanStatusResponse(scans=scans)
|
||||
|
||||
|
||||
async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetScanStatusResponse:
|
||||
results: list[ScanProgress] = []
|
||||
for root in roots:
|
||||
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
|
||||
results.append(PROGRESS_BY_ROOT[root])
|
||||
continue
|
||||
|
||||
prog = ScanProgress(scan_id=new_scan_id(root), root=root, status="scheduled")
|
||||
PROGRESS_BY_ROOT[root] = prog
|
||||
state = SlowQueueState(queue=asyncio.Queue())
|
||||
SLOW_STATE_BY_ROOT[root] = state
|
||||
RUNNING_TASKS[root] = asyncio.create_task(
|
||||
_run_hash_verify_pipeline(root, prog, state),
|
||||
name=f"asset-scan:{root}",
|
||||
)
|
||||
results.append(prog)
|
||||
return _status_response_for(results)
|
||||
|
||||
|
||||
async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
|
||||
t_total = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
for r in roots:
|
||||
try:
|
||||
survivors = await _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
|
||||
if survivors:
|
||||
existing_paths.update(survivors)
|
||||
except Exception as ex:
|
||||
LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_output_directory()))
|
||||
|
||||
specs: list[dict] = []
|
||||
tag_pool: set[str] = set()
|
||||
for p in paths:
|
||||
ap = os.path.abspath(p)
|
||||
if ap in existing_paths:
|
||||
skipped_existing += 1
|
||||
continue
|
||||
try:
|
||||
st = os.stat(ap, follow_symlinks=True)
|
||||
except OSError:
|
||||
continue
|
||||
if not st.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(ap)
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": ap,
|
||||
"size_bytes": st.st_size,
|
||||
"mtime_ns": getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": compute_relative_filename(ap),
|
||||
}
|
||||
)
|
||||
for t in tags:
|
||||
tag_pool.add(t)
|
||||
|
||||
if not specs:
|
||||
return
|
||||
async with await create_session() as sess:
|
||||
if tag_pool:
|
||||
await ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
|
||||
result = await seed_from_paths_batch(sess, specs=specs, owner_id="")
|
||||
created += result["inserted_infos"]
|
||||
await sess.commit()
|
||||
finally:
|
||||
LOGGER.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_total,
|
||||
created,
|
||||
skipped_existing,
|
||||
len(paths),
|
||||
)
|
||||
|
||||
|
||||
def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
|
||||
return schemas_out.AssetScanStatusResponse(scans=[_scan_progress_to_scan_status_model(p) for p in progresses])
|
||||
|
||||
|
||||
def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.AssetScanStatus:
|
||||
return schemas_out.AssetScanStatus(
|
||||
scan_id=progress.scan_id,
|
||||
root=progress.root,
|
||||
status=progress.status,
|
||||
scheduled_at=ts_to_iso(progress.scheduled_at),
|
||||
started_at=ts_to_iso(progress.started_at),
|
||||
finished_at=ts_to_iso(progress.finished_at),
|
||||
discovered=progress.discovered,
|
||||
processed=progress.processed,
|
||||
file_errors=[
|
||||
schemas_out.AssetScanError(
|
||||
path=e.get("path", ""),
|
||||
message=e.get("message", ""),
|
||||
at=e.get("at"),
|
||||
)
|
||||
for e in (progress.file_errors or [])
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
|
||||
prog.status = "running"
|
||||
prog.started_at = time.time()
|
||||
try:
|
||||
prefixes = prefixes_for_root(root)
|
||||
|
||||
await _fast_db_consistency_pass(root)
|
||||
|
||||
# collect candidates from DB
|
||||
async with await create_session() as sess:
|
||||
verify_ids = await list_verify_candidates_under_prefixes(sess, prefixes=prefixes)
|
||||
unhashed_ids = await list_unhashed_candidates_under_prefixes(sess, prefixes=prefixes)
|
||||
# dedupe: prioritize verification first
|
||||
seen = set()
|
||||
ordered: list[int] = []
|
||||
for lst in (verify_ids, unhashed_ids):
|
||||
for sid in lst:
|
||||
if sid not in seen:
|
||||
seen.add(sid)
|
||||
ordered.append(sid)
|
||||
|
||||
prog.discovered = len(ordered)
|
||||
|
||||
# queue up work
|
||||
for sid in ordered:
|
||||
await state.queue.put(sid)
|
||||
state.closed = True
|
||||
_start_state_workers(root, prog, state)
|
||||
await _await_state_workers_then_finish(root, prog, state)
|
||||
except asyncio.CancelledError:
|
||||
prog.status = "cancelled"
|
||||
raise
|
||||
except Exception as exc:
|
||||
_append_error(prog, path="", message=str(exc))
|
||||
prog.status = "failed"
|
||||
prog.finished_at = time.time()
|
||||
LOGGER.exception("Asset scan failed for %s", root)
|
||||
finally:
|
||||
RUNNING_TASKS.pop(root, None)
|
||||
|
||||
|
||||
async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
|
||||
"""
|
||||
Detect missing files quickly and toggle 'missing' tag per asset_id.
|
||||
|
||||
Rules:
|
||||
- Only hashed assets (assets.hash != NULL) participate in missing tagging.
|
||||
- We consider ALL cache states of the asset (across roots) before tagging.
|
||||
"""
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
elif root == "input":
|
||||
bases = [folder_paths.get_input_directory()]
|
||||
else:
|
||||
bases = [folder_paths.get_output_directory()]
|
||||
|
||||
try:
|
||||
async with await create_session() as sess:
|
||||
# state + hash + size for the current root
|
||||
rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases)
|
||||
|
||||
# Track fast_ok within the scanned root and whether the asset is hashed
|
||||
by_asset: dict[str, dict[str, bool]] = {}
|
||||
for state, a_hash, size_db in rows:
|
||||
aid = state.asset_id
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)}
|
||||
by_asset[aid] = acc
|
||||
try:
|
||||
if acc["hashed"]:
|
||||
st = os.stat(state.file_path, follow_symlinks=True)
|
||||
if fast_asset_file_check(mtime_db=state.mtime_ns, size_db=acc["size_db"], stat_result=st):
|
||||
acc["any_fast_ok_here"] = True
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except OSError as e:
|
||||
_append_error(prog, path=state.file_path, message=str(e))
|
||||
|
||||
# Decide per asset, considering ALL its states (not just this root)
|
||||
for aid, acc in by_asset.items():
|
||||
try:
|
||||
if not acc["hashed"]:
|
||||
# Never tag seed assets as missing
|
||||
continue
|
||||
|
||||
any_fast_ok_global = acc["any_fast_ok_here"]
|
||||
if not any_fast_ok_global:
|
||||
# Check other states outside this root
|
||||
others = await list_cache_states_by_asset_id(sess, asset_id=aid)
|
||||
for st in others:
|
||||
try:
|
||||
any_fast_ok_global = fast_asset_file_check(
|
||||
mtime_db=st.mtime_ns,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(st.file_path, follow_symlinks=True),
|
||||
)
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
if any_fast_ok_global:
|
||||
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
else:
|
||||
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
except Exception as ex:
|
||||
_append_error(prog, path="", message=f"reconcile {aid[:8]}: {ex}")
|
||||
|
||||
await sess.commit()
|
||||
except Exception as e:
|
||||
_append_error(prog, path="", message=f"reconcile failed: {e}")
|
||||
|
||||
|
||||
def _start_state_workers(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
|
||||
if state.workers:
|
||||
return
|
||||
|
||||
async def _worker(_wid: int):
|
||||
while True:
|
||||
sid = await state.queue.get()
|
||||
try:
|
||||
if sid is None:
|
||||
return
|
||||
try:
|
||||
async with await create_session() as sess:
|
||||
# Optional: fetch path for better error messages
|
||||
st = await sess.get(AssetCacheState, sid)
|
||||
try:
|
||||
await compute_hash_and_dedup_for_cache_state(sess, state_id=sid)
|
||||
await sess.commit()
|
||||
except Exception as e:
|
||||
path = st.file_path if st else f"state:{sid}"
|
||||
_append_error(prog, path=path, message=str(e))
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
prog.processed += 1
|
||||
finally:
|
||||
state.queue.task_done()
|
||||
|
||||
state.workers = [
|
||||
asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}")
|
||||
for i in range(SLOW_HASH_CONCURRENCY)
|
||||
]
|
||||
|
||||
async def _close_when_ready():
|
||||
while not state.closed:
|
||||
await asyncio.sleep(0.05)
|
||||
for _ in range(SLOW_HASH_CONCURRENCY):
|
||||
await state.queue.put(None)
|
||||
|
||||
asyncio.create_task(_close_when_ready())
|
||||
|
||||
|
||||
async def _await_state_workers_then_finish(
|
||||
root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState
|
||||
) -> None:
|
||||
if state.workers:
|
||||
await asyncio.gather(*state.workers, return_exceptions=True)
|
||||
await _reconcile_missing_tags_for_root(root, prog)
|
||||
prog.finished_at = time.time()
|
||||
prog.status = "completed"
|
||||
|
||||
|
||||
def _append_error(prog: ScanProgress, *, path: str, message: str) -> None:
|
||||
prog.file_errors.append({
|
||||
"path": path,
|
||||
"message": message,
|
||||
"at": ts_to_iso(time.time()),
|
||||
})
|
||||
|
||||
|
||||
async def _fast_db_consistency_pass(
|
||||
root: schemas_in.RootType,
|
||||
*,
|
||||
collect_existing_paths: bool = False,
|
||||
update_missing_tags: bool = False,
|
||||
) -> Optional[set[str]]:
|
||||
"""Fast DB+FS pass for a root:
|
||||
- Toggle needs_verify per state using fast check
|
||||
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
|
||||
- For seed assets with all states missing: delete Asset and its AssetInfos
|
||||
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
||||
- Optionally return surviving absolute paths
|
||||
"""
|
||||
prefixes = prefixes_for_root(root)
|
||||
if not prefixes:
|
||||
return set() if collect_existing_paths else None
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
async with await create_session() as sess:
|
||||
rows = (
|
||||
await sess.execute(
|
||||
sa.select(
|
||||
AssetCacheState.id,
|
||||
AssetCacheState.file_path,
|
||||
AssetCacheState.mtime_ns,
|
||||
AssetCacheState.needs_verify,
|
||||
AssetCacheState.asset_id,
|
||||
Asset.hash,
|
||||
Asset.size_bytes,
|
||||
)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sa.or_(*conds))
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
)
|
||||
).all()
|
||||
|
||||
by_asset: dict[str, dict] = {}
|
||||
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
|
||||
by_asset[aid] = acc
|
||||
|
||||
fast_ok = False
|
||||
try:
|
||||
exists = True
|
||||
fast_ok = fast_asset_file_check(
|
||||
mtime_db=mtime_db,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(fp, follow_symlinks=True),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except OSError:
|
||||
exists = False
|
||||
|
||||
acc["states"].append({
|
||||
"sid": sid,
|
||||
"fp": fp,
|
||||
"exists": exists,
|
||||
"fast_ok": fast_ok,
|
||||
"needs_verify": bool(needs_verify),
|
||||
})
|
||||
|
||||
to_set_verify: list[int] = []
|
||||
to_clear_verify: list[int] = []
|
||||
stale_state_ids: list[int] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
states = acc["states"]
|
||||
any_fast_ok = any(s["fast_ok"] for s in states)
|
||||
all_missing = all(not s["exists"] for s in states)
|
||||
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
continue
|
||||
if s["fast_ok"] and s["needs_verify"]:
|
||||
to_clear_verify.append(s["sid"])
|
||||
if not s["fast_ok"] and not s["needs_verify"]:
|
||||
to_set_verify.append(s["sid"])
|
||||
|
||||
if a_hash is None:
|
||||
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
|
||||
await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid))
|
||||
asset = await sess.get(Asset, aid)
|
||||
if asset:
|
||||
await sess.delete(asset)
|
||||
else:
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
continue
|
||||
|
||||
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
stale_state_ids.append(s["sid"])
|
||||
if update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
elif update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
|
||||
if stale_state_ids:
|
||||
await sess.execute(sa.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
|
||||
if to_set_verify:
|
||||
await sess.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_set_verify))
|
||||
.values(needs_verify=True)
|
||||
)
|
||||
if to_clear_verify:
|
||||
await sess.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_clear_verify))
|
||||
.values(needs_verify=False)
|
||||
)
|
||||
await sess.commit()
|
||||
return survivors if collect_existing_paths else None
|
||||
0
app/assets/storage/__init__.py
Normal file
0
app/assets/storage/__init__.py
Normal file
72
app/assets/storage/hashing.py
Normal file
72
app/assets/storage/hashing.py
Normal file
@ -0,0 +1,72 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import IO, Union
|
||||
|
||||
from blake3 import blake3
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB
|
||||
|
||||
|
||||
def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str:
|
||||
"""Hash an already-open binary file object by streaming in chunks.
|
||||
- Seeks to the beginning before reading (if supported).
|
||||
- Restores the original position afterward (if tell/seek are supported).
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
orig_pos = None
|
||||
if hasattr(file_obj, "tell"):
|
||||
orig_pos = file_obj.tell()
|
||||
|
||||
try:
|
||||
if hasattr(file_obj, "seek"):
|
||||
file_obj.seek(0)
|
||||
|
||||
h = blake3()
|
||||
while True:
|
||||
chunk = file_obj.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
finally:
|
||||
if hasattr(file_obj, "seek") and orig_pos is not None:
|
||||
file_obj.seek(orig_pos)
|
||||
|
||||
|
||||
def blake3_hash_sync(
|
||||
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Returns a BLAKE3 hex digest for ``fp``, which may be:
|
||||
- a filename (str/bytes) or PathLike
|
||||
- an open binary file object
|
||||
|
||||
If ``fp`` is a file object, it must be opened in **binary** mode and support
|
||||
``read``, ``seek``, and ``tell``. The function will seek to the start before
|
||||
reading and will attempt to restore the original position afterward.
|
||||
"""
|
||||
if hasattr(fp, "read"):
|
||||
return _hash_file_obj_sync(fp, chunk_size)
|
||||
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj_sync(f, chunk_size)
|
||||
|
||||
|
||||
async def blake3_hash(
|
||||
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Async wrapper for ``blake3_hash_sync``.
|
||||
Uses a worker thread so the event loop remains responsive.
|
||||
"""
|
||||
# If it is a path, open inside the worker thread to keep I/O off the loop.
|
||||
if hasattr(fp, "read"):
|
||||
return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size)
|
||||
|
||||
def _worker() -> str:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj_sync(f, chunk_size)
|
||||
|
||||
return await asyncio.to_thread(_worker)
|
||||
@ -1,112 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from app.logger import log_startup_warning
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
from comfy.cli_args import args
|
||||
|
||||
_DB_AVAILABLE = False
|
||||
Session = None
|
||||
|
||||
|
||||
try:
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
log_startup_warning(
|
||||
f"""
|
||||
------------------------------------------------------------------------
|
||||
Error importing dependencies: {e}
|
||||
{get_missing_requirements_message()}
|
||||
This error is happening because ComfyUI now uses a local sqlite database.
|
||||
------------------------------------------------------------------------
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def dependencies_available():
|
||||
"""
|
||||
Temporary function to check if the dependencies are available
|
||||
"""
|
||||
return _DB_AVAILABLE
|
||||
|
||||
|
||||
def can_create_session():
|
||||
"""
|
||||
Temporary function to check if the database is available to create a session
|
||||
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
|
||||
"""
|
||||
return dependencies_available() and Session is not None
|
||||
|
||||
|
||||
def get_alembic_config():
|
||||
root_path = os.path.join(os.path.dirname(__file__), "../..")
|
||||
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
||||
|
||||
config = Config(config_path)
|
||||
config.set_main_option("script_location", scripts_path)
|
||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_db_path():
|
||||
url = args.database_url
|
||||
if url.startswith("sqlite:///"):
|
||||
return url.split("///")[1]
|
||||
else:
|
||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
||||
|
||||
|
||||
def init_db():
|
||||
db_url = args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
db_path = get_db_path()
|
||||
db_exists = os.path.exists(db_path)
|
||||
|
||||
config = get_alembic_config()
|
||||
|
||||
# Check if we need to upgrade
|
||||
engine = create_engine(db_url)
|
||||
conn = engine.connect()
|
||||
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
script = ScriptDirectory.from_config(config)
|
||||
target_rev = script.get_current_head()
|
||||
|
||||
if target_rev is None:
|
||||
logging.warning("No target revision found.")
|
||||
elif current_rev != target_rev:
|
||||
# Backup the database pre upgrade
|
||||
backup_path = db_path + ".bkp"
|
||||
if db_exists:
|
||||
shutil.copy(db_path, backup_path)
|
||||
else:
|
||||
backup_path = None
|
||||
|
||||
try:
|
||||
command.upgrade(config, target_rev)
|
||||
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
|
||||
except Exception as e:
|
||||
if backup_path:
|
||||
# Restore the database from backup if upgrade fails
|
||||
shutil.copy(backup_path, db_path)
|
||||
os.remove(backup_path)
|
||||
logging.exception("Error upgrading database: ")
|
||||
raise e
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
|
||||
|
||||
|
||||
def create_session():
|
||||
return Session()
|
||||
@ -1,14 +0,0 @@
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def to_dict(obj):
|
||||
fields = obj.__table__.columns.keys()
|
||||
return {
|
||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
||||
for field in fields
|
||||
if (val := getattr(obj, field))
|
||||
}
|
||||
|
||||
# TODO: Define models here
|
||||
255
app/db.py
Normal file
255
app/db.py
Normal file
@ -0,0 +1,255 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
ENGINE: Optional[AsyncEngine] = None
|
||||
SESSION: Optional[async_sessionmaker] = None
|
||||
|
||||
|
||||
def _root_paths():
|
||||
"""Resolve alembic.ini and migrations script folder."""
|
||||
root_path = os.path.abspath(os.path.dirname(__file__))
|
||||
config_path = os.path.abspath(os.path.join(root_path, "../alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
||||
return config_path, scripts_path
|
||||
|
||||
|
||||
def _absolutize_sqlite_url(db_url: str) -> str:
|
||||
"""Make SQLite database path absolute. No-op for non-SQLite URLs."""
|
||||
try:
|
||||
u = make_url(db_url)
|
||||
except Exception:
|
||||
return db_url
|
||||
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return db_url
|
||||
|
||||
db_path: str = u.database or ""
|
||||
if isinstance(db_path, str) and db_path.startswith("file:"):
|
||||
return str(u) # Do not touch SQLite URI databases like: "file:xxx?mode=memory&cache=shared"
|
||||
if not os.path.isabs(db_path):
|
||||
db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
|
||||
u = u.set(database=db_path)
|
||||
return str(u)
|
||||
|
||||
|
||||
def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]:
|
||||
"""
|
||||
If db_url points at an in-memory SQLite DB (":memory:" or file:... mode=memory),
|
||||
rewrite it to a *named* shared in-memory URI and ensure 'uri=true' is present.
|
||||
Returns: (normalized_url, is_memory)
|
||||
"""
|
||||
try:
|
||||
u = make_url(db_url)
|
||||
except Exception:
|
||||
return db_url, False
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return db_url, False
|
||||
|
||||
db = u.database or ""
|
||||
if db == ":memory:":
|
||||
u = u.set(database=f"file:comfyui_db_{os.getpid()}?mode=memory&cache=shared&uri=true")
|
||||
return str(u), True
|
||||
if isinstance(db, str) and db.startswith("file:") and "mode=memory" in db:
|
||||
if "uri=true" not in db:
|
||||
u = u.set(database=(db + ("&" if "?" in db else "?") + "uri=true"))
|
||||
return str(u), True
|
||||
return str(u), False
|
||||
|
||||
|
||||
def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
|
||||
"""Return the on-disk path for a SQLite URL, else None."""
|
||||
try:
|
||||
u = make_url(sync_url)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return None
|
||||
db_path = u.database
|
||||
if isinstance(db_path, str) and db_path.startswith("file:"):
|
||||
return None # Not a real file if it is a URI like "file:...?"
|
||||
return db_path
|
||||
|
||||
|
||||
def _get_alembic_config(sync_url: str) -> Config:
|
||||
"""Prepare Alembic Config with script location and DB URL."""
|
||||
config_path, scripts_path = _root_paths()
|
||||
cfg = Config(config_path)
|
||||
cfg.set_main_option("script_location", scripts_path)
|
||||
cfg.set_main_option("sqlalchemy.url", sync_url)
|
||||
return cfg
|
||||
|
||||
|
||||
async def init_db_engine() -> None:
|
||||
"""Initialize async engine + sessionmaker and run migrations to head.
|
||||
|
||||
This must be called once on application startup before any DB usage.
|
||||
"""
|
||||
global ENGINE, SESSION
|
||||
|
||||
if ENGINE is not None:
|
||||
return
|
||||
|
||||
raw_url = args.database_url
|
||||
if not raw_url:
|
||||
raise RuntimeError("Database URL is not configured.")
|
||||
|
||||
db_url, is_mem = _normalize_sqlite_memory_url(raw_url)
|
||||
db_url = _absolutize_sqlite_url(db_url)
|
||||
|
||||
# Prepare async engine
|
||||
connect_args = {}
|
||||
if db_url.startswith("sqlite"):
|
||||
connect_args = {
|
||||
"check_same_thread": False,
|
||||
"timeout": 12,
|
||||
}
|
||||
if is_mem:
|
||||
connect_args["uri"] = True
|
||||
|
||||
ENGINE = create_async_engine(
|
||||
db_url,
|
||||
connect_args=connect_args,
|
||||
pool_pre_ping=True,
|
||||
future=True,
|
||||
)
|
||||
|
||||
# Enforce SQLite pragmas on the async engine
|
||||
if db_url.startswith("sqlite"):
|
||||
async with ENGINE.begin() as conn:
|
||||
if not is_mem:
|
||||
# WAL for concurrency and durability, Foreign Keys for referential integrity
|
||||
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
|
||||
if str(current_mode).lower() != "wal":
|
||||
new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
|
||||
if str(new_mode).lower() != "wal":
|
||||
raise RuntimeError("Failed to set SQLite journal mode to WAL.")
|
||||
LOGGER.info("SQLite journal mode set to WAL.")
|
||||
|
||||
await conn.execute(text("PRAGMA foreign_keys = ON;"))
|
||||
await conn.execute(text("PRAGMA synchronous = NORMAL;"))
|
||||
|
||||
await _run_migrations(database_url=db_url, connect_args=connect_args)
|
||||
|
||||
SESSION = async_sessionmaker(
|
||||
bind=ENGINE,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
|
||||
async def _run_migrations(database_url: str, connect_args: dict) -> None:
|
||||
if database_url.find("postgresql+psycopg") == -1:
|
||||
"""SQLite: Convert an async SQLAlchemy URL to a sync URL for Alembic."""
|
||||
u = make_url(database_url)
|
||||
driver = u.drivername
|
||||
if not driver.startswith("sqlite+aiosqlite"):
|
||||
raise ValueError(f"Unsupported DB driver: {driver}")
|
||||
database_url, is_mem = _normalize_sqlite_memory_url(str(u.set(drivername="sqlite")))
|
||||
database_url = _absolutize_sqlite_url(database_url)
|
||||
|
||||
cfg = _get_alembic_config(database_url)
|
||||
engine = create_engine(database_url, future=True, connect_args=connect_args)
|
||||
with engine.connect() as conn:
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
script = ScriptDirectory.from_config(cfg)
|
||||
target_rev = script.get_current_head()
|
||||
|
||||
if target_rev is None:
|
||||
LOGGER.warning("Alembic: no target revision found.")
|
||||
return
|
||||
|
||||
if current_rev == target_rev:
|
||||
LOGGER.debug("Alembic: database already at head %s", target_rev)
|
||||
return
|
||||
|
||||
LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)
|
||||
|
||||
# Optional backup for SQLite file DBs
|
||||
backup_path = None
|
||||
sqlite_path = _get_sqlite_file_path(database_url)
|
||||
if sqlite_path and os.path.exists(sqlite_path):
|
||||
backup_path = sqlite_path + ".bkp"
|
||||
try:
|
||||
shutil.copy(sqlite_path, backup_path)
|
||||
except Exception as exc:
|
||||
LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)
|
||||
|
||||
try:
|
||||
command.upgrade(cfg, target_rev)
|
||||
except Exception:
|
||||
if backup_path and os.path.exists(backup_path):
|
||||
LOGGER.exception("Error upgrading database, attempting restore from backup.")
|
||||
try:
|
||||
shutil.copy(backup_path, sqlite_path) # restore
|
||||
os.remove(backup_path)
|
||||
except Exception as re:
|
||||
LOGGER.error("Failed to restore SQLite backup: %s", re)
|
||||
else:
|
||||
LOGGER.exception("Error upgrading database, backup is not available.")
|
||||
raise
|
||||
|
||||
|
||||
def get_engine():
|
||||
"""Return the global async engine (initialized after init_db_engine())."""
|
||||
if ENGINE is None:
|
||||
raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
|
||||
return ENGINE
|
||||
|
||||
|
||||
def get_session_maker():
|
||||
"""Return the global async_sessionmaker (initialized after init_db_engine())."""
|
||||
if SESSION is None:
|
||||
raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
|
||||
return SESSION
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def session_scope():
|
||||
"""Async context manager for a unit of work:
|
||||
|
||||
async with session_scope() as sess:
|
||||
... use sess ...
|
||||
"""
|
||||
maker = get_session_maker()
|
||||
async with maker() as sess:
|
||||
try:
|
||||
yield sess
|
||||
await sess.commit()
|
||||
except Exception:
|
||||
await sess.rollback()
|
||||
raise
|
||||
|
||||
|
||||
async def create_session():
|
||||
"""Convenience helper to acquire a single AsyncSession instance.
|
||||
|
||||
Typical usage:
|
||||
async with (await create_session()) as sess:
|
||||
...
|
||||
"""
|
||||
maker = get_session_maker()
|
||||
return maker()
|
||||
@ -42,7 +42,6 @@ def get_installed_frontend_version():
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
return frontend_version_str
|
||||
|
||||
|
||||
def get_required_frontend_version():
|
||||
"""Get the required frontend version from requirements.txt."""
|
||||
try:
|
||||
@ -64,7 +63,6 @@ def get_required_frontend_version():
|
||||
logging.error(f"Error reading requirements.txt: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def check_frontend_version():
|
||||
"""Check if the frontend version is up to date."""
|
||||
|
||||
@ -198,6 +196,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||
|
||||
|
||||
class FrontendManager:
|
||||
"""
|
||||
A class to manage ComfyUI frontend versions and installations.
|
||||
|
||||
This class handles the initialization and management of different frontend versions,
|
||||
including the default frontend from the pip package and custom frontend versions
|
||||
from GitHub repositories.
|
||||
|
||||
Attributes:
|
||||
CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored.
|
||||
"""
|
||||
|
||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||
|
||||
@classmethod
|
||||
@ -205,39 +214,17 @@ class FrontendManager:
|
||||
"""Get the required frontend package version."""
|
||||
return get_required_frontend_version()
|
||||
|
||||
@classmethod
|
||||
def get_installed_templates_version(cls) -> str:
|
||||
"""Get the currently installed workflow templates package version."""
|
||||
try:
|
||||
templates_version_str = version("comfyui-workflow-templates")
|
||||
return templates_version_str
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_required_templates_version(cls) -> str:
|
||||
"""Get the required workflow templates version from requirements.txt."""
|
||||
try:
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("comfyui-workflow-templates=="):
|
||||
version_str = line.split("==")[-1]
|
||||
if not is_valid_version(version_str):
|
||||
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
|
||||
return None
|
||||
return version_str
|
||||
logging.error("comfyui-workflow-templates not found in requirements.txt")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("requirements.txt not found. Cannot determine required templates version.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading requirements.txt: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
"""
|
||||
Get the path to the default frontend installation from the pip package.
|
||||
|
||||
Returns:
|
||||
str: The path to the default frontend static files.
|
||||
|
||||
Raises:
|
||||
SystemExit: If the comfyui-frontend-package is not installed.
|
||||
"""
|
||||
try:
|
||||
import comfyui_frontend_package
|
||||
|
||||
@ -258,6 +245,15 @@ comfyui-frontend-package is not installed.
|
||||
|
||||
@classmethod
|
||||
def templates_path(cls) -> str:
|
||||
"""
|
||||
Get the path to the workflow templates.
|
||||
|
||||
Returns:
|
||||
str: The path to the workflow templates directory.
|
||||
|
||||
Raises:
|
||||
SystemExit: If the comfyui-workflow-templates package is not installed.
|
||||
"""
|
||||
try:
|
||||
import comfyui_workflow_templates
|
||||
|
||||
@ -293,11 +289,16 @@ comfyui-workflow-templates is not installed.
|
||||
@classmethod
|
||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||
"""
|
||||
Parse a version string into its components.
|
||||
|
||||
The version string should be in the format: 'owner/repo@version'
|
||||
where version can be either a semantic version (v1.2.3) or 'latest'.
|
||||
|
||||
Args:
|
||||
value (str): The version string to parse.
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: A tuple containing provider name and version.
|
||||
tuple[str, str, str]: A tuple containing (owner, repo, version).
|
||||
|
||||
Raises:
|
||||
argparse.ArgumentTypeError: If the version string is invalid.
|
||||
@ -314,18 +315,22 @@ comfyui-workflow-templates is not installed.
|
||||
cls, version_string: str, provider: Optional[FrontEndProvider] = None
|
||||
) -> str:
|
||||
"""
|
||||
Initializes the frontend for the specified version.
|
||||
Initialize a frontend version without error handling.
|
||||
|
||||
This method attempts to initialize a specific frontend version, either from
|
||||
the default pip package or from a custom GitHub repository. It will download
|
||||
and extract the frontend files if necessary.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string.
|
||||
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
|
||||
version_string (str): The version string specifying which frontend to use.
|
||||
provider (FrontEndProvider, optional): The provider to use for custom frontends.
|
||||
|
||||
Returns:
|
||||
str: The path to the initialized frontend.
|
||||
|
||||
Raises:
|
||||
Exception: If there is an error during the initialization process.
|
||||
main error source might be request timeout or invalid URL.
|
||||
Exception: If there is an error during initialization (e.g., network timeout,
|
||||
invalid URL, or missing assets).
|
||||
"""
|
||||
if version_string == DEFAULT_VERSION_STRING:
|
||||
check_frontend_version()
|
||||
@ -377,13 +382,17 @@ comfyui-workflow-templates is not installed.
|
||||
@classmethod
|
||||
def init_frontend(cls, version_string: str) -> str:
|
||||
"""
|
||||
Initializes the frontend with the specified version string.
|
||||
Initialize a frontend version with error handling.
|
||||
|
||||
This is the main method to initialize a frontend version. It wraps init_frontend_unsafe
|
||||
with error handling, falling back to the default frontend if initialization fails.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string to initialize the frontend with.
|
||||
version_string (str): The version string specifying which frontend to use.
|
||||
|
||||
Returns:
|
||||
str: The path of the initialized frontend.
|
||||
str: The path to the initialized frontend. If initialization fails,
|
||||
returns the path to the default frontend.
|
||||
"""
|
||||
try:
|
||||
return cls.init_frontend_unsafe(version_string)
|
||||
|
||||
@ -212,7 +212,8 @@ parser.add_argument(
|
||||
database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
|
||||
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -37,10 +37,7 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
|
||||
@ -237,7 +237,6 @@ class WanAttentionBlock(nn.Module):
|
||||
freqs, transformer_options=transformer_options)
|
||||
|
||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||
del y
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
|
||||
@ -50,10 +50,16 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
||||
else:
|
||||
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
||||
|
||||
def is_html_file(file_path):
|
||||
with open(file_path, "rb") as f:
|
||||
content = f.read(100)
|
||||
return b"<!DOCTYPE html>" in content or b"<html" in content
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
metadata = None
|
||||
|
||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||
try:
|
||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||
@ -66,6 +72,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if return_metadata:
|
||||
metadata = f.metadata()
|
||||
except Exception as e:
|
||||
if is_html_file(ckpt):
|
||||
raise ValueError("{}\n\nFile path: {}\n\nThe requested file is an HTML document not a safetensors file. Please re-download the file, not the web page.".format(e, ckpt))
|
||||
if len(e.args) > 0:
|
||||
message = e.args[0]
|
||||
if "HeaderTooLarge" in message:
|
||||
@ -93,6 +101,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
sd = pl_sd
|
||||
else:
|
||||
sd = pl_sd
|
||||
|
||||
# populate_db_with_asset(ckpt) # surprise tool that can help us later - performs hashing on model file
|
||||
return (sd, metadata) if return_metadata else sd
|
||||
|
||||
def save_torch_file(sd, ckpt, metadata=None):
|
||||
|
||||
@ -392,6 +392,20 @@ class MultiCombo(ComfyTypeI):
|
||||
})
|
||||
return to_return
|
||||
|
||||
@comfytype(io_type="ASSET")
|
||||
class Asset(ComfyTypeI):
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, query_tags: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: str=None, socketless: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
||||
self.query_tags = query_tags
|
||||
|
||||
def as_dict(self):
|
||||
to_return = super().as_dict() | prune_dict({
|
||||
"query_tags": self.query_tags
|
||||
})
|
||||
return to_return
|
||||
|
||||
@comfytype(io_type="IMAGE")
|
||||
class Image(ComfyTypeIO):
|
||||
Type = torch.Tensor
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,8 +1,7 @@
|
||||
from __future__ import annotations
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis.luma_api import (
|
||||
LumaImageModel,
|
||||
@ -52,186 +51,174 @@ def image_result_url_extractor(response: LumaGeneration):
|
||||
def video_result_url_extractor(response: LumaGeneration):
|
||||
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
|
||||
|
||||
class LumaReferenceNode(comfy_io.ComfyNode):
|
||||
class LumaReferenceNode(ComfyNodeABC):
|
||||
"""
|
||||
Holds an image and weight for use with Luma Generate Image node.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
node_id="LumaReferenceNode",
|
||||
display_name="Luma Reference",
|
||||
category="api node/image/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
tooltip="Image to use as reference.",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
"weight",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Weight of image reference.",
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_REF).Input(
|
||||
"luma_ref",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
RETURN_TYPES = (LumaIO.LUMA_REF,)
|
||||
RETURN_NAMES = ("luma_ref",)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "create_luma_reference"
|
||||
CATEGORY = "api node/image/Luma"
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
|
||||
) -> comfy_io.NodeOutput:
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "Image to use as reference.",
|
||||
},
|
||||
),
|
||||
"weight": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 1.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"tooltip": "Weight of image reference.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {"luma_ref": (LumaIO.LUMA_REF,)},
|
||||
}
|
||||
|
||||
def create_luma_reference(
|
||||
self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
|
||||
):
|
||||
if luma_ref is not None:
|
||||
luma_ref = luma_ref.clone()
|
||||
else:
|
||||
luma_ref = LumaReferenceChain()
|
||||
luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
|
||||
return comfy_io.NodeOutput(luma_ref)
|
||||
return (luma_ref,)
|
||||
|
||||
|
||||
class LumaConceptsNode(comfy_io.ComfyNode):
|
||||
class LumaConceptsNode(ComfyNodeABC):
|
||||
"""
|
||||
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
node_id="LumaConceptsNode",
|
||||
display_name="Luma Concepts",
|
||||
category="api node/video/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"concept1",
|
||||
options=get_luma_concepts(include_none=True),
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"concept2",
|
||||
options=get_luma_concepts(include_none=True),
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"concept3",
|
||||
options=get_luma_concepts(include_none=True),
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"concept4",
|
||||
options=get_luma_concepts(include_none=True),
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
"luma_concepts",
|
||||
tooltip="Optional Camera Concepts to add to the ones chosen here.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
RETURN_TYPES = (LumaIO.LUMA_CONCEPTS,)
|
||||
RETURN_NAMES = ("luma_concepts",)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "create_concepts"
|
||||
CATEGORY = "api node/video/Luma"
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"concept1": (get_luma_concepts(include_none=True),),
|
||||
"concept2": (get_luma_concepts(include_none=True),),
|
||||
"concept3": (get_luma_concepts(include_none=True),),
|
||||
"concept4": (get_luma_concepts(include_none=True),),
|
||||
},
|
||||
"optional": {
|
||||
"luma_concepts": (
|
||||
LumaIO.LUMA_CONCEPTS,
|
||||
{
|
||||
"tooltip": "Optional Camera Concepts to add to the ones chosen here."
|
||||
},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
def create_concepts(
|
||||
self,
|
||||
concept1: str,
|
||||
concept2: str,
|
||||
concept3: str,
|
||||
concept4: str,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
):
|
||||
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
|
||||
if luma_concepts is not None:
|
||||
chain = luma_concepts.clone_and_merge(chain)
|
||||
return comfy_io.NodeOutput(chain)
|
||||
return (chain,)
|
||||
|
||||
|
||||
class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
class LumaImageGenerationNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
node_id="LumaImageNode",
|
||||
display_name="Luma Text to Image",
|
||||
category="api node/image/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in LumaImageModel],
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[ratio.value for ratio in LumaAspectRatio],
|
||||
default=LumaAspectRatio.ratio_16_9,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
"style_image_weight",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Weight of style image. Ignored if no style_image provided.",
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_REF).Input(
|
||||
"image_luma_ref",
|
||||
tooltip="Luma Reference node connection to influence generation with input images; up to 4 images can be considered.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"style_image",
|
||||
tooltip="Style reference image; only 1 image will be used.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"character_image",
|
||||
tooltip="Character reference images; can be a batch of multiple, up to 4 images can be considered.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Image.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Luma"
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"model": ([model.value for model in LumaImageModel],),
|
||||
"aspect_ratio": (
|
||||
[ratio.value for ratio in LumaAspectRatio],
|
||||
{
|
||||
"default": LumaAspectRatio.ratio_16_9,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
},
|
||||
),
|
||||
"style_image_weight": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 1.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"tooltip": "Weight of style image. Ignored if no style_image provided.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image_luma_ref": (
|
||||
LumaIO.LUMA_REF,
|
||||
{
|
||||
"tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered."
|
||||
},
|
||||
),
|
||||
"style_image": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Style reference image; only 1 image will be used."},
|
||||
),
|
||||
"character_image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
aspect_ratio: str,
|
||||
@ -240,29 +227,27 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
image_luma_ref: LumaReferenceChain = None,
|
||||
style_image: torch.Tensor = None,
|
||||
character_image: torch.Tensor = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
unique_id: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=True, min_length=3)
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
# handle image_luma_ref
|
||||
api_image_ref = None
|
||||
if image_luma_ref is not None:
|
||||
api_image_ref = await cls._convert_luma_refs(
|
||||
image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
|
||||
api_image_ref = await self._convert_luma_refs(
|
||||
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
|
||||
)
|
||||
# handle style_luma_ref
|
||||
api_style_ref = None
|
||||
if style_image is not None:
|
||||
api_style_ref = await cls._convert_style_image(
|
||||
style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
|
||||
api_style_ref = await self._convert_style_image(
|
||||
style_image, weight=style_image_weight, auth_kwargs=kwargs,
|
||||
)
|
||||
# handle character_ref images
|
||||
character_ref = None
|
||||
if character_image is not None:
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
character_image, max_images=4, auth_kwargs=auth_kwargs,
|
||||
character_image, max_images=4, auth_kwargs=kwargs,
|
||||
)
|
||||
character_ref = LumaCharacterRef(
|
||||
identity0=LumaImageIdentity(images=download_urls)
|
||||
@ -283,7 +268,7 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
style_ref=api_style_ref,
|
||||
character_ref=character_ref,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
@ -298,19 +283,18 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
failed_statuses=[LumaState.failed],
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=image_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=unique_id,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.image) as img_response:
|
||||
img = process_image_response(await img_response.content.read())
|
||||
return comfy_io.NodeOutput(img)
|
||||
return (img,)
|
||||
|
||||
@classmethod
|
||||
async def _convert_luma_refs(
|
||||
cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
|
||||
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
luma_urls = []
|
||||
ref_count = 0
|
||||
@ -324,84 +308,82 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
break
|
||||
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
|
||||
|
||||
@classmethod
|
||||
async def _convert_style_image(
|
||||
cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
|
||||
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
chain = LumaReferenceChain(
|
||||
first_ref=LumaReference(image=style_image, weight=weight)
|
||||
)
|
||||
return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
|
||||
return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
|
||||
|
||||
|
||||
class LumaImageModifyNode(comfy_io.ComfyNode):
|
||||
class LumaImageModifyNode(ComfyNodeABC):
|
||||
"""
|
||||
Modifies images synchronously based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
node_id="LumaImageModifyNode",
|
||||
display_name="Luma Image to Image",
|
||||
category="api node/image/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
"image_weight",
|
||||
default=0.1,
|
||||
min=0.0,
|
||||
max=0.98,
|
||||
step=0.01,
|
||||
tooltip="Weight of the image; the closer to 1.0, the less the image will be modified.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in LumaImageModel],
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Image.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Luma"
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"image_weight": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.1,
|
||||
"min": 0.0,
|
||||
"max": 0.98,
|
||||
"step": 0.01,
|
||||
"tooltip": "Weight of the image; the closer to 1.0, the less the image will be modified.",
|
||||
},
|
||||
),
|
||||
"model": ([model.value for model in LumaImageModel],),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
image_weight: float,
|
||||
seed,
|
||||
) -> comfy_io.NodeOutput:
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
unique_id: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
# first, upload image
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
image, max_images=1, auth_kwargs=kwargs,
|
||||
)
|
||||
image_url = download_urls[0]
|
||||
# next, make Luma call with download url provided
|
||||
@ -419,7 +401,7 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
|
||||
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
@ -434,84 +416,88 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
|
||||
failed_statuses=[LumaState.failed],
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=image_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=unique_id,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.image) as img_response:
|
||||
img = process_image_response(await img_response.content.read())
|
||||
return comfy_io.NodeOutput(img)
|
||||
return (img,)
|
||||
|
||||
|
||||
class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
class LumaTextToVideoGenerationNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos synchronously based on prompt and output_size.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
node_id="LumaVideoNode",
|
||||
display_name="Luma Text to Video",
|
||||
category="api node/video/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in LumaVideoModel],
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[ratio.value for ratio in LumaAspectRatio],
|
||||
default=LumaAspectRatio.ratio_16_9,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=[resolution.value for resolution in LumaVideoOutputResolution],
|
||||
default=LumaVideoOutputResolution.res_540p,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"duration",
|
||||
options=[dur.value for dur in LumaVideoModelOutputDuration],
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"loop",
|
||||
default=False,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
"luma_concepts",
|
||||
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
|
||||
optional=True,
|
||||
)
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/video/Luma"
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the video generation",
|
||||
},
|
||||
),
|
||||
"model": ([model.value for model in LumaVideoModel],),
|
||||
"aspect_ratio": (
|
||||
[ratio.value for ratio in LumaAspectRatio],
|
||||
{
|
||||
"default": LumaAspectRatio.ratio_16_9,
|
||||
},
|
||||
),
|
||||
"resolution": (
|
||||
[resolution.value for resolution in LumaVideoOutputResolution],
|
||||
{
|
||||
"default": LumaVideoOutputResolution.res_540p,
|
||||
},
|
||||
),
|
||||
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
|
||||
"loop": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"luma_concepts": (
|
||||
LumaIO.LUMA_CONCEPTS,
|
||||
{
|
||||
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
aspect_ratio: str,
|
||||
@ -520,15 +506,13 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
loop: bool,
|
||||
seed,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
unique_id: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False, min_length=3)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations",
|
||||
@ -545,12 +529,12 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
loop=loop,
|
||||
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
|
||||
if unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
@ -563,94 +547,90 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
failed_statuses=[LumaState.failed],
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=video_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
node_id=unique_id,
|
||||
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
|
||||
auth_kwargs=auth_kwargs,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.video) as vid_response:
|
||||
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
||||
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||
|
||||
|
||||
class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos synchronously based on prompt, input images, and output_size.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
node_id="LumaImageToVideoNode",
|
||||
display_name="Luma Image to Video",
|
||||
category="api node/video/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in LumaVideoModel],
|
||||
),
|
||||
# comfy_io.Combo.Input(
|
||||
# "aspect_ratio",
|
||||
# options=[ratio.value for ratio in LumaAspectRatio],
|
||||
# default=LumaAspectRatio.ratio_16_9,
|
||||
# ),
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=[resolution.value for resolution in LumaVideoOutputResolution],
|
||||
default=LumaVideoOutputResolution.res_540p,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"duration",
|
||||
options=[dur.value for dur in LumaVideoModelOutputDuration],
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"loop",
|
||||
default=False,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"first_image",
|
||||
tooltip="First frame of generated video.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"last_image",
|
||||
tooltip="Last frame of generated video.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
"luma_concepts",
|
||||
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
|
||||
optional=True,
|
||||
)
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/video/Luma"
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the video generation",
|
||||
},
|
||||
),
|
||||
"model": ([model.value for model in LumaVideoModel],),
|
||||
# "aspect_ratio": ([ratio.value for ratio in LumaAspectRatio], {
|
||||
# "default": LumaAspectRatio.ratio_16_9,
|
||||
# }),
|
||||
"resolution": (
|
||||
[resolution.value for resolution in LumaVideoOutputResolution],
|
||||
{
|
||||
"default": LumaVideoOutputResolution.res_540p,
|
||||
},
|
||||
),
|
||||
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
|
||||
"loop": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"first_image": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "First frame of generated video."},
|
||||
),
|
||||
"last_image": (IO.IMAGE, {"tooltip": "Last frame of generated video."}),
|
||||
"luma_concepts": (
|
||||
LumaIO.LUMA_CONCEPTS,
|
||||
{
|
||||
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
resolution: str,
|
||||
@ -660,16 +640,14 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
first_image: torch.Tensor = None,
|
||||
last_image: torch.Tensor = None,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
unique_id: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
if first_image is None and last_image is None:
|
||||
raise Exception(
|
||||
"At least one of first_image and last_image requires an input."
|
||||
)
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
|
||||
keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||
|
||||
@ -690,12 +668,12 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
keyframes=keyframes,
|
||||
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
|
||||
if unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
@ -708,19 +686,18 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
failed_statuses=[LumaState.failed],
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=video_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
node_id=unique_id,
|
||||
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
|
||||
auth_kwargs=auth_kwargs,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.video) as vid_response:
|
||||
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
||||
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||
|
||||
@classmethod
|
||||
async def _convert_to_keyframes(
|
||||
cls,
|
||||
self,
|
||||
first_image: torch.Tensor = None,
|
||||
last_image: torch.Tensor = None,
|
||||
auth_kwargs: Optional[dict[str,str]] = None,
|
||||
@ -742,18 +719,23 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
return LumaKeyframes(frame0=frame0, frame1=frame1)
|
||||
|
||||
|
||||
class LumaExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
return [
|
||||
LumaImageGenerationNode,
|
||||
LumaImageModifyNode,
|
||||
LumaTextToVideoGenerationNode,
|
||||
LumaImageToVideoGenerationNode,
|
||||
LumaReferenceNode,
|
||||
LumaConceptsNode,
|
||||
]
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LumaImageNode": LumaImageGenerationNode,
|
||||
"LumaImageModifyNode": LumaImageModifyNode,
|
||||
"LumaVideoNode": LumaTextToVideoGenerationNode,
|
||||
"LumaImageToVideoNode": LumaImageToVideoGenerationNode,
|
||||
"LumaReferenceNode": LumaReferenceNode,
|
||||
"LumaConceptsNode": LumaConceptsNode,
|
||||
}
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> LumaExtension:
|
||||
return LumaExtension()
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LumaImageNode": "Luma Text to Image",
|
||||
"LumaImageModifyNode": "Luma Image to Image",
|
||||
"LumaVideoNode": "Luma Text to Video",
|
||||
"LumaImageToVideoNode": "Luma Image to Video",
|
||||
"LumaReferenceNode": "Luma Reference",
|
||||
"LumaConceptsNode": "Luma Concepts",
|
||||
}
|
||||
|
||||
@ -11,6 +11,7 @@ from comfy.comfy_types.node_typing import IO
|
||||
import folder_paths as comfy_paths
|
||||
import aiohttp
|
||||
import os
|
||||
import datetime
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
@ -242,8 +243,8 @@ class Rodin3DAPI:
|
||||
|
||||
return mesh_mode, quality_override
|
||||
|
||||
async def download_files(self, url_list, task_uuid):
|
||||
save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}")
|
||||
async def download_files(self, url_list):
|
||||
save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
model_file_path = None
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@ -319,7 +320,7 @@ class Rodin3D_Regular(Rodin3DAPI):
|
||||
**kwargs)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
model = await self.download_files(download_list, task_uuid)
|
||||
model = await self.download_files(download_list)
|
||||
|
||||
return (model,)
|
||||
|
||||
@ -365,7 +366,7 @@ class Rodin3D_Detail(Rodin3DAPI):
|
||||
**kwargs)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
model = await self.download_files(download_list, task_uuid)
|
||||
model = await self.download_files(download_list)
|
||||
|
||||
return (model,)
|
||||
|
||||
@ -411,7 +412,7 @@ class Rodin3D_Smooth(Rodin3DAPI):
|
||||
**kwargs)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
model = await self.download_files(download_list, task_uuid)
|
||||
model = await self.download_files(download_list)
|
||||
|
||||
return (model,)
|
||||
|
||||
@ -466,7 +467,7 @@ class Rodin3D_Sketch(Rodin3DAPI):
|
||||
)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
model = await self.download_files(download_list, task_uuid)
|
||||
model = await self.download_files(download_list)
|
||||
|
||||
return (model,)
|
||||
|
||||
|
||||
@ -28,12 +28,6 @@ class Text2ImageInputField(BaseModel):
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class Image2ImageInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
images: list[str] = Field(..., min_length=1, max_length=2)
|
||||
|
||||
|
||||
class Text2VideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
@ -55,13 +49,6 @@ class Txt2ImageParametersField(BaseModel):
|
||||
watermark: bool = Field(True)
|
||||
|
||||
|
||||
class Image2ImageParametersField(BaseModel):
|
||||
size: Optional[str] = Field(None)
|
||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
watermark: bool = Field(True)
|
||||
|
||||
|
||||
class Text2VideoParametersField(BaseModel):
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
@ -86,12 +73,6 @@ class Text2ImageTaskCreationRequest(BaseModel):
|
||||
parameters: Txt2ImageParametersField = Field(...)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Image2ImageInputField = Field(...)
|
||||
parameters: Image2ImageParametersField = Field(...)
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Text2VideoInputField = Field(...)
|
||||
@ -154,12 +135,7 @@ async def process_task(
|
||||
url: str,
|
||||
request_model: Type[T],
|
||||
response_model: Type[R],
|
||||
payload: Union[
|
||||
Text2ImageTaskCreationRequest,
|
||||
Image2ImageTaskCreationRequest,
|
||||
Text2VideoTaskCreationRequest,
|
||||
Image2VideoTaskCreationRequest,
|
||||
],
|
||||
payload: Union[Text2ImageTaskCreationRequest, Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
|
||||
node_id: str,
|
||||
estimated_duration: int,
|
||||
poll_interval: int,
|
||||
@ -312,128 +288,6 @@ class WanTextToImageApi(comfy_io.ComfyNode):
|
||||
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
|
||||
|
||||
|
||||
class WanImageToImageApi(comfy_io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="WanImageToImageApi",
|
||||
display_name="Wan Image to Image",
|
||||
category="api node/image/Wan",
|
||||
description="Generates an image from one or two input images and a text prompt. "
|
||||
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=["wan2.5-i2i-preview"],
|
||||
default="wan2.5-i2i-preview",
|
||||
tooltip="Model to use.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid.",
|
||||
optional=True,
|
||||
),
|
||||
# redo this later as an optional combo of recommended resolutions
|
||||
# comfy_io.Int.Input(
|
||||
# "width",
|
||||
# default=1280,
|
||||
# min=384,
|
||||
# max=1440,
|
||||
# step=16,
|
||||
# optional=True,
|
||||
# ),
|
||||
# comfy_io.Int.Input(
|
||||
# "height",
|
||||
# default=1280,
|
||||
# min=384,
|
||||
# max=1440,
|
||||
# step=16,
|
||||
# optional=True,
|
||||
# ),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to use for generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
# width: int = 1024,
|
||||
# height: int = 1024,
|
||||
seed: int = 0,
|
||||
watermark: bool = True,
|
||||
):
|
||||
n_images = get_number_of_images(image)
|
||||
if n_images not in (1, 2):
|
||||
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
|
||||
images = []
|
||||
for i in image:
|
||||
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096))
|
||||
payload = Image2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
|
||||
parameters=Image2ImageParametersField(
|
||||
# size=f"{width}*{height}",
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
),
|
||||
)
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/image2image/image-synthesis",
|
||||
request_model=Image2ImageTaskCreationRequest,
|
||||
response_model=ImageTaskStatusResponse,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=42,
|
||||
poll_interval=3,
|
||||
)
|
||||
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
|
||||
|
||||
|
||||
class WanTextToVideoApi(comfy_io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -739,7 +593,6 @@ class WanApiExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
return [
|
||||
WanTextToImageApi,
|
||||
WanImageToImageApi,
|
||||
WanTextToVideoApi,
|
||||
WanImageToVideoApi,
|
||||
]
|
||||
|
||||
@ -1,73 +1,55 @@
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
import comfy.sd
|
||||
import comfy.model_management
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class QuadrupleCLIPLoader(io.ComfyNode):
|
||||
class QuadrupleCLIPLoader:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="QuadrupleCLIPLoader",
|
||||
category="advanced/loaders",
|
||||
description="[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct",
|
||||
inputs=[
|
||||
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
|
||||
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
|
||||
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
|
||||
io.Combo.Input("clip_name4", options=folder_paths.get_filename_list("text_encoders")),
|
||||
],
|
||||
outputs=[
|
||||
io.Clip.Output(),
|
||||
]
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name3": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name4": (folder_paths.get_filename_list("text_encoders"), )
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip_name1, clip_name2, clip_name3, clip_name4):
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct"
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4):
|
||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
|
||||
clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return io.NodeOutput(clip)
|
||||
return (clip,)
|
||||
|
||||
class CLIPTextEncodeHiDream(io.ComfyNode):
|
||||
class CLIPTextEncodeHiDream:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeHiDream",
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("llama", multiline=True, dynamic_prompts=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
]
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"clip": ("CLIP", ),
|
||||
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
"llama": ("STRING", {"multiline": True, "dynamicPrompts": True})
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "advanced/conditioning"
|
||||
|
||||
def encode(self, clip, clip_l, clip_g, t5xxl, llama):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, clip_l, clip_g, t5xxl, llama):
|
||||
tokens = clip.tokenize(clip_g)
|
||||
tokens["l"] = clip.tokenize(clip_l)["l"]
|
||||
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
||||
tokens["llama"] = clip.tokenize(llama)["llama"]
|
||||
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||
|
||||
|
||||
class HiDreamExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
QuadrupleCLIPLoader,
|
||||
CLIPTextEncodeHiDream,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> HiDreamExtension:
|
||||
return HiDreamExtension()
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
|
||||
"CLIPTextEncodeHiDream": CLIPTextEncodeHiDream,
|
||||
}
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
#Taken from: https://github.com/tfernd/HyperTile/
|
||||
|
||||
import math
|
||||
from typing_extensions import override
|
||||
from einops import rearrange
|
||||
# Use torch rng for consistency across generations
|
||||
from torch import randint
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
||||
min_value = min(min_value, value)
|
||||
@ -22,31 +20,25 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
||||
|
||||
return ns[idx]
|
||||
|
||||
class HyperTile(io.ComfyNode):
|
||||
class HyperTile:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="HyperTile",
|
||||
category="model_patches/unet",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Int.Input("tile_size", default=256, min=1, max=2048),
|
||||
io.Int.Input("swap_size", default=2, min=1, max=128),
|
||||
io.Int.Input("max_depth", default=0, min=0, max=10),
|
||||
io.Boolean.Input("scale_depth", default=False),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
|
||||
"swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
|
||||
"max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
|
||||
"scale_depth": ("BOOLEAN", {"default": False}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, tile_size, swap_size, max_depth, scale_depth) -> io.NodeOutput:
|
||||
CATEGORY = "model_patches/unet"
|
||||
|
||||
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
|
||||
latent_tile_size = max(32, tile_size) // 8
|
||||
temp = None
|
||||
self.temp = None
|
||||
|
||||
def hypertile_in(q, k, v, extra_options):
|
||||
nonlocal temp
|
||||
model_chans = q.shape[-2]
|
||||
orig_shape = extra_options['original_shape']
|
||||
apply_to = []
|
||||
@ -66,15 +58,14 @@ class HyperTile(io.ComfyNode):
|
||||
|
||||
if nh * nw > 1:
|
||||
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
||||
temp = (nh, nw, h, w)
|
||||
self.temp = (nh, nw, h, w)
|
||||
return q, k, v
|
||||
|
||||
return q, k, v
|
||||
def hypertile_out(out, extra_options):
|
||||
nonlocal temp
|
||||
if temp is not None:
|
||||
nh, nw, h, w = temp
|
||||
temp = None
|
||||
if self.temp is not None:
|
||||
nh, nw, h, w = self.temp
|
||||
self.temp = None
|
||||
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
|
||||
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
|
||||
return out
|
||||
@ -85,14 +76,6 @@ class HyperTile(io.ComfyNode):
|
||||
m.set_model_attn1_output_patch(hypertile_out)
|
||||
return (m, )
|
||||
|
||||
|
||||
class HyperTileExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
HyperTile,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> HyperTileExtension:
|
||||
return HyperTileExtension()
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"HyperTile": HyperTile,
|
||||
}
|
||||
|
||||
@ -1,22 +1,20 @@
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
import comfy.model_management as mm
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class LotusConditioning(io.ComfyNode):
|
||||
class LotusConditioning:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LotusConditioning",
|
||||
category="conditioning/lotus",
|
||||
inputs=[],
|
||||
outputs=[io.Conditioning.Output(display_name="conditioning")],
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def execute(cls) -> io.NodeOutput:
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
RETURN_NAMES = ("conditioning",)
|
||||
FUNCTION = "conditioning"
|
||||
CATEGORY = "conditioning/lotus"
|
||||
|
||||
def conditioning(self):
|
||||
device = mm.get_torch_device()
|
||||
#lotus uses a frozen encoder and null conditioning, i'm just inlining the results of that operation since it doesn't change
|
||||
#and getting parity with the reference implementation would otherwise require inference and 800mb of tensors
|
||||
@ -24,16 +22,8 @@ class LotusConditioning(io.ComfyNode):
|
||||
|
||||
cond = [[prompt_embeds, {}]]
|
||||
|
||||
return io.NodeOutput(cond)
|
||||
return (cond,)
|
||||
|
||||
|
||||
class LotusExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
LotusConditioning,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> LotusExtension:
|
||||
return LotusExtension()
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LotusConditioning" : LotusConditioning,
|
||||
}
|
||||
|
||||
@ -1,27 +1,20 @@
|
||||
from typing_extensions import override
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
import torch
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class RenormCFG(io.ComfyNode):
|
||||
class RenormCFG:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RenormCFG",
|
||||
category="advanced/model",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01),
|
||||
io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"cfg_trunc": ("FLOAT", {"default": 100, "min": 0.0, "max": 100.0, "step": 0.01}),
|
||||
"renorm_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, cfg_trunc, renorm_cfg) -> io.NodeOutput:
|
||||
CATEGORY = "advanced/model"
|
||||
|
||||
def patch(self, model, cfg_trunc, renorm_cfg):
|
||||
def renorm_cfg_func(args):
|
||||
cond_denoised = args["cond_denoised"]
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
@ -60,10 +53,10 @@ class RenormCFG(io.ComfyNode):
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_cfg_function(renorm_cfg_func)
|
||||
return io.NodeOutput(m)
|
||||
return (m, )
|
||||
|
||||
|
||||
class CLIPTextEncodeLumina2(io.ComfyNode):
|
||||
class CLIPTextEncodeLumina2(ComfyNodeABC):
|
||||
SYSTEM_PROMPT = {
|
||||
"superior": "You are an assistant designed to generate superior images with the superior "\
|
||||
"degree of image-text alignment based on textual prompts or user prompts.",
|
||||
@ -76,52 +69,36 @@ class CLIPTextEncodeLumina2(io.ComfyNode):
|
||||
"Alignment: You are an assistant designed to generate high-quality images with the highest "\
|
||||
"degree of image-text alignment based on textual prompts."
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeLumina2",
|
||||
display_name="CLIP Text Encode for Lumina2",
|
||||
category="conditioning",
|
||||
description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
|
||||
"that can be used to guide the diffusion model towards generating specific images.",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"system_prompt",
|
||||
options=list(cls.SYSTEM_PROMPT.keys()),
|
||||
tooltip=cls.SYSTEM_PROMPT_TIP,
|
||||
),
|
||||
io.String.Input(
|
||||
"user_prompt",
|
||||
multiline=True,
|
||||
dynamic_prompts=True,
|
||||
tooltip="The text to be encoded.",
|
||||
),
|
||||
io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(
|
||||
tooltip="A conditioning containing the embedded text used to guide the diffusion model.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"system_prompt": (list(CLIPTextEncodeLumina2.SYSTEM_PROMPT.keys()), {"tooltip": CLIPTextEncodeLumina2.SYSTEM_PROMPT_TIP}),
|
||||
"user_prompt": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
|
||||
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = (IO.CONDITIONING,)
|
||||
OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, user_prompt, system_prompt) -> io.NodeOutput:
|
||||
CATEGORY = "conditioning"
|
||||
DESCRIPTION = "Encodes a system prompt and a user prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
|
||||
|
||||
def encode(self, clip, user_prompt, system_prompt):
|
||||
if clip is None:
|
||||
raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
|
||||
system_prompt = cls.SYSTEM_PROMPT[system_prompt]
|
||||
system_prompt = CLIPTextEncodeLumina2.SYSTEM_PROMPT[system_prompt]
|
||||
prompt = f'{system_prompt} <Prompt Start> {user_prompt}'
|
||||
tokens = clip.tokenize(prompt)
|
||||
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||
|
||||
|
||||
class Lumina2Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
CLIPTextEncodeLumina2,
|
||||
RenormCFG,
|
||||
]
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeLumina2": CLIPTextEncodeLumina2,
|
||||
"RenormCFG": RenormCFG
|
||||
}
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> Lumina2Extension:
|
||||
return Lumina2Extension()
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CLIPTextEncodeLumina2": "CLIP Text Encode for Lumina2",
|
||||
}
|
||||
|
||||
@ -1,29 +1,17 @@
|
||||
from typing_extensions import override
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class Mahiro(io.ComfyNode):
|
||||
class Mahiro:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Mahiro",
|
||||
display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
|
||||
category="_for_testing",
|
||||
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name="patched_model"),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model) -> io.NodeOutput:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL",),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("patched_model",)
|
||||
FUNCTION = "patch"
|
||||
CATEGORY = "_for_testing"
|
||||
DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt."
|
||||
def patch(self, model):
|
||||
m = model.clone()
|
||||
def mahiro_normd(args):
|
||||
scale: float = args['cond_scale']
|
||||
@ -42,16 +30,12 @@ class Mahiro(io.ComfyNode):
|
||||
wm = (simsc*cfg + (4-simsc)*leap) / 4
|
||||
return wm
|
||||
m.set_model_sampler_post_cfg_function(mahiro_normd)
|
||||
return io.NodeOutput(m)
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"Mahiro": Mahiro
|
||||
}
|
||||
|
||||
class MahiroExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
Mahiro,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MahiroExtension:
|
||||
return MahiroExtension()
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
|
||||
}
|
||||
|
||||
@ -1,40 +1,23 @@
|
||||
from typing_extensions import override
|
||||
import nodes
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import nodes
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class EmptyMochiLatentVideo(io.ComfyNode):
|
||||
class EmptyMochiLatentVideo:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyMochiLatentVideo",
|
||||
category="latent/video",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=25, min=7, max=nodes.MAX_RESOLUTION, step=6),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
|
||||
CATEGORY = "latent/video"
|
||||
|
||||
def generate(self, width, height, length, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent})
|
||||
return ({"samples":latent}, )
|
||||
|
||||
|
||||
class MochiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
EmptyMochiLatentVideo,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MochiExtension:
|
||||
return MochiExtension()
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyMochiLatentVideo": EmptyMochiLatentVideo,
|
||||
}
|
||||
|
||||
@ -5,9 +5,6 @@ import comfy.samplers
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
import math
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):
|
||||
pos = noise_pred_pos - noise_pred_nocond
|
||||
@ -19,27 +16,20 @@ def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, co
|
||||
return cfg_result
|
||||
|
||||
#TODO: This node should be removed, it has been replaced with PerpNegGuider
|
||||
class PerpNeg(io.ComfyNode):
|
||||
class PerpNeg:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PerpNeg",
|
||||
display_name="Perp-Neg (DEPRECATED by PerpNegGuider)",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Conditioning.Input("empty_conditioning"),
|
||||
io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
is_experimental=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL", ),
|
||||
"empty_conditioning": ("CONDITIONING", ),
|
||||
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, empty_conditioning, neg_scale) -> io.NodeOutput:
|
||||
CATEGORY = "_for_testing"
|
||||
DEPRECATED = True
|
||||
|
||||
def patch(self, model, empty_conditioning, neg_scale):
|
||||
m = model.clone()
|
||||
nocond = comfy.sampler_helpers.convert_cond(empty_conditioning)
|
||||
|
||||
@ -60,7 +50,7 @@ class PerpNeg(io.ComfyNode):
|
||||
|
||||
m.set_model_sampler_cfg_function(cfg_function)
|
||||
|
||||
return io.NodeOutput(m)
|
||||
return (m, )
|
||||
|
||||
|
||||
class Guider_PerpNeg(comfy.samplers.CFGGuider):
|
||||
@ -122,42 +112,35 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
|
||||
|
||||
return cfg_result
|
||||
|
||||
class PerpNegGuider(io.ComfyNode):
|
||||
class PerpNegGuider:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PerpNegGuider",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Conditioning.Input("empty_conditioning"),
|
||||
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
|
||||
io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Guider.Output(),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"empty_conditioning": ("CONDITIONING", ),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, positive, negative, empty_conditioning, cfg, neg_scale) -> io.NodeOutput:
|
||||
RETURN_TYPES = ("GUIDER",)
|
||||
|
||||
FUNCTION = "get_guider"
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def get_guider(self, model, positive, negative, empty_conditioning, cfg, neg_scale):
|
||||
guider = Guider_PerpNeg(model)
|
||||
guider.set_conds(positive, negative, empty_conditioning)
|
||||
guider.set_cfg(cfg, neg_scale)
|
||||
return io.NodeOutput(guider)
|
||||
return (guider,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PerpNeg": PerpNeg,
|
||||
"PerpNegGuider": PerpNegGuider,
|
||||
}
|
||||
|
||||
class PerpNegExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
PerpNeg,
|
||||
PerpNegGuider,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> PerpNegExtension:
|
||||
return PerpNegExtension()
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PerpNeg": "Perp-Neg (DEPRECATED by PerpNegGuider)",
|
||||
}
|
||||
|
||||
@ -4,8 +4,6 @@ import folder_paths
|
||||
import comfy.clip_model
|
||||
import comfy.clip_vision
|
||||
import comfy.ops
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0
|
||||
VISION_CONFIG_DICT = {
|
||||
@ -118,52 +116,41 @@ class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
|
||||
return updated_prompt_embeds
|
||||
|
||||
|
||||
class PhotoMakerLoader(io.ComfyNode):
|
||||
class PhotoMakerLoader:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PhotoMakerLoader",
|
||||
category="_for_testing/photomaker",
|
||||
inputs=[
|
||||
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
|
||||
],
|
||||
outputs=[
|
||||
io.Photomaker.Output(),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}}
|
||||
|
||||
@classmethod
|
||||
def execute(cls, photomaker_model_name):
|
||||
RETURN_TYPES = ("PHOTOMAKER",)
|
||||
FUNCTION = "load_photomaker_model"
|
||||
|
||||
CATEGORY = "_for_testing/photomaker"
|
||||
|
||||
def load_photomaker_model(self, photomaker_model_name):
|
||||
photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
|
||||
photomaker_model = PhotoMakerIDEncoder()
|
||||
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
|
||||
if "id_encoder" in data:
|
||||
data = data["id_encoder"]
|
||||
photomaker_model.load_state_dict(data)
|
||||
return io.NodeOutput(photomaker_model)
|
||||
return (photomaker_model,)
|
||||
|
||||
|
||||
class PhotoMakerEncode(io.ComfyNode):
|
||||
class PhotoMakerEncode:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PhotoMakerEncode",
|
||||
category="_for_testing/photomaker",
|
||||
inputs=[
|
||||
io.Photomaker.Input("photomaker"),
|
||||
io.Image.Input("image"),
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("text", multiline=True, dynamic_prompts=True, default="photograph of photomaker"),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "photomaker": ("PHOTOMAKER",),
|
||||
"image": ("IMAGE",),
|
||||
"clip": ("CLIP", ),
|
||||
"text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": "photograph of photomaker"}),
|
||||
}}
|
||||
|
||||
@classmethod
|
||||
def execute(cls, photomaker, image, clip, text):
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "apply_photomaker"
|
||||
|
||||
CATEGORY = "_for_testing/photomaker"
|
||||
|
||||
def apply_photomaker(self, photomaker, image, clip, text):
|
||||
special_token = "photomaker"
|
||||
pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
|
||||
try:
|
||||
@ -191,16 +178,11 @@ class PhotoMakerEncode(io.ComfyNode):
|
||||
else:
|
||||
out = cond
|
||||
|
||||
return io.NodeOutput([[out, {"pooled_output": pooled}]])
|
||||
return ([[out, {"pooled_output": pooled}]], )
|
||||
|
||||
|
||||
class PhotomakerExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
PhotoMakerLoader,
|
||||
PhotoMakerEncode,
|
||||
]
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PhotoMakerLoader": PhotoMakerLoader,
|
||||
"PhotoMakerEncode": PhotoMakerEncode,
|
||||
}
|
||||
|
||||
async def comfy_entrypoint() -> PhotomakerExtension:
|
||||
return PhotomakerExtension()
|
||||
|
||||
@ -1,38 +1,24 @@
|
||||
from typing_extensions import override
|
||||
import nodes
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
class CLIPTextEncodePixArtAlpha(io.ComfyNode):
|
||||
class CLIPTextEncodePixArtAlpha:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodePixArtAlpha",
|
||||
category="advanced/conditioning",
|
||||
description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
|
||||
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
io.String.Input("text", multiline=True, dynamic_prompts=True),
|
||||
io.Clip.Input("clip"),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
|
||||
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
|
||||
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
|
||||
}}
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, width, height, text):
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
CATEGORY = "advanced/conditioning"
|
||||
DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma."
|
||||
|
||||
def encode(self, clip, width, height, text):
|
||||
tokens = clip.tokenize(text)
|
||||
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}))
|
||||
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),)
|
||||
|
||||
|
||||
class PixArtExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
CLIPTextEncodePixArtAlpha,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> PixArtExtension:
|
||||
return PixArtExtension()
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha,
|
||||
}
|
||||
|
||||
@ -1,29 +1,24 @@
|
||||
import node_helpers
|
||||
import comfy.utils
|
||||
import math
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class TextEncodeQwenImageEdit(io.ComfyNode):
|
||||
class TextEncodeQwenImageEdit:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TextEncodeQwenImageEdit",
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
|
||||
io.Vae.Input("vae", optional=True),
|
||||
io.Image.Input("image", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"clip": ("CLIP", ),
|
||||
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
},
|
||||
"optional": {"vae": ("VAE", ),
|
||||
"image": ("IMAGE", ),}}
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, vae=None, image=None) -> io.NodeOutput:
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "advanced/conditioning"
|
||||
|
||||
def encode(self, clip, prompt, vae=None, image=None):
|
||||
ref_latent = None
|
||||
if image is None:
|
||||
images = []
|
||||
@ -45,30 +40,28 @@ class TextEncodeQwenImageEdit(io.ComfyNode):
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
if ref_latent is not None:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
|
||||
return io.NodeOutput(conditioning)
|
||||
return (conditioning, )
|
||||
|
||||
|
||||
class TextEncodeQwenImageEditPlus(io.ComfyNode):
|
||||
class TextEncodeQwenImageEditPlus:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TextEncodeQwenImageEditPlus",
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
|
||||
io.Vae.Input("vae", optional=True),
|
||||
io.Image.Input("image1", optional=True),
|
||||
io.Image.Input("image2", optional=True),
|
||||
io.Image.Input("image3", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"clip": ("CLIP", ),
|
||||
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
},
|
||||
"optional": {"vae": ("VAE", ),
|
||||
"image1": ("IMAGE", ),
|
||||
"image2": ("IMAGE", ),
|
||||
"image3": ("IMAGE", ),
|
||||
}}
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput:
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "advanced/conditioning"
|
||||
|
||||
def encode(self, clip, prompt, vae=None, image1=None, image2=None, image3=None):
|
||||
ref_latents = []
|
||||
images = [image1, image2, image3]
|
||||
images_vl = []
|
||||
@ -101,17 +94,10 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
if len(ref_latents) > 0:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
|
||||
return io.NodeOutput(conditioning)
|
||||
return (conditioning, )
|
||||
|
||||
|
||||
class QwenExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
TextEncodeQwenImageEdit,
|
||||
TextEncodeQwenImageEditPlus,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> QwenExtension:
|
||||
return QwenExtension()
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TextEncodeQwenImageEdit": TextEncodeQwenImageEdit,
|
||||
"TextEncodeQwenImageEditPlus": TextEncodeQwenImageEditPlus,
|
||||
}
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.61"
|
||||
__version__ = "0.3.60"
|
||||
|
||||
@ -279,10 +279,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
|
||||
|
||||
|
||||
|
||||
def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||
"""
|
||||
Get the full path of a file in a folder, has to be a file
|
||||
"""
|
||||
def get_full_path(folder_name: str, filename: str, allow_missing: bool = False) -> str | None:
|
||||
global folder_names_and_paths
|
||||
folder_name = map_legacy(folder_name)
|
||||
if folder_name not in folder_names_and_paths:
|
||||
@ -295,6 +292,8 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||
return full_path
|
||||
elif os.path.islink(full_path):
|
||||
logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))
|
||||
elif allow_missing:
|
||||
return full_path
|
||||
|
||||
return None
|
||||
|
||||
@ -309,6 +308,27 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
||||
return full_path
|
||||
|
||||
|
||||
def get_relative_path(full_path: str) -> tuple[str, str] | None:
|
||||
"""Convert a full path back to a type-relative path.
|
||||
|
||||
Args:
|
||||
full_path: The full path to the file
|
||||
|
||||
Returns:
|
||||
tuple[str, str] | None: A tuple of (model_type, relative_path) if found, None otherwise
|
||||
"""
|
||||
global folder_names_and_paths
|
||||
full_path = os.path.normpath(full_path)
|
||||
|
||||
for model_type, (paths, _) in folder_names_and_paths.items():
|
||||
for base_path in paths:
|
||||
base_path = os.path.normpath(base_path)
|
||||
if full_path.startswith(base_path):
|
||||
relative_path = os.path.relpath(full_path, base_path)
|
||||
return model_type, relative_path
|
||||
|
||||
return None
|
||||
|
||||
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
||||
folder_name = map_legacy(folder_name)
|
||||
global folder_names_and_paths
|
||||
|
||||
18
main.py
18
main.py
@ -127,7 +127,6 @@ if __name__ == "__main__":
|
||||
if args.cuda_device is not None:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
|
||||
logging.info("Set cuda device to: {}".format(args.cuda_device))
|
||||
|
||||
if args.oneapi_device_selector is not None:
|
||||
@ -165,7 +164,6 @@ def cuda_malloc_warning():
|
||||
if cuda_malloc_warning:
|
||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||
|
||||
|
||||
def prompt_worker(q, server_instance):
|
||||
current_time: float = 0.0
|
||||
cache_type = execution.CacheType.CLASSIC
|
||||
@ -280,14 +278,13 @@ def cleanup_temp():
|
||||
if os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
async def setup_database():
|
||||
from app.assets import sync_seed_assets
|
||||
from app.db import init_db_engine
|
||||
|
||||
def setup_database():
|
||||
try:
|
||||
from app.database.db import init_db, dependencies_available
|
||||
if dependencies_available():
|
||||
init_db()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||
await init_db_engine()
|
||||
if not args.disable_assets_autoscan:
|
||||
await sync_seed_assets(["models"])
|
||||
|
||||
|
||||
def start_comfyui(asyncio_loop=None):
|
||||
@ -313,6 +310,8 @@ def start_comfyui(asyncio_loop=None):
|
||||
asyncio.set_event_loop(asyncio_loop)
|
||||
prompt_server = server.PromptServer(asyncio_loop)
|
||||
|
||||
asyncio_loop.run_until_complete(setup_database())
|
||||
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
asyncio_loop.run_until_complete(nodes.init_extra_nodes(
|
||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||
@ -321,7 +320,6 @@ def start_comfyui(asyncio_loop=None):
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
cuda_malloc_warning()
|
||||
setup_database()
|
||||
|
||||
prompt_server.add_routes()
|
||||
hijack_progress(prompt_server)
|
||||
|
||||
@ -26,12 +26,11 @@ async def cache_control(
|
||||
"""Cache control middleware that sets appropriate cache headers based on file type and response status"""
|
||||
response: web.Response = await handler(request)
|
||||
|
||||
path_filename = request.path.rsplit("/", 1)[-1]
|
||||
is_entry_point = path_filename.startswith("index") and path_filename.endswith(
|
||||
".json"
|
||||
)
|
||||
|
||||
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
|
||||
if (
|
||||
request.path.endswith(".js")
|
||||
or request.path.endswith(".css")
|
||||
or request.path.endswith("index.json")
|
||||
):
|
||||
response.headers.setdefault("Cache-Control", "no-cache")
|
||||
return response
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.61"
|
||||
version = "0.3.60"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.26.13
|
||||
comfyui-workflow-templates==0.1.91
|
||||
comfyui-workflow-templates==0.1.86
|
||||
comfyui-embedded-docs==0.2.6
|
||||
torch
|
||||
torchsde
|
||||
@ -20,7 +20,9 @@ tqdm
|
||||
psutil
|
||||
alembic
|
||||
SQLAlchemy
|
||||
aiosqlite
|
||||
av>=14.2.0
|
||||
blake3
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
|
||||
@ -37,6 +37,7 @@ from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from app.assets import sync_seed_assets, register_assets_system
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
# Import cache control middleware
|
||||
@ -178,6 +179,7 @@ class PromptServer():
|
||||
else args.front_end_root
|
||||
)
|
||||
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
||||
register_assets_system(self.app, self.user_manager)
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
@ -550,8 +552,6 @@ class PromptServer():
|
||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||
required_frontend_version = FrontendManager.get_required_frontend_version()
|
||||
installed_templates_version = FrontendManager.get_installed_templates_version()
|
||||
required_templates_version = FrontendManager.get_required_templates_version()
|
||||
|
||||
system_stats = {
|
||||
"system": {
|
||||
@ -560,8 +560,6 @@ class PromptServer():
|
||||
"ram_free": ram_free,
|
||||
"comfyui_version": __version__,
|
||||
"required_frontend_version": required_frontend_version,
|
||||
"installed_templates_version": installed_templates_version,
|
||||
"required_templates_version": required_templates_version,
|
||||
"python_version": sys.version,
|
||||
"pytorch_version": comfy.model_management.torch_version,
|
||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||
@ -626,6 +624,7 @@ class PromptServer():
|
||||
|
||||
@routes.get("/object_info")
|
||||
async def get_object_info(request):
|
||||
await sync_seed_assets(["models"])
|
||||
with folder_paths.cache_helper:
|
||||
out = {}
|
||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||
|
||||
307
tests-assets/conftest.py
Normal file
307
tests-assets/conftest.py
Normal file
@ -0,0 +1,307 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Callable, Optional
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
"""
|
||||
Allow overriding the database URL used by the spawned ComfyUI process.
|
||||
Priority:
|
||||
1) --db-url command line option
|
||||
2) ASSETS_TEST_DB_URL environment variable (used by CI)
|
||||
3) default: sqlite in-memory
|
||||
"""
|
||||
parser.addoption(
|
||||
"--db-url",
|
||||
action="store",
|
||||
default=os.environ.get("ASSETS_TEST_DB_URL", "sqlite+aiosqlite:///:memory:"),
|
||||
help="Async SQLAlchemy DB URL (e.g. sqlite+aiosqlite:///:memory: or postgresql+psycopg://user:pass@host/db)",
|
||||
)
|
||||
|
||||
|
||||
def _free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def _make_base_dirs(root: Path) -> None:
|
||||
for sub in ("models", "custom_nodes", "input", "output", "temp", "user"):
|
||||
(root / sub).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
async def _wait_http_ready(base: str, session: aiohttp.ClientSession, timeout: float = 90.0) -> None:
|
||||
start = time.time()
|
||||
last_err = None
|
||||
while time.time() - start < timeout:
|
||||
try:
|
||||
async with session.get(base + "/api/assets") as r:
|
||||
if r.status in (200, 400):
|
||||
return
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
await asyncio.sleep(0.25)
|
||||
raise RuntimeError(f"ComfyUI HTTP did not become ready: {last_err}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def comfy_tmp_base_dir() -> Path:
|
||||
env_base = os.environ.get("ASSETS_TEST_BASE_DIR")
|
||||
created_by_fixture = False
|
||||
if env_base:
|
||||
tmp = Path(env_base)
|
||||
tmp.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-"))
|
||||
created_by_fixture = True
|
||||
_make_base_dirs(tmp)
|
||||
yield tmp
|
||||
if created_by_fixture:
|
||||
with contextlib.suppress(Exception):
|
||||
for p in sorted(tmp.rglob("*"), reverse=True):
|
||||
if p.is_file() or p.is_symlink():
|
||||
p.unlink(missing_ok=True)
|
||||
for p in sorted(tmp.glob("**/*"), reverse=True):
|
||||
with contextlib.suppress(Exception):
|
||||
p.rmdir()
|
||||
tmp.rmdir()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest):
|
||||
"""
|
||||
Boot ComfyUI subprocess with:
|
||||
- sandbox base dir
|
||||
- sqlite memory DB (default)
|
||||
- autoscan disabled
|
||||
Returns (base_url, process, port)
|
||||
"""
|
||||
port = _free_port()
|
||||
db_url = request.config.getoption("--db-url")
|
||||
|
||||
logs_dir = comfy_tmp_base_dir / "logs"
|
||||
logs_dir.mkdir(exist_ok=True)
|
||||
out_log = open(logs_dir / "stdout.log", "w", buffering=1)
|
||||
err_log = open(logs_dir / "stderr.log", "w", buffering=1)
|
||||
|
||||
comfy_root = Path(__file__).resolve().parent.parent
|
||||
if not (comfy_root / "main.py").is_file():
|
||||
raise FileNotFoundError(f"main.py not found under {comfy_root}")
|
||||
|
||||
proc = subprocess.Popen(
|
||||
args=[
|
||||
sys.executable,
|
||||
"main.py",
|
||||
f"--base-directory={str(comfy_tmp_base_dir)}",
|
||||
f"--database-url={db_url}",
|
||||
"--disable-assets-autoscan",
|
||||
"--listen",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
str(port),
|
||||
"--cpu",
|
||||
],
|
||||
stdout=out_log,
|
||||
stderr=err_log,
|
||||
cwd=str(comfy_root),
|
||||
env={**os.environ},
|
||||
)
|
||||
|
||||
for _ in range(50):
|
||||
if proc.poll() is not None:
|
||||
out_log.flush()
|
||||
err_log.flush()
|
||||
raise RuntimeError(f"ComfyUI exited early with code {proc.returncode}")
|
||||
time.sleep(0.1)
|
||||
|
||||
base_url = f"http://127.0.0.1:{port}"
|
||||
try:
|
||||
async def _probe():
|
||||
async with aiohttp.ClientSession() as s:
|
||||
await _wait_http_ready(base_url, s, timeout=90.0)
|
||||
|
||||
asyncio.run(_probe())
|
||||
yield base_url, proc, port
|
||||
except Exception as e:
|
||||
with contextlib.suppress(Exception):
|
||||
proc.terminate()
|
||||
proc.wait(timeout=10)
|
||||
with contextlib.suppress(Exception):
|
||||
out_log.flush()
|
||||
err_log.flush()
|
||||
raise RuntimeError(f"ComfyUI did not become ready: {e}")
|
||||
|
||||
if proc and proc.poll() is None:
|
||||
with contextlib.suppress(Exception):
|
||||
proc.terminate()
|
||||
proc.wait(timeout=15)
|
||||
out_log.close()
|
||||
err_log.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def http() -> AsyncIterator[aiohttp.ClientSession]:
|
||||
timeout = aiohttp.ClientTimeout(total=120)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as s:
|
||||
yield s
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_base(comfy_url_and_proc) -> str:
|
||||
base_url, _proc, _port = comfy_url_and_proc
|
||||
return base_url
|
||||
|
||||
|
||||
async def _post_multipart_asset(
|
||||
session: aiohttp.ClientSession,
|
||||
base: str,
|
||||
*,
|
||||
name: str,
|
||||
tags: list[str],
|
||||
meta: dict,
|
||||
data: bytes,
|
||||
extra_fields: Optional[dict] = None,
|
||||
) -> tuple[int, dict]:
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", data, filename=name, content_type="application/octet-stream")
|
||||
form.add_field("tags", json.dumps(tags))
|
||||
form.add_field("name", name)
|
||||
form.add_field("user_metadata", json.dumps(meta))
|
||||
if extra_fields:
|
||||
for k, v in extra_fields.items():
|
||||
form.add_field(k, v)
|
||||
async with session.post(base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
return r.status, body
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_asset_bytes() -> Callable[[str, int], bytes]:
|
||||
def _make(name: str, size: int = 8192) -> bytes:
|
||||
seed = sum(ord(c) for c in name) % 251
|
||||
return bytes((i * 31 + seed) % 256 for i in range(size))
|
||||
return _make
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def asset_factory(http: aiohttp.ClientSession, api_base: str):
|
||||
"""
|
||||
Returns create(name, tags, meta, data) -> response dict
|
||||
Tracks created ids and deletes them after the test.
|
||||
"""
|
||||
created: list[str] = []
|
||||
|
||||
async def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict:
|
||||
status, body = await _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data)
|
||||
assert status in (200, 201), body
|
||||
created.append(body["id"])
|
||||
return body
|
||||
|
||||
yield create
|
||||
|
||||
# cleanup by id
|
||||
for aid in created:
|
||||
with contextlib.suppress(Exception):
|
||||
async with http.delete(f"{api_base}/api/assets/{aid}") as r:
|
||||
await r.read()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def seeded_asset(request: pytest.FixtureRequest, http: aiohttp.ClientSession, api_base: str) -> dict:
|
||||
"""
|
||||
Upload one asset with ".safetensors" extension into models/checkpoints/unit-tests/<name>.
|
||||
Returns response dict with id, asset_hash, tags, etc.
|
||||
"""
|
||||
name = "unit_1_example.safetensors"
|
||||
p = getattr(request, "param", {}) or {}
|
||||
tags: Optional[list[str]] = p.get("tags")
|
||||
if tags is None:
|
||||
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||
meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None}
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"A" * 4096, filename=name, content_type="application/octet-stream")
|
||||
form.add_field("tags", json.dumps(tags))
|
||||
form.add_field("name", name)
|
||||
form.add_field("user_metadata", json.dumps(meta))
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 201, body
|
||||
return body
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def autoclean_unit_test_assets(http: aiohttp.ClientSession, api_base: str):
|
||||
"""Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test."""
|
||||
yield
|
||||
|
||||
while True:
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests", "limit": "500", "sort": "name"},
|
||||
) as r:
|
||||
body = await r.json()
|
||||
if r.status != 200:
|
||||
break
|
||||
ids = [a["id"] for a in body.get("assets", [])]
|
||||
if not ids:
|
||||
break
|
||||
for aid in ids:
|
||||
with contextlib.suppress(Exception):
|
||||
async with http.delete(f"{api_base}/api/assets/{aid}") as dr:
|
||||
await dr.read()
|
||||
|
||||
|
||||
async def trigger_sync_seed_assets(session: aiohttp.ClientSession, base_url: str) -> None:
|
||||
"""Force a fast sync/seed pass by calling the ComfyUI '/object_info' endpoint."""
|
||||
async with session.post(base_url + "/api/assets/scan/seed", json={"roots": ["models", "input", "output"]}) as r:
|
||||
await r.read()
|
||||
await asyncio.sleep(0.1) # tiny yield to the event loop to let any final DB commits flush
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def run_scan_and_wait(http: aiohttp.ClientSession, api_base: str):
|
||||
"""Schedule an asset scan for a given root and wait until it finishes."""
|
||||
async def _run(root: str, timeout: float = 120.0):
|
||||
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r:
|
||||
# we ignore body; scheduling returns 202 with a status payload
|
||||
await r.read()
|
||||
|
||||
start = time.time()
|
||||
while True:
|
||||
async with http.get(api_base + "/api/assets/scan", params={"root": root}) as st:
|
||||
body = await st.json()
|
||||
scans = (body or {}).get("scans", [])
|
||||
status = None
|
||||
if scans:
|
||||
status = scans[-1].get("status")
|
||||
if status in {"completed", "failed", "cancelled"}:
|
||||
if status != "completed":
|
||||
raise RuntimeError(f"Scan for root={root} finished with status={status}")
|
||||
return
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError(f"Timed out waiting for scan of root={root}")
|
||||
await asyncio.sleep(0.1)
|
||||
return _run
|
||||
|
||||
|
||||
def get_asset_filename(asset_hash: str, extension: str) -> str:
|
||||
return asset_hash.removeprefix("blake3:") + extension
|
||||
347
tests-assets/test_assets_missing_sync.py
Normal file
347
tests-assets/test_assets_missing_sync.py
Normal file
@ -0,0 +1,347 @@
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_seed_asset_removed_when_file_is_deleted(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
):
|
||||
"""Asset without hash (seed) whose file disappears:
|
||||
after triggering sync_seed_assets, Asset + AssetInfo disappear.
|
||||
"""
|
||||
# Create a file directly under input/unit-tests/<case> so tags include "unit-tests"
|
||||
case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed"
|
||||
case_dir.mkdir(parents=True, exist_ok=True)
|
||||
name = f"seed_{uuid.uuid4().hex[:8]}.bin"
|
||||
fp = case_dir / name
|
||||
fp.write_bytes(b"Z" * 2048)
|
||||
|
||||
# Trigger a seed sync so DB sees this path (seed asset => hash is NULL)
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Verify it is visible via API and carries no hash (seed)
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||
) as r1:
|
||||
body1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
# there should be exactly one with that name
|
||||
matches = [a for a in body1.get("assets", []) if a.get("name") == name]
|
||||
assert matches
|
||||
assert matches[0].get("asset_hash") is None
|
||||
asset_info_id = matches[0]["id"]
|
||||
|
||||
# Remove the underlying file and sync again
|
||||
if fp.exists():
|
||||
fp.unlink()
|
||||
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# It should disappear (AssetInfo and seed Asset gone)
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||
) as r2:
|
||||
body2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
matches2 = [a for a in body2.get("assets", []) if a.get("name") == name]
|
||||
assert not matches2, f"Seed asset {asset_info_id} should be gone after sync"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hashed_asset_missing_tag_added_then_removed_after_scan(
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""Hashed asset with a single cache_state:
|
||||
1. delete its file -> sync adds 'missing'
|
||||
2. restore file -> scan removes 'missing'
|
||||
"""
|
||||
name = "missing_tag_test.png"
|
||||
tags = ["input", "unit-tests", "msync2"]
|
||||
data = make_asset_bytes(name, 4096)
|
||||
a = await asset_factory(name, tags, {}, data)
|
||||
|
||||
# Compute its on-disk path and remove it
|
||||
dest = comfy_tmp_base_dir / "input" / "unit-tests" / "msync2" / get_asset_filename(a["asset_hash"], ".png")
|
||||
assert dest.exists(), f"Expected asset file at {dest}"
|
||||
dest.unlink()
|
||||
|
||||
# Fast sync should add 'missing' to the AssetInfo
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{a['id']}") as g1:
|
||||
d1 = await g1.json()
|
||||
assert g1.status == 200, d1
|
||||
assert "missing" in set(d1.get("tags", [])), "Expected 'missing' tag after deletion"
|
||||
|
||||
# Restore the file with the exact same content and re-hash/verify via scan
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_bytes(data)
|
||||
|
||||
await run_scan_and_wait("input")
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{a['id']}") as g2:
|
||||
d2 = await g2.json()
|
||||
assert g2.status == 200, d2
|
||||
assert "missing" not in set(d2.get("tags", [])), "Missing tag should be cleared after verify"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hashed_asset_two_asset_infos_both_get_missing(
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
):
|
||||
"""Hashed asset with a single cache_state, but two AssetInfo rows:
|
||||
deleting the single file then syncing should add 'missing' to both infos.
|
||||
"""
|
||||
# Upload one hashed asset
|
||||
name = "two_infos_one_path.png"
|
||||
base_tags = ["input", "unit-tests", "multiinfo"]
|
||||
created = await asset_factory(name, base_tags, {}, b"A" * 2048)
|
||||
|
||||
# Create second AssetInfo for the same Asset via from-hash
|
||||
payload = {
|
||||
"hash": created["asset_hash"],
|
||||
"name": "two_infos_one_path_copy.png",
|
||||
"tags": base_tags, # keep it in our unit-tests scope for cleanup
|
||||
"user_metadata": {"k": "v"},
|
||||
}
|
||||
async with http.post(api_base + "/api/assets/from-hash", json=payload) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 201, b2
|
||||
second_id = b2["id"]
|
||||
|
||||
# Remove the single underlying file
|
||||
p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png")
|
||||
assert p.exists()
|
||||
p.unlink()
|
||||
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r0:
|
||||
tags0 = await r0.json()
|
||||
assert r0.status == 200, tags0
|
||||
byname0 = {t["name"]: t for t in tags0.get("tags", [])}
|
||||
old_missing = int(byname0.get("missing", {}).get("count", 0))
|
||||
|
||||
# Sync -> both AssetInfos for this asset must receive 'missing'
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{created['id']}") as ga:
|
||||
da = await ga.json()
|
||||
assert ga.status == 200, da
|
||||
assert "missing" in set(da.get("tags", []))
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{second_id}") as gb:
|
||||
db = await gb.json()
|
||||
assert gb.status == 200, db
|
||||
assert "missing" in set(db.get("tags", []))
|
||||
|
||||
# Tag usage for 'missing' increased by exactly 2 (two AssetInfos)
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1:
|
||||
tags1 = await r1.json()
|
||||
assert r1.status == 200, tags1
|
||||
byname1 = {t["name"]: t for t in tags1.get("tags", [])}
|
||||
new_missing = int(byname1.get("missing", {}).get("count", 0))
|
||||
assert new_missing == old_missing + 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hashed_asset_two_cache_states_partial_delete_then_full_delete(
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""Hashed asset with two cache_state rows:
|
||||
1. delete one file -> sync should NOT add 'missing'
|
||||
2. delete second file -> sync should add 'missing'
|
||||
"""
|
||||
name = "two_cache_states_partial_delete.png"
|
||||
tags = ["input", "unit-tests", "dual"]
|
||||
data = make_asset_bytes(name, 3072)
|
||||
|
||||
created = await asset_factory(name, tags, {}, data)
|
||||
path1 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual" / get_asset_filename(created["asset_hash"], ".png")
|
||||
assert path1.exists()
|
||||
|
||||
# Create a second on-disk copy under the same root but different subfolder
|
||||
path2 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual_copy" / name
|
||||
path2.parent.mkdir(parents=True, exist_ok=True)
|
||||
path2.write_bytes(data)
|
||||
|
||||
# Fast seed so the second path appears (as a seed initially)
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Deduplication of AssetInfo-s will not happen as first AssetInfo has owner='default' and second has empty owner.
|
||||
await run_scan_and_wait("input")
|
||||
|
||||
# Remove only one file and sync -> asset should still be healthy (no 'missing')
|
||||
path1.unlink()
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{created['id']}") as g1:
|
||||
d1 = await g1.json()
|
||||
assert g1.status == 200, d1
|
||||
assert "missing" not in set(d1.get("tags", [])), "Should not be missing while one valid path remains"
|
||||
|
||||
# Baseline 'missing' usage count just before last file removal
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r0:
|
||||
tags0 = await r0.json()
|
||||
assert r0.status == 200, tags0
|
||||
old_missing = int({t["name"]: t for t in tags0.get("tags", [])}.get("missing", {}).get("count", 0))
|
||||
|
||||
# Remove the second (last) file and sync -> now we expect 'missing' on this AssetInfo
|
||||
path2.unlink()
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{created['id']}") as g2:
|
||||
d2 = await g2.json()
|
||||
assert g2.status == 200, d2
|
||||
assert "missing" in set(d2.get("tags", [])), "Missing must be set once no valid paths remain"
|
||||
|
||||
# Tag usage for 'missing' increased by exactly 2 (two AssetInfo for one Asset)
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1:
|
||||
tags1 = await r1.json()
|
||||
assert r1.status == 200, tags1
|
||||
new_missing = int({t["name"]: t for t in tags1.get("tags", [])}.get("missing", {}).get("count", 0))
|
||||
assert new_missing == old_missing + 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""
|
||||
Fast pass alone clears 'missing' when size and mtime match exactly:
|
||||
1) upload (hashed), record original mtime_ns
|
||||
2) delete -> fast pass adds 'missing'
|
||||
3) restore same bytes and set mtime back to the original value
|
||||
4) run fast pass again -> 'missing' is removed (no slow scan)
|
||||
"""
|
||||
scope = f"fastclear-{uuid.uuid4().hex[:6]}"
|
||||
name = "fastpass_clear.bin"
|
||||
data = make_asset_bytes(name, 3072)
|
||||
|
||||
a = await asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
aid = a["id"]
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||
p = base / get_asset_filename(a["asset_hash"], ".bin")
|
||||
st0 = p.stat()
|
||||
orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000))
|
||||
|
||||
# Delete -> fast pass adds 'missing'
|
||||
p.unlink()
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as g1:
|
||||
d1 = await g1.json()
|
||||
assert g1.status == 200, d1
|
||||
assert "missing" in set(d1.get("tags", []))
|
||||
|
||||
# Restore same bytes and revert mtime to the original value
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
p.write_bytes(data)
|
||||
# set both atime and mtime in ns to ensure exact match
|
||||
os.utime(p, ns=(orig_mtime_ns, orig_mtime_ns))
|
||||
|
||||
# Fast pass should clear 'missing' without a scan
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as g2:
|
||||
d2 = await g2.json()
|
||||
assert g2.status == 200, d2
|
||||
assert "missing" not in set(d2.get("tags", [])), "Fast pass should clear 'missing' when size+mtime match"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_fastpass_removes_stale_state_row_no_missing(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""
|
||||
Hashed asset with two states:
|
||||
- delete one file
|
||||
- run fast pass only
|
||||
Expect:
|
||||
- asset stays healthy (no 'missing')
|
||||
- stale AssetCacheState row for the deleted path is removed.
|
||||
We verify this behaviorally by recreating the deleted path and running fast pass again:
|
||||
a new *seed* AssetInfo is created, which proves the old state row was not reused.
|
||||
"""
|
||||
scope = f"stale-{uuid.uuid4().hex[:6]}"
|
||||
name = "two_states.bin"
|
||||
data = make_asset_bytes(name, 2048)
|
||||
|
||||
# Upload hashed asset at path1
|
||||
a = await asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||
a1_filename = get_asset_filename(a["asset_hash"], ".bin")
|
||||
p1 = base / a1_filename
|
||||
assert p1.exists()
|
||||
|
||||
aid = a["id"]
|
||||
h = a["asset_hash"]
|
||||
|
||||
# Create second state path2, seed+scan to dedupe into the same Asset
|
||||
p2 = base / "copy" / name
|
||||
p2.parent.mkdir(parents=True, exist_ok=True)
|
||||
p2.write_bytes(data)
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
await run_scan_and_wait(root)
|
||||
|
||||
# Delete path1 and run fast pass -> no 'missing' and stale state row should be removed
|
||||
p1.unlink()
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as g1:
|
||||
d1 = await g1.json()
|
||||
assert g1.status == 200, d1
|
||||
assert "missing" not in set(d1.get("tags", []))
|
||||
|
||||
# Recreate path1 and run fast pass again.
|
||||
# If the stale state row was removed, a NEW seed AssetInfo will appear for this path.
|
||||
p1.write_bytes(data)
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}"},
|
||||
) as rl:
|
||||
bl = await rl.json()
|
||||
assert rl.status == 200, bl
|
||||
items = bl.get("assets", [])
|
||||
# one hashed AssetInfo (asset_hash == h) + one seed AssetInfo (asset_hash == null)
|
||||
hashes = [it.get("asset_hash") for it in items if it.get("name") in (name, a1_filename)]
|
||||
assert h in hashes
|
||||
assert any(x is None for x in hashes), "Expected a new seed AssetInfo for the recreated path"
|
||||
|
||||
# Asset identity still healthy
|
||||
async with http.head(f"{api_base}/api/assets/hash/{h}") as rh:
|
||||
assert rh.status == 200
|
||||
316
tests-assets/test_crud.py
Normal file
316
tests-assets/test_crud.py
Normal file
@ -0,0 +1,316 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_from_hash_success(
|
||||
http: aiohttp.ClientSession, api_base: str, seeded_asset: dict
|
||||
):
|
||||
h = seeded_asset["asset_hash"]
|
||||
payload = {
|
||||
"hash": h,
|
||||
"name": "from_hash_ok.safetensors",
|
||||
"tags": ["models", "checkpoints", "unit-tests", "from-hash"],
|
||||
"user_metadata": {"k": "v"},
|
||||
}
|
||||
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 201, b1
|
||||
assert b1["asset_hash"] == h
|
||||
assert b1["created_new"] is False
|
||||
aid = b1["id"]
|
||||
|
||||
# Calling again with the same name should return the same AssetInfo id
|
||||
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 201, b2
|
||||
assert b2["id"] == aid
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_and_delete_asset(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
# GET detail
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
|
||||
detail = await rg.json()
|
||||
assert rg.status == 200, detail
|
||||
assert detail["id"] == aid
|
||||
assert "user_metadata" in detail
|
||||
assert "filename" in detail["user_metadata"]
|
||||
|
||||
# DELETE
|
||||
async with http.delete(f"{api_base}/api/assets/{aid}") as rd:
|
||||
assert rd.status == 204
|
||||
|
||||
# GET again -> 404
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg2:
|
||||
body = await rg2.json()
|
||||
assert rg2.status == 404
|
||||
assert body["error"]["code"] == "ASSET_NOT_FOUND"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_upon_reference_count(
|
||||
http: aiohttp.ClientSession, api_base: str, seeded_asset: dict
|
||||
):
|
||||
# Create a second reference to the same asset via from-hash
|
||||
src_hash = seeded_asset["asset_hash"]
|
||||
payload = {
|
||||
"hash": src_hash,
|
||||
"name": "unit_ref_copy.safetensors",
|
||||
"tags": ["models", "checkpoints", "unit-tests", "del-flow"],
|
||||
"user_metadata": {"note": "copy"},
|
||||
}
|
||||
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r2:
|
||||
copy = await r2.json()
|
||||
assert r2.status == 201, copy
|
||||
assert copy["asset_hash"] == src_hash
|
||||
assert copy["created_new"] is False
|
||||
|
||||
# Delete original reference -> asset identity must remain
|
||||
aid1 = seeded_asset["id"]
|
||||
async with http.delete(f"{api_base}/api/assets/{aid1}") as rd1:
|
||||
assert rd1.status == 204
|
||||
|
||||
async with http.head(f"{api_base}/api/assets/hash/{src_hash}") as rh1:
|
||||
assert rh1.status == 200 # identity still present
|
||||
|
||||
# Delete the last reference with default semantics -> identity and cached files removed
|
||||
aid2 = copy["id"]
|
||||
async with http.delete(f"{api_base}/api/assets/{aid2}") as rd2:
|
||||
assert rd2.status == 204
|
||||
|
||||
async with http.head(f"{api_base}/api/assets/hash/{src_hash}") as rh2:
|
||||
assert rh2.status == 404 # orphan content removed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_asset_fields(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
payload = {
|
||||
"name": "unit_1_renamed.safetensors",
|
||||
"tags": ["models", "checkpoints", "unit-tests", "beta"],
|
||||
"user_metadata": {"purpose": "updated", "epoch": 2},
|
||||
}
|
||||
async with http.put(f"{api_base}/api/assets/{aid}", json=payload) as ru:
|
||||
body = await ru.json()
|
||||
assert ru.status == 200, body
|
||||
assert body["name"] == payload["name"]
|
||||
assert "beta" in body["tags"]
|
||||
assert body["user_metadata"]["purpose"] == "updated"
|
||||
# filename should still be present and normalized by server
|
||||
assert "filename" in body["user_metadata"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_head_asset_by_hash(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
h = seeded_asset["asset_hash"]
|
||||
|
||||
# Existing
|
||||
async with http.head(f"{api_base}/api/assets/hash/{h}") as rh1:
|
||||
assert rh1.status == 200
|
||||
|
||||
# Non-existent
|
||||
async with http.head(f"{api_base}/api/assets/hash/blake3:{'0'*64}") as rh2:
|
||||
assert rh2.status == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_head_asset_bad_hash_returns_400_and_no_body(http: aiohttp.ClientSession, api_base: str):
|
||||
# Invalid format; handler returns a JSON error, but HEAD responses must not carry a payload.
|
||||
# aiohttp exposes an empty body for HEAD, so validate status and that there is no payload.
|
||||
async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh:
|
||||
assert rh.status == 400
|
||||
body = await rh.read()
|
||||
assert body == b""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_returns_404(http: aiohttp.ClientSession, api_base: str):
|
||||
bogus = str(uuid.uuid4())
|
||||
async with http.delete(f"{api_base}/api/assets/{bogus}") as r:
|
||||
body = await r.json()
|
||||
assert r.status == 404
|
||||
assert body["error"]["code"] == "ASSET_NOT_FOUND"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_from_hash_invalids(http: aiohttp.ClientSession, api_base: str):
|
||||
# Bad hash algorithm
|
||||
bad = {
|
||||
"hash": "sha256:" + "0" * 64,
|
||||
"name": "x.bin",
|
||||
"tags": ["models", "checkpoints", "unit-tests"],
|
||||
}
|
||||
async with http.post(f"{api_base}/api/assets/from-hash", json=bad) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 400
|
||||
assert b1["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
# Invalid JSON body
|
||||
async with http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}") as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 400
|
||||
assert b2["error"]["code"] == "INVALID_JSON"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_update_download_bad_ids(http: aiohttp.ClientSession, api_base: str):
|
||||
# All endpoints should be not found, as we UUID regex directly in the route definition.
|
||||
bad_id = "not-a-uuid"
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{bad_id}") as r1:
|
||||
assert r1.status == 404
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{bad_id}/content") as r3:
|
||||
assert r3.status == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_requires_at_least_one_field(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
async with http.put(f"{api_base}/api/assets/{aid}", json={}) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_concurrent_delete_same_asset_info_single_204(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""
|
||||
Many concurrent DELETE for the same AssetInfo should result in:
|
||||
- exactly one 204 No Content (the one that actually deleted)
|
||||
- all others 404 Not Found (row already gone)
|
||||
"""
|
||||
scope = f"conc-del-{uuid.uuid4().hex[:6]}"
|
||||
name = "to_delete.bin"
|
||||
data = make_asset_bytes(name, 1536)
|
||||
|
||||
created = await asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
aid = created["id"]
|
||||
|
||||
# Hit the same endpoint N times in parallel.
|
||||
n_tests = 4
|
||||
url = f"{api_base}/api/assets/{aid}?delete_content=false"
|
||||
tasks = [asyncio.create_task(http.delete(url)) for _ in range(n_tests)]
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
statuses = [r.status for r in responses]
|
||||
# Drain bodies to close connections (optional but avoids warnings).
|
||||
await asyncio.gather(*[r.read() for r in responses])
|
||||
|
||||
# Exactly one actual delete, the rest must be 404
|
||||
assert statuses.count(204) == 1, f"Expected exactly one 204; got: {statuses}"
|
||||
assert statuses.count(404) == n_tests - 1, f"Expected {n_tests-1} 404; got: {statuses}"
|
||||
|
||||
# The resource must be gone.
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
|
||||
assert rg.status == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_metadata_filename_is_set_for_seed_asset_without_hash(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
):
|
||||
"""Seed ingest (no hash yet) must compute user_metadata['filename'] immediately."""
|
||||
scope = f"seedmeta-{uuid.uuid4().hex[:6]}"
|
||||
name = "seed_filename.bin"
|
||||
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope / "a" / "b"
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
fp = base / name
|
||||
fp.write_bytes(b"Z" * 2048)
|
||||
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
|
||||
) as r1:
|
||||
body = await r1.json()
|
||||
assert r1.status == 200, body
|
||||
matches = [a for a in body.get("assets", []) if a.get("name") == name]
|
||||
assert matches, "Seed asset should be visible after sync"
|
||||
assert matches[0].get("asset_hash") is None # still a seed
|
||||
aid = matches[0]["id"]
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as r2:
|
||||
detail = await r2.json()
|
||||
assert r2.status == 200, detail
|
||||
filename = (detail.get("user_metadata") or {}).get("filename")
|
||||
expected = str(fp.relative_to(comfy_tmp_base_dir / root)).replace("\\", "/")
|
||||
assert filename == expected, f"expected filename={expected}, got {filename!r}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_metadata_filename_computed_and_updated_on_retarget(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""
|
||||
1) Ingest under {root}/unit-tests/<scope>/a/b/<name> -> filename reflects relative path.
|
||||
2) Retarget by copying to {root}/unit-tests/<scope>/x/<new_name>, remove old file,
|
||||
run fast pass + scan -> filename updates to new relative path.
|
||||
"""
|
||||
scope = f"meta-fn-{uuid.uuid4().hex[:6]}"
|
||||
name1 = "compute_metadata_filename.png"
|
||||
name2 = "compute_changed_metadata_filename.png"
|
||||
data = make_asset_bytes(name1, 2100)
|
||||
|
||||
# Upload into nested path a/b
|
||||
a = await asset_factory(name1, [root, "unit-tests", scope, "a", "b"], {}, data)
|
||||
aid = a["id"]
|
||||
|
||||
root_base = comfy_tmp_base_dir / root
|
||||
p1 = (root_base / "unit-tests" / scope / "a" / "b" / get_asset_filename(a["asset_hash"], ".png"))
|
||||
assert p1.exists()
|
||||
|
||||
# filename at ingest should be the path relative to root
|
||||
rel1 = str(p1.relative_to(root_base)).replace("\\", "/")
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as g1:
|
||||
d1 = await g1.json()
|
||||
assert g1.status == 200, d1
|
||||
fn1 = d1["user_metadata"].get("filename")
|
||||
assert fn1 == rel1
|
||||
|
||||
# Retarget: copy to x/<name2>, remove old, then sync+scan
|
||||
p2 = root_base / "unit-tests" / scope / "x" / name2
|
||||
p2.parent.mkdir(parents=True, exist_ok=True)
|
||||
p2.write_bytes(data)
|
||||
if p1.exists():
|
||||
p1.unlink()
|
||||
|
||||
await trigger_sync_seed_assets(http, api_base) # seed the new path
|
||||
await run_scan_and_wait(root) # verify/hash and reconcile
|
||||
|
||||
# filename should now point at x/<name2>
|
||||
rel2 = str(p2.relative_to(root_base)).replace("\\", "/")
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as g2:
|
||||
d2 = await g2.json()
|
||||
assert g2.status == 200, d2
|
||||
fn2 = d2["user_metadata"].get("filename")
|
||||
assert fn2 == rel2
|
||||
168
tests-assets/test_downloads.py
Normal file
168
tests-assets/test_downloads.py
Normal file
@ -0,0 +1,168 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_attachment_and_inline(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
# default attachment
|
||||
async with http.get(f"{api_base}/api/assets/{aid}/content") as r1:
|
||||
data = await r1.read()
|
||||
assert r1.status == 200
|
||||
cd = r1.headers.get("Content-Disposition", "")
|
||||
assert "attachment" in cd
|
||||
assert data and len(data) == 4096
|
||||
|
||||
# inline requested
|
||||
async with http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline") as r2:
|
||||
await r2.read()
|
||||
assert r2.status == 200
|
||||
cd2 = r2.headers.get("Content-Disposition", "")
|
||||
assert "inline" in cd2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_download_chooses_existing_state_and_updates_access_time(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""
|
||||
Hashed asset with two state paths: if the first one disappears,
|
||||
GET /content still serves from the remaining path and bumps last_access_time.
|
||||
"""
|
||||
scope = f"dl-first-{uuid.uuid4().hex[:6]}"
|
||||
name = "first_existing_state.bin"
|
||||
data = make_asset_bytes(name, 3072)
|
||||
|
||||
# Upload -> path1
|
||||
a = await asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
aid = a["id"]
|
||||
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||
path1 = base / get_asset_filename(a["asset_hash"], ".bin")
|
||||
assert path1.exists()
|
||||
|
||||
# Seed path2 by copying, then scan to dedupe into a second state
|
||||
path2 = base / "alt" / name
|
||||
path2.parent.mkdir(parents=True, exist_ok=True)
|
||||
path2.write_bytes(data)
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
await run_scan_and_wait(root)
|
||||
|
||||
# Remove path1 so server must fall back to path2
|
||||
path1.unlink()
|
||||
|
||||
# last_access_time before
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg0:
|
||||
d0 = await rg0.json()
|
||||
assert rg0.status == 200, d0
|
||||
ts0 = d0.get("last_access_time")
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
async with http.get(f"{api_base}/api/assets/{aid}/content") as r:
|
||||
blob = await r.read()
|
||||
assert r.status == 200
|
||||
assert blob == data # must serve from the surviving state (same bytes)
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg1:
|
||||
d1 = await rg1.json()
|
||||
assert rg1.status == 200, d1
|
||||
ts1 = d1.get("last_access_time")
|
||||
|
||||
def _parse_iso8601(s: Optional[str]) -> Optional[float]:
|
||||
if not s:
|
||||
return None
|
||||
s = s[:-1] if s.endswith("Z") else s
|
||||
return datetime.fromisoformat(s).timestamp()
|
||||
|
||||
t0 = _parse_iso8601(ts0)
|
||||
t1 = _parse_iso8601(ts1)
|
||||
assert t1 is not None
|
||||
if t0 is not None:
|
||||
assert t1 > t0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True)
|
||||
async def test_download_missing_file_returns_404(
|
||||
http: aiohttp.ClientSession, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict
|
||||
):
|
||||
# Remove the underlying file then attempt download.
|
||||
# We initialize fixture without additional tags to know exactly the asset file path.
|
||||
try:
|
||||
aid = seeded_asset["id"]
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
|
||||
detail = await rg.json()
|
||||
assert rg.status == 200
|
||||
asset_filename = get_asset_filename(detail["asset_hash"], ".safetensors")
|
||||
abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / asset_filename
|
||||
assert abs_path.exists()
|
||||
abs_path.unlink()
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{aid}/content") as r2:
|
||||
assert r2.status == 404
|
||||
body = await r2.json()
|
||||
assert body["error"]["code"] == "FILE_NOT_FOUND"
|
||||
finally:
|
||||
# We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually.
|
||||
async with http.delete(f"{api_base}/api/assets/{aid}") as dr:
|
||||
await dr.read()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_download_404_if_all_states_missing(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""Multi-state asset: after the last remaining on-disk file is removed, download must return 404."""
|
||||
scope = f"dl-404-{uuid.uuid4().hex[:6]}"
|
||||
name = "missing_all_states.bin"
|
||||
data = make_asset_bytes(name, 2048)
|
||||
|
||||
# Upload -> path1
|
||||
a = await asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
aid = a["id"]
|
||||
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||
p1 = base / get_asset_filename(a["asset_hash"], ".bin")
|
||||
assert p1.exists()
|
||||
|
||||
# Seed a second state and dedupe
|
||||
p2 = base / "copy" / name
|
||||
p2.parent.mkdir(parents=True, exist_ok=True)
|
||||
p2.write_bytes(data)
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
await run_scan_and_wait(root)
|
||||
|
||||
# Remove first file -> download should still work via the second state
|
||||
p1.unlink()
|
||||
async with http.get(f"{api_base}/api/assets/{aid}/content") as ok1:
|
||||
b1 = await ok1.read()
|
||||
assert ok1.status == 200 and b1 == data
|
||||
|
||||
# Remove the last file -> download must 404
|
||||
p2.unlink()
|
||||
async with http.get(f"{api_base}/api/assets/{aid}/content") as r2:
|
||||
body = await r2.json()
|
||||
assert r2.status == 404
|
||||
assert body["error"]["code"] == "FILE_NOT_FOUND"
|
||||
337
tests-assets/test_list_filter.py
Normal file
337
tests-assets/test_list_filter.py
Normal file
@ -0,0 +1,337 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_paging_and_sort(http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = ["a1_u.safetensors", "a2_u.safetensors", "a3_u.safetensors"]
|
||||
for n in names:
|
||||
await asset_factory(
|
||||
n,
|
||||
["models", "checkpoints", "unit-tests", "paging"],
|
||||
{"epoch": 1},
|
||||
make_asset_bytes(n, size=2048),
|
||||
)
|
||||
|
||||
# name ascending for stable order
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "0"},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
got1 = [a["name"] for a in b1["assets"]]
|
||||
assert got1 == sorted(names)[:2]
|
||||
assert b1["has_more"] is True
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "2"},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
got2 = [a["name"] for a in b2["assets"]]
|
||||
assert got2 == sorted(names)[2:]
|
||||
assert b2["has_more"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_include_exclude_and_name_contains(http: aiohttp.ClientSession, api_base: str, asset_factory):
|
||||
a = await asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
|
||||
b = await asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,alpha", "exclude_tags": "beta", "limit": "50"},
|
||||
) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert a["name"] in names
|
||||
assert b["name"] not in names
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests", "name_contains": "inc_"},
|
||||
) as r2:
|
||||
body2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
names2 = [x["name"] for x in body2["assets"]]
|
||||
assert a["name"] in names2
|
||||
assert b["name"] in names2
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "non-existing-tag"},
|
||||
) as r2:
|
||||
body3 = await r2.json()
|
||||
assert r2.status == 200
|
||||
assert not body3["assets"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-size"]
|
||||
n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors"
|
||||
await asset_factory(n1, t, {}, make_asset_bytes(n1, 1024))
|
||||
await asset_factory(n2, t, {}, make_asset_bytes(n2, 2048))
|
||||
await asset_factory(n3, t, {}, make_asset_bytes(n3, 3072))
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "asc"},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
names = [a["name"] for a in b1["assets"]]
|
||||
assert names[:3] == [n1, n2, n3]
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "desc"},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
names2 = [a["name"] for a in b2["assets"]]
|
||||
assert names2[:3] == [n3, n2, n1]
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-upd"]
|
||||
a1 = await asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200))
|
||||
a2 = await asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200))
|
||||
|
||||
# Rename the second asset to bump updated_at
|
||||
async with http.put(f"{api_base}/api/assets/{a2['id']}", json={"name": "upd_b_renamed.safetensors"}) as rp:
|
||||
upd = await rp.json()
|
||||
assert rp.status == 200, upd
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-upd", "sort": "updated_at", "order": "desc"},
|
||||
) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert names[0] == "upd_b_renamed.safetensors"
|
||||
assert a1["name"] in names
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-access"]
|
||||
await asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100))
|
||||
await asyncio.sleep(0.02)
|
||||
a2 = await asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100))
|
||||
|
||||
# Touch last_access_time of b by downloading its content
|
||||
await asyncio.sleep(0.02)
|
||||
async with http.get(f"{api_base}/api/assets/{a2['id']}/content") as dl:
|
||||
assert dl.status == 200
|
||||
await dl.read()
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-access", "sort": "last_access_time", "order": "desc"},
|
||||
) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert names[0] == a2["name"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-include"]
|
||||
a = await asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva"))
|
||||
await asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb"))
|
||||
|
||||
# CSV + case-insensitive
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names1 = [x["name"] for x in b1["assets"]]
|
||||
assert a["name"] in names1
|
||||
assert not any("beta" in x for x in names1)
|
||||
|
||||
# Repeated query params for include_tags
|
||||
params_multi = [
|
||||
("include_tags", "unit-tests"),
|
||||
("include_tags", "lf-include"),
|
||||
("include_tags", "alpha"),
|
||||
]
|
||||
async with http.get(api_base + "/api/assets", params=params_multi) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
names2 = [x["name"] for x in b2["assets"]]
|
||||
assert a["name"] in names2
|
||||
assert not any("beta" in x for x in names2)
|
||||
|
||||
# Duplicates and spaces in CSV
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": " unit-tests , lf-include , alpha , alpha "},
|
||||
) as r3:
|
||||
b3 = await r3.json()
|
||||
assert r3.status == 200
|
||||
names3 = [x["name"] for x in b3["assets"]]
|
||||
assert a["name"] in names3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-exclude"]
|
||||
a = await asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900))
|
||||
await asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900))
|
||||
|
||||
# Exclude uppercase should work
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names1 = [x["name"] for x in b1["assets"]]
|
||||
assert a["name"] in names1
|
||||
# Repeated excludes with duplicates
|
||||
params_multi = [
|
||||
("include_tags", "unit-tests"),
|
||||
("include_tags", "lf-exclude"),
|
||||
("exclude_tags", "beta"),
|
||||
("exclude_tags", "beta"),
|
||||
]
|
||||
async with http.get(api_base + "/api/assets", params=params_multi) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
names2 = [x["name"] for x in b2["assets"]]
|
||||
assert all("beta" not in x for x in names2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-name"]
|
||||
a1 = await asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800))
|
||||
a2 = await asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800))
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-name", "name_contains": "casemix"},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names1 = [x["name"] for x in b1["assets"]]
|
||||
assert a1["name"] in names1
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-name", "name_contains": ".SAFE"},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
names2 = [x["name"] for x in b2["assets"]]
|
||||
assert a1["name"] in names2
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-name", "name_contains": "case-"},
|
||||
) as r3:
|
||||
b3 = await r3.json()
|
||||
assert r3.status == 200
|
||||
names3 = [x["name"] for x in b3["assets"]]
|
||||
assert a2["name"] in names3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"]
|
||||
await asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600))
|
||||
await asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600))
|
||||
await asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600))
|
||||
|
||||
# Offset far beyond total
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-pagelimits", "limit": "2", "offset": "10"},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
assert not b1["assets"]
|
||||
assert b1["has_more"] is False
|
||||
|
||||
# Boundary large limit (<=500 is valid)
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-pagelimits", "limit": "500"},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
assert len(b2["assets"]) == 3
|
||||
assert b2["has_more"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base):
|
||||
async with http.get(api_base + "/api/assets", params={"offset": "-1"}) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 400
|
||||
assert b1["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
async with http.get(api_base + "/api/assets", params={"limit": "abc"}) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 400
|
||||
assert b2["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_invalid_query_rejected(http: aiohttp.ClientSession, api_base: str):
|
||||
# limit too small
|
||||
async with http.get(api_base + "/api/assets", params={"limit": "0"}) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 400
|
||||
assert b1["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
# bad metadata JSON
|
||||
async with http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 400
|
||||
assert b2["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_name_contains_literal_underscore(
|
||||
http,
|
||||
api_base,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""'name_contains' must treat '_' literally, not as a SQL wildcard.
|
||||
We create:
|
||||
- foo_bar.safetensors (should match)
|
||||
- fooxbar.safetensors (must NOT match if '_' is escaped)
|
||||
- foobar.safetensors (must NOT match)
|
||||
"""
|
||||
scope = f"lf-underscore-{uuid.uuid4().hex[:6]}"
|
||||
tags = ["models", "checkpoints", "unit-tests", scope]
|
||||
|
||||
a = await asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700))
|
||||
b = await asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700))
|
||||
c = await asset_factory("foobar.safetensors", tags, {}, make_asset_bytes("c", 700))
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "name_contains": "foo_bar"},
|
||||
) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200, body
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert a["name"] in names, f"Expected literal underscore match to include {a['name']}"
|
||||
assert b["name"] not in names, "Underscore must be escaped — should not match 'fooxbar'"
|
||||
assert c["name"] not in names, "Underscore must be escaped — should not match 'foobar'"
|
||||
assert body["total"] == 1
|
||||
387
tests-assets/test_metadata_filters.py
Normal file
387
tests-assets/test_metadata_filters.py
Normal file
@ -0,0 +1,387 @@
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_and_across_keys_and_types(
|
||||
http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes
|
||||
):
|
||||
name = "mf_and_mix.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-and"]
|
||||
meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
|
||||
await asset_factory(name, tags, meta, make_asset_bytes(name, 4096))
|
||||
|
||||
# All keys must match (AND semantics)
|
||||
f_ok = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-and",
|
||||
"metadata_filter": json.dumps(f_ok),
|
||||
},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names = [a["name"] for a in b1["assets"]]
|
||||
assert name in names
|
||||
|
||||
# One key mismatched -> no result
|
||||
f_bad = {"purpose": "mix", "epoch": 2, "active": True}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-and",
|
||||
"metadata_filter": json.dumps(f_bad),
|
||||
},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
assert not b2["assets"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_types.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-types"]
|
||||
meta = {"epoch": 1, "active": True}
|
||||
await asset_factory(name, tags, meta, make_asset_bytes(name))
|
||||
|
||||
# int filter matches numeric
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-types",
|
||||
"metadata_filter": json.dumps({"epoch": 1}),
|
||||
},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# string "1" must NOT match numeric 1
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-types",
|
||||
"metadata_filter": json.dumps({"epoch": "1"}),
|
||||
},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and not b2["assets"]
|
||||
|
||||
# bool True matches, string "true" must NOT match
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-types",
|
||||
"metadata_filter": json.dumps({"active": True}),
|
||||
},
|
||||
) as r3:
|
||||
b3 = await r3.json()
|
||||
assert r3.status == 200 and any(a["name"] == name for a in b3["assets"])
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-types",
|
||||
"metadata_filter": json.dumps({"active": "true"}),
|
||||
},
|
||||
) as r4:
|
||||
b4 = await r4.json()
|
||||
assert r4.status == 200 and not b4["assets"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_list_scalars.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-list"]
|
||||
meta = {"flags": ["red", "green"]}
|
||||
await asset_factory(name, tags, meta, make_asset_bytes(name, 3000))
|
||||
|
||||
# Any-of should match because "green" is present
|
||||
filt_ok = {"flags": ["blue", "green"]}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_ok)},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# None of provided flags present -> no match
|
||||
filt_miss = {"flags": ["blue", "yellow"]}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_miss)},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and not b2["assets"]
|
||||
|
||||
# Duplicates in list should not break matching
|
||||
filt_dup = {"flags": ["green", "green", "green"]}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_dup)},
|
||||
) as r3:
|
||||
b3 = await r3.json()
|
||||
assert r3.status == 200 and any(a["name"] == name for a in b3["assets"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
|
||||
http, api_base, asset_factory, make_asset_bytes
|
||||
):
|
||||
# a1: key missing; a2: explicit null; a3: concrete value
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-none"]
|
||||
a1 = await asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1"))
|
||||
a2 = await asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2"))
|
||||
a3 = await asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3"))
|
||||
|
||||
# Filter {maybe: None} must match a1 and a2, not a3
|
||||
filt = {"maybe": None}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt), "sort": "name"},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
got = [a["name"] for a in b1["assets"]]
|
||||
assert a1["name"] in got and a2["name"] in got and a3["name"] not in got
|
||||
|
||||
# Any-of with None should include missing/null plus value matches
|
||||
filt_any = {"maybe": [None, "x"]}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt_any), "sort": "name"},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
got2 = [a["name"] for a in b2["assets"]]
|
||||
assert a1["name"] in got2 and a2["name"] in got2 and a3["name"] in got2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_nested_json.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-nested"]
|
||||
cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}}
|
||||
await asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200))
|
||||
|
||||
# Exact JSON object equality (same structure)
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-nested",
|
||||
"metadata_filter": json.dumps({"config": cfg}),
|
||||
},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# Different JSON object should not match
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-nested",
|
||||
"metadata_filter": json.dumps({"config": {"optimizer": "sgd"}}),
|
||||
},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and not b2["assets"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_list_objects.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-objlist"]
|
||||
transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}]
|
||||
await asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048))
|
||||
|
||||
# Any-of for list of objects should match when one element equals the filter object
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-objlist",
|
||||
"metadata_filter": json.dumps({"transforms": {"type": "flip", "p": 0.5}}),
|
||||
},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# Non-matching object -> no match
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-objlist",
|
||||
"metadata_filter": json.dumps({"transforms": {"type": "rotate", "deg": 90}}),
|
||||
},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and not b2["assets"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_keys_unicode.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-keys"]
|
||||
meta = {
|
||||
"weird.key": "v1",
|
||||
"path/like": 7,
|
||||
"with:colon": True,
|
||||
"ключ": "значение",
|
||||
"emoji": "🐍",
|
||||
}
|
||||
await asset_factory(name, tags, meta, make_asset_bytes(name, 1500))
|
||||
|
||||
# Match all the special keys
|
||||
filt = {"weird.key": "v1", "path/like": 7, "with:colon": True, "emoji": "🐍"}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps(filt)},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# Unicode key match
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps({"ключ": "значение"})},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and any(a["name"] == name for a in b2["assets"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"]
|
||||
a0 = await asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025))
|
||||
a1 = await asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026))
|
||||
|
||||
# count == 0 must match only a0
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"count": 0})},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names1 = [a["name"] for a in b1["assets"]]
|
||||
assert a0["name"] in names1 and a1["name"] not in names1
|
||||
|
||||
# Any-of list of booleans: True matches second asset
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"choices": True})},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and any(a["name"] == a1["name"] for a in b2["assets"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_mixed_list.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-mixed"]
|
||||
meta = {"mix": ["1", 1, True, None]}
|
||||
await asset_factory(name, tags, meta, make_asset_bytes(name, 1999))
|
||||
|
||||
# Should match because 1 is present
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": [2, 1]})},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# Should NOT match for False
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": False})},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and not b2["assets"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes):
|
||||
# Use a unique scope tag to avoid interference
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"]
|
||||
x = await asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua"))
|
||||
y = await asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub"))
|
||||
|
||||
# Filtering by unknown key with None should return both (missing key OR null)
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": None})},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names = {a["name"] for a in b1["assets"]}
|
||||
assert x["name"] in names and y["name"] in names
|
||||
|
||||
# Filtering by unknown key with concrete value should return none
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": "x"})},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and not b2["assets"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_factory, make_asset_bytes):
|
||||
# alpha matches epoch=1; beta has epoch=2
|
||||
a = await asset_factory(
|
||||
"mf_tag_alpha.safetensors",
|
||||
["models", "checkpoints", "unit-tests", "mf-tag", "alpha"],
|
||||
{"epoch": 1},
|
||||
make_asset_bytes("alpha"),
|
||||
)
|
||||
b = await asset_factory(
|
||||
"mf_tag_beta.safetensors",
|
||||
["models", "checkpoints", "unit-tests", "mf-tag", "beta"],
|
||||
{"epoch": 2},
|
||||
make_asset_bytes("beta"),
|
||||
)
|
||||
|
||||
params = {
|
||||
"include_tags": "unit-tests,mf-tag,alpha",
|
||||
"exclude_tags": "beta",
|
||||
"name_contains": "mf_tag_",
|
||||
"metadata_filter": json.dumps({"epoch": 1}),
|
||||
}
|
||||
async with http.get(api_base + "/api/assets", params=params) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert a["name"] in names
|
||||
assert b["name"] not in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes):
|
||||
# Three assets in same scope with different sizes and a common filter key
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-sort"]
|
||||
n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors"
|
||||
await asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024))
|
||||
await asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048))
|
||||
await asset_factory(n3, t, {"group": "g"}, make_asset_bytes(n3, 3072))
|
||||
|
||||
# Sort by size ascending with paging
|
||||
q = {
|
||||
"include_tags": "unit-tests,mf-sort",
|
||||
"metadata_filter": json.dumps({"group": "g"}),
|
||||
"sort": "size", "order": "asc", "limit": "2",
|
||||
}
|
||||
async with http.get(api_base + "/api/assets", params=q) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
got1 = [a["name"] for a in b1["assets"]]
|
||||
assert got1 == [n1, n2]
|
||||
assert b1["has_more"] is True
|
||||
|
||||
q2 = {**q, "offset": "2"}
|
||||
async with http.get(api_base + "/api/assets", params=q2) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
got2 = [a["name"] for a in b2["assets"]]
|
||||
assert got2 == [n3]
|
||||
assert b2["has_more"] is False
|
||||
510
tests-assets/test_scans.py
Normal file
510
tests-assets/test_scans.py
Normal file
@ -0,0 +1,510 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
def _base_for(root: str, comfy_tmp_base_dir: Path) -> Path:
|
||||
assert root in ("input", "output")
|
||||
return comfy_tmp_base_dir / root
|
||||
|
||||
|
||||
def _mkbytes(label: str, size: int) -> bytes:
|
||||
seed = sum(label.encode("utf-8")) % 251
|
||||
return bytes((i * 31 + seed) % 256 for i in range(size))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_scan_schedule_idempotent_while_running(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""Idempotent schedule while running."""
|
||||
scope = f"idem-{uuid.uuid4().hex[:6]}"
|
||||
base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create several seed files (non-zero) to ensure the scan runs long enough
|
||||
for i in range(8):
|
||||
(base / f"f{i}.bin").write_bytes(_mkbytes(f"{scope}-{i}", 2 * 1024 * 1024)) # ~2 MiB each
|
||||
|
||||
# Seed -> states with hash=NULL
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Schedule once
|
||||
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 202, b1
|
||||
scans1 = {s["root"]: s for s in b1.get("scans", [])}
|
||||
s1 = scans1.get(root)
|
||||
assert s1 and s1["status"] in {"scheduled", "running"}
|
||||
sid1 = s1["scan_id"]
|
||||
|
||||
# Schedule again immediately — must return the same scan entry (no new worker)
|
||||
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 202, b2
|
||||
scans2 = {s["root"]: s for s in b2.get("scans", [])}
|
||||
s2 = scans2.get(root)
|
||||
assert s2 and s2["scan_id"] == sid1
|
||||
|
||||
# Filtered GET must show exactly one scan for this root
|
||||
async with http.get(api_base + "/api/assets/scan", params={"root": root}) as gs:
|
||||
bs = await gs.json()
|
||||
assert gs.status == 200, bs
|
||||
scans = bs.get("scans", [])
|
||||
assert len(scans) == 1 and scans[0]["scan_id"] == sid1
|
||||
|
||||
# Let it finish to avoid cross-test interference
|
||||
await run_scan_and_wait(root)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_status_filter_by_root_and_file_errors(
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
run_scan_and_wait,
|
||||
asset_factory,
|
||||
):
|
||||
"""Filtering get scan status by root (schedule for both input and output) + file_errors presence."""
|
||||
# Create one hashed asset in input under a dir we will chmod to 000 to force PermissionError in reconcile stage
|
||||
in_scope = f"filter-in-{uuid.uuid4().hex[:6]}"
|
||||
protected_dir = _base_for("input", comfy_tmp_base_dir) / "unit-tests" / in_scope / "deny"
|
||||
protected_dir.mkdir(parents=True, exist_ok=True)
|
||||
name_in = "protected.bin"
|
||||
|
||||
data = b"A" * 4096
|
||||
await asset_factory(name_in, ["input", "unit-tests", in_scope, "deny"], {}, data)
|
||||
try:
|
||||
os.chmod(protected_dir, 0o000)
|
||||
|
||||
# Also schedule a scan for output root (no errors there)
|
||||
out_scope = f"filter-out-{uuid.uuid4().hex[:6]}"
|
||||
out_dir = _base_for("output", comfy_tmp_base_dir) / "unit-tests" / out_scope
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
(out_dir / "ok.bin").write_bytes(b"B" * 1024)
|
||||
await trigger_sync_seed_assets(http, api_base) # seed output file
|
||||
|
||||
# Schedule both roots
|
||||
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": ["input"]}) as r_in:
|
||||
assert r_in.status == 202
|
||||
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": ["output"]}) as r_out:
|
||||
assert r_out.status == 202
|
||||
|
||||
# Wait both to complete, input last (we want its errors)
|
||||
await run_scan_and_wait("output")
|
||||
await run_scan_and_wait("input")
|
||||
|
||||
# Filter by root=input: only input scan listed and must have file_errors
|
||||
async with http.get(api_base + "/api/assets/scan", params={"root": "input"}) as gs:
|
||||
body = await gs.json()
|
||||
assert gs.status == 200, body
|
||||
scans = body.get("scans", [])
|
||||
assert len(scans) == 1
|
||||
errs = scans[0].get("file_errors", [])
|
||||
# Must contain at least one error with a message
|
||||
assert errs and any(e.get("message") for e in errs)
|
||||
finally:
|
||||
# Restore perms so cleanup can remove files/dirs
|
||||
try:
|
||||
os.chmod(protected_dir, 0o755)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
@pytest.mark.skipif(os.name == "nt", reason="Permission-based file_errors are unreliable on Windows")
|
||||
async def test_scan_records_file_errors_permission_denied(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""file_errors recording (permission denied) for input/output"""
|
||||
scope = f"errs-{uuid.uuid4().hex[:6]}"
|
||||
deny_dir = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / "deny"
|
||||
deny_dir.mkdir(parents=True, exist_ok=True)
|
||||
name = "deny.bin"
|
||||
|
||||
a1 = await asset_factory(name, [root, "unit-tests", scope, "deny"], {}, b"X" * 2048)
|
||||
asset_filename = get_asset_filename(a1["asset_hash"], ".bin")
|
||||
try:
|
||||
os.chmod(deny_dir, 0o000)
|
||||
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r:
|
||||
assert r.status == 202
|
||||
await run_scan_and_wait(root)
|
||||
|
||||
async with http.get(api_base + "/api/assets/scan", params={"root": root}) as gs:
|
||||
body = await gs.json()
|
||||
assert gs.status == 200, body
|
||||
scans = body.get("scans", [])
|
||||
assert len(scans) == 1
|
||||
errs = scans[0].get("file_errors", [])
|
||||
# Should contain at least one PermissionError-like record
|
||||
assert errs
|
||||
assert any(e.get("path", "").endswith(asset_filename) and e.get("message") for e in errs)
|
||||
finally:
|
||||
try:
|
||||
os.chmod(deny_dir, 0o755)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_missing_tag_created_and_visible_in_tags(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
):
|
||||
"""Missing tag appears in tags list and increments count (input/output)"""
|
||||
# Baseline count of 'missing' tag (may be absent)
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "1000"}) as r0:
|
||||
t0 = await r0.json()
|
||||
assert r0.status == 200, t0
|
||||
byname = {t["name"]: t for t in t0.get("tags", [])}
|
||||
old_count = int(byname.get("missing", {}).get("count", 0))
|
||||
|
||||
scope = f"miss-{uuid.uuid4().hex[:6]}"
|
||||
name = "missing_me.bin"
|
||||
created = await asset_factory(name, [root, "unit-tests", scope], {}, b"Y" * 4096)
|
||||
|
||||
# Remove the only file and trigger fast pass
|
||||
p = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(created["asset_hash"], ".bin")
|
||||
assert p.exists()
|
||||
p.unlink()
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Asset has 'missing' tag
|
||||
async with http.get(f"{api_base}/api/assets/{created['id']}") as g1:
|
||||
d1 = await g1.json()
|
||||
assert g1.status == 200, d1
|
||||
assert "missing" in set(d1.get("tags", []))
|
||||
|
||||
# Tag list now contains 'missing' with increased count
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1:
|
||||
t1 = await r1.json()
|
||||
assert r1.status == 200, t1
|
||||
byname1 = {t["name"]: t for t in t1.get("tags", [])}
|
||||
assert "missing" in byname1
|
||||
assert int(byname1["missing"]["count"]) >= old_count + 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_missing_reapplies_after_manual_removal(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
):
|
||||
"""Manual removal of 'missing' does not block automatic re-apply (input/output)"""
|
||||
scope = f"reapply-{uuid.uuid4().hex[:6]}"
|
||||
name = "reapply.bin"
|
||||
created = await asset_factory(name, [root, "unit-tests", scope], {}, b"Z" * 1024)
|
||||
|
||||
# Make it missing
|
||||
p = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(created["asset_hash"], ".bin")
|
||||
p.unlink()
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Remove the 'missing' tag manually
|
||||
async with http.delete(f"{api_base}/api/assets/{created['id']}/tags", json={"tags": ["missing"]}) as rdel:
|
||||
b = await rdel.json()
|
||||
assert rdel.status == 200, b
|
||||
assert "missing" in set(b.get("removed", []))
|
||||
|
||||
# Next sync must re-add it
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
async with http.get(f"{api_base}/api/assets/{created['id']}") as g2:
|
||||
d2 = await g2.json()
|
||||
assert g2.status == 200, d2
|
||||
assert "missing" in set(d2.get("tags", []))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_delete_one_asset_info_of_missing_asset_keeps_identity(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
):
|
||||
"""Delete one AssetInfo of a missing asset while another exists (input/output)"""
|
||||
scope = f"twoinfos-{uuid.uuid4().hex[:6]}"
|
||||
name = "twoinfos.bin"
|
||||
a1 = await asset_factory(name, [root, "unit-tests", scope], {}, b"W" * 2048)
|
||||
|
||||
# Second AssetInfo for the same content under same root (different name to avoid collision)
|
||||
a2 = await asset_factory("copy_" + name, [root, "unit-tests", scope], {}, b"W" * 2048)
|
||||
|
||||
# Remove file of the first (both point to the same Asset, but we know on-disk path name for a1)
|
||||
p1 = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(a1["asset_hash"], ".bin")
|
||||
p1.unlink()
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Both infos should be marked missing
|
||||
async with http.get(f"{api_base}/api/assets/{a1['id']}") as g1:
|
||||
d1 = await g1.json()
|
||||
assert "missing" in set(d1.get("tags", []))
|
||||
async with http.get(f"{api_base}/api/assets/{a2['id']}") as g2:
|
||||
d2 = await g2.json()
|
||||
assert "missing" in set(d2.get("tags", []))
|
||||
|
||||
# Delete one info
|
||||
async with http.delete(f"{api_base}/api/assets/{a1['id']}") as rd:
|
||||
assert rd.status == 204
|
||||
|
||||
# Asset identity still exists (by hash)
|
||||
h = a1["asset_hash"]
|
||||
async with http.head(f"{api_base}/api/assets/hash/{h}") as rh:
|
||||
assert rh.status == 200
|
||||
|
||||
# Remaining info still reflects 'missing'
|
||||
async with http.get(f"{api_base}/api/assets/{a2['id']}") as g3:
|
||||
d3 = await g3.json()
|
||||
assert g3.status == 200 and "missing" in set(d3.get("tags", []))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("keep_root", ["input", "output"])
|
||||
async def test_delete_last_asset_info_false_keeps_asset_and_states_multiroot(
|
||||
keep_root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
make_asset_bytes,
|
||||
asset_factory,
|
||||
):
|
||||
"""Delete last AssetInfo with delete_content_if_orphan=false keeps asset and the underlying on-disk content."""
|
||||
other_root = "output" if keep_root == "input" else "input"
|
||||
scope = f"delfalse-{uuid.uuid4().hex[:6]}"
|
||||
data = make_asset_bytes(scope, 3072)
|
||||
|
||||
# First upload creates the physical file
|
||||
a1 = await asset_factory("keep1.bin", [keep_root, "unit-tests", scope], {}, data)
|
||||
# Second upload (other root) is deduped to the same content; no new file on disk
|
||||
a2 = await asset_factory("keep2.bin", [other_root, "unit-tests", scope], {}, data)
|
||||
|
||||
h = a1["asset_hash"]
|
||||
p1 = _base_for(keep_root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(h, ".bin")
|
||||
|
||||
# De-dup semantics: only the first physical file exists
|
||||
assert p1.exists(), "Expected the first physical file to exist"
|
||||
|
||||
# Delete both AssetInfos; keep content on the very last delete
|
||||
async with http.delete(f"{api_base}/api/assets/{a2['id']}") as rfirst:
|
||||
assert rfirst.status == 204
|
||||
async with http.delete(f"{api_base}/api/assets/{a1['id']}?delete_content=false") as rlast:
|
||||
assert rlast.status == 204
|
||||
|
||||
# Asset identity remains and physical content is still present
|
||||
async with http.head(f"{api_base}/api/assets/hash/{h}") as rh:
|
||||
assert rh.status == 200
|
||||
assert p1.exists(), "Content file should remain after keep-content delete"
|
||||
|
||||
# Cleanup: re-create a reference by hash and then delete to purge content
|
||||
payload = {
|
||||
"hash": h,
|
||||
"name": "cleanup.bin",
|
||||
"tags": [keep_root, "unit-tests", scope, "cleanup"],
|
||||
"user_metadata": {},
|
||||
}
|
||||
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as rfh:
|
||||
ref = await rfh.json()
|
||||
assert rfh.status == 201, ref
|
||||
cid = ref["id"]
|
||||
async with http.delete(f"{api_base}/api/assets/{cid}") as rdel:
|
||||
assert rdel.status == 204
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_sync_seed_ignores_zero_byte_files(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
):
|
||||
scope = f"zero-{uuid.uuid4().hex[:6]}"
|
||||
base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
z = base / "empty.dat"
|
||||
z.write_bytes(b"") # zero bytes
|
||||
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# No AssetInfo created for this zero-byte file
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests," + scope, "name_contains": "empty.dat"},
|
||||
) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200, body
|
||||
assert not [a for a in body.get("assets", []) if a.get("name") == "empty.dat"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_sync_seed_idempotency(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
):
|
||||
scope = f"idemseed-{uuid.uuid4().hex[:6]}"
|
||||
base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
files = [f"f{i}.dat" for i in range(3)]
|
||||
for i, n in enumerate(files):
|
||||
(base / n).write_bytes(_mkbytes(n, 1500 + i * 10))
|
||||
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
async with http.get(api_base + "/api/assets", params={"include_tags": "unit-tests," + scope}) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200, b1
|
||||
c1 = len(b1.get("assets", []))
|
||||
|
||||
# Seed again -> count must stay the same
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
async with http.get(api_base + "/api/assets", params={"include_tags": "unit-tests," + scope}) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200, b2
|
||||
c2 = len(b2.get("assets", []))
|
||||
assert c1 == c2 == len(files)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_sync_seed_nested_dirs_produce_parent_tags(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
):
|
||||
scope = f"nest-{uuid.uuid4().hex[:6]}"
|
||||
# nested: unit-tests / scope / a / b / c / deep.txt
|
||||
deep_dir = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / "a" / "b" / "c"
|
||||
deep_dir.mkdir(parents=True, exist_ok=True)
|
||||
(deep_dir / "deep.txt").write_bytes(scope.encode())
|
||||
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "name_contains": "deep.txt"},
|
||||
) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200, body
|
||||
assets = body.get("assets", [])
|
||||
assert assets, "seeded asset not found"
|
||||
tags = set(assets[0].get("tags", []))
|
||||
# Must include all parent parts as tags + the root
|
||||
for must in {root, "unit-tests", scope, "a", "b", "c"}:
|
||||
assert must in tags
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_concurrent_seed_hashing_same_file_no_dupes(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""
|
||||
Create a single seed file, then schedule two scans back-to-back.
|
||||
Expect: no duplicate AssetInfos, a single hashed asset, and no scan failure.
|
||||
"""
|
||||
scope = f"conc-seed-{uuid.uuid4().hex[:6]}"
|
||||
name = "seed_concurrent.bin"
|
||||
|
||||
base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
(base / name).write_bytes(b"Z" * 2048)
|
||||
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
s1, s2 = await asyncio.gather(
|
||||
http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}),
|
||||
http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}),
|
||||
)
|
||||
await s1.read()
|
||||
await s2.read()
|
||||
assert s1.status in (200, 202)
|
||||
assert s2.status in (200, 202)
|
||||
|
||||
await run_scan_and_wait(root)
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
|
||||
) as r:
|
||||
b = await r.json()
|
||||
assert r.status == 200, b
|
||||
matches = [a for a in b.get("assets", []) if a.get("name") == name]
|
||||
assert len(matches) == 1
|
||||
assert matches[0].get("asset_hash"), "Seed should have been hashed into an Asset"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_cache_state_retarget_on_content_change_asset_info_stays(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""
|
||||
Start with hashed H1 (AssetInfo A1). Replace file bytes on disk to become H2.
|
||||
After scan: AssetCacheState points to H2; A1 still references H1; downloading A1 -> 404.
|
||||
"""
|
||||
scope = f"retarget-{uuid.uuid4().hex[:6]}"
|
||||
name = "content_change.bin"
|
||||
d1 = make_asset_bytes("v1-" + scope, 2048)
|
||||
|
||||
a1 = await asset_factory(name, [root, "unit-tests", scope], {}, d1)
|
||||
aid = a1["id"]
|
||||
h1 = a1["asset_hash"]
|
||||
|
||||
p = comfy_tmp_base_dir / root / "unit-tests" / scope / get_asset_filename(a1["asset_hash"], ".bin")
|
||||
assert p.exists()
|
||||
|
||||
# Change the bytes in place to force a new content hash (H2)
|
||||
d2 = make_asset_bytes("v2-" + scope, 3072)
|
||||
p.write_bytes(d2)
|
||||
|
||||
# Scan to verify and retarget the state; reconcilers run after scan
|
||||
await run_scan_and_wait(root)
|
||||
|
||||
# AssetInfo still on the old content identity (H1)
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
|
||||
g = await rg.json()
|
||||
assert rg.status == 200, g
|
||||
assert g.get("asset_hash") == h1
|
||||
|
||||
# Download must fail until a state exists for H1 (we removed the only one by retarget)
|
||||
async with http.get(f"{api_base}/api/assets/{aid}/content") as dl:
|
||||
body = await dl.json()
|
||||
assert dl.status == 404, body
|
||||
assert body["error"]["code"] == "FILE_NOT_FOUND"
|
||||
228
tests-assets/test_tags.py
Normal file
228
tests-assets/test_tags.py
Normal file
@ -0,0 +1,228 @@
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_present(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
# Include zero-usage tags by default
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
|
||||
body1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names = [t["name"] for t in body1["tags"]]
|
||||
# A few system tags from migration should exist:
|
||||
assert "models" in names
|
||||
assert "checkpoints" in names
|
||||
|
||||
# Only used tags before we add anything new from this test cycle
|
||||
async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
|
||||
body2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
# We already seeded one asset via fixture, so used tags must be non-empty
|
||||
used_names = [t["name"] for t in body2["tags"]]
|
||||
assert "models" in used_names
|
||||
assert "checkpoints" in used_names
|
||||
|
||||
# Prefix filter should refine the list
|
||||
async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}) as r3:
|
||||
b3 = await r3.json()
|
||||
assert r3.status == 200
|
||||
names3 = [t["name"] for t in b3["tags"]]
|
||||
assert "unit-tests" in names3
|
||||
assert "models" not in names3 # filtered out by prefix
|
||||
|
||||
# Order by name ascending should be stable
|
||||
async with http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}) as r4:
|
||||
b4 = await r4.json()
|
||||
assert r4.status == 200
|
||||
names4 = [t["name"] for t in b4["tags"]]
|
||||
assert names4 == sorted(names4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_empty_usage(http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes):
|
||||
# Baseline: system tags exist when include_zero (default) is true
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "500"}) as r1:
|
||||
body1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names = [t["name"] for t in body1["tags"]]
|
||||
assert "models" in names and "checkpoints" in names
|
||||
|
||||
# Create a short-lived asset under input with a unique custom tag
|
||||
scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}"
|
||||
custom_tag = f"temp-{uuid.uuid4().hex[:8]}"
|
||||
name = "tag_seed.bin"
|
||||
_asset = await asset_factory(
|
||||
name,
|
||||
["input", "unit-tests", scope, custom_tag],
|
||||
{},
|
||||
make_asset_bytes(name, 512),
|
||||
)
|
||||
|
||||
# While the asset exists, the custom tag must appear when include_zero=false
|
||||
async with http.get(
|
||||
api_base + "/api/tags",
|
||||
params={"include_zero": "false", "prefix": custom_tag, "limit": "50"},
|
||||
) as r2:
|
||||
body2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
used_names = [t["name"] for t in body2["tags"]]
|
||||
assert custom_tag in used_names
|
||||
|
||||
# Delete the asset so the tag usage drops to zero
|
||||
async with http.delete(f"{api_base}/api/assets/{_asset['id']}") as rd:
|
||||
assert rd.status == 204
|
||||
|
||||
# Now the custom tag must not be returned when include_zero=false
|
||||
async with http.get(
|
||||
api_base + "/api/tags",
|
||||
params={"include_zero": "false", "prefix": custom_tag, "limit": "50"},
|
||||
) as r3:
|
||||
body3 = await r3.json()
|
||||
assert r3.status == 200
|
||||
names_after = [t["name"] for t in body3["tags"]]
|
||||
assert custom_tag not in names_after
|
||||
assert not names_after # filtered view should be empty now
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_and_remove_tags(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
# Add tags with duplicates and mixed case
|
||||
payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]}
|
||||
async with http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200, b1
|
||||
# normalized, deduplicated; 'unit-tests' was already present from the seed
|
||||
assert set(b1["added"]) == {"newtag", "beta"}
|
||||
assert set(b1["already_present"]) == {"unit-tests"}
|
||||
assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
|
||||
g = await rg.json()
|
||||
assert rg.status == 200
|
||||
tags_now = set(g["tags"])
|
||||
assert {"newtag", "beta"}.issubset(tags_now)
|
||||
|
||||
# Remove a tag and a non-existent tag
|
||||
payload_del = {"tags": ["newtag", "does-not-exist"]}
|
||||
async with http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
assert set(b2["removed"]) == {"newtag"}
|
||||
assert set(b2["not_present"]) == {"does-not-exist"}
|
||||
|
||||
# Verify remaining tags after deletion
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg2:
|
||||
g2 = await rg2.json()
|
||||
assert rg2.status == 200
|
||||
tags_later = set(g2["tags"])
|
||||
assert "newtag" not in tags_later
|
||||
assert "beta" in tags_later # still present
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_list_order_and_prefix(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
h = seeded_asset["asset_hash"]
|
||||
|
||||
# Add both tags to the seeded asset (usage: orderaaa=1, orderbbb=1)
|
||||
async with http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": ["orderaaa", "orderbbb"]}) as r_add:
|
||||
add_body = await r_add.json()
|
||||
assert r_add.status == 200, add_body
|
||||
|
||||
# Create another AssetInfo from the same content but tagged ONLY with 'orderbbb'.
|
||||
payload = {
|
||||
"hash": h,
|
||||
"name": "order_only_bbb.safetensors",
|
||||
"tags": ["input", "unit-tests", "orderbbb"],
|
||||
"user_metadata": {},
|
||||
}
|
||||
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r_copy:
|
||||
copy_body = await r_copy.json()
|
||||
assert r_copy.status == 201, copy_body
|
||||
|
||||
# 1) Default order (count_desc): 'orderbbb' should come before 'orderaaa'
|
||||
# because it has higher usage (2 vs 1).
|
||||
async with http.get(api_base + "/api/tags", params={"prefix": "order", "include_zero": "false"}) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200, b1
|
||||
names1 = [t["name"] for t in b1["tags"]]
|
||||
counts1 = {t["name"]: t["count"] for t in b1["tags"]}
|
||||
# Both must be present within the prefix subset
|
||||
assert "orderaaa" in names1 and "orderbbb" in names1
|
||||
# Usage of 'orderbbb' must be >= 'orderaaa'; in our setup it's 2 vs 1
|
||||
assert counts1["orderbbb"] >= counts1["orderaaa"]
|
||||
# And with count_desc, 'orderbbb' appears earlier than 'orderaaa'
|
||||
assert names1.index("orderbbb") < names1.index("orderaaa")
|
||||
|
||||
# 2) name_asc: lexical order should flip the relative order
|
||||
async with http.get(
|
||||
api_base + "/api/tags",
|
||||
params={"prefix": "order", "include_zero": "false", "order": "name_asc"},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200, b2
|
||||
names2 = [t["name"] for t in b2["tags"]]
|
||||
assert "orderaaa" in names2 and "orderbbb" in names2
|
||||
assert names2.index("orderaaa") < names2.index("orderbbb")
|
||||
|
||||
# 3) invalid limit rejected (existing negative case retained)
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "1001"}) as r3:
|
||||
b3 = await r3.json()
|
||||
assert r3.status == 400
|
||||
assert b3["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_endpoints_invalid_bodies(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
# Add with empty list
|
||||
async with http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": []}) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 400
|
||||
assert b1["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
# Remove with wrong type
|
||||
async with http.delete(f"{api_base}/api/assets/{aid}/tags", json={"tags": [123]}) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 400
|
||||
assert b2["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
# metadata_filter provided as JSON array should be rejected (must be object)
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"metadata_filter": json.dumps([{"x": 1}])},
|
||||
) as r3:
|
||||
b3 = await r3.json()
|
||||
assert r3.status == 400
|
||||
assert b3["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_prefix_treats_underscore_literal(
|
||||
http,
|
||||
api_base,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""'prefix' for /api/tags must treat '_' literally, not as a wildcard."""
|
||||
base = f"pref_{uuid.uuid4().hex[:6]}"
|
||||
tag_ok = f"{base}_ok" # should match prefix=f"{base}_"
|
||||
tag_bad = f"{base}xok" # must NOT match if '_' is escaped
|
||||
scope = f"tags-underscore-{uuid.uuid4().hex[:6]}"
|
||||
|
||||
await asset_factory("t1.bin", ["input", "unit-tests", scope, tag_ok], {}, make_asset_bytes("t1", 512))
|
||||
await asset_factory("t2.bin", ["input", "unit-tests", scope, tag_bad], {}, make_asset_bytes("t2", 512))
|
||||
|
||||
async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": f"{base}_"}) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200, body
|
||||
names = [t["name"] for t in body["tags"]]
|
||||
assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'"
|
||||
assert tag_bad not in names, f"'{tag_bad}' must not match — '_' is not a wildcard"
|
||||
assert body["total"] == 1
|
||||
325
tests-assets/test_uploads.py
Normal file
325
tests-assets/test_uploads.py
Normal file
@ -0,0 +1,325 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_ok_duplicate_reference(http: aiohttp.ClientSession, api_base: str, make_asset_bytes):
|
||||
name = "dup_a.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||
meta = {"purpose": "dup"}
|
||||
data = make_asset_bytes(name)
|
||||
form1 = aiohttp.FormData()
|
||||
form1.add_field("file", data, filename=name, content_type="application/octet-stream")
|
||||
form1.add_field("tags", json.dumps(tags))
|
||||
form1.add_field("name", name)
|
||||
form1.add_field("user_metadata", json.dumps(meta))
|
||||
async with http.post(api_base + "/api/assets", data=form1) as r1:
|
||||
a1 = await r1.json()
|
||||
assert r1.status == 201, a1
|
||||
assert a1["created_new"] is True
|
||||
|
||||
# Second upload with the same data and name should return created_new == False and the same asset
|
||||
form2 = aiohttp.FormData()
|
||||
form2.add_field("file", data, filename=name, content_type="application/octet-stream")
|
||||
form2.add_field("tags", json.dumps(tags))
|
||||
form2.add_field("name", name)
|
||||
form2.add_field("user_metadata", json.dumps(meta))
|
||||
async with http.post(api_base + "/api/assets", data=form2) as r2:
|
||||
a2 = await r2.json()
|
||||
assert r2.status == 200, a2
|
||||
assert a2["created_new"] is False
|
||||
assert a2["asset_hash"] == a1["asset_hash"]
|
||||
assert a2["id"] == a1["id"] # old reference
|
||||
|
||||
# Third upload with the same data but new name should return created_new == False and the new AssetReference
|
||||
form3 = aiohttp.FormData()
|
||||
form3.add_field("file", data, filename=name, content_type="application/octet-stream")
|
||||
form3.add_field("tags", json.dumps(tags))
|
||||
form3.add_field("name", name + "_d")
|
||||
form3.add_field("user_metadata", json.dumps(meta))
|
||||
async with http.post(api_base + "/api/assets", data=form3) as r2:
|
||||
a3 = await r2.json()
|
||||
assert r2.status == 200, a3
|
||||
assert a3["created_new"] is False
|
||||
assert a3["asset_hash"] == a1["asset_hash"]
|
||||
assert a3["id"] != a1["id"] # old reference
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_fastpath_from_existing_hash_no_file(http: aiohttp.ClientSession, api_base: str):
|
||||
# Seed a small file first
|
||||
name = "fastpath_seed.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests"]
|
||||
meta = {}
|
||||
form1 = aiohttp.FormData()
|
||||
form1.add_field("file", b"B" * 1024, filename=name, content_type="application/octet-stream")
|
||||
form1.add_field("tags", json.dumps(tags))
|
||||
form1.add_field("name", name)
|
||||
form1.add_field("user_metadata", json.dumps(meta))
|
||||
async with http.post(api_base + "/api/assets", data=form1) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 201, b1
|
||||
h = b1["asset_hash"]
|
||||
|
||||
# Now POST /api/assets with only hash and no file
|
||||
form2 = aiohttp.FormData(default_to_multipart=True)
|
||||
form2.add_field("hash", h)
|
||||
form2.add_field("tags", json.dumps(tags))
|
||||
form2.add_field("name", "fastpath_copy.safetensors")
|
||||
form2.add_field("user_metadata", json.dumps({"purpose": "copy"}))
|
||||
async with http.post(api_base + "/api/assets", data=form2) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200, b2 # fast path returns 200 with created_new == False
|
||||
assert b2["created_new"] is False
|
||||
assert b2["asset_hash"] == h
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_fastpath_with_known_hash_and_file(
|
||||
http: aiohttp.ClientSession, api_base: str
|
||||
):
|
||||
# Seed
|
||||
form1 = aiohttp.FormData()
|
||||
form1.add_field("file", b"C" * 128, filename="seed.safetensors", content_type="application/octet-stream")
|
||||
form1.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "fp"]))
|
||||
form1.add_field("name", "seed.safetensors")
|
||||
form1.add_field("user_metadata", json.dumps({}))
|
||||
async with http.post(api_base + "/api/assets", data=form1) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 201, b1
|
||||
h = b1["asset_hash"]
|
||||
|
||||
# Send both file and hash of existing content -> server must drain file and create from hash (200)
|
||||
form2 = aiohttp.FormData()
|
||||
form2.add_field("file", b"ignored" * 10, filename="ignored.bin", content_type="application/octet-stream")
|
||||
form2.add_field("hash", h)
|
||||
form2.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "fp"]))
|
||||
form2.add_field("name", "copy_from_hash.safetensors")
|
||||
form2.add_field("user_metadata", json.dumps({}))
|
||||
async with http.post(api_base + "/api/assets", data=form2) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200, b2
|
||||
assert b2["created_new"] is False
|
||||
assert b2["asset_hash"] == h
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_multiple_tags_fields_are_merged(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"B" * 256, filename="merge.safetensors", content_type="application/octet-stream")
|
||||
form.add_field("tags", "models,checkpoints") # CSV
|
||||
form.add_field("tags", json.dumps(["unit-tests", "alpha"])) # JSON array in second field
|
||||
form.add_field("name", "merge.safetensors")
|
||||
form.add_field("user_metadata", json.dumps({"u": 1}))
|
||||
async with http.post(api_base + "/api/assets", data=form) as r1:
|
||||
created = await r1.json()
|
||||
assert r1.status in (200, 201), created
|
||||
aid = created["id"]
|
||||
|
||||
# Verify all tags are present on the resource
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
|
||||
detail = await rg.json()
|
||||
assert rg.status == 200, detail
|
||||
tags = set(detail["tags"])
|
||||
assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_concurrent_upload_identical_bytes_different_names(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""
|
||||
Two concurrent uploads of identical bytes but different names.
|
||||
Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True.
|
||||
"""
|
||||
scope = f"concupload-{uuid.uuid4().hex[:6]}"
|
||||
name1, name2 = "cu_a.bin", "cu_b.bin"
|
||||
data = make_asset_bytes("concurrent", 4096)
|
||||
tags = [root, "unit-tests", scope]
|
||||
|
||||
def _form(name: str) -> aiohttp.FormData:
|
||||
f = aiohttp.FormData()
|
||||
f.add_field("file", data, filename=name, content_type="application/octet-stream")
|
||||
f.add_field("tags", json.dumps(tags))
|
||||
f.add_field("name", name)
|
||||
f.add_field("user_metadata", json.dumps({}))
|
||||
return f
|
||||
|
||||
r1, r2 = await asyncio.gather(
|
||||
http.post(api_base + "/api/assets", data=_form(name1)),
|
||||
http.post(api_base + "/api/assets", data=_form(name2)),
|
||||
)
|
||||
b1, b2 = await r1.json(), await r2.json()
|
||||
assert r1.status in (200, 201), b1
|
||||
assert r2.status in (200, 201), b2
|
||||
assert b1["asset_hash"] == b2["asset_hash"]
|
||||
assert b1["id"] != b2["id"]
|
||||
|
||||
created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))])
|
||||
assert created_flags == [False, True]
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "sort": "name"},
|
||||
) as rl:
|
||||
bl = await rl.json()
|
||||
assert rl.status == 200, bl
|
||||
names = [a["name"] for a in bl.get("assets", [])]
|
||||
assert set([name1, name2]).issubset(names)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_from_hash_endpoint_404(http: aiohttp.ClientSession, api_base: str):
|
||||
payload = {
|
||||
"hash": "blake3:" + "0" * 64,
|
||||
"name": "nonexistent.bin",
|
||||
"tags": ["models", "checkpoints", "unit-tests"],
|
||||
}
|
||||
async with http.post(api_base + "/api/assets/from-hash", json=payload) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 404
|
||||
assert body["error"]["code"] == "ASSET_NOT_FOUND"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_zero_byte_rejected(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"", filename="empty.safetensors", content_type="application/octet-stream")
|
||||
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "edge"]))
|
||||
form.add_field("name", "empty.safetensors")
|
||||
form.add_field("user_metadata", json.dumps({}))
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] == "EMPTY_UPLOAD"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_invalid_root_tag_rejected(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"A" * 64, filename="badroot.bin", content_type="application/octet-stream")
|
||||
form.add_field("tags", json.dumps(["not-a-root", "whatever"]))
|
||||
form.add_field("name", "badroot.bin")
|
||||
form.add_field("user_metadata", json.dumps({}))
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_user_metadata_must_be_json(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"A" * 128, filename="badmeta.bin", content_type="application/octet-stream")
|
||||
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "edge"]))
|
||||
form.add_field("name", "badmeta.bin")
|
||||
form.add_field("user_metadata", "{not json}") # invalid
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_requires_multipart(http: aiohttp.ClientSession, api_base: str):
|
||||
async with http.post(api_base + "/api/assets", json={"foo": "bar"}) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 415
|
||||
assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_missing_file_and_hash(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData(default_to_multipart=True)
|
||||
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests"]))
|
||||
form.add_field("name", "x.safetensors")
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] == "MISSING_FILE"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"A" * 128, filename="m.safetensors", content_type="application/octet-stream")
|
||||
form.add_field("tags", json.dumps(["models", "no_such_category", "unit-tests"]))
|
||||
form.add_field("name", "m.safetensors")
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
assert body["error"]["message"].startswith("unknown models category")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_models_requires_category(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"A" * 64, filename="nocat.safetensors", content_type="application/octet-stream")
|
||||
form.add_field("tags", json.dumps(["models"])) # missing category
|
||||
form.add_field("name", "nocat.safetensors")
|
||||
form.add_field("user_metadata", json.dumps({}))
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_tags_traversal_guard(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"A" * 256, filename="evil.safetensors", content_type="application/octet-stream")
|
||||
# '..' should be rejected by destination resolver
|
||||
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]))
|
||||
form.add_field("name", "evil.safetensors")
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_duplicate_upload_same_display_name_does_not_clobber(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""
|
||||
Two uploads use the same tags and the same display name but different bytes.
|
||||
With hash-based filenames, they must NOT overwrite each other. Both assets
|
||||
remain accessible and serve their original content.
|
||||
"""
|
||||
scope = f"dup-path-{uuid.uuid4().hex[:6]}"
|
||||
display_name = "same_display.bin"
|
||||
|
||||
d1 = make_asset_bytes(scope + "-v1", 1536)
|
||||
d2 = make_asset_bytes(scope + "-v2", 2048)
|
||||
tags = [root, "unit-tests", scope]
|
||||
|
||||
first = await asset_factory(display_name, tags, {}, d1)
|
||||
second = await asset_factory(display_name, tags, {}, d2)
|
||||
|
||||
assert first["id"] != second["id"]
|
||||
assert first["asset_hash"] != second["asset_hash"] # different content
|
||||
assert first["name"] == second["name"] == display_name
|
||||
|
||||
# Both must be independently retrievable
|
||||
async with http.get(f"{api_base}/api/assets/{first['id']}/content") as r1:
|
||||
b1 = await r1.read()
|
||||
assert r1.status == 200
|
||||
assert b1 == d1
|
||||
async with http.get(f"{api_base}/api/assets/{second['id']}/content") as r2:
|
||||
b2 = await r2.read()
|
||||
assert r2.status == 200
|
||||
assert b2 == d2
|
||||
@ -205,74 +205,3 @@ numpy"""
|
||||
|
||||
# Assert
|
||||
assert version is None
|
||||
|
||||
|
||||
def test_get_templates_version():
|
||||
# Arrange
|
||||
expected_version = "0.1.41"
|
||||
mock_requirements_content = """torch
|
||||
torchsde
|
||||
comfyui-frontend-package==1.25.0
|
||||
comfyui-workflow-templates==0.1.41
|
||||
other-package==1.0.0
|
||||
numpy"""
|
||||
|
||||
# Act
|
||||
with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
|
||||
version = FrontendManager.get_required_templates_version()
|
||||
|
||||
# Assert
|
||||
assert version == expected_version
|
||||
|
||||
|
||||
def test_get_templates_version_not_found():
|
||||
# Arrange
|
||||
mock_requirements_content = """torch
|
||||
torchsde
|
||||
comfyui-frontend-package==1.25.0
|
||||
other-package==1.0.0
|
||||
numpy"""
|
||||
|
||||
# Act
|
||||
with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
|
||||
version = FrontendManager.get_required_templates_version()
|
||||
|
||||
# Assert
|
||||
assert version is None
|
||||
|
||||
|
||||
def test_get_templates_version_invalid_semver():
|
||||
# Arrange
|
||||
mock_requirements_content = """torch
|
||||
torchsde
|
||||
comfyui-workflow-templates==1.0.0.beta
|
||||
other-package==1.0.0
|
||||
numpy"""
|
||||
|
||||
# Act
|
||||
with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
|
||||
version = FrontendManager.get_required_templates_version()
|
||||
|
||||
# Assert
|
||||
assert version is None
|
||||
|
||||
|
||||
def test_get_installed_templates_version():
|
||||
# Arrange
|
||||
expected_version = "0.1.40"
|
||||
|
||||
# Act
|
||||
with patch("app.frontend_management.version", return_value=expected_version):
|
||||
version = FrontendManager.get_installed_templates_version()
|
||||
|
||||
# Assert
|
||||
assert version == expected_version
|
||||
|
||||
|
||||
def test_get_installed_templates_version_not_installed():
|
||||
# Act
|
||||
with patch("app.frontend_management.version", side_effect=Exception("Package not found")):
|
||||
version = FrontendManager.get_installed_templates_version()
|
||||
|
||||
# Assert
|
||||
assert version is None
|
||||
|
||||
@ -1,62 +0,0 @@
|
||||
import pytest
|
||||
import base64
|
||||
import json
|
||||
import struct
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from aiohttp import web
|
||||
from unittest.mock import patch
|
||||
from app.model_manager import ModelFileManager
|
||||
|
||||
pytestmark = (
|
||||
pytest.mark.asyncio
|
||||
) # This applies the asyncio mark to all test functions in the module
|
||||
|
||||
@pytest.fixture
|
||||
def model_manager():
|
||||
return ModelFileManager()
|
||||
|
||||
@pytest.fixture
|
||||
def app(model_manager):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
model_manager.add_routes(routes)
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
|
||||
img = Image.new('RGB', (100, 100), 'white')
|
||||
img_byte_arr = BytesIO()
|
||||
img.save(img_byte_arr, format='PNG')
|
||||
img_byte_arr.seek(0)
|
||||
img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
|
||||
|
||||
safetensors_file = tmp_path / "test_model.safetensors"
|
||||
header_bytes = json.dumps({
|
||||
"__metadata__": {
|
||||
"ssmd_cover_images": json.dumps([img_b64])
|
||||
}
|
||||
}).encode('utf-8')
|
||||
length_bytes = struct.pack('<Q', len(header_bytes))
|
||||
with open(safetensors_file, 'wb') as f:
|
||||
f.write(length_bytes)
|
||||
f.write(header_bytes)
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/0/test_model.safetensors')
|
||||
|
||||
# Verify response
|
||||
assert response.status == 200
|
||||
assert response.content_type == 'image/webp'
|
||||
|
||||
# Verify the response contains valid image data
|
||||
img_bytes = BytesIO(await response.read())
|
||||
img = Image.open(img_bytes)
|
||||
assert img.format
|
||||
assert img.format.lower() == 'webp'
|
||||
|
||||
# Clean up
|
||||
img.close()
|
||||
@ -48,13 +48,6 @@ CACHE_SCENARIOS = [
|
||||
"expected_cache": "no-cache",
|
||||
"should_have_header": True,
|
||||
},
|
||||
{
|
||||
"name": "localized_index_json_no_cache",
|
||||
"path": "/templates/index.zh.json",
|
||||
"status": 200,
|
||||
"expected_cache": "no-cache",
|
||||
"should_have_header": True,
|
||||
},
|
||||
# Non-matching files
|
||||
{
|
||||
"name": "html_no_header",
|
||||
|
||||
Reference in New Issue
Block a user