Compare commits

..

40 Commits

Author SHA1 Message Date
6c611b0b99 Change node id to reflect node name 2025-08-18 15:39:16 -07:00
cd54d502fc Make sure models_memory_reserve is considered with inference_memory as well in max func calls 2025-08-18 15:34:53 -07:00
63571c6c3d Renamed to Reserve Additional Memory 2025-08-18 15:04:49 -07:00
bae0c31a68 Added missing model.clone() call 2025-08-18 14:51:11 -07:00
34b1f51f4a Created Add Memory to Reserve node 2025-08-18 14:45:21 -07:00
bd2ab73976 fix(WAN-nodes): invalid nodeid for WanTrackToVideo (#9396) 2025-08-18 03:26:55 -04:00
da2efeaec6 Bump frontend to 1.25.9 (#9394) 2025-08-17 20:21:02 -07:00
7f3b9b16c6 Make step index detection much more robust (#9392) 2025-08-17 18:54:07 -04:00
d4e353a94e Update template to 0.1.60 (#9377) 2025-08-17 17:38:40 -04:00
ed43784b0d WIP Qwen edit model: The diffusion model part. (#9383) 2025-08-17 16:45:39 -04:00
0f2b8525bc Qwen image model refactor. (#9375) 2025-08-16 17:51:28 -04:00
20a84166d0 record audio node (#8716)
* record audio node

* sf
2025-08-16 02:07:12 -04:00
ed2e33c69a bump frontend version to 1.25.8 (#9361) 2025-08-15 23:32:58 -04:00
1702e6df16 Implement wan2.2 camera model. (#9357)
Use the old WanCameraImageToVideo node.
2025-08-15 17:29:58 -04:00
c308a8840a Add FluxKontextMultiReferenceLatentMethod node. (#9356)
This node is only useful if someone trains the kontext model to properly
use multiple reference images via the index method.

The default is the offset method which feeds the multiple images like if
they were stitched together as one. This method works with the current
flux kontext model.
2025-08-15 15:50:39 -04:00
027c63f63a fix(OpenAIGPTImage1): set correct MIME type for multipart uploads to OpenAI edits (#9348) 2025-08-15 14:57:47 -04:00
e08ecfbd8a Add warning when using old pytorch. (#9347) 2025-08-15 00:22:26 -04:00
4e5c230f6a Fix last commit not working on older pytorch. (#9346) 2025-08-14 23:44:02 -04:00
f0d5d0111f Avoid torch compile graphbreak for older pytorch versions (#9344)
Turns out torch.compile has some gaps in context manager decorator
syntax support. I've sent patches to fix that in PyTorch, but it won't
be available for all the folks running older versions of PyTorch, hence
this trivial patch.
2025-08-14 23:41:37 -04:00
ad19a069f6 Make SLG nodes work on Qwen Image model. (#9345) 2025-08-14 23:16:01 -04:00
5d65d6753b convert WAN nodes to V3 schema (#9201) 2025-08-14 21:48:41 -04:00
deebee4ff6 Update default parameters for Moonvalley video nodes (#9290)
* Update default parameters for Moonvalley video nodes

- Changed default negative prompts to a more extensive list for both BaseMoonvalleyVideoNode and MoonvalleyVideo2VideoNode.
- Updated default guidance scale values for both nodes to enhance prompt adherence.
- Set a fixed default seed value for consistency in video generation.

* no message

* ruff fix

---------

Co-authored-by: thorsten <thorsten@tripod-digital.co.nz>
2025-08-14 21:46:55 -04:00
fa570cbf59 Update CODEOWNERS (#9343) 2025-08-14 19:44:22 -04:00
644b23ac0b Make custom node testing checkbox optional in issue templates (#9342)
The checkbox for confirming custom node testing is now optional in both bug report and user support templates. This allows users to submit issues even if they haven't been able to test with custom nodes disabled, making the reporting process more accessible.
2025-08-14 17:36:53 -04:00
72fd4d22b6 av is an essential dependency. (#9341) 2025-08-14 16:03:21 -04:00
e4f7ea105f Added context window support to core sampling code (#9238)
* Added initial support for basic context windows - in progress

* Add prepare_sampling wrapper for context window to more accurately estimate latent memory requirements, fixed merging wrappers/callbacks dicts in prepare_model_patcher

* Made context windows compatible with different dimensions; works for WAN, but results are bad

* Fix comfy.patcher_extension.merge_nested_dicts calls in prepare_model_patcher in sampler_helpers.py

* Considering adding some callbacks to context window code to allow extensions of behavior without the need to rewrite code

* Made dim slicing cleaner

* Add Wan Context WIndows node for testing

* Made context schedule and fuse method functions be stored on the handler instead of needing to be registered in core code to be found

* Moved some code around between node_context_windows.py and context_windows.py

* Change manual context window nodes names/ids

* Added callbacks to IndexListContexHandler

* Adjusted default values for context_length and context_overlap, made schema.inputs definition for WAN Context Windows less annoying

* Make get_resized_cond more robust for various dim sizes

* Fix typo

* Another small fix
2025-08-13 21:33:05 -04:00
c991a5da65 Fix XPU iGPU regressions (#9322)
* Change bf16 check and switch non-blocking to off default with option to force to regain speed on certain classes of iGPUs and refactor xpu check.

* Turn non_blocking off by default for xpu.

* Update README.md for Intel GPUs.
2025-08-13 19:13:35 -04:00
9df8792d4b Make last PR not crash comfy on old pytorch. (#9324) 2025-08-13 15:12:41 -04:00
3da5a07510 SDPA backend priority (#9299) 2025-08-13 14:53:27 -04:00
afa0a45206 Reduce portable size again. (#9323)
* compress more

* test

* not needed
2025-08-13 14:42:08 -04:00
615eb52049 Put back frontend version. (#9317) 2025-08-13 03:48:06 -04:00
d5c1954d5c ComfyUI version 0.3.50 2025-08-13 03:46:38 -04:00
e400f26c8f Downgrade frontend for release. (#9316) 2025-08-13 03:44:54 -04:00
5ca8e2fac3 Update release workflow to python3.13 pytorch cu129 (#9315)
* Try to reduce size of portable even more.

* Update stable release workflow to python 3.13 cu129

* Update dependencies workflow to python3.13 cu129
2025-08-13 03:01:12 -04:00
3294782d19 Update template to 0.1.59 (#9313) 2025-08-13 02:50:50 -04:00
898d88e10e Make torchaudio exception catching less specific (#9309) 2025-08-12 23:34:58 -04:00
560d38f34c Wan2.2 fun control support. (#9292) 2025-08-12 23:26:33 -04:00
e1d4f36d8d Update test release package workflow with python 3.13 cu129. (#9306) 2025-08-12 20:13:04 -04:00
1e3ae1eed8 Update template to 0.1.58 (#9302) 2025-08-12 17:14:27 -04:00
f4231a80b1 fix(Kling Image API Node): do not pass "image_type" when no image (#9271)
* fix(Kling Image API Node): do not pass "image_type" when no image

* fix(Kling Image API Node): raise client-side error when kling_v1 is used with reference image
2025-08-11 17:15:14 -04:00
45 changed files with 1448 additions and 1146 deletions

View File

@ -22,7 +22,7 @@ body:
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
required: true
required: false
- type: textarea
attributes:
label: Expected Behavior

View File

@ -18,7 +18,7 @@ body:
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
required: true
required: false
- type: textarea
attributes:
label: Your question

View File

@ -12,17 +12,17 @@ on:
description: 'CUDA version'
required: true
type: string
default: "128"
default: "129"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "12"
default: "13"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "10"
default: "6"
jobs:
@ -66,8 +66,13 @@ jobs:
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
cd ..
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
@ -85,7 +90,7 @@ jobs:
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
cd ComfyUI_windows_portable

View File

@ -17,19 +17,19 @@ on:
description: 'cuda version'
required: true
type: string
default: "128"
default: "129"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
default: "13"
python_patch:
description: 'python patch version'
required: true
type: string
default: "10"
default: "6"
# push:
# branches:
# - master

View File

@ -7,19 +7,19 @@ on:
description: 'cuda version'
required: true
type: string
default: "128"
default: "129"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
default: "13"
python_patch:
description: 'python patch version'
required: true
type: string
default: "10"
default: "6"
# push:
# branches:
# - master
@ -64,6 +64,10 @@ jobs:
./python.exe get-pip.py
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
@ -82,7 +86,7 @@ jobs:
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
cd ComfyUI_windows_portable

View File

@ -5,20 +5,21 @@
# Inlined the team members for now.
# Maintainers
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
# Python web server
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
# Node developers
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill

View File

@ -39,7 +39,7 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
## Get Started
#### [Desktop Application](https://www.comfy.org/download)
- The easiest way to get started.
- The easiest way to get started.
- Available on Windows & macOS.
#### [Windows Portable Package](#installing)
@ -211,27 +211,19 @@ This is the command to install the nightly with ROCm 6.4 which might have some p
### Intel GPUs (Windows and Linux)
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
1. To install PyTorch nightly, use the following command:
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
1. To install PyTorch xpu, use the following command:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu```
This is the command to install the Pytorch xpu nightly which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
2. Launch ComfyUI by running `python main.py`
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
```
conda install libuv
pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
```
For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
### NVIDIA
@ -352,7 +344,7 @@ Generate a self-signed certificate (not appropriate for shared/production use) a
Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`.
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
<br/><br/>If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.
## Support and dev channel

View File

@ -1,40 +0,0 @@
"""init
Revision ID: e9c714da8d57
Revises:
Create Date: 2025-05-30 20:14:33.772039
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'e9c714da8d57'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
op.create_table('model',
sa.Column('type', sa.Text(), nullable=False),
sa.Column('path', sa.Text(), nullable=False),
sa.Column('file_name', sa.Text(), nullable=True),
sa.Column('file_size', sa.Integer(), nullable=True),
sa.Column('hash', sa.Text(), nullable=True),
sa.Column('hash_algorithm', sa.Text(), nullable=True),
sa.Column('source_url', sa.Text(), nullable=True),
sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
sa.PrimaryKeyConstraint('type', 'path')
)
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('model')
# ### end Alembic commands ###

View File

@ -1,11 +1,4 @@
from sqlalchemy import (
Column,
Integer,
Text,
DateTime,
)
from sqlalchemy.orm import declarative_base
from sqlalchemy.sql import func
Base = declarative_base()
@ -18,42 +11,4 @@ def to_dict(obj):
if (val := getattr(obj, field))
}
class Model(Base):
"""
sqlalchemy model representing a model file in the system.
This class defines the database schema for storing information about model files,
including their type, path, hash, and when they were added to the system.
Attributes:
type (Text): The type of the model, this is the name of the folder in the models folder (primary key)
path (Text): The file path of the model relative to the type folder (primary key)
file_name (Text): The name of the model file
file_size (Integer): The size of the model file in bytes
hash (Text): A hash of the model file
hash_algorithm (Text): The algorithm used to generate the hash
source_url (Text): The URL of the model file
date_added (DateTime): Timestamp of when the model was added to the system
"""
__tablename__ = "model"
type = Column(Text, primary_key=True)
path = Column(Text, primary_key=True)
file_name = Column(Text)
file_size = Column(Integer)
hash = Column(Text)
hash_algorithm = Column(Text)
source_url = Column(Text)
date_added = Column(DateTime, server_default=func.now())
def to_dict(self):
"""
Convert the model instance to a dictionary representation.
Returns:
dict: A dictionary containing the attributes of the model
"""
dict = to_dict(self)
return dict
# TODO: Define models here

View File

@ -196,17 +196,6 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
"""
A class to manage ComfyUI frontend versions and installations.
This class handles the initialization and management of different frontend versions,
including the default frontend from the pip package and custom frontend versions
from GitHub repositories.
Attributes:
CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored.
"""
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod
@ -216,15 +205,6 @@ class FrontendManager:
@classmethod
def default_frontend_path(cls) -> str:
"""
Get the path to the default frontend installation from the pip package.
Returns:
str: The path to the default frontend static files.
Raises:
SystemExit: If the comfyui-frontend-package is not installed.
"""
try:
import comfyui_frontend_package
@ -245,15 +225,6 @@ comfyui-frontend-package is not installed.
@classmethod
def templates_path(cls) -> str:
"""
Get the path to the workflow templates.
Returns:
str: The path to the workflow templates directory.
Raises:
SystemExit: If the comfyui-workflow-templates package is not installed.
"""
try:
import comfyui_workflow_templates
@ -289,16 +260,11 @@ comfyui-workflow-templates is not installed.
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
Parse a version string into its components.
The version string should be in the format: 'owner/repo@version'
where version can be either a semantic version (v1.2.3) or 'latest'.
Args:
value (str): The version string to parse.
Returns:
tuple[str, str, str]: A tuple containing (owner, repo, version).
tuple[str, str]: A tuple containing provider name and version.
Raises:
argparse.ArgumentTypeError: If the version string is invalid.
@ -315,22 +281,18 @@ comfyui-workflow-templates is not installed.
cls, version_string: str, provider: Optional[FrontEndProvider] = None
) -> str:
"""
Initialize a frontend version without error handling.
This method attempts to initialize a specific frontend version, either from
the default pip package or from a custom GitHub repository. It will download
and extract the frontend files if necessary.
Initializes the frontend for the specified version.
Args:
version_string (str): The version string specifying which frontend to use.
provider (FrontEndProvider, optional): The provider to use for custom frontends.
version_string (str): The version string.
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
Returns:
str: The path to the initialized frontend.
Raises:
Exception: If there is an error during initialization (e.g., network timeout,
invalid URL, or missing assets).
Exception: If there is an error during the initialization process.
main error source might be request timeout or invalid URL.
"""
if version_string == DEFAULT_VERSION_STRING:
check_frontend_version()
@ -382,17 +344,13 @@ comfyui-workflow-templates is not installed.
@classmethod
def init_frontend(cls, version_string: str) -> str:
"""
Initialize a frontend version with error handling.
This is the main method to initialize a frontend version. It wraps init_frontend_unsafe
with error handling, falling back to the default frontend if initialization fails.
Initializes the frontend with the specified version string.
Args:
version_string (str): The version string specifying which frontend to use.
version_string (str): The version string to initialize the frontend with.
Returns:
str: The path to the initialized frontend. If initialization fails,
returns the path to the default frontend.
str: The path of the initialized frontend.
"""
try:
return cls.init_frontend_unsafe(version_string)

View File

@ -1,331 +0,0 @@
import os
import logging
import time
import requests
from tqdm import tqdm
from folder_paths import get_relative_path, get_full_path
from app.database.db import create_session, dependencies_available, can_create_session
import blake3
import comfy.utils
if dependencies_available():
from app.database.models import Model
class ModelProcessor:
def _validate_path(self, model_path):
try:
if not self._file_exists(model_path):
logging.error(f"Model file not found: {model_path}")
return None
result = get_relative_path(model_path)
if not result:
logging.error(
f"Model file not in a recognized model directory: {model_path}"
)
return None
return result
except Exception as e:
logging.error(f"Error validating model path {model_path}: {str(e)}")
return None
def _file_exists(self, path):
"""Check if a file exists."""
return os.path.exists(path)
def _get_file_size(self, path):
"""Get file size."""
return os.path.getsize(path)
def _get_hasher(self):
return blake3.blake3()
def _hash_file(self, model_path):
try:
hasher = self._get_hasher()
with open(model_path, "rb", buffering=0) as f:
b = bytearray(128 * 1024)
mv = memoryview(b)
while n := f.readinto(mv):
hasher.update(mv[:n])
return hasher.hexdigest()
except Exception as e:
logging.error(f"Error hashing file {model_path}: {str(e)}")
return None
def _get_existing_model(self, session, model_type, model_relative_path):
return (
session.query(Model)
.filter(Model.type == model_type)
.filter(Model.path == model_relative_path)
.first()
)
def _ensure_source_url(self, session, model, source_url):
if model.source_url is None:
model.source_url = source_url
session.commit()
def _update_database(
self,
session,
model_type,
model_path,
model_relative_path,
model_hash,
model,
source_url,
):
try:
if not model:
model = self._get_existing_model(
session, model_type, model_relative_path
)
if not model:
model = Model(
path=model_relative_path,
type=model_type,
file_name=os.path.basename(model_path),
)
session.add(model)
model.file_size = self._get_file_size(model_path)
model.hash = model_hash
if model_hash:
model.hash_algorithm = "blake3"
model.source_url = source_url
session.commit()
return model
except Exception as e:
logging.error(
f"Error updating database for {model_relative_path}: {str(e)}"
)
def process_file(self, model_path, source_url=None, model_hash=None):
"""
Process a model file and update the database with metadata.
If the file already exists and matches the database, it will not be processed again.
Returns the model object or if an error occurs, returns None.
"""
try:
if not can_create_session():
return
result = self._validate_path(model_path)
if not result:
return
model_type, model_relative_path = result
with create_session() as session:
session.expire_on_commit = False
existing_model = self._get_existing_model(
session, model_type, model_relative_path
)
if (
existing_model
and existing_model.hash
and existing_model.file_size == self._get_file_size(model_path)
):
# File exists with hash and same size, no need to process
self._ensure_source_url(session, existing_model, source_url)
return existing_model
if model_hash:
model_hash = model_hash.lower()
logging.info(f"Using provided hash: {model_hash}")
else:
start_time = time.time()
logging.info(f"Hashing model {model_relative_path}")
model_hash = self._hash_file(model_path)
if not model_hash:
return
logging.info(
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
)
return self._update_database(
session,
model_type,
model_path,
model_relative_path,
model_hash,
existing_model,
source_url,
)
except Exception as e:
logging.error(f"Error processing model file {model_path}: {str(e)}")
return None
def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
"""
Retrieve a model file from the database by hash and optionally by model type.
Returns the model object or None if the model doesnt exist or an error occurs.
"""
try:
if not can_create_session():
return
dispose_session = False
if session is None:
session = create_session()
dispose_session = True
model = session.query(Model).filter(Model.hash == model_hash)
if model_type is not None:
model = model.filter(Model.type == model_type)
return model.first()
except Exception as e:
logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
return None
finally:
if dispose_session:
session.close()
def retrieve_hash(self, model_path, model_type=None):
"""
Retrieve the hash of a model file from the database.
Returns the hash or None if the model doesnt exist or an error occurs.
"""
try:
if not can_create_session():
return
if model_type is not None:
result = self._validate_path(model_path)
if not result:
return None
model_type, model_relative_path = result
with create_session() as session:
model = self._get_existing_model(
session, model_type, model_relative_path
)
if model and model.hash:
return model.hash
return None
except Exception as e:
logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
return None
def _validate_file_extension(self, file_name):
"""Validate that the file extension is supported."""
extension = os.path.splitext(file_name)[1]
if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
raise ValueError(f"Unsupported unsafe file for download: {file_name}")
def _check_existing_file(self, model_type, file_name, expected_hash):
"""Check if file exists and has correct hash."""
destination_path = get_full_path(model_type, file_name, allow_missing=True)
if self._file_exists(destination_path):
model = self.process_file(destination_path)
if model and (expected_hash is None or model.hash == expected_hash):
logging.debug(
f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
)
return destination_path
else:
raise ValueError(
f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
)
return None
def _check_existing_file_by_hash(self, hash, type, url):
"""Check if a file with the given hash exists in the database and on disk."""
hash = hash.lower()
with create_session() as session:
model = self.retrieve_model_by_hash(hash, type, session)
if model:
existing_path = get_full_path(type, model.path)
if existing_path:
logging.debug(
f"File {model.path} already exists in the database at {existing_path}"
)
self._ensure_source_url(session, model, url)
return existing_path
else:
logging.debug(
f"File {model.path} exists in the database but not on disk"
)
return None
def _download_file(self, url, destination_path, hasher):
"""Download a file and update the hasher with its contents."""
response = requests.get(url, stream=True)
logging.info(f"Downloading {url} to {destination_path}")
with open(destination_path, "wb") as f:
total_size = int(response.headers.get("content-length", 0))
if total_size > 0:
pbar = comfy.utils.ProgressBar(total_size)
else:
pbar = None
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
for chunk in response.iter_content(chunk_size=128 * 1024):
if chunk:
f.write(chunk)
hasher.update(chunk)
progress_bar.update(len(chunk))
if pbar:
pbar.update(len(chunk))
def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
"""Verify that the downloaded file has the expected hash."""
if expected_hash is not None and calculated_hash != expected_hash:
self._remove_file(destination_path)
raise ValueError(
f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
)
def _remove_file(self, file_path):
"""Remove a file from disk."""
os.remove(file_path)
def ensure_downloaded(self, type, url, desired_file_name, hash=None):
"""
Ensure a model file is downloaded and has the correct hash.
Returns the path to the downloaded file.
"""
logging.debug(
f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
)
# Validate file extension
self._validate_file_extension(desired_file_name)
# Check if file exists with correct hash
if hash:
existing_path = self._check_existing_file_by_hash(hash, type, url)
if existing_path:
return existing_path
# Check if file exists locally
destination_path = get_full_path(type, desired_file_name, allow_missing=True)
existing_path = self._check_existing_file(type, desired_file_name, hash)
if existing_path:
return existing_path
# Download the file
hasher = self._get_hasher()
self._download_file(url, destination_path, hasher)
# Verify hash
calculated_hash = hasher.hexdigest()
self._verify_downloaded_hash(calculated_hash, hash, destination_path)
# Update database
self.process_file(destination_path, url, calculated_hash)
# TODO: Notify frontend to reload models
return destination_path
model_processor = ModelProcessor()

View File

@ -132,6 +132,8 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
@ -210,7 +212,6 @@ database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
if comfy.options.args_parsing:
args = parser.parse_args()

540
comfy/context_windows.py Normal file
View File

@ -0,0 +1,540 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable
import torch
import numpy as np
import collections
from dataclasses import dataclass
from abc import ABC, abstractmethod
import logging
import comfy.model_management
import comfy.patcher_extension
if TYPE_CHECKING:
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase
class ContextWindowABC(ABC):
def __init__(self):
...
@abstractmethod
def get_tensor(self, full: torch.Tensor) -> torch.Tensor:
"""
Get torch.Tensor applicable to current window.
"""
raise NotImplementedError("Not implemented.")
@abstractmethod
def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor:
"""
Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy.
"""
raise NotImplementedError("Not implemented.")
class ContextHandlerABC(ABC):
def __init__(self):
...
@abstractmethod
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
raise NotImplementedError("Not implemented.")
@abstractmethod
def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list:
raise NotImplementedError("Not implemented.")
@abstractmethod
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
raise NotImplementedError("Not implemented.")
class IndexListContextWindow(ContextWindowABC):
def __init__(self, index_list: list[int], dim: int=0):
self.index_list = index_list
self.context_length = len(index_list)
self.dim = dim
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
if dim is None:
dim = self.dim
if dim == 0 and full.shape[dim] == 1:
return full
idx = [slice(None)] * dim + [self.index_list]
return full[idx].to(device)
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
if dim is None:
dim = self.dim
idx = [slice(None)] * dim + [self.index_list]
full[idx] += to_add
return full
class IndexListCallbacks:
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
EXECUTE_START = "execute_start"
EXECUTE_CLEANUP = "execute_cleanup"
def init_callbacks(self):
return {}
@dataclass
class ContextSchedule:
name: str
func: Callable
@dataclass
class ContextFuseMethod:
name: str
func: Callable
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
self.context_overlap = context_overlap
self.context_stride = context_stride
self.closed_loop = closed_loop
self.dim = dim
self._step = 0
self.callbacks = {}
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
if x_in.size(self.dim) > self.context_length:
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
return True
return False
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
if control.previous_controlnet is not None:
self.prepare_control_objects(control.previous_controlnet, device)
return control
def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list:
if cond_in is None:
return None
# reuse or resize cond items to match context requirements
resized_cond = []
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
for actual_cond in cond_in:
resized_actual_cond = actual_cond.copy()
# now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
for key in actual_cond:
try:
cond_item = actual_cond[key]
if isinstance(cond_item, torch.Tensor):
# check that tensor is the expected length - x.size(0)
if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim):
# if so, it's subsetting time - tell controls the expected indeces so they can handle them
actual_cond_item = window.get_tensor(cond_item)
resized_actual_cond[key] = actual_cond_item.to(device)
else:
resized_actual_cond[key] = cond_item.to(device)
# look for control
elif key == "control":
resized_actual_cond[key] = self.prepare_control_objects(cond_item, device)
elif isinstance(cond_item, dict):
new_cond_item = cond_item.copy()
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items():
if isinstance(cond_value, torch.Tensor):
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# if has cond that is a Tensor, check if needs to be subset
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
elif cond_key == "num_video_frames": # for SVD
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
new_cond_item[cond_key].cond = window.context_length
resized_actual_cond[key] = new_cond_item
else:
resized_actual_cond[key] = cond_item
finally:
del cond_item # just in case to prevent VRAM issues
resized_cond.append(resized_actual_cond)
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
self._step = int(matches[0].item())
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model
context_windows = self.context_schedule.func(full_length, self, model_options)
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
self.set_step(timestep, model_options)
context_windows = self.get_context_windows(model, x_in, model_options)
enumerated_context_windows = list(enumerate(context_windows))
conds_final = [torch.zeros_like(x_in) for _ in conds]
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
else:
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options)
for enum_window in enumerated_context_windows:
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
for result in results:
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
conds_final, counts_final, biases_final)
try:
# finalize conds
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
# relative is already normalized, so return as is
del counts_final
return conds_final
else:
# normalize conds via division by context usage counts
for i in range(len(conds_final)):
conds_final[i] /= counts_final[i]
del counts_final
return conds_final
finally:
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options)
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
model_options, device=None, first_device=None):
results: list[ContextResults] = []
for window_idx, window in enumerated_context_windows:
# allow processing to end between context window executions for faster Cancel
comfy.model_management.throw_exception_if_processing_interrupted()
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
# update exposed params
model_options["transformer_options"]["context_window"] = window
# get subsections of x, timestep, conds
sub_x = window.get_tensor(x_in, device)
sub_timestep = window.get_tensor(timestep, device, dim=0)
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
if device is not None:
for i in range(len(sub_conds_out)):
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
return results
def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
for pos, idx in enumerate(window.index_list):
# bias is the influence of a specific index in relation to the whole context window
bias = 1 - abs(idx - (window.index_list[0] + window.index_list[-1]) / 2) / ((window.index_list[-1] - window.index_list[0] + 1e-2) / 2)
bias = max(1e-2, bias)
# take weighted average relative to total bias of current idx
for i in range(len(sub_conds_out)):
bias_total = biases_final[i][idx]
prev_weight = (bias_total / (bias_total + bias))
new_weight = (bias / (bias_total + bias))
# account for dims of tensors
idx_window = [slice(None)] * self.dim + [idx]
pos_window = [slice(None)] * self.dim + [pos]
# apply new values
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
biases_final[i][idx] = bias_total + bias
else:
# add conds and counts based on weights of fuse method
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
for i in range(len(sub_conds_out)):
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
window.add_window(counts_final[i], weights_tensor)
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
# limit noise_shape length to context_length for more accurate vram use estimation
model_options = kwargs.get("model_options", None)
if model_options is None:
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
handler: IndexListContextHandler = model_options.get("context_handler", None)
if handler is not None:
noise_shape = list(noise_shape)
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
return executor(model, noise_shape, *args, **kwargs)
def create_prepare_sampling_wrapper(model: ModelPatcher):
model.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
"ContextWindows_prepare_sampling",
_prepare_sampling_wrapper
)
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
total_dims = len(x_in.shape)
weights_tensor = torch.Tensor(weights).to(device=device)
for _ in range(dim):
weights_tensor = weights_tensor.unsqueeze(0)
for _ in range(total_dims - dim - 1):
weights_tensor = weights_tensor.unsqueeze(-1)
return weights_tensor
def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
total_dims = len(x_in.shape)
shape = []
for _ in range(dim):
shape.append(1)
shape.append(x_in.shape[dim])
for _ in range(total_dims - dim - 1):
shape.append(1)
return shape
class ContextSchedules:
UNIFORM_LOOPED = "looped_uniform"
UNIFORM_STANDARD = "standard_uniform"
STATIC_STANDARD = "standard_static"
BATCHED = "batched"
# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py
def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
windows = []
if num_frames < handler.context_length:
windows.append(list(range(num_frames)))
return windows
context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
# obtain uniform windows as normal, looping and all
for context_step in 1 << np.arange(context_stride):
pad = int(round(num_frames * ordered_halving(handler._step)))
for j in range(
int(ordered_halving(handler._step) * context_step) + pad,
num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
(handler.context_length * context_step - handler.context_overlap),
):
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
return windows
def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
# unlike looped, uniform_straight does NOT allow windows that loop back to the beginning;
# instead, they get shifted to the corresponding end of the frames.
# in the case that a window (shifted or not) is identical to the previous one, it gets skipped.
windows = []
if num_frames <= handler.context_length:
windows.append(list(range(num_frames)))
return windows
context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
# first, obtain uniform windows as normal, looping and all
for context_step in 1 << np.arange(context_stride):
pad = int(round(num_frames * ordered_halving(handler._step)))
for j in range(
int(ordered_halving(handler._step) * context_step) + pad,
num_frames + pad + (-handler.context_overlap),
(handler.context_length * context_step - handler.context_overlap),
):
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
# now that windows are created, shift any windows that loop, and delete duplicate windows
delete_idxs = []
win_i = 0
while win_i < len(windows):
# if window is rolls over itself, need to shift it
is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
if is_roll:
roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
shift_window_to_end(windows[win_i], num_frames=num_frames)
# check if next window (cyclical) is missing roll_val
if roll_val not in windows[(win_i+1) % len(windows)]:
# need to insert new window here - just insert window starting at roll_val
windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
# delete window if it's not unique
for pre_i in range(0, win_i):
if windows[win_i] == windows[pre_i]:
delete_idxs.append(win_i)
break
win_i += 1
# reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
delete_idxs.reverse()
for i in delete_idxs:
windows.pop(i)
return windows
def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
windows = []
if num_frames <= handler.context_length:
windows.append(list(range(num_frames)))
return windows
# always return the same set of windows
delta = handler.context_length - handler.context_overlap
for start_idx in range(0, num_frames, delta):
# if past the end of frames, move start_idx back to allow same context_length
ending = start_idx + handler.context_length
if ending >= num_frames:
final_delta = ending - num_frames
final_start_idx = start_idx - final_delta
windows.append(list(range(final_start_idx, final_start_idx + handler.context_length)))
break
windows.append(list(range(start_idx, start_idx + handler.context_length)))
return windows
def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
windows = []
if num_frames <= handler.context_length:
windows.append(list(range(num_frames)))
return windows
# always return the same set of windows;
# no overlap, just cut up based on context_length;
# last window size will be different if num_frames % opts.context_length != 0
for start_idx in range(0, num_frames, handler.context_length):
windows.append(list(range(start_idx, min(start_idx + handler.context_length, num_frames))))
return windows
def create_windows_default(num_frames: int, handler: IndexListContextHandler):
return [list(range(num_frames))]
CONTEXT_MAPPING = {
ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
ContextSchedules.BATCHED: create_windows_batched,
}
def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
func = CONTEXT_MAPPING.get(context_schedule, None)
if func is None:
raise ValueError(f"Unknown context_schedule '{context_schedule}'.")
return ContextSchedule(context_schedule, func)
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
def create_weights_flat(length: int, **kwargs) -> list[float]:
# weight is the same for all
return [1.0] * length
def create_weights_pyramid(length: int, **kwargs) -> list[float]:
# weight is based on the distance away from the edge of the context window;
# based on weighted average concept in FreeNoise paper
if length % 2 == 0:
max_weight = length // 2
weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
else:
max_weight = (length + 1) // 2
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
return weight_sequence
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
# only expected overlap is given different weights
weights_torch = torch.ones((length))
# blend left-side on all except first window
if min(idxs) > 0:
ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
weights_torch[:handler.context_overlap] = ramp_up
# blend right-side on all except last window
if max(idxs) < full_length-1:
ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
weights_torch[-handler.context_overlap:] = ramp_down
return weights_torch
class ContextFuseMethods:
FLAT = "flat"
PYRAMID = "pyramid"
RELATIVE = "relative"
OVERLAP_LINEAR = "overlap-linear"
LIST = [PYRAMID, FLAT, OVERLAP_LINEAR]
LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR]
FUSE_MAPPING = {
ContextFuseMethods.FLAT: create_weights_flat,
ContextFuseMethods.PYRAMID: create_weights_pyramid,
ContextFuseMethods.RELATIVE: create_weights_pyramid,
ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
}
def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
func = FUSE_MAPPING.get(fuse_method, None)
if func is None:
raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
return ContextFuseMethod(fuse_method, func)
# Returns fraction that has denominator that is a power of 2
def ordered_halving(val):
# get binary value, padded with 0s for 64 bits
bin_str = f"{val:064b}"
# flip binary value, padding included
bin_flip = bin_str[::-1]
# convert binary to int
as_int = int(bin_flip, 2)
# divide by 1 << 64, equivalent to 2**64, or 18446744073709551616,
# or b10000000000000000000000000000000000000000000000000000000000000000 (1 with 64 zero's)
return as_int / (1 << 64)
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
all_indexes = list(range(num_frames))
for w in windows:
for val in w:
try:
all_indexes.remove(val)
except ValueError:
pass
return all_indexes
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
prev_val = -1
for i, val in enumerate(window):
val = val % num_frames
if val < prev_val:
return True, i
prev_val = val
return False, -1
def shift_window_to_start(window: list[int], num_frames: int):
start_val = window[0]
for i in range(len(window)):
# 1) subtract each element by start_val to move vals relative to the start of all frames
# 2) add num_frames and take modulus to get adjusted vals
window[i] = ((window[i] - start_val) + num_frames) % num_frames
def shift_window_to_end(window: list[int], num_frames: int):
# 1) shift window to start
shift_window_to_start(window, num_frames)
end_val = window[-1]
end_delta = num_frames - end_val - 1
for i in range(len(window)):
# 2) add end_delta to each val to slide windows to end
window[i] = window[i] + end_delta

View File

@ -224,19 +224,27 @@ class Flux(nn.Module):
if ref_latents is not None:
h = 0
w = 0
index = 0
index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
for ref in ref_latents:
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
if index_ref_method:
index += 1
h_offset = 0
w_offset = 0
else:
h_offset = h
index = 1
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
else:
h_offset = h
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))

View File

@ -178,7 +178,7 @@ class FourierEmbedder(nn.Module):
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
out = F.scaled_dot_product_attention(q, k, v)
out = comfy.ops.scaled_dot_product_attention(q, k, v)
return out

View File

@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
mask = mask.unsqueeze(1)
if SDP_BATCH_LIMIT >= b:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.shape[0] > 1:
m = mask[i : i + SDP_BATCH_LIMIT]
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
q[i : i + SDP_BATCH_LIMIT],
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],

View File

@ -285,7 +285,7 @@ def pytorch_attention(q, k, v):
)
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")

View File

@ -333,21 +333,25 @@ class QwenImageTransformer2DModel(nn.Module):
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
self.gradient_checkpointing = False
def pos_embeds(self, x, context):
def process_img(self, x, index=0, h_offset=0, w_offset=0):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
txt_start = round(max(h_len, w_len))
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
def forward(
self,
@ -356,19 +360,46 @@ class QwenImageTransformer2DModel(nn.Module):
context,
attention_mask=None,
guidance: torch.Tensor = None,
ref_latents=None,
transformer_options={},
**kwargs
):
timestep = timesteps
encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask
image_rotary_emb = self.pos_embeds(x, context)
hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1]
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
if ref_latents is not None:
h = 0
w = 0
index = 0
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
for ref in ref_latents:
if index_ref_method:
index += 1
h_offset = 0
w_offset = 0
else:
index = 1
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
else:
h_offset = h
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
@ -383,18 +414,30 @@ class QwenImageTransformer2DModel(nn.Module):
else self.time_text_embed(timestep, guidance, hidden_states)
)
for block in self.transformer_blocks:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]

View File

@ -391,6 +391,7 @@ class WanModel(torch.nn.Module):
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
in_dim_ref_conv=None,
image_model=None,
device=None,
dtype=None,
@ -484,6 +485,11 @@ class WanModel(torch.nn.Module):
else:
self.img_emb = None
if in_dim_ref_conv is not None:
self.ref_conv = operations.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:], device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
else:
self.ref_conv = None
def forward_orig(
self,
x,
@ -526,6 +532,13 @@ class WanModel(torch.nn.Module):
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
full_ref = None
if self.ref_conv is not None:
full_ref = kwargs.get("reference_latent", None)
if full_ref is not None:
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
x = torch.concat((full_ref, x), dim=1)
# context
context = self.text_embedding(context)
@ -552,6 +565,9 @@ class WanModel(torch.nn.Module):
# head
x = self.head(x, e)
if full_ref is not None:
x = x[:, full_ref.shape[1]:]
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
@ -570,6 +586,9 @@ class WanModel(torch.nn.Module):
x = torch.cat([x, time_dim_concat], dim=2)
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
if self.ref_conv is not None and "reference_latent" in kwargs:
t_len += 1
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
@ -749,7 +768,12 @@ class CameraWanModel(WanModel):
operations=None,
):
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
if model_type == 'camera':
model_type = 'i2v'
else:
model_type = 't2v'
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)

View File

@ -890,6 +890,10 @@ class Flux(BaseModel):
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
ref_latents_method = kwargs.get("reference_latents_method", None)
if ref_latents_method is not None:
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
return out
def extra_conds_shapes(self, **kwargs):
@ -1124,7 +1128,11 @@ class WAN21(BaseModel):
mask = mask.repeat(1, 4, 1, 1, 1)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((mask, image), dim=1)
concat_mask_index = kwargs.get("concat_mask_index", 0)
if concat_mask_index != 0:
return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1)
else:
return torch.cat((mask, image), dim=1)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@ -1140,6 +1148,10 @@ class WAN21(BaseModel):
if time_dim_concat is not None:
out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
reference_latents = kwargs.get("reference_latents", None)
if reference_latents is not None:
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
return out
@ -1319,4 +1331,14 @@ class QwenImage(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
ref_latents_method = kwargs.get("reference_latents_method", None)
if ref_latents_method is not None:
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
return out

View File

@ -364,7 +364,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "camera"
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "camera"
else:
dit_config["model_type"] = "camera_2.2"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"
@ -373,6 +376,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
if flf_weight is not None:
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
if ref_conv_weight is not None:
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D

View File

@ -78,7 +78,6 @@ try:
torch_version = torch.version.__version__
temp = torch_version.split(".")
torch_version_numeric = (int(temp[0]), int(temp[1]))
xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
except:
pass
@ -102,10 +101,14 @@ if args.directml is not None:
try:
import intel_extension_for_pytorch as ipex # noqa: F401
_ = torch.xpu.device_count()
xpu_available = xpu_available or torch.xpu.is_available()
except:
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
pass
try:
_ = torch.xpu.device_count()
xpu_available = torch.xpu.is_available()
except:
xpu_available = False
try:
if torch.backends.mps.is_available():
@ -579,16 +582,23 @@ def free_memory(memory_required, device, keep_loaded=[]):
soft_empty_cache()
return unloaded_models
def get_models_memory_reserve(models):
total_reserve = 0
for model in models:
total_reserve += model.get_model_memory_reserve(convert_to_bytes=True)
return total_reserve
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
models_memory_reserve = get_models_memory_reserve(models)
extra_mem = max(inference_memory + models_memory_reserve, memory_required + extra_reserved_memory() + models_memory_reserve)
if minimum_memory_required is None:
minimum_memory_required = extra_mem
else:
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
minimum_memory_required = max(inference_memory + models_memory_reserve, minimum_memory_required + extra_reserved_memory() + models_memory_reserve)
models = set(models)
@ -946,10 +956,12 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None):
return dtype
def device_supports_non_blocking(device):
if args.force_non_blocking:
return True
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
if is_intel_xpu():
return True
if is_intel_xpu(): #xpu does support non blocking but it is slower on iGPUs for some reason so disable by default until situation changes
return False
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
return False
if directml_enabled:
@ -1282,10 +1294,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
if torch_version_numeric < (2, 6):
if torch_version_numeric < (2, 3):
return True
else:
return torch.xpu.get_device_capability(device)['has_bfloat16_conversions']
return torch.xpu.is_bf16_supported()
if is_ascend_npu():
return True

View File

@ -24,7 +24,7 @@ import inspect
import logging
import math
import uuid
from typing import Callable, Optional
from typing import Callable, Optional, Union
import torch
@ -84,6 +84,12 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
model_options["disable_cfg1_optimization"] = True
return model_options
def add_model_options_memory_reserve(model_options, memory_reserve_gb: float):
if "model_memory_reserve" not in model_options:
model_options["model_memory_reserve"] = []
model_options["model_memory_reserve"].append(memory_reserve_gb)
return model_options
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
@ -439,6 +445,17 @@ class ModelPatcher:
self.force_cast_weights = True
self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
def add_model_memory_reserve(self, memory_reserve_gb: float):
"""Adds additional expected memory usage for the model, in gigabytes."""
self.model_options = add_model_options_memory_reserve(self.model_options, memory_reserve_gb)
def get_model_memory_reserve(self, convert_to_bytes: bool = False) -> Union[float, int]:
"""Returns the total expected memory usage for the model in gigabytes, or bytes if convert_to_bytes is True."""
total_reserve = sum(self.model_options.get("model_memory_reserve", []))
if convert_to_bytes:
return total_reserve * 1024 * 1024 * 1024
return total_reserve
def add_weight_wrapper(self, name, function):
self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
self.patches_uuid = uuid.uuid4()

View File

@ -24,6 +24,32 @@ import comfy.float
import comfy.rmsnorm
import contextlib
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
try:
if torch.cuda.is_available():
from torch.nn.attention import SDPBackend, sdpa_kernel
import inspect
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
SDPA_BACKEND_PRIORITY = [
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
]
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
else:
logging.warning("Torch version too old to set sdpa backend priority.")
except (ModuleNotFoundError, TypeError):
logging.warning("Could not set sdpa backend priority.")
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
def cast_to_input(weight, input, non_blocking=False, copy=True):

View File

@ -1,6 +1,7 @@
import torch
import comfy.model_management
import numbers
import logging
RMSNorm = None
@ -9,6 +10,7 @@ try:
RMSNorm = torch.nn.RMSNorm
except:
rms_norm_torch = None
logging.warning("Please update pytorch to use native RMSNorm")
def rms_norm(x, weight=None, eps=1e-6):

View File

@ -149,7 +149,7 @@ def cleanup_models(conds, models):
cleanup_additional_models(set(control_cleanup))
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
'''
Registers hooks from conds.
'''
@ -158,8 +158,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
for k in conds:
get_hooks_from_cond(conds[k], hooks)
# add wrappers and callbacks from ModelPatcher to transformer_options
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("wrappers", {}), model.wrappers, copy_dict1=False)
comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("callbacks", {}), model.callbacks, copy_dict1=False)
# begin registering hooks
registered = comfy.hooks.HookGroup()
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)

View File

@ -16,6 +16,7 @@ import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
import scipy.stats
import numpy
@ -198,14 +199,20 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options: dict[str]):
handler: comfy.context_windows.ContextHandlerABC = model_options.get("context_handler", None)
if handler is None or not handler.should_use_context(model, conds, x_in, timestep, model_options):
return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options)
return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options)
def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_calc_cond_batch,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
)
return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
out_conds = []
out_counts = []
# separate conds by matching hooks

View File

@ -1046,6 +1046,18 @@ class WAN21_Camera(WAN21_T2V):
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
return out
class WAN22_Camera(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "camera_2.2",
"in_dim": 36,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
return out
class WAN21_Vace(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@ -1260,6 +1272,6 @@ class QwenImage(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@ -50,16 +50,10 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def is_html_file(file_path):
with open(file_path, "rb") as f:
content = f.read(100)
return b"<!DOCTYPE html>" in content or b"<html" in content
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
device = torch.device("cpu")
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
@ -72,8 +66,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if return_metadata:
metadata = f.metadata()
except Exception as e:
if is_html_file(ckpt):
raise ValueError("{}\n\nFile path: {}\n\nThe requested file is an HTML document not a safetensors file. Please re-download the file, not the web page.".format(e, ckpt))
if len(e.args) > 0:
message = e.args[0]
if "HeaderTooLarge" in message:
@ -101,13 +93,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
sd = pl_sd
else:
sd = pl_sd
try:
from app.model_processor import model_processor
model_processor.process_file(ckpt)
except Exception as e:
logging.error(f"Error processing file {ckpt}: {e}")
return (sd, metadata) if return_metadata else sd
def save_torch_file(sd, ckpt, metadata=None):

View File

@ -12,7 +12,7 @@ import torch
try:
import torchaudio
TORCH_AUDIO_AVAILABLE = True
except ImportError:
except:
TORCH_AUDIO_AVAILABLE = False
from PIL import Image as PILImage
from PIL.PngImagePlugin import PngInfo

View File

@ -1690,7 +1690,11 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
):
self.validate_prompt(prompt, negative_prompt)
if image is not None:
if image is None:
image_type = None
elif model_name == KlingImageGenModelName.kling_v1:
raise ValueError(f"The model {KlingImageGenModelName.kling_v1.value} does not support reference images.")
else:
image = tensor_to_base64_string(image)
initial_operation = SynchronousOperation(

View File

@ -1,6 +1,5 @@
import logging
from typing import Any, Callable, Optional, TypeVar
import random
import torch
from comfy_api_nodes.util.validation_utils import (
get_image_dimensions,
@ -208,20 +207,29 @@ def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
def _validate_video_dimensions(width: int, height: int) -> None:
"""Validates video dimensions meet Moonvalley V2V requirements."""
supported_resolutions = {
(1920, 1080), (1080, 1920), (1152, 1152),
(1536, 1152), (1152, 1536)
(1920, 1080),
(1080, 1920),
(1152, 1152),
(1536, 1152),
(1152, 1536),
}
if (width, height) not in supported_resolutions:
supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)])
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
supported_list = ", ".join(
[f"{w}x{h}" for w, h in sorted(supported_resolutions)]
)
raise ValueError(
f"Resolution {width}x{height} not supported. Supported: {supported_list}"
)
def _validate_container_format(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
raise ValueError(
f"Only MP4 container format supported. Got: {container_format}"
)
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
@ -244,7 +252,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
return video
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
"""
Returns a new VideoInput object trimmed from the beginning to the specified duration,
@ -302,7 +309,9 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
# Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate
estimated_frames = int(duration_sec * fps)
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
target_frames = (
estimated_frames // 16
) * 16 # Round down to nearest multiple of 16
if target_frames == 0:
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
@ -424,7 +433,7 @@ class BaseMoonvalleyVideoNode:
MoonvalleyTextToVideoInferenceParams,
"negative_prompt",
multiline=True,
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts",
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
),
"resolution": (
IO.COMBO,
@ -441,12 +450,11 @@ class BaseMoonvalleyVideoNode:
"tooltip": "Resolution of the output video",
},
),
# "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}),
"prompt_adherence": model_field_to_node_input(
IO.FLOAT,
MoonvalleyTextToVideoInferenceParams,
"guidance_scale",
default=7.0,
default=10.0,
step=1,
min=1,
max=20,
@ -455,13 +463,12 @@ class BaseMoonvalleyVideoNode:
IO.INT,
MoonvalleyTextToVideoInferenceParams,
"seed",
default=random.randint(0, 2**32 - 1),
default=9,
min=0,
max=4294967295,
step=1,
display="number",
tooltip="Random seed value",
control_after_generate=True,
),
"steps": model_field_to_node_input(
IO.INT,
@ -532,9 +539,11 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
# Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png"
image_url = (await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
))[0]
image_url = (
await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
)
)[0]
request = MoonvalleyTextToVideoRequest(
image_url=image_url, prompt_text=prompt, inference_params=inference_params
@ -570,17 +579,39 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text",
multiline=True
IO.STRING,
MoonvalleyVideoToVideoRequest,
"prompt_text",
multiline=True,
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyVideoToVideoInferenceParams,
"negative_prompt",
multiline=True,
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts"
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
),
"seed": model_field_to_node_input(
IO.INT,
MoonvalleyVideoToVideoInferenceParams,
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display="number",
tooltip="Random seed value",
control_after_generate=False,
),
"prompt_adherence": model_field_to_node_input(
IO.FLOAT,
MoonvalleyVideoToVideoInferenceParams,
"guidance_scale",
default=10.0,
step=1,
min=1,
max=20,
),
"seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
@ -588,7 +619,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
"unique_id": "UNIQUE_ID",
},
"optional": {
"video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}),
"video": (
IO.VIDEO,
{
"default": "",
"multiline": False,
"tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
},
),
"control_type": (
["Motion Transfer", "Pose Transfer"],
{"default": "Motion Transfer"},
@ -602,8 +640,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
"max": 100,
"tooltip": "Only used if control_type is 'Motion Transfer'",
},
)
}
),
"image": model_field_to_node_input(
IO.IMAGE,
MoonvalleyTextToVideoRequest,
"image_url",
tooltip="The reference image used to generate the video",
),
},
}
RETURN_TYPES = ("VIDEO",)
@ -613,6 +657,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
video = kwargs.get("video")
image = kwargs.get("image", None)
if not video:
raise MoonvalleyApiError("video is required")
@ -620,8 +665,16 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
video_url = ""
if video:
validated_video = validate_video_to_video_input(video)
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
video_url = await upload_video_to_comfyapi(
validated_video, auth_kwargs=kwargs
)
mime_type = "image/png"
if not image is None:
validate_input_image(image, with_frame_conditioning=True)
image_url = await upload_images_to_comfyapi(
image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type
)
control_type = kwargs.get("control_type")
motion_intensity = kwargs.get("motion_intensity")
@ -631,12 +684,12 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
# Only include motion_intensity for Motion Transfer
control_params = {}
if control_type == "Motion Transfer" and motion_intensity is not None:
control_params['motion_intensity'] = motion_intensity
control_params["motion_intensity"] = motion_intensity
inference_params=MoonvalleyVideoToVideoInferenceParams(
inference_params = MoonvalleyVideoToVideoInferenceParams(
negative_prompt=negative_prompt,
seed=kwargs.get("seed"),
control_params=control_params
control_params=control_params,
)
control = self.parseControlParameter(control_type)
@ -647,6 +700,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
prompt_text=prompt,
inference_params=inference_params,
)
request.image_url = image_url if not image is None else None
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
@ -694,15 +748,15 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
inference_params=MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
num_frames=128,
width=width_height.get("width"),
height=width_height.get("height"),
)
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
num_frames=128,
width=width_height.get("width"),
height=width_height.get("height"),
)
request = MoonvalleyTextToVideoRequest(
prompt_text=prompt, inference_params=inference_params
)

View File

@ -464,8 +464,6 @@ class OpenAIGPTImage1(ComfyNodeABC):
path = "/proxy/openai/images/generations"
content_type = "application/json"
request_class = OpenAIImageGenerationRequest
img_binaries = []
mask_binary = None
files = []
if image is not None:
@ -484,14 +482,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
img_binary = img_byte_arr
img_binary.name = f"image_{i}.png"
img_binaries.append(img_binary)
if batch_size == 1:
files.append(("image", img_binary))
files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
else:
files.append(("image[]", img_binary))
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
if mask is not None:
if image is None:
@ -511,9 +506,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
mask_img_byte_arr = io.BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0)
mask_binary = mask_img_byte_arr
mask_binary.name = "mask.png"
files.append(("mask", mask_binary))
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
# Build the operation
operation = SynchronousOperation(

View File

@ -346,6 +346,24 @@ class LoadAudio:
return "Invalid audio file: {}".format(audio)
return True
class RecordAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": {"audio": ("AUDIO_RECORD", {})}}
CATEGORY = "audio"
RETURN_TYPES = ("AUDIO", )
FUNCTION = "load"
def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = torchaudio.load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, )
NODE_CLASS_MAPPINGS = {
"EmptyLatentAudio": EmptyLatentAudio,
"VAEEncodeAudio": VAEEncodeAudio,
@ -356,6 +374,7 @@ NODE_CLASS_MAPPINGS = {
"LoadAudio": LoadAudio,
"PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio,
"RecordAudio": RecordAudio,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@ -367,4 +386,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"SaveAudio": "Save Audio (FLAC)",
"SaveAudioMP3": "Save Audio (MP3)",
"SaveAudioOpus": "Save Audio (Opus)",
"RecordAudio": "Record Audio",
}

View File

@ -0,0 +1,89 @@
from __future__ import annotations
from comfy_api.latest import ComfyExtension, io
import comfy.context_windows
import nodes
class ContextWindowsManualNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="ContextWindowsManual",
display_name="Context Windows (Manual)",
category="context",
description="Manually set context windows.",
inputs=[
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
io.Combo.Input("context_schedule", options=[
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
comfy.context_windows.ContextSchedules.BATCHED,
], tooltip="The stride of the context window."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
],
outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."),
],
is_experimental=True,
)
@classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
model = model.clone()
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
fuse_method=comfy.context_windows.get_matching_fuse_method(fuse_method),
context_length=context_length,
context_overlap=context_overlap,
context_stride=context_stride,
closed_loop=closed_loop,
dim=dim)
# make memory usage calculation only take into account the context window latents
comfy.context_windows.create_prepare_sampling_wrapper(model)
return io.NodeOutput(model)
class WanContextWindowsManualNode(ContextWindowsManualNode):
@classmethod
def define_schema(cls) -> io.Schema:
schema = super().define_schema()
schema.node_id = "WanContextWindowsManual"
schema.display_name = "WAN Context Windows (Manual)"
schema.description = "Manually set context windows for WAN-like models (dim=2)."
schema.inputs = [
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window."),
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window."),
io.Combo.Input("context_schedule", options=[
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
comfy.context_windows.ContextSchedules.BATCHED,
], tooltip="The stride of the context window."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
]
return schema
@classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
class ContextWindowsExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ContextWindowsManualNode,
WanContextWindowsManualNode,
]
def comfy_entrypoint():
return ContextWindowsExtension()

View File

@ -100,9 +100,28 @@ class FluxKontextImageScale:
return (image, )
class FluxKontextMultiReferenceLatentMethod:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"reference_latents_method": (("offset", "index"), ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
EXPERIMENTAL = True
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, reference_latents_method):
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
return (c, )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
"FluxDisableGuidance": FluxDisableGuidance,
"FluxKontextImageScale": FluxKontextImageScale,
"FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod,
}

View File

@ -0,0 +1,33 @@
from comfy_api.latest import io, ComfyExtension
class MemoryReserveNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="ReserveAdditionalMemory",
display_name="Reserve Additional Memory",
description="Adds additional expected memory usage for the model, in gigabytes.",
category="advanced/debug/model",
inputs=[
io.Model.Input("model", tooltip="The model to add memory reserve to."),
io.Float.Input("memory_reserve_gb", min=0.0, default=0.0, max=2048.0, step=0.1, tooltip="The additional expected memory usage for the model, in gigabytes."),
],
outputs=[
io.Model.Output(tooltip="The model with the additional memory reserve."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, memory_reserve_gb: float) -> io.NodeOutput:
model = model.clone()
model.add_model_memory_reserve(memory_reserve_gb)
return io.NodeOutput(model)
class MemoryReserveExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
MemoryReserveNode,
]
def comfy_entrypoint():
return MemoryReserveExtension()

View File

@ -9,29 +9,35 @@ import comfy.clip_vision
import json
import numpy as np
from typing import Tuple
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class WanImageToVideo:
class WanImageToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="WanImageToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("start_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None):
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
@ -51,32 +57,36 @@ class WanImageToVideo:
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
return io.NodeOutput(positive, negative, out_latent)
class WanFunControlToVideo:
class WanFunControlToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"control_video": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="WanFunControlToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("start_image", optional=True),
io.Image.Input("control_video", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
@ -101,32 +111,96 @@ class WanFunControlToVideo:
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
return io.NodeOutput(positive, negative, out_latent)
class WanFirstLastFrameToVideo:
class Wan22FunControlToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
"clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="Wan22FunControlToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("ref_image", optional=True),
io.Image.Input("control_video", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
CATEGORY = "conditioning/video_models"
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
mask[:, :, :start_image.shape[0] + 3] = 0.0
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None):
ref_latent = None
if ref_image is not None:
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
ref_latent = vae.encode(ref_image[:, :, :, :3])
if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(control_video[:, :, :, :3])
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
if ref_latent is not None:
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class WanFirstLastFrameToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanFirstLastFrameToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_start_image", optional=True),
io.ClipVisionOutput.Input("clip_vision_end_image", optional=True),
io.Image.Input("start_image", optional=True),
io.Image.Input("end_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
@ -167,62 +241,70 @@ class WanFirstLastFrameToVideo:
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
return io.NodeOutput(positive, negative, out_latent)
class WanFunInpaintToVideo:
class WanFunInpaintToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="WanFunInpaintToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("start_image", optional=True),
io.Image.Input("end_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None) -> io.NodeOutput:
flfv = WanFirstLastFrameToVideo()
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
return flfv.execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
class WanVaceToVideo:
class WanVaceToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
},
"optional": {"control_video": ("IMAGE", ),
"control_masks": ("MASK", ),
"reference_image": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="WanVaceToVideo",
category="conditioning/video_models",
is_experimental=True,
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("strength", default=1.0, min=0.0, max=1000.0, step=0.01),
io.Image.Input("control_video", optional=True),
io.Mask.Input("control_masks", optional=True),
io.Image.Input("reference_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
io.Int.Output(display_name="trim_latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
EXPERIMENTAL = True
def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None):
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None) -> io.NodeOutput:
latent_length = ((length - 1) // 4) + 1
if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
@ -279,52 +361,59 @@ class WanVaceToVideo:
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent, trim_latent)
return io.NodeOutput(positive, negative, out_latent, trim_latent)
class TrimVideoLatent:
class TrimVideoLatent(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
}}
def define_schema(cls):
return io.Schema(
node_id="TrimVideoLatent",
category="latent/video",
is_experimental=True,
inputs=[
io.Latent.Input("samples"),
io.Int.Input("trim_amount", default=0, min=0, max=99999),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/video"
EXPERIMENTAL = True
def op(self, samples, trim_amount):
@classmethod
def execute(cls, samples, trim_amount) -> io.NodeOutput:
samples_out = samples.copy()
s1 = samples["samples"]
samples_out["samples"] = s1[:, :, trim_amount:]
return (samples_out,)
return io.NodeOutput(samples_out)
class WanCameraImageToVideo:
class WanCameraImageToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"camera_conditions": ("WAN_CAMERA_EMBEDDING", ),
}}
def define_schema(cls):
return io.Schema(
node_id="WanCameraImageToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("start_image", optional=True),
io.WanCameraEmbedding.Input("camera_conditions", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None):
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
@ -333,9 +422,12 @@ class WanCameraImageToVideo:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
mask[:, :, :start_image.shape[0] + 3] = 0.0
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask})
if camera_conditions is not None:
positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})
@ -347,29 +439,34 @@ class WanCameraImageToVideo:
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
return io.NodeOutput(positive, negative, out_latent)
class WanPhantomSubjectToVideo:
class WanPhantomSubjectToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"images": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="WanPhantomSubjectToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("images", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative_text"),
io.Conditioning.Output(display_name="negative_img_text"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, images):
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, images) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
cond2 = negative
if images is not None:
@ -385,7 +482,7 @@ class WanPhantomSubjectToVideo:
out_latent = {}
out_latent["samples"] = latent
return (positive, cond2, negative, out_latent)
return io.NodeOutput(positive, cond2, negative, out_latent)
def parse_json_tracks(tracks):
"""Parse JSON track data into a standardized format"""
@ -598,39 +695,41 @@ def patch_motion(
return out_mask_full, out_feature_full
class WanTrackToVideo:
class WanTrackToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"tracks": ("STRING", {"multiline": True, "default": "[]"}),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
"topk": ("INT", {"default": 2, "min": 1, "max": 10}),
"start_image": ("IMAGE", ),
},
"optional": {
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
}}
def define_schema(cls):
return io.Schema(
node_id="WanTrackToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.String.Input("tracks", multiline=True, default="[]"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("temperature", default=220.0, min=1.0, max=1000.0, step=0.1),
io.Int.Input("topk", default=2, min=1, max=10),
io.Image.Input("start_image"),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
temperature, topk, start_image=None, clip_vision_output=None):
@classmethod
def execute(cls, positive, negative, vae, tracks, width, height, length, batch_size,
temperature, topk, start_image=None, clip_vision_output=None) -> io.NodeOutput:
tracks_data = parse_json_tracks(tracks)
if not tracks_data:
return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
return WanImageToVideo().execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
device=comfy.model_management.intermediate_device())
@ -684,34 +783,36 @@ class WanTrackToVideo:
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
return io.NodeOutput(positive, negative, out_latent)
class Wan22ImageToVideoLatent:
class Wan22ImageToVideoLatent(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"vae": ("VAE", ),
"width": ("INT", {"default": 1280, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 704, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 49, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"start_image": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="Wan22ImageToVideoLatent",
category="conditioning/inpaint",
inputs=[
io.Vae.Input("vae"),
io.Int.Input("width", default=1280, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=704, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("start_image", optional=True),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, vae, width, height, length, batch_size, start_image=None):
@classmethod
def execute(cls, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
if start_image is None:
out_latent = {}
out_latent["samples"] = latent
return (out_latent,)
return io.NodeOutput(out_latent)
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
@ -726,18 +827,25 @@ class Wan22ImageToVideoLatent:
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
return io.NodeOutput(out_latent)
NODE_CLASS_MAPPINGS = {
"WanTrackToVideo": WanTrackToVideo,
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo,
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
"WanVaceToVideo": WanVaceToVideo,
"TrimVideoLatent": TrimVideoLatent,
"WanCameraImageToVideo": WanCameraImageToVideo,
"WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
"Wan22ImageToVideoLatent": Wan22ImageToVideoLatent,
}
class WanExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
WanTrackToVideo,
WanImageToVideo,
WanFunControlToVideo,
Wan22FunControlToVideo,
WanFunInpaintToVideo,
WanFirstLastFrameToVideo,
WanVaceToVideo,
TrimVideoLatent,
WanCameraImageToVideo,
WanPhantomSubjectToVideo,
Wan22ImageToVideoLatent,
]
async def comfy_entrypoint() -> WanExtension:
return WanExtension()

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.49"
__version__ = "0.3.50"

View File

@ -275,7 +275,10 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
def get_full_path(folder_name: str, filename: str, allow_missing: bool = False) -> str | None:
def get_full_path(folder_name: str, filename: str) -> str | None:
"""
Get the full path of a file in a folder, has to be a file
"""
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in folder_names_and_paths:
@ -288,8 +291,6 @@ def get_full_path(folder_name: str, filename: str, allow_missing: bool = False)
return full_path
elif os.path.islink(full_path):
logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))
elif allow_missing:
return full_path
return None
@ -304,27 +305,6 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
return full_path
def get_relative_path(full_path: str) -> tuple[str, str] | None:
"""Convert a full path back to a type-relative path.
Args:
full_path: The full path to the file
Returns:
tuple[str, str] | None: A tuple of (model_type, relative_path) if found, None otherwise
"""
global folder_names_and_paths
full_path = os.path.normpath(full_path)
for model_type, (paths, _) in folder_names_and_paths.items():
for base_path in paths:
base_path = os.path.normpath(base_path)
if full_path.startswith(base_path):
relative_path = os.path.relpath(full_path, base_path)
return model_type, relative_path
return None
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths

View File

@ -164,6 +164,7 @@ def cuda_malloc_warning():
if cuda_malloc_warning:
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server_instance):
current_time: float = 0.0
cache_type = execution.CacheType.CLASSIC
@ -278,6 +279,7 @@ def cleanup_temp():
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
def setup_database():
try:
from app.database.db import init_db, dependencies_available
@ -286,6 +288,7 @@ def setup_database():
except Exception as e:
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
def start_comfyui(asyncio_loop=None):
"""
Starts the ComfyUI server using the provided asyncio event loop or creates a new one.

View File

@ -2320,6 +2320,8 @@ async def init_builtin_extra_nodes():
"nodes_camera_trajectory.py",
"nodes_edit_model.py",
"nodes_tcfg.py",
"nodes_context_windows.py",
"nodes_memory_reserve.py",
]
import_failed = []

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.49"
version = "0.3.50"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.24.4
comfyui-workflow-templates==0.1.53
comfyui-frontend-package==1.25.9
comfyui-workflow-templates==0.1.60
comfyui-embedded-docs==0.2.6
torch
torchsde
@ -20,12 +20,11 @@ tqdm
psutil
alembic
SQLAlchemy
blake3
av>=14.2.0
#non essential dependencies:
kornia>=0.7.1
spandrel
soundfile
av>=14.2.0
pydantic~=2.0
pydantic-settings~=2.0

View File

@ -1,253 +0,0 @@
import pytest
from unittest.mock import patch, MagicMock
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.model_processor import ModelProcessor
from app.database.models import Model, Base
import os
# Test data constants
TEST_MODEL_TYPE = "checkpoints"
TEST_URL = "http://example.com/model.safetensors"
TEST_FILE_NAME = "model.safetensors"
TEST_EXPECTED_HASH = "abc123"
TEST_DESTINATION_PATH = "/path/to/model.safetensors"
def create_test_model(session, file_name, model_type, hash_value, file_size=1000, source_url=None):
"""Helper to create a test model in the database."""
model = Model(path=file_name, type=model_type, hash=hash_value, file_size=file_size, source_url=source_url)
session.add(model)
session.commit()
return model
def setup_mock_hash_calculation(model_processor, hash_value):
"""Helper to setup hash calculation mocks."""
mock_hash = MagicMock()
mock_hash.hexdigest.return_value = hash_value
return patch.object(model_processor, "_get_hasher", return_value=mock_hash)
def verify_model_in_db(session, file_name, expected_hash=None, expected_type=None):
"""Helper to verify model exists in database with correct attributes."""
db_model = session.query(Model).filter_by(path=file_name).first()
assert db_model is not None
if expected_hash:
assert db_model.hash == expected_hash
if expected_type:
assert db_model.type == expected_type
return db_model
@pytest.fixture
def db_engine():
# Configure in-memory database
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
yield engine
Base.metadata.drop_all(engine)
@pytest.fixture
def db_session(db_engine):
Session = sessionmaker(bind=db_engine)
session = Session()
yield session
session.close()
@pytest.fixture
def mock_get_relative_path():
with patch("app.model_processor.get_relative_path") as mock:
mock.side_effect = lambda path: (TEST_MODEL_TYPE, os.path.basename(path))
yield mock
@pytest.fixture
def mock_get_full_path():
with patch("app.model_processor.get_full_path") as mock:
mock.return_value = TEST_DESTINATION_PATH
yield mock
@pytest.fixture
def model_processor(db_session, mock_get_relative_path, mock_get_full_path):
with patch("app.model_processor.create_session", return_value=db_session):
with patch("app.model_processor.can_create_session", return_value=True):
processor = ModelProcessor()
# Setup test state
processor.removed_files = []
processor.downloaded_files = []
processor.file_exists = {}
def mock_download_file(url, destination_path, hasher):
processor.downloaded_files.append((url, destination_path))
processor.file_exists[destination_path] = True
# Simulate writing some data to the file
test_data = b"test data"
hasher.update(test_data)
def mock_remove_file(file_path):
processor.removed_files.append(file_path)
if file_path in processor.file_exists:
del processor.file_exists[file_path]
# Setup common patches
file_exists_patch = patch.object(
processor,
"_file_exists",
side_effect=lambda path: processor.file_exists.get(path, False),
)
file_size_patch = patch.object(
processor,
"_get_file_size",
side_effect=lambda path: (
1000 if processor.file_exists.get(path, False) else 0
),
)
download_file_patch = patch.object(
processor, "_download_file", side_effect=mock_download_file
)
remove_file_patch = patch.object(
processor, "_remove_file", side_effect=mock_remove_file
)
with (
file_exists_patch,
file_size_patch,
download_file_patch,
remove_file_patch,
):
yield processor
def test_ensure_downloaded_invalid_extension(model_processor):
# Ensure that an unsupported file extension raises an error to prevent unsafe file downloads
with pytest.raises(ValueError, match="Unsupported unsafe file for download"):
model_processor.ensure_downloaded(TEST_MODEL_TYPE, TEST_URL, "model.exe")
def test_ensure_downloaded_existing_file_with_hash(model_processor, db_session):
# Ensure that a file with the same hash but from a different source is not downloaded again
SOURCE_URL = "https://example.com/other.sft"
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH, source_url=SOURCE_URL)
model_processor.file_exists[TEST_DESTINATION_PATH] = True
result = model_processor.ensure_downloaded(
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
)
assert result == TEST_DESTINATION_PATH
model = verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
assert model.source_url == SOURCE_URL # Ensure the source URL is not overwritten
def test_ensure_downloaded_existing_file_hash_mismatch(model_processor, db_session):
# Ensure that a file with a different hash raises an error
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, "different_hash")
model_processor.file_exists[TEST_DESTINATION_PATH] = True
with pytest.raises(ValueError, match="File .* exists with hash .* but expected .*"):
model_processor.ensure_downloaded(
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
)
def test_ensure_downloaded_new_file(model_processor, db_session):
# Ensure that a new file is downloaded
model_processor.file_exists[TEST_DESTINATION_PATH] = False
with setup_mock_hash_calculation(model_processor, TEST_EXPECTED_HASH):
result = model_processor.ensure_downloaded(
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
)
assert result == TEST_DESTINATION_PATH
assert len(model_processor.downloaded_files) == 1
assert model_processor.downloaded_files[0] == (TEST_URL, TEST_DESTINATION_PATH)
assert model_processor.file_exists[TEST_DESTINATION_PATH]
verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
def test_ensure_downloaded_hash_mismatch(model_processor, db_session):
# Ensure that download that results in a different hash raises an error
model_processor.file_exists[TEST_DESTINATION_PATH] = False
with setup_mock_hash_calculation(model_processor, "different_hash"):
with pytest.raises(
ValueError,
match="Downloaded file hash .* does not match expected hash .*",
):
model_processor.ensure_downloaded(
TEST_MODEL_TYPE,
TEST_URL,
TEST_FILE_NAME,
TEST_EXPECTED_HASH,
)
assert len(model_processor.removed_files) == 1
assert model_processor.removed_files[0] == TEST_DESTINATION_PATH
assert TEST_DESTINATION_PATH not in model_processor.file_exists
assert db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() is None
def test_process_file_without_hash(model_processor, db_session):
# Test processing file without provided hash
model_processor.file_exists[TEST_DESTINATION_PATH] = True
with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH):
result = model_processor.process_file(TEST_DESTINATION_PATH)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
def test_retrieve_model_by_hash(model_processor, db_session):
# Test retrieving model by hash
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
def test_retrieve_model_by_hash_and_type(model_processor, db_session):
# Test retrieving model by hash and type
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
assert result.type == TEST_MODEL_TYPE
def test_retrieve_hash(model_processor, db_session):
# Test retrieving hash for existing model
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
with patch.object(
model_processor,
"_validate_path",
return_value=(TEST_MODEL_TYPE, TEST_FILE_NAME),
):
result = model_processor.retrieve_hash(TEST_DESTINATION_PATH, TEST_MODEL_TYPE)
assert result == TEST_EXPECTED_HASH
def test_validate_file_extension_valid_extensions(model_processor):
# Test all valid file extensions
valid_extensions = [".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"]
for ext in valid_extensions:
model_processor._validate_file_extension(f"test{ext}") # Should not raise
def test_process_file_existing_without_source_url(model_processor, db_session):
# Test processing an existing file that needs its source URL updated
model_processor.file_exists[TEST_DESTINATION_PATH] = True
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
assert result.source_url == TEST_URL
db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first()
assert db_model.source_url == TEST_URL