Compare commits

..

1 Commits

Author SHA1 Message Date
1d47ec38d8 Set torch version to be 2.3.1 for v0.0.3 2024-07-26 18:54:29 -07:00
598 changed files with 47827 additions and 727891 deletions

View File

@ -28,17 +28,17 @@ def pull(repo, remote_name='origin', branch='master'):
if repo.index.conflicts is not None:
for conflict in repo.index.conflicts:
print('Conflicts found in:', conflict[0].path) # noqa: T201
print('Conflicts found in:', conflict[0].path)
raise AssertionError('Conflicts, ahhhhh!!')
user = repo.default_signature
tree = repo.index.write_tree()
repo.create_commit('HEAD',
user,
user,
'Merge!',
tree,
[repo.head.target, remote_master_id])
commit = repo.create_commit('HEAD',
user,
user,
'Merge!',
tree,
[repo.head.target, remote_master_id])
# We need to do this or git CLI will think we are still merging.
repo.state_cleanup()
else:
@ -49,57 +49,26 @@ repo_path = str(sys.argv[1])
repo = pygit2.Repository(repo_path)
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:
print("stashing current changes") # noqa: T201
print("stashing current changes")
repo.stash(ident)
except KeyError:
print("nothing to stash") # noqa: T201
print("nothing to stash")
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
print("creating backup branch: {}".format(backup_branch_name))
try:
repo.branches.local.create(backup_branch_name, repo.head.peel())
except:
pass
print("checking out master branch") # noqa: T201
print("checking out master branch")
branch = repo.lookup_branch('master')
if branch is None:
try:
ref = repo.lookup_reference('refs/remotes/origin/master')
except:
print("pulling.") # noqa: T201
pull(repo)
ref = repo.lookup_reference('refs/remotes/origin/master')
repo.checkout(ref)
branch = repo.lookup_branch('master')
if branch is None:
repo.create_branch('master', repo.get(ref.target))
else:
ref = repo.lookup_reference(branch.name)
repo.checkout(ref)
ref = repo.lookup_reference(branch.name)
repo.checkout(ref)
print("pulling latest changes") # noqa: T201
print("pulling latest changes")
pull(repo)
if "--stable" in sys.argv:
def latest_tag(repo):
versions = []
for k in repo.references:
try:
prefix = "refs/tags/v"
if k.startswith(prefix):
version = list(map(int, k[len(prefix):].split(".")))
versions.append((version[0] * 10000000000 + version[1] * 100000 + version[2], k))
except:
pass
versions.sort()
if len(versions) > 0:
return versions[-1][1]
return None
latest_tag = latest_tag(repo)
if latest_tag is not None:
repo.checkout(latest_tag)
print("Done!") # noqa: T201
print("Done!")
self_update = True
if len(sys.argv) > 2:
@ -139,13 +108,3 @@ if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
shutil.copy(repo_req_path, req_path)
except:
pass
stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")
try:
if not file_size(stable_update_script_to) > 10:
shutil.copy(stable_update_script, stable_update_script_to)
except:
pass

View File

@ -1,8 +0,0 @@
@echo off
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --stable
if exist update_new.py (
move /y update_new.py update.py
echo Running updater again since it got updated.
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update --stable
)
if "%~1"=="" pause

View File

@ -1,27 +0,0 @@
As of the time of writing this you need this preview driver for best results:
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html
HOW TO RUN:
If you have a AMD gpu:
run_amd_gpu.bat
If you have memory issues you can try disabling the smart memory management by running comfyui with:
run_amd_gpu_disable_smart_memory.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors
RECOMMENDED WAY TO UPDATE:
To update the ComfyUI code: update\update_comfyui.bat
TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.

View File

@ -1,2 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
pause

View File

@ -4,9 +4,6 @@ if you have a NVIDIA gpu:
run_nvidia_gpu.bat
if you want to enable the fast fp16 accumulation (faster for fp16 models with slightly less quality):
run_nvidia_gpu_fast_fp16_accumulation.bat
To run it in slow CPU mode:
@ -17,7 +14,7 @@ run_cpu.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
RECOMMENDED WAY TO UPDATE:

View File

@ -1,2 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
pause

View File

@ -1,3 +0,0 @@
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@ -1,3 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@ -1,3 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

3
.gitattributes vendored
View File

@ -1,3 +0,0 @@
/web/assets/** linguist-generated
/web/** linguist-vendored
comfy_api_nodes/apis/__init__.py linguist-generated

View File

@ -8,23 +8,13 @@ body:
Before submitting a **Bug Report**, please ensure the following:
- **1:** You are running the latest version of ComfyUI.
- **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
- **2:** You have looked at the existing bug reports and made sure this isn't already reported.
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
`--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version.
`--disable-all-custom-nodes` command line argument.
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
## Very Important
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
- type: checkboxes
id: custom-nodes-test
attributes:
label: Custom Node Testing
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
required: false
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
- type: textarea
attributes:
label: Expected Behavior

View File

@ -1,8 +1,5 @@
blank_issues_enabled: true
contact_links:
- name: ComfyUI Frontend Issues
url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
about: Issues related to the ComfyUI frontend (display issues, user interaction bugs), please go to the frontend repo to file the issue
- name: ComfyUI Matrix Space
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).

View File

@ -11,14 +11,6 @@ body:
**2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
- type: checkboxes
id: custom-nodes-test
attributes:
label: Custom Node Testing
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
required: false
- type: textarea
attributes:
label: Your question

View File

@ -1,21 +0,0 @@
<!-- API_NODE_PR_CHECKLIST: do not remove -->
## API Node PR Checklist
### Scope
- [ ] **Is API Node Change**
### Pricing & Billing
- [ ] **Need pricing update**
- [ ] **No pricing update**
If **Need pricing update**:
- [ ] Metronome rate cards updated
- [ ] Autobilling tests updated and passing
### QA
- [ ] **QA done**
- [ ] **QA not required**
### Comms
- [ ] Informed **Kosinkadink**

View File

@ -1,58 +0,0 @@
name: Append API Node PR template
on:
pull_request_target:
types: [opened, reopened, synchronize, ready_for_review]
paths:
- 'comfy_api_nodes/**' # only run if these files changed
permissions:
contents: read
pull-requests: write
jobs:
inject:
runs-on: ubuntu-latest
steps:
- name: Ensure template exists and append to PR body
uses: actions/github-script@v7
with:
script: |
const { owner, repo } = context.repo;
const number = context.payload.pull_request.number;
const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';
const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });
let templateText;
try {
const res = await github.rest.repos.getContent({
owner,
repo,
path: templatePath,
ref: pr.base.ref
});
const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
templateText = buf.toString('utf8');
} catch (e) {
core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
return;
}
// Enforce the presence of the marker inside the template (for idempotence)
if (!templateText.includes(marker)) {
core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
return;
}
// If the PR already contains the marker, do not append again.
const body = pr.body || '';
if (body.includes(marker)) {
core.info('Template already present in PR body; nothing to inject.');
return;
}
const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
core.notice('API Node template appended to PR description.');

View File

@ -1,40 +0,0 @@
name: Check for Windows Line Endings
on:
pull_request:
branches: ['*'] # Trigger on all pull requests to any branch
jobs:
check-line-endings:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0 # Fetch all history to compare changes
- name: Check for Windows line endings (CRLF)
run: |
# Get the list of changed files in the PR
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
# Flag to track if CRLF is found
CRLF_FOUND=false
# Loop through each changed file
for FILE in $CHANGED_FILES; do
# Check if the file exists and is a text file
if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
# Check for CRLF line endings
if grep -UP '\r$' "$FILE"; then
echo "Error: Windows line endings (CRLF) detected in $FILE"
CRLF_FOUND=true
fi
fi
done
# Exit with error if CRLF was found
if [ "$CRLF_FOUND" = true ]; then
exit 1
fi

View File

@ -1,53 +0,0 @@
# This is the GitHub Workflow that drives full-GPU-enabled tests of pull requests to ComfyUI, when the 'Run-CI-Test' label is added
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Pull Request CI Workflow Runs
on:
pull_request_target:
types: [labeled]
jobs:
pr-test-stable:
if: ${{ github.event.label.name == 'Run-CI-Test' }}
strategy:
fail-fast: false
matrix:
os: [macos, linux, windows]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
include:
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
- os: windows
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
use_prior_commit: 'true'
comment:
if: ${{ github.event.label.name == 'Run-CI-Test' }}
runs-on: ubuntu-latest
permissions:
pull-requests: write
steps:
- uses: actions/github-script@v6
with:
script: |
github.rest.issues.createComment({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
})

23
.github/workflows/pylint.yml vendored Normal file
View File

@ -0,0 +1,23 @@
name: Python Linting
on: [push, pull_request]
jobs:
pylint:
name: Run Pylint
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.x
- name: Install Pylint
run: pip install pylint
- name: Run Pylint
run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py")

View File

@ -1,78 +0,0 @@
name: "Release Stable All Portable Versions"
on:
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
jobs:
release_nvidia_default:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA Default (cu130)"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu130"
python_minor: "13"
python_patch: "9"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true
secrets: inherit
release_nvidia_cu128:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA cu128"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu128"
python_minor: "12"
python_patch: "10"
rel_name: "nvidia"
rel_extra_name: "_cu128"
test_release: true
secrets: inherit
release_nvidia_cu126:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA cu126"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu126"
python_minor: "12"
python_patch: "10"
rel_name: "nvidia"
rel_extra_name: "_cu126"
test_release: true
secrets: inherit
release_amd_rocm:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release AMD ROCm 6.4.4"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "rocm644"
python_minor: "12"
python_patch: "10"
rel_name: "amd"
rel_extra_name: ""
test_release: false
secrets: inherit

View File

@ -1,108 +0,0 @@
name: Release Webhook
on:
release:
types: [published]
jobs:
send-webhook:
runs-on: ubuntu-latest
steps:
- name: Send release webhook
env:
WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
run: |
# Generate UUID for delivery ID
DELIVERY_ID=$(uuidgen)
HOOK_ID="release-webhook-$(date +%s)"
# Create webhook payload matching GitHub release webhook format
PAYLOAD=$(cat <<EOF
{
"action": "published",
"release": {
"id": ${{ github.event.release.id }},
"node_id": "${{ github.event.release.node_id }}",
"url": "${{ github.event.release.url }}",
"html_url": "${{ github.event.release.html_url }}",
"assets_url": "${{ github.event.release.assets_url }}",
"upload_url": "${{ github.event.release.upload_url }}",
"tag_name": "${{ github.event.release.tag_name }}",
"target_commitish": "${{ github.event.release.target_commitish }}",
"name": ${{ toJSON(github.event.release.name) }},
"body": ${{ toJSON(github.event.release.body) }},
"draft": ${{ github.event.release.draft }},
"prerelease": ${{ github.event.release.prerelease }},
"created_at": "${{ github.event.release.created_at }}",
"published_at": "${{ github.event.release.published_at }}",
"author": {
"login": "${{ github.event.release.author.login }}",
"id": ${{ github.event.release.author.id }},
"node_id": "${{ github.event.release.author.node_id }}",
"avatar_url": "${{ github.event.release.author.avatar_url }}",
"url": "${{ github.event.release.author.url }}",
"html_url": "${{ github.event.release.author.html_url }}",
"type": "${{ github.event.release.author.type }}",
"site_admin": ${{ github.event.release.author.site_admin }}
},
"tarball_url": "${{ github.event.release.tarball_url }}",
"zipball_url": "${{ github.event.release.zipball_url }}",
"assets": ${{ toJSON(github.event.release.assets) }}
},
"repository": {
"id": ${{ github.event.repository.id }},
"node_id": "${{ github.event.repository.node_id }}",
"name": "${{ github.event.repository.name }}",
"full_name": "${{ github.event.repository.full_name }}",
"private": ${{ github.event.repository.private }},
"owner": {
"login": "${{ github.event.repository.owner.login }}",
"id": ${{ github.event.repository.owner.id }},
"node_id": "${{ github.event.repository.owner.node_id }}",
"avatar_url": "${{ github.event.repository.owner.avatar_url }}",
"url": "${{ github.event.repository.owner.url }}",
"html_url": "${{ github.event.repository.owner.html_url }}",
"type": "${{ github.event.repository.owner.type }}",
"site_admin": ${{ github.event.repository.owner.site_admin }}
},
"html_url": "${{ github.event.repository.html_url }}",
"clone_url": "${{ github.event.repository.clone_url }}",
"git_url": "${{ github.event.repository.git_url }}",
"ssh_url": "${{ github.event.repository.ssh_url }}",
"url": "${{ github.event.repository.url }}",
"created_at": "${{ github.event.repository.created_at }}",
"updated_at": "${{ github.event.repository.updated_at }}",
"pushed_at": "${{ github.event.repository.pushed_at }}",
"default_branch": "${{ github.event.repository.default_branch }}",
"fork": ${{ github.event.repository.fork }}
},
"sender": {
"login": "${{ github.event.sender.login }}",
"id": ${{ github.event.sender.id }},
"node_id": "${{ github.event.sender.node_id }}",
"avatar_url": "${{ github.event.sender.avatar_url }}",
"url": "${{ github.event.sender.url }}",
"html_url": "${{ github.event.sender.html_url }}",
"type": "${{ github.event.sender.type }}",
"site_admin": ${{ github.event.sender.site_admin }}
}
}
EOF
)
# Generate HMAC-SHA256 signature
SIGNATURE=$(echo -n "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)
# Send webhook with required headers
curl -X POST "$WEBHOOK_URL" \
-H "Content-Type: application/json" \
-H "X-GitHub-Event: release" \
-H "X-GitHub-Delivery: $DELIVERY_ID" \
-H "X-GitHub-Hook-ID: $HOOK_ID" \
-H "X-Hub-Signature-256: sha256=$SIGNATURE" \
-H "User-Agent: GitHub-Actions-Webhook/1.0" \
-d "$PAYLOAD" \
--fail --silent --show-error
echo "✅ Release webhook sent successfully"

View File

@ -1,48 +0,0 @@
name: Python Linting
on: [push, pull_request]
jobs:
ruff:
name: Run Ruff
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.x
- name: Install Ruff
run: pip install ruff
- name: Run Ruff
run: ruff check .
pylint:
name: Run Pylint
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Install Pylint
run: pip install pylint
- name: Run Pylint
run: pylint comfy_api_nodes

View File

@ -2,78 +2,9 @@
name: "Release Stable Version"
on:
workflow_call:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu129"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "13"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu129"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "13"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
push:
tags:
- 'v*'
jobs:
package_comfy_windows:
@ -82,57 +13,70 @@ jobs:
packages: "write"
pull-requests: "read"
runs-on: windows-latest
strategy:
matrix:
python_version: [3.11.8]
cuda_version: [121]
steps:
- name: Calculate Minor Version
shell: bash
run: |
# Extract the minor version from the Python version
MINOR_VERSION=$(echo "${{ matrix.python_version }}" | cut -d'.' -f2)
echo "MINOR_VERSION=$MINOR_VERSION" >> $GITHUB_ENV
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
- uses: actions/checkout@v4
with:
ref: ${{ inputs.git_tag }}
fetch-depth: 150
fetch-depth: 0
persist-credentials: false
- uses: actions/cache/restore@v4
id: cache
with:
path: |
${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
- shell: bash
run: |
mv ${{ inputs.cache_tag }}_python_deps.tar ../
echo "@echo off
call update_comfyui.bat nopause
echo -
echo This will try to update pytorch and all python dependencies.
echo -
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo -
pause
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir cu${{ matrix.cuda_version }}_python_deps
mv cu${{ matrix.cuda_version }}_python_deps ../
mv update_comfyui_and_python_dependencies.bat ../
cd ..
tar xf ${{ inputs.cache_tag }}_python_deps.tar
pwd
ls
- shell: bash
run: |
cd ..
cp -r ComfyUI ComfyUI_copy
curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
curl https://www.python.org/ftp/python/${{ matrix.python_version }}/python-${{ matrix.python_version }}-embed-amd64.zip -o python_embeded.zip
unzip python_embeded.zip -d python_embeded
cd python_embeded
echo ${{ env.MINOR_VERSION }}
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
echo 'import site' >> ./python3${{ env.MINOR_VERSION }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
./python.exe -s -m pip install -r requirements_comfyui.txt
rm requirements_comfyui.txt
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
fi
./python.exe --version
echo "Pip version:"
./python.exe -m pip --version
set PATH=$PWD/Scripts:$PATH
echo $PATH
./python.exe -s -m pip install ../cu${{ matrix.cuda_version }}_python_deps/*
sed -i '1i../ComfyUI' ./python3${{ env.MINOR_VERSION }}._pth
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
git clone https://github.com/comfyanonymous/taesd
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable
mv python_embeded ComfyUI_windows_portable
@ -142,29 +86,24 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
cp -r ComfyUI/.ci/windows_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
- shell: bash
if: ${{ inputs.test_release }}
run: |
cd ..
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
python_embeded/python.exe -s ./update/update.py ComfyUI/
ls
- name: Upload binaries to release
uses: softprops/action-gh-release@v2
uses: svenstaro/upload-release-action@v2
with:
files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
tag_name: ${{ inputs.git_tag }}
draft: true
overwrite_files: true
repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ComfyUI_windows_portable_nvidia.7z
tag: ${{ github.ref }}
overwrite: true

View File

@ -1,21 +0,0 @@
name: 'Close stale issues'
on:
schedule:
# Run daily at 430 am PT
- cron: '30 11 * * *'
permissions:
issues: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v9
with:
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
days-before-stale: 30
days-before-close: 7
stale-issue-label: 'Stale'
only-labels: 'User Support'
exempt-all-assignees: true
exempt-all-milestones: true

76
.github/workflows/test-browser.yml vendored Normal file
View File

@ -0,0 +1,76 @@
# This is a temporary action during frontend TS migration.
# This file should be removed after TS migration is completed.
# The browser test is here to ensure TS repo is working the same way as the
# current JS code.
# If you are adding UI feature, please sync your changes to the TS repo:
# huchenlei/ComfyUI_frontend and update test expectation files accordingly.
name: Playwright Browser Tests CI
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout ComfyUI
uses: actions/checkout@v4
with:
repository: "comfyanonymous/ComfyUI"
path: "ComfyUI"
- name: Checkout ComfyUI_frontend
uses: actions/checkout@v4
with:
repository: "huchenlei/ComfyUI_frontend"
path: "ComfyUI_frontend"
ref: "fcc54d803e5b6a9b08a462a1d94899318c96dcbb"
- uses: actions/setup-node@v3
with:
node-version: lts/*
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install wait-for-it
working-directory: ComfyUI
- name: Start ComfyUI server
run: |
python main.py --cpu 2>&1 | tee console_output.log &
wait-for-it --service 127.0.0.1:8188 -t 600
working-directory: ComfyUI
- name: Install ComfyUI_frontend dependencies
run: |
npm ci
working-directory: ComfyUI_frontend
- name: Install Playwright Browsers
run: npx playwright install --with-deps
working-directory: ComfyUI_frontend
- name: Run Playwright tests
run: npx playwright test
working-directory: ComfyUI_frontend
- name: Check for unhandled exceptions in server log
run: |
if grep -qE "Exception|Error" console_output.log; then
echo "Unhandled exception/error found in server log."
exit 1
fi
working-directory: ComfyUI
- uses: actions/upload-artifact@v4
if: always()
with:
name: playwright-report
path: ComfyUI_frontend/playwright-report/
retention-days: 30
- uses: actions/upload-artifact@v4
if: always()
with:
name: console-output
path: ComfyUI/console_output.log
retention-days: 30

View File

@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@ -28,4 +28,4 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements.txt

View File

@ -1,98 +0,0 @@
# This is the GitHub Workflow that drives automatic full-GPU-enabled tests of all new commits to the master branch of ComfyUI
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Full Comfy CI Workflow Runs
on:
push:
branches:
- master
paths-ignore:
- 'app/**'
- 'input/**'
- 'output/**'
- 'notebooks/**'
- 'script_examples/**'
- '.github/**'
- 'web/**'
workflow_dispatch:
jobs:
test-stable:
strategy:
fail-fast: false
matrix:
# os: [macos, linux, windows]
# os: [macos, linux]
os: [linux]
python_version: ["3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
include:
# - os: macos
# runner_label: [self-hosted, macOS]
# flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
# - os: windows
# runner_label: [self-hosted, Windows]
# flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
# test-win-nightly:
# strategy:
# fail-fast: true
# matrix:
# os: [windows]
# python_version: ["3.9", "3.10", "3.11", "3.12"]
# cuda_version: ["12.1"]
# torch_version: ["nightly"]
# include:
# - os: windows
# runner_label: [self-hosted, Windows]
# flags: ""
# runs-on: ${{ matrix.runner_label }}
# steps:
# - name: Test Workflows
# uses: comfy-org/comfy-action@main
# with:
# os: ${{ matrix.os }}
# python_version: ${{ matrix.python_version }}
# torch_version: ${{ matrix.torch_version }}
# google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
# comfyui_flags: ${{ matrix.flags }}
test-unix-nightly:
strategy:
fail-fast: false
matrix:
# os: [macos, linux]
os: [linux]
python_version: ["3.11"]
cuda_version: ["12.1"]
torch_version: ["nightly"]
include:
# - os: macos
# runner_label: [self-hosted, macOS]
# flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}

View File

@ -1,30 +0,0 @@
name: Execution Tests
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
runs-on: ${{ matrix.os }}
continue-on-error: true
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install -r tests-unit/requirements.txt
- name: Run Execution Tests
run: |
python -m pytest tests/execution -v --skip-timing-checks

View File

@ -1,45 +0,0 @@
name: Test server launches without errors
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout ComfyUI
uses: actions/checkout@v4
with:
repository: "comfyanonymous/ComfyUI"
path: "ComfyUI"
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install wait-for-it
working-directory: ComfyUI
- name: Start ComfyUI server
run: |
python main.py --cpu 2>&1 | tee console_output.log &
wait-for-it --service 127.0.0.1:8188 -t 30
working-directory: ComfyUI
- name: Check for unhandled exceptions in server log
run: |
if grep -qE "Exception|Error" console_output.log; then
echo "Unhandled exception/error found in server log."
exit 1
fi
working-directory: ComfyUI
- uses: actions/upload-artifact@v4
if: always()
with:
name: console-output
path: ComfyUI/console_output.log
retention-days: 30

View File

@ -1,29 +1,29 @@
name: Unit Tests
name: Tests CI
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
on: [push, pull_request]
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, windows-2022, macos-latest]
runs-on: ${{ matrix.os }}
continue-on-error: true
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
- uses: actions/setup-node@v3
with:
python-version: '3.12'
node-version: 18
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Run Tests
run: |
npm ci
npm run test:generate
npm test -- --verbose
working-directory: ./tests-ui
- name: Run Unit Tests
run: |
pip install -r tests-unit/requirements.txt

View File

@ -1,56 +0,0 @@
name: Generate Pydantic Stubs from api.comfy.org
on:
schedule:
- cron: '0 0 * * 1'
workflow_dispatch:
jobs:
generate-models:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install 'datamodel-code-generator[http]'
npm install @redocly/cli
- name: Download OpenAPI spec
run: |
curl -o openapi.yaml https://api.comfy.org/openapi
- name: Filter OpenAPI spec with Redocly
run: |
npx @redocly/cli bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
- name: Generate API models
run: |
datamodel-codegen --use-subclass-enum --input filtered-openapi.yaml --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
- name: Check for changes
id: git-check
run: |
git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT
- name: Create Pull Request
if: steps.git-check.outputs.changes == 'true'
uses: peter-evans/create-pull-request@v5
with:
commit-message: 'chore: update API models from OpenAPI spec'
title: 'Update API models from api.comfy.org'
body: |
This PR updates the API models based on the latest api.comfy.org OpenAPI specification.
Generated automatically by the a Github workflow.
branch: update-api-stubs
delete-branch: true
base: master

View File

@ -1,58 +0,0 @@
name: Update Version File
on:
pull_request:
paths:
- "pyproject.toml"
branches:
- master
jobs:
update-version:
runs-on: ubuntu-latest
# Don't run on fork PRs
if: github.event.pull_request.head.repo.full_name == github.repository
permissions:
pull-requests: write
contents: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- name: Update comfyui_version.py
run: |
# Read version from pyproject.toml and update comfyui_version.py
python -c '
import tomllib
# Read version from pyproject.toml
with open("pyproject.toml", "rb") as f:
config = tomllib.load(f)
version = config["project"]["version"]
# Write version to comfyui_version.py
with open("comfyui_version.py", "w") as f:
f.write("# This file is automatically generated by the build process when version is\n")
f.write("# updated in pyproject.toml.\n")
f.write(f"__version__ = \"{version}\"\n")
'
- name: Commit changes
run: |
git config --local user.name "github-actions"
git config --local user.email "github-actions@github.com"
git fetch origin ${{ github.head_ref }}
git checkout -B ${{ github.head_ref }} origin/${{ github.head_ref }}
git add comfyui_version.py
git diff --quiet && git diff --staged --quiet || git commit -m "chore: Update comfyui_version.py to match pyproject.toml"
git push origin HEAD:${{ github.head_ref }}

View File

@ -8,28 +8,23 @@ on:
required: false
type: string
default: ""
extra_dependencies:
description: 'extra dependencies'
required: false
type: string
default: ""
cu:
description: 'cuda version'
required: true
type: string
default: "130"
default: "121"
python_minor:
description: 'python minor version'
required: true
type: string
default: "13"
default: "11"
python_patch:
description: 'python patch version'
required: true
type: string
default: "9"
default: "8"
# push:
# branches:
# - master
@ -56,8 +51,7 @@ jobs:
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir

View File

@ -1,64 +0,0 @@
name: "Windows Release dependencies Manual"
on:
workflow_dispatch:
inputs:
torch_dependencies:
description: 'torch dependencies'
required: false
type: string
default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128"
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu128"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "10"
jobs:
build_dependencies:
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
- shell: bash
run: |
echo "@echo off
call update_comfyui.bat nopause
echo -
echo This will try to update pytorch and all python dependencies.
echo -
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo -
pause
..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps
tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps
- uses: actions/cache/save@v4
with:
path: |
${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}

View File

@ -7,19 +7,19 @@ on:
description: 'cuda version'
required: true
type: string
default: "129"
default: "124"
python_minor:
description: 'python minor version'
required: true
type: string
default: "13"
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "5"
default: "3"
# push:
# branches:
# - master
@ -34,7 +34,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 30
fetch-depth: 0
persist-credentials: false
- uses: actions/setup-python@v5
with:
@ -49,16 +49,14 @@ jobs:
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
python -m pip wheel torch torchvision torchaudio mpmath==1.3.0 numpy==1.26.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
git clone https://github.com/comfyanonymous/taesd
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable_nightly_pytorch
mv python_embeded ComfyUI_windows_portable_nightly_pytorch
@ -68,15 +66,14 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
cp -r ComfyUI/.ci/windows_base_files/* ./
echo "call update_comfyui.bat nopause
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > ./update/update_comfyui_and_python_dependencies.bat
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
cd ComfyUI_windows_portable_nightly_pytorch

View File

@ -7,19 +7,19 @@ on:
description: 'cuda version'
required: true
type: string
default: "129"
default: "121"
python_minor:
description: 'python minor version'
required: true
type: string
default: "13"
default: "11"
python_patch:
description: 'python patch version'
required: true
type: string
default: "6"
default: "8"
# push:
# branches:
# - master
@ -50,7 +50,7 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 150
fetch-depth: 0
persist-credentials: false
- shell: bash
run: |
@ -64,14 +64,10 @@ jobs:
./python.exe get-pip.py
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
git clone https://github.com/comfyanonymous/taesd
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable
mv python_embeded ComfyUI_windows_portable
@ -81,19 +77,17 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp -r ComfyUI/.ci/windows_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
python_embeded/python.exe -s ./update/update.py ComfyUI/
ls
- name: Upload binaries to release

7
.gitignore vendored
View File

@ -12,15 +12,10 @@ extra_model_paths.yaml
.vscode/
.idea/
venv/
.venv/
/web/extensions/*
!/web/extensions/logging.js.example
!/web/extensions/core/
/tests-ui/data/object_info.json
/user/
*.log
web_custom_versions/
.DS_Store
openapi.yaml
filtered-openapi.yaml
uv.lock
web_custom_versions/

3
.pylintrc Normal file
View File

@ -0,0 +1,3 @@
[MESSAGES CONTROL]
disable=all
enable=eval-used

View File

@ -1,3 +1 @@
# Admins
* @comfyanonymous
* @kosinkadink
* @comfyanonymous

View File

@ -1,168 +0,0 @@
# The Comfy guide to Quantization
## How does quantization work?
Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware.
When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused"
By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling.
```
absmax = max(abs(tensor))
scale = amax / max_dynamic_range_low_precision
# Quantization
tensor_q = (tensor / scale).to(low_precision_dtype)
# De-Quantization
tensor_dq = tensor_q.to(fp16) * scale
tensor_dq ~ tensor
```
Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
## Quantization in Comfy
```
QuantizedTensor (torch.Tensor subclass)
↓ __torch_dispatch__
Two-Level Registry (generic + layout handlers)
MixedPrecisionOps + Metadata Detection
```
### Representation
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py`
A `Layout` class defines how a specific quantization format behaves:
- Required parameters
- Quantize method
- De-Quantize method
```python
from comfy.quant_ops import QuantizedLayout
class MyLayout(QuantizedLayout):
@classmethod
def quantize(cls, tensor, **kwargs):
# Convert to quantized format
qdata = ...
params = {'scale': ..., 'orig_dtype': tensor.dtype}
return qdata, params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
return qdata.to(orig_dtype) * scale
```
To then run operations using these QuantizedTensors we use two registry systems to define supported operations.
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
The second registry is layout-specific and allows to implement fast-paths like nn.Linear.
```python
from comfy.quant_ops import register_layout_op
@register_layout_op(torch.ops.aten.linear.default, MyLayout)
def my_linear(func, args, kwargs):
# Extract tensors, call optimized kernel
...
```
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation.
### Mixed Precision
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
**Architecture:**
```python
class MixedPrecisionOps(disable_weight_init):
_layer_quant_config = {} # Maps layer names to quantization configs
_compute_dtype = torch.bfloat16 # Default compute / dequantize precision
```
**Key mechanism:**
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
- If the layer name **is** in `_layer_quant_config`:
- Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
- Load associated quantization parameters (scales, block_size, etc.)
**Why it's needed:**
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
## Checkpoint Format
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
The quantized checkpoint will contain the same layers as the original checkpoint but:
- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8.
- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe.
- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used.
### Scaling Parameters details
We define 4 possible scaling parameters that should cover most recipes in the near-future:
- **weight_scale**: quantization scalers for the weights
- **weight_scale_2**: global scalers in the context of double scaling
- **pre_quant_scale**: scalers used for smoothing salient weights
- **input_scale**: quantization scalers for the activations
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|--------|---------------|--------------|----------------|-----------------|-------------|
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
### Quantization Metadata
The metadata stored alongside the checkpoint contains:
- **format_version**: String to define a version of the standard
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
Example:
```json
{
"_quantization_metadata": {
"format_version": "1.0",
"layers": {
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
}
}
}
```
## Creating Quantized Checkpoints
To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
### Weight Quantization
Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
### Calibration (for Activation Quantization)
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:
1. **Collect statistics**: Run inference on N representative samples
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
3. **Compute scales**: Derive `input_scale` from collected statistics
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.

340
README.md
View File

@ -1,95 +1,22 @@
<div align="center">
ComfyUI
=======
The most powerful and modular stable diffusion GUI and backend.
-----------
![ComfyUI Screenshot](comfyui_screenshot.png)
# ComfyUI
**The most powerful and modular visual AI engine and application.**
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
[![Website][website-shield]][website-url]
[![Dynamic JSON Badge][discord-shield]][discord-url]
[![Twitter][twitter-shield]][twitter-url]
[![Matrix][matrix-shield]][matrix-url]
<br>
[![][github-release-shield]][github-release-link]
[![][github-release-date-shield]][github-release-link]
[![][github-downloads-shield]][github-downloads-link]
[![][github-downloads-latest-shield]][github-downloads-link]
[matrix-shield]: https://img.shields.io/badge/Matrix-000000?style=flat&logo=matrix&logoColor=white
[matrix-url]: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
[website-shield]: https://img.shields.io/badge/ComfyOrg-4285F4?style=flat
[website-url]: https://www.comfy.org/
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
[discord-url]: https://www.comfy.org/discord
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
[twitter-url]: https://x.com/ComfyUI
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
[github-release-date-shield]: https://img.shields.io/github/release-date/comfyanonymous/ComfyUI?style=flat
[github-downloads-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/total?style=flat
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe)
</div>
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
## Get Started
#### [Desktop Application](https://www.comfy.org/download)
- The easiest way to get started.
- Available on Windows & macOS.
#### [Windows Portable Package](#installing)
- Get the latest commits and completely portable.
- Available on Windows.
#### [Manual Install](#manual-install-windows-linux)
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
### [Installing ComfyUI](#installing)
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Image Models
- SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
- Pixart Alpha and Sigma
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
- [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- 3D Models
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Smart memory management: can automatically run large models on GPUs with as low as 1GB vram with smart offloading.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
- Works even if you don't have a GPU with: ```--cpu``` (slow)
- Can load ckpt and safetensors: All in one checkpoints or standalone diffusion models, VAEs and CLIP models.
- Safe loading of ckpt, pt, pth, etc.. files.
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
- Embeddings/Textual inversion
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
@ -100,114 +27,72 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Works fully offline: core will never download anything unless you want to.
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
- Starts up very fast.
- Works fully offline: will never download anything.
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
## Release Process
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0) roughly every week.
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
- Serves as the foundation for the desktop release
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
- Builds a new release using the latest stable core version
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
- Weekly frontend updates are merged into the core repository
- Features are frozen for the upcoming core release
- Development continues for the next release cycle
## Shortcuts
| Keybind | Explanation |
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
| `Ctrl` + `Enter` | Queue up current graph for generation |
| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
| `Ctrl` + `S` | Save workflow |
| `Ctrl` + `O` | Load workflow |
| `Ctrl` + `A` | Select all nodes |
| `Alt `+ `C` | Collapse/uncollapse selected nodes |
| `Ctrl` + `M` | Mute/unmute selected nodes |
| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| `Delete`/`Backspace` | Delete selected nodes |
| `Ctrl` + `Backspace` | Delete the current graph |
| `Space` | Move the canvas around when held and moving the cursor |
| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| `Shift` + `Drag` | Move multiple selected nodes at the same time |
| `Ctrl` + `D` | Load default graph |
| `Alt` + `+` | Canvas Zoom in |
| `Alt` + `-` | Canvas Zoom out |
| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
| `P` | Pin/Unpin selected nodes |
| `Ctrl` + `G` | Group selected nodes |
| `Q` | Toggle visibility of the queue |
| `H` | Toggle visibility of history |
| `R` | Refresh graph |
| `F` | Show/Hide menu |
| `.` | Fit view to selection (Whole graph when nothing is selected) |
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + Z/Ctrl + Y | Undo/Redo |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Alt + C | Collapse/uncollapse selected nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
| Ctrl/Shift + Click | Add clicked node to selection |
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| Shift + Drag | Move multiple selected nodes at the same time |
| Ctrl + D | Load default graph |
| Alt + `+` | Canvas Zoom in |
| Alt + `-` | Canvas Zoom out |
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |
| Double-Click LMB | Open node quick search palette |
| `Shift` + Drag | Move multiple wires at once |
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
`Ctrl` can also be replaced with `Cmd` instead for macOS users
Ctrl can also be replaced with Cmd instead for macOS users
# Installing
## Windows Portable
## Windows
There is a portable standalone build for Windows that should work for running on Nvidia GPUs or for running on your CPU only on the [releases page](https://github.com/comfyanonymous/ComfyUI/releases).
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu121_or_cpu.7z)
Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
If you have trouble extracting it, right click the file -> properties -> unblock
Update your Nvidia drivers if it doesn't start.
#### Alternative Downloads:
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
## Jupyter Notebook
## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)
You can install and start ComfyUI using comfy-cli:
```bash
pip install comfy-cli
comfy install
```
To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb)
## Manual Install (Windows, Linux)
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
### Instructions:
Git clone this repo.
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
@ -215,54 +100,24 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae
### AMD GPUs (Linux)
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware.
RDNA 3 (RX 7000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/```
RDNA 4 (RX 9000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/```
### Intel GPUs (Windows and Linux)
Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
1. To install PyTorch xpu, use the following command:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu```
This is the command to install the Pytorch xpu nightly which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
### NVIDIA
Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130```
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
This is the command to install pytorch nightly instead which might have performance improvements.
This is the command to install pytorch nightly instead which might have performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124```
#### Troubleshooting
@ -282,6 +137,17 @@ After this you should have everything installed and can proceed to running Comfy
### Others:
#### Intel GPUs
Intel GPU support is available for all Intel GPUs supported by Intel's Extension for Pytorch (IPEX) with the support requirements listed in the [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) page. Choose your platform and method of install and follow the instructions. The steps are as follows:
1. Start by installing the drivers or kernel listed or newer in the Installation page of IPEX linked above for Windows and Linux if needed.
1. Follow the instructions to install [Intel's oneAPI Basekit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html) for your platform.
1. Install the packages for IPEX using the instructions provided in the Installation page for your platform.
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux and run ComfyUI normally as described above after everything is installed.
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
#### Apple Mac silicon
You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
@ -293,29 +159,23 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
#### Ascend NPUs
#### DirectML (AMD Cards on Windows)
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
1. Begin by installing the recommended or newer kernel version for Linux as specified in the Installation page of torch-npu, if necessary.
2. Proceed with the installation of Ascend Basekit, which includes the driver, firmware, and CANN, following the instructions provided for your specific platform.
3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies?
#### Cambricon MLUs
You don't. If you have another UI installed and working with its own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it:
For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a step-by-step guide tailored to your platform and installation method:
```source path_to_other_sd_gui/venv/bin/activate```
1. Install the Cambricon CNToolkit by adhering to the platform-specific instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cntoolkit_3.7.2/cntoolkit_install_3.7.2/index.html)
2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
3. Launch ComfyUI by running `python main.py`
or on Windows:
#### Iluvatar Corex
With Powershell: ```"path_to_other_sd_gui\venv\Scripts\Activate.ps1"```
For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method:
With cmd.exe: ```"path_to_other_sd_gui\venv\Scripts\activate.bat"```
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
2. Launch ComfyUI by running `python main.py`
And then you can use that terminal to run ComfyUI without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI.
# Running
@ -329,14 +189,6 @@ For 6700, 6600 and maybe other RDNA2 or older: ```HSA_OVERRIDE_GFX_VERSION=10.3.
For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py```
### AMD ROCm Tips
You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command, it should already be enabled by default on RDNA3. If this improves speed for you on latest pytorch on your GPU please report it so that I can enable it by default.
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
# Notes
Only parts of the graph that have an output with all the correct inputs will be executed.
@ -360,67 +212,25 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
Use ```--preview-method auto``` to enable previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
## How to use TLS/SSL?
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`.
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
<br/><br/>If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.
## Support and dev channel
[Discord](https://comfy.org/discord): Try the #help or #feedback channels.
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
See also: [https://www.comfy.org/](https://www.comfy.org/)
## Frontend Development
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
### Reporting Issues and Requesting Features
For any bugs, issues, or feature requests related to the frontend, please use the [ComfyUI Frontend repository](https://github.com/Comfy-Org/ComfyUI_frontend). This will help us manage and address frontend-specific concerns more efficiently.
### Using the Latest Frontend
The new frontend is now the default for ComfyUI. However, please note:
1. The frontend in the main ComfyUI repository is updated fortnightly.
2. Daily releases are available in the separate frontend repository.
To use the most up-to-date frontend version:
1. For the latest daily release, launch ComfyUI with this command line argument:
```
--front-end-version Comfy-Org/ComfyUI_frontend@latest
```
2. For a specific version, replace `latest` with the desired version number:
```
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
```
This approach allows you to easily switch between the stable fortnightly release and the cutting-edge daily updates, or even specific versions for testing purposes.
### Accessing the Legacy Frontend
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
```
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
```
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
# QA
### Which GPU should I buy for this?
[See this page for some recommendations](https://github.com/comfyanonymous/ComfyUI/wiki/Which-GPU-should-I-buy-for-ComfyUI)

View File

@ -1,84 +0,0 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = alembic_db
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic_db/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = sqlite:///user/comfyui.db
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

View File

@ -1,4 +0,0 @@
## Generate new revision
1. Update models in `/app/database/models.py`
2. Run `alembic revision --autogenerate -m "{your message}"`

View File

@ -1,64 +0,0 @@
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
from app.database.models import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@ -1,28 +0,0 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View File

View File

@ -1,3 +0,0 @@
# ComfyUI Internal Routes
All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications may change at any time without notice.

View File

@ -1,73 +0,0 @@
from aiohttp import web
from typing import Optional
from folder_paths import folder_names_and_paths, get_directory_by_type
from api_server.services.terminal_service import TerminalService
import app.logger
import os
class InternalRoutes:
'''
The top level web router for internal routes: /internal/*
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
Check README.md for more information.
'''
def __init__(self, prompt_server):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
self.prompt_server = prompt_server
self.terminal_service = TerminalService(prompt_server)
def setup_routes(self):
@self.routes.get('/logs')
async def get_logs(request):
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
@self.routes.get('/logs/raw')
async def get_raw_logs(request):
self.terminal_service.update_size()
return web.json_response({
"entries": list(app.logger.get_logs()),
"size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
})
@self.routes.patch('/logs/subscribe')
async def subscribe_logs(request):
json_data = await request.json()
client_id = json_data["clientId"]
enabled = json_data["enabled"]
if enabled:
self.terminal_service.subscribe(client_id)
else:
self.terminal_service.unsubscribe(client_id)
return web.Response(status=200)
@self.routes.get('/folder_paths')
async def get_folder_paths(request):
response = {}
for key in folder_names_and_paths:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)
@self.routes.get('/files/{directory_type}')
async def get_files(request: web.Request) -> web.Response:
directory_type = request.match_info['directory_type']
if directory_type not in ("output", "input", "temp"):
return web.json_response({"error": "Invalid directory type"}, status=400)
directory = get_directory_by_type(directory_type)
sorted_files = sorted(
(entry for entry in os.scandir(directory) if entry.is_file()),
key=lambda entry: -entry.stat().st_mtime
)
return web.json_response([entry.name for entry in sorted_files], status=200)
def get_app(self):
if self._app is None:
self._app = web.Application()
self.setup_routes()
self._app.add_routes(self.routes)
return self._app

View File

@ -1,60 +0,0 @@
from app.logger import on_flush
import os
import shutil
class TerminalService:
def __init__(self, server):
self.server = server
self.cols = None
self.rows = None
self.subscriptions = set()
on_flush(self.send_messages)
def get_terminal_size(self):
try:
size = os.get_terminal_size()
return (size.columns, size.lines)
except OSError:
try:
size = shutil.get_terminal_size()
return (size.columns, size.lines)
except OSError:
return (80, 24) # fallback to 80x24
def update_size(self):
columns, lines = self.get_terminal_size()
changed = False
if columns != self.cols:
self.cols = columns
changed = True
if lines != self.rows:
self.rows = lines
changed = True
if changed:
return {"cols": self.cols, "rows": self.rows}
return None
def subscribe(self, client_id):
self.subscriptions.add(client_id)
def unsubscribe(self, client_id):
self.subscriptions.discard(client_id)
def send_messages(self, entries):
if not len(entries) or not len(self.subscriptions):
return
new_size = self.update_size()
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
if client_id not in self.server.sockets:
# Automatically unsub if the socket has disconnected
self.unsubscribe(client_id)
continue
self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)

View File

@ -1,42 +0,0 @@
import os
from typing import List, Union, TypedDict, Literal
from typing_extensions import TypeGuard
class FileInfo(TypedDict):
name: str
path: str
type: Literal["file"]
size: int
class DirectoryInfo(TypedDict):
name: str
path: str
type: Literal["directory"]
FileSystemItem = Union[FileInfo, DirectoryInfo]
def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
return item["type"] == "file"
class FileSystemOperations:
@staticmethod
def walk_directory(directory: str) -> List[FileSystemItem]:
file_list: List[FileSystemItem] = []
for root, dirs, files in os.walk(directory):
for name in files:
file_path = os.path.join(root, name)
relative_path = os.path.relpath(file_path, directory)
file_list.append({
"name": name,
"path": relative_path,
"type": "file",
"size": os.path.getsize(file_path)
})
for name in dirs:
dir_path = os.path.join(root, name)
relative_path = os.path.relpath(dir_path, directory)
file_list.append({
"name": name,
"path": relative_path,
"type": "directory"
})
return file_list

View File

@ -1,7 +1,6 @@
import os
import json
from aiohttp import web
import logging
class AppSettings():
@ -9,21 +8,11 @@ class AppSettings():
self.user_manager = user_manager
def get_settings(self, request):
try:
file = self.user_manager.get_request_user_filepath(
request,
"comfy.settings.json"
)
except KeyError as e:
logging.error("User settings not found.")
raise web.HTTPUnauthorized() from e
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
if os.path.isfile(file):
try:
with open(file) as f:
return json.load(f)
except:
logging.error(f"The user settings file is corrupted: {file}")
return {}
with open(file) as f:
return json.load(f)
else:
return {}
@ -62,4 +51,4 @@ class AppSettings():
settings = self.get_settings(request)
settings[setting_id] = await request.json()
self.save_settings(request, settings)
return web.Response(status=200)
return web.Response(status=200)

View File

@ -1,145 +0,0 @@
from __future__ import annotations
import os
import folder_paths
import glob
from aiohttp import web
import json
import logging
from functools import lru_cache
from utils.json_util import merge_json_recursive
# Extra locale files to load into main.json
EXTRA_LOCALE_FILES = [
"nodeDefs.json",
"commands.json",
"settings.json",
]
def safe_load_json_file(file_path: str) -> dict:
if not os.path.exists(file_path):
return {}
try:
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except json.JSONDecodeError:
logging.error(f"Error loading {file_path}")
return {}
class CustomNodeManager:
@lru_cache(maxsize=1)
def build_translations(self):
"""Load all custom nodes translations during initialization. Translations are
expected to be loaded from `locales/` folder.
The folder structure is expected to be the following:
- custom_nodes/
- custom_node_1/
- locales/
- en/
- main.json
- commands.json
- settings.json
returned translations are expected to be in the following format:
{
"en": {
"nodeDefs": {...},
"commands": {...},
"settings": {...},
...{other main.json keys}
}
}
"""
translations = {}
for folder in folder_paths.get_folder_paths("custom_nodes"):
# Sort glob results for deterministic ordering
for custom_node_dir in sorted(glob.glob(os.path.join(folder, "*/"))):
locales_dir = os.path.join(custom_node_dir, "locales")
if not os.path.exists(locales_dir):
continue
for lang_dir in glob.glob(os.path.join(locales_dir, "*/")):
lang_code = os.path.basename(os.path.dirname(lang_dir))
if lang_code not in translations:
translations[lang_code] = {}
# Load main.json
main_file = os.path.join(lang_dir, "main.json")
node_translations = safe_load_json_file(main_file)
# Load extra locale files
for extra_file in EXTRA_LOCALE_FILES:
extra_file_path = os.path.join(lang_dir, extra_file)
key = extra_file.split(".")[0]
json_data = safe_load_json_file(extra_file_path)
if json_data:
node_translations[key] = json_data
if node_translations:
translations[lang_code] = merge_json_recursive(
translations[lang_code], node_translations
)
return translations
def add_routes(self, routes, webapp, loadedModules):
example_workflow_folder_names = ["example_workflows", "example", "examples", "workflow", "workflows"]
@routes.get("/workflow_templates")
async def get_workflow_templates(request):
"""Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
files = []
for folder in folder_paths.get_folder_paths("custom_nodes"):
for folder_name in example_workflow_folder_names:
pattern = os.path.join(folder, f"*/{folder_name}/*.json")
matched_files = glob.glob(pattern)
files.extend(matched_files)
workflow_templates_dict = (
{}
) # custom_nodes folder name -> example workflow names
for file in files:
custom_nodes_name = os.path.basename(
os.path.dirname(os.path.dirname(file))
)
workflow_name = os.path.splitext(os.path.basename(file))[0]
workflow_templates_dict.setdefault(custom_nodes_name, []).append(
workflow_name
)
return web.json_response(workflow_templates_dict)
# Serve workflow templates from custom nodes.
for module_name, module_dir in loadedModules:
for folder_name in example_workflow_folder_names:
workflows_dir = os.path.join(module_dir, folder_name)
if os.path.exists(workflows_dir):
if folder_name != "example_workflows":
logging.debug(
"Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
folder_name, module_name)
webapp.add_routes(
[
web.static(
"/api/workflow_templates/" + module_name, workflows_dir
)
]
)
@routes.get("/i18n")
async def get_i18n(request):
"""Returns translations from all custom nodes' locales folders."""
return web.json_response(self.build_translations())

View File

@ -1,112 +0,0 @@
import logging
import os
import shutil
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args
_DB_AVAILABLE = False
Session = None
try:
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
_DB_AVAILABLE = True
except ImportError as e:
log_startup_warning(
f"""
------------------------------------------------------------------------
Error importing dependencies: {e}
{get_missing_requirements_message()}
This error is happening because ComfyUI now uses a local sqlite database.
------------------------------------------------------------------------
""".strip()
)
def dependencies_available():
"""
Temporary function to check if the dependencies are available
"""
return _DB_AVAILABLE
def can_create_session():
"""
Temporary function to check if the database is available to create a session
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
"""
return dependencies_available() and Session is not None
def get_alembic_config():
root_path = os.path.join(os.path.dirname(__file__), "../..")
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
config = Config(config_path)
config.set_main_option("script_location", scripts_path)
config.set_main_option("sqlalchemy.url", args.database_url)
return config
def get_db_path():
url = args.database_url
if url.startswith("sqlite:///"):
return url.split("///")[1]
else:
raise ValueError(f"Unsupported database URL '{url}'.")
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
db_path = get_db_path()
db_exists = os.path.exists(db_path)
config = get_alembic_config()
# Check if we need to upgrade
engine = create_engine(db_url)
conn = engine.connect()
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(config)
target_rev = script.get_current_head()
if target_rev is None:
logging.warning("No target revision found.")
elif current_rev != target_rev:
# Backup the database pre upgrade
backup_path = db_path + ".bkp"
if db_exists:
shutil.copy(db_path, backup_path)
else:
backup_path = None
try:
command.upgrade(config, target_rev)
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
except Exception as e:
if backup_path:
# Restore the database from backup if upgrade fails
shutil.copy(backup_path, db_path)
os.remove(backup_path)
logging.exception("Error upgrading database: ")
raise e
global Session
Session = sessionmaker(bind=engine)
def create_session():
return Session()

View File

@ -1,14 +0,0 @@
from sqlalchemy.orm import declarative_base
Base = declarative_base()
def to_dict(obj):
fields = obj.__table__.columns.keys()
return {
field: (val.to_dict() if hasattr(val, "to_dict") else val)
for field in fields
if (val := getattr(obj, field))
}
# TODO: Define models here

View File

@ -3,93 +3,16 @@ import argparse
import logging
import os
import re
import sys
import tempfile
import zipfile
import importlib
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Dict, TypedDict, Optional
from aiohttp import web
from importlib.metadata import version
from typing import TypedDict
import requests
from typing_extensions import NotRequired
from utils.install_util import get_missing_requirements_message, requirements_path
from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger
def frontend_install_warning_message():
return f"""
{get_missing_requirements_message()}
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
""".strip()
def parse_version(version: str) -> tuple[int, int, int]:
return tuple(map(int, version.split(".")))
def is_valid_version(version: str) -> bool:
"""Validate if a string is a valid semantic version (X.Y.Z format)."""
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
return bool(re.match(pattern, version))
def get_installed_frontend_version():
"""Get the currently installed frontend package version."""
frontend_version_str = version("comfyui-frontend-package")
return frontend_version_str
def get_required_frontend_version():
"""Get the required frontend version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-frontend-package=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-frontend-package not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required frontend version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
def check_frontend_version():
"""Check if the frontend version is up to date."""
try:
frontend_version_str = get_installed_frontend_version()
frontend_version = parse_version(frontend_version_str)
required_frontend_str = get_required_frontend_version()
required_frontend = parse_version(required_frontend_str)
if frontend_version < required_frontend:
app.logger.log_startup_warning(
f"""
________________________________________________________________________
WARNING WARNING WARNING WARNING WARNING
Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.
{frontend_install_warning_message()}
________________________________________________________________________
""".strip()
)
else:
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
except Exception as e:
logging.error(f"Failed to check frontend version: {e}")
REQUEST_TIMEOUT = 10 # seconds
@ -145,22 +68,9 @@ class FrontEndProvider:
response.raise_for_status() # Raises an HTTPError if the response was an error
return response.json()
@cached_property
def latest_prerelease(self) -> Release:
"""Get the latest pre-release version - even if it's older than the latest release"""
release = [release for release in self.all_releases if release["prerelease"]]
if not release:
raise ValueError("No pre-releases found")
# GitHub returns releases in reverse chronological order, so first is latest
return release[0]
def get_release(self, version: str) -> Release:
if version == "latest":
return self.latest_release
elif version == "prerelease":
return self.latest_prerelease
else:
for release in self.all_releases:
if release["tag_name"] in [version, f"v{version}"]:
@ -199,146 +109,9 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod
def get_required_frontend_version(cls) -> str:
"""Get the required frontend package version."""
return get_required_frontend_version()
@classmethod
def get_installed_templates_version(cls) -> str:
"""Get the currently installed workflow templates package version."""
try:
templates_version_str = version("comfyui-workflow-templates")
return templates_version_str
except Exception:
return None
@classmethod
def get_required_templates_version(cls) -> str:
"""Get the required workflow templates version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-workflow-templates=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-workflow-templates not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required templates version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
@classmethod
def default_frontend_path(cls) -> str:
try:
import comfyui_frontend_package
return str(importlib.resources.files(comfyui_frontend_package) / "static")
except ImportError:
logging.error(
f"""
********** ERROR ***********
comfyui-frontend-package is not installed.
{frontend_install_warning_message()}
********** ERROR ***********
""".strip()
)
sys.exit(-1)
@classmethod
def template_asset_map(cls) -> Optional[Dict[str, str]]:
"""Return a mapping of template asset names to their absolute paths."""
try:
from comfyui_workflow_templates import (
get_asset_path,
iter_templates,
)
except ImportError:
logging.error(
f"""
********** ERROR ***********
comfyui-workflow-templates is not installed.
{frontend_install_warning_message()}
********** ERROR ***********
""".strip()
)
return None
try:
template_entries = list(iter_templates())
except Exception as exc:
logging.error(f"Failed to enumerate workflow templates: {exc}")
return None
asset_map: Dict[str, str] = {}
try:
for entry in template_entries:
for asset in entry.assets:
asset_map[asset.filename] = get_asset_path(
entry.template_id, asset.filename
)
except Exception as exc:
logging.error(f"Failed to resolve template asset paths: {exc}")
return None
if not asset_map:
logging.error("No workflow template assets found. Did the packages install correctly?")
return None
return asset_map
@classmethod
def legacy_templates_path(cls) -> Optional[str]:
"""Return the legacy templates directory shipped inside the meta package."""
try:
import comfyui_workflow_templates
return str(
importlib.resources.files(comfyui_workflow_templates) / "templates"
)
except ImportError:
logging.error(
f"""
********** ERROR ***********
comfyui-workflow-templates is not installed.
{frontend_install_warning_message()}
********** ERROR ***********
""".strip()
)
return None
@classmethod
def embedded_docs_path(cls) -> str:
"""Get the path to embedded documentation"""
try:
import comfyui_embedded_docs
return str(
importlib.resources.files(comfyui_embedded_docs) / "docs"
)
except ImportError:
logging.info("comfyui-embedded-docs package not found")
return None
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
@ -351,7 +124,7 @@ comfyui-workflow-templates is not installed.
Raises:
argparse.ArgumentTypeError: If the version string is invalid.
"""
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$"
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
match_result = re.match(VERSION_PATTERN, value)
if match_result is None:
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
@ -359,15 +132,12 @@ comfyui-workflow-templates is not installed.
return match_result.group(1), match_result.group(2), match_result.group(3)
@classmethod
def init_frontend_unsafe(
cls, version_string: str, provider: Optional[FrontEndProvider] = None
) -> str:
def init_frontend_unsafe(cls, version_string: str) -> str:
"""
Initializes the frontend for the specified version.
Args:
version_string (str): The version string.
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
Returns:
str: The path to the initialized frontend.
@ -377,28 +147,10 @@ comfyui-workflow-templates is not installed.
main error source might be request timeout or invalid URL.
"""
if version_string == DEFAULT_VERSION_STRING:
check_frontend_version()
return cls.default_frontend_path()
return cls.DEFAULT_FRONTEND_PATH
repo_owner, repo_name, version = cls.parse_version_string(version_string)
if version.startswith("v"):
expected_path = str(
Path(cls.CUSTOM_FRONTENDS_ROOT)
/ f"{repo_owner}_{repo_name}"
/ version.lstrip("v")
)
if os.path.exists(expected_path):
logging.info(
f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
)
return expected_path
logging.info(
f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
)
provider = provider or FrontEndProvider(repo_owner, repo_name)
provider = FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version)
semantic_version = release["tag_name"].lstrip("v")
@ -406,21 +158,15 @@ comfyui-workflow-templates is not installed.
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
)
if not os.path.exists(web_root):
try:
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
finally:
# Clean up the directory if it is empty, i.e. the download failed
if not os.listdir(web_root):
os.rmdir(web_root)
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
return web_root
@classmethod
@ -439,19 +185,4 @@ comfyui-workflow-templates is not installed.
except Exception as e:
logging.error("Failed to initialize frontend: %s", e)
logging.info("Falling back to the default frontend.")
check_frontend_version()
return cls.default_frontend_path()
@classmethod
def template_asset_handler(cls):
assets = cls.template_asset_map()
if not assets:
return None
async def serve_template(request: web.Request) -> web.StreamResponse:
rel_path = request.match_info.get("path", "")
target = assets.get(rel_path)
if target is None:
raise web.HTTPNotFound()
return web.FileResponse(target)
return serve_template
return cls.DEFAULT_FRONTEND_PATH

View File

@ -1,98 +0,0 @@
from collections import deque
from datetime import datetime
import io
import logging
import sys
import threading
logs = None
stdout_interceptor = None
stderr_interceptor = None
class LogInterceptor(io.TextIOWrapper):
def __init__(self, stream, *args, **kwargs):
buffer = stream.buffer
encoding = stream.encoding
super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
self._lock = threading.Lock()
self._flush_callbacks = []
self._logs_since_flush = []
def write(self, data):
entry = {"t": datetime.now().isoformat(), "m": data}
with self._lock:
self._logs_since_flush.append(entry)
# Simple handling for cr to overwrite the last output if it isnt a full line
# else logs just get full of progress messages
if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"):
logs.pop()
logs.append(entry)
super().write(data)
def flush(self):
super().flush()
for cb in self._flush_callbacks:
cb(self._logs_since_flush)
self._logs_since_flush = []
def on_flush(self, callback):
self._flush_callbacks.append(callback)
def get_logs():
return logs
def on_flush(callback):
if stdout_interceptor is not None:
stdout_interceptor.on_flush(callback)
if stderr_interceptor is not None:
stderr_interceptor.on_flush(callback)
def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
global logs
if logs:
return
# Override output streams and log to buffer
logs = deque(maxlen=capacity)
global stdout_interceptor
global stderr_interceptor
stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)
# Setup default global logger
logger = logging.getLogger()
logger.setLevel(log_level)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(message)s"))
if use_stdout:
# Only errors and critical to stderr
stream_handler.addFilter(lambda record: not record.levelno < logging.ERROR)
# Lesser to stdout
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setFormatter(logging.Formatter("%(message)s"))
stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
logger.addHandler(stdout_handler)
logger.addHandler(stream_handler)
STARTUP_WARNINGS = []
def log_startup_warning(msg):
logging.warning(msg)
STARTUP_WARNINGS.append(msg)
def print_startup_warnings():
for s in STARTUP_WARNINGS:
logging.warning(s)
STARTUP_WARNINGS.clear()

View File

@ -1,195 +0,0 @@
from __future__ import annotations
import os
import base64
import json
import time
import logging
import folder_paths
import glob
import comfy.utils
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
class ModelFileManager:
def __init__(self) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
return self.cache.get(key, default)
def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
self.cache[key] = value
def clear_cache(self):
self.cache.clear()
def add_routes(self, routes):
# NOTE: This is an experiment to replace `/models`
@routes.get("/experiment/models")
async def get_model_folders(request):
model_types = list(folder_paths.folder_names_and_paths.keys())
folder_black_list = ["configs", "custom_nodes"]
output_folders: list[dict] = []
for folder in model_types:
if folder in folder_black_list:
continue
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
return web.json_response(output_folders)
# NOTE: This is an experiment to replace `/models/{folder}`
@routes.get("/experiment/models/{folder}")
async def get_all_models(request):
folder = request.match_info.get("folder", None)
if not folder in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = self.get_model_file_list(folder)
return web.json_response(files)
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
async def get_model_preview(request):
folder_name = request.match_info.get("folder", None)
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)
if not folder_name in folder_paths.folder_names_and_paths:
return web.Response(status=404)
folders = folder_paths.folder_names_and_paths[folder_name]
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)
previews = self.get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
try:
with Image.open(default_preview) as img:
img_bytes = BytesIO()
img.save(img_bytes, format="WEBP")
img_bytes.seek(0)
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
except:
return web.Response(status=404)
def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
output_list: list[dict] = []
for index, folder in enumerate(folders[0]):
if not os.path.isdir(folder):
continue
out = self.cache_model_file_list_(folder)
if out is None:
out = self.recursive_search_models_(folder, index)
self.set_cache(folder, out)
output_list.extend(out[0])
return output_list
def cache_model_file_list_(self, folder: str):
model_file_list_cache = self.get_cache(folder)
if model_file_list_cache is None:
return None
if not os.path.isdir(folder):
return None
if os.path.getmtime(folder) != model_file_list_cache[1]:
return None
for x in model_file_list_cache[1]:
time_modified = model_file_list_cache[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None
return model_file_list_cache
def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
if not os.path.isdir(directory):
return [], {}, time.perf_counter()
excluded_dir_names = [".git"]
# TODO use settings
include_hidden_files = False
result: list[str] = []
dirs: dict[str, float] = {}
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
if not include_hidden_files:
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
filenames = [f for f in filenames if not f.startswith(".")]
filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
for file_name in filenames:
try:
full_path = os.path.join(dirpath, file_name)
relative_path = os.path.relpath(full_path, directory)
# Get file metadata
file_info = {
"name": relative_path,
"pathIndex": pathIndex,
"modified": os.path.getmtime(full_path), # Add modification time
"created": os.path.getctime(full_path), # Add creation time
"size": os.path.getsize(full_path) # Add file size
}
result.append(file_info)
except Exception as e:
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
continue
for d in subdirs:
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return result, dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):
self.clear_cache()

View File

@ -1,112 +0,0 @@
from __future__ import annotations
from typing import TypedDict
import os
import folder_paths
import glob
from aiohttp import web
import hashlib
class Source:
custom_node = "custom_node"
class SubgraphEntry(TypedDict):
source: str
"""
Source of subgraph - custom_nodes vs templates.
"""
path: str
"""
Relative path of the subgraph file.
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
"""
name: str
"""
Name of subgraph file.
"""
info: CustomNodeSubgraphEntryInfo
"""
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
"""
data: str
class CustomNodeSubgraphEntryInfo(TypedDict):
node_pack: str
"""Node pack name."""
class SubgraphManager:
def __init__(self):
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
async def load_entry_data(self, entry: SubgraphEntry):
with open(entry['path'], 'r') as f:
entry['data'] = f.read()
return entry
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
if entry is None:
return None
entry = entry.copy()
entry.pop('path', None)
if remove_data:
entry.pop('data', None)
return entry
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
entries = entries.copy()
for key in list(entries.keys()):
entries[key] = await self.sanitize_entry(entries[key], remove_data)
return entries
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
# if not forced to reload and cached, return cache
if not force_reload and self.cached_custom_node_subgraphs is not None:
return self.cached_custom_node_subgraphs
# Load subgraphs from custom nodes
subfolder = "subgraphs"
subgraphs_dict: dict[SubgraphEntry] = {}
for folder in folder_paths.get_folder_paths("custom_nodes"):
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
matched_files = glob.glob(pattern)
for file in matched_files:
# replace backslashes with forward slashes
file = file.replace('\\', '/')
info: CustomNodeSubgraphEntryInfo = {
"node_pack": "custom_nodes." + file.split('/')[-3]
}
source = Source.custom_node
# hash source + path to make sure id will be as unique as possible, but
# reproducible across backend reloads
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": Source.custom_node,
"name": os.path.splitext(os.path.basename(file))[0],
"path": file,
"info": info,
}
subgraphs_dict[id] = entry
self.cached_custom_node_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_custom_node_subgraph(self, id: str, loadedModules):
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
entry: SubgraphEntry = subgraphs.get(id, None)
if entry is not None and entry.get('data', None) is None:
await self.load_entry_data(entry)
return entry
def add_routes(self, routes, loadedModules):
@routes.get("/global_subgraphs")
async def get_global_subgraphs(request):
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
# that's the reasoning for the current implementation
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
@routes.get("/global_subgraphs/{id}")
async def get_global_subgraph(request):
id = request.match_info.get("id", None)
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
return web.json_response(await self.sanitize_entry(subgraph))

View File

@ -1,60 +1,38 @@
from __future__ import annotations
import json
import os
import re
import uuid
import glob
import shutil
import logging
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
import folder_paths
from folder_paths import user_directory
from .app_settings import AppSettings
from typing import TypedDict
default_user = "default"
class FileInfo(TypedDict):
path: str
size: int
modified: int
created: int
def get_file_info(path: str, relative_to: str) -> FileInfo:
return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path),
"modified": os.path.getmtime(path),
"created": os.path.getctime(path)
}
users_file = os.path.join(user_directory, "users.json")
class UserManager():
def __init__(self):
user_directory = folder_paths.get_user_directory()
global user_directory
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.makedirs(user_directory, exist_ok=True)
os.mkdir(user_directory)
if not args.multi_user:
logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user:
if os.path.isfile(self.get_users_file()):
with open(self.get_users_file()) as f:
if os.path.isfile(users_file):
with open(users_file) as f:
self.users = json.load(f)
else:
self.users = {}
else:
self.users = {"default": "default"}
def get_users_file(self):
return os.path.join(folder_paths.get_user_directory(), "users.json")
def get_request_user_id(self, request):
user = "default"
if args.multi_user and "comfy-user" in request.headers:
@ -66,7 +44,7 @@ class UserManager():
return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
user_directory = folder_paths.get_user_directory()
global user_directory
if type == "userdata":
root_dir = user_directory
@ -81,10 +59,6 @@ class UserManager():
return None
if file is not None:
# Check if filename is url encoded
if "%" in file:
file = parse.unquote(file)
# prevent leaving /{type}/{user}
path = os.path.abspath(os.path.join(user_root, file))
if os.path.commonpath((user_root, path)) != user_root:
@ -106,7 +80,8 @@ class UserManager():
self.users[user_id] = name
with open(self.get_users_file(), "w") as f:
global users_file
with open(users_file, "w") as f:
json.dump(self.users, f)
return user_id
@ -137,171 +112,25 @@ class UserManager():
@routes.get("/userdata")
async def listuserdata(request):
"""
List user data files in a specified directory.
This endpoint allows listing files in a user's data directory, with options for recursion,
full file information, and path splitting.
Query Parameters:
- dir (required): The directory to list files from.
- recurse (optional): If "true", recursively list files in subdirectories.
- full_info (optional): If "true", return detailed file information (path, size, modified time).
- split (optional): If "true", split file paths into components (only applies when full_info is false).
Returns:
- 400: If 'dir' parameter is missing.
- 403: If the requested path is not allowed.
- 404: If the requested directory does not exist.
- 200: JSON response with the list of files or file information.
The response format depends on the query parameters:
- Default: List of relative file paths.
- full_info=true: List of dictionaries with file details.
- split=true (and full_info=false): List of lists, each containing path components.
"""
directory = request.rel_url.query.get('dir', '')
if not directory:
return web.Response(status=400, text="Directory not provided")
return web.Response(status=400)
path = self.get_request_user_filepath(request, directory)
if not path:
return web.Response(status=403, text="Invalid directory")
return web.Response(status=403)
if not os.path.exists(path):
return web.Response(status=404, text="Directory not found")
return web.Response(status=404)
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
results = glob.glob(os.path.join(
glob.escape(path), '**/*'), recursive=recurse)
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
split_path = request.rel_url.query.get('split', '').lower() == "true"
# Use different patterns based on whether we're recursing or not
if recurse:
pattern = os.path.join(glob.escape(path), '**', '*')
else:
pattern = os.path.join(glob.escape(path), '*')
def process_full_path(full_path: str) -> FileInfo | str | list[str]:
if full_info:
return get_file_info(full_path, path)
rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
if split_path:
return [rel_path] + rel_path.split('/')
return rel_path
results = [
process_full_path(full_path)
for full_path in glob.glob(pattern, recursive=recurse)
if os.path.isfile(full_path)
]
return web.json_response(results)
@routes.get("/v2/userdata")
async def list_userdata_v2(request):
"""
List files and directories in a user's data directory.
This endpoint provides a structured listing of contents within a specified
subdirectory of the user's data storage.
Query Parameters:
- path (optional): The relative path within the user's data directory
to list. Defaults to the root ('').
Returns:
- 400: If the requested path is invalid, outside the user's data directory, or is not a directory.
- 404: If the requested path does not exist.
- 403: If the user is invalid.
- 500: If there is an error reading the directory contents.
- 200: JSON response containing a list of file and directory objects.
Each object includes:
- name: The name of the file or directory.
- type: 'file' or 'directory'.
- path: The relative path from the user's data root.
- size (for files): The size in bytes.
- modified (for files): The last modified timestamp (Unix epoch).
"""
requested_rel_path = request.rel_url.query.get('path', '')
# URL-decode the path parameter
try:
requested_rel_path = parse.unquote(requested_rel_path)
except Exception as e:
logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
return web.Response(status=400, text="Invalid characters in path parameter")
# Check user validity and get the absolute path for the requested directory
try:
base_user_path = self.get_request_user_filepath(request, None, create_dir=False)
if requested_rel_path:
target_abs_path = self.get_request_user_filepath(request, requested_rel_path, create_dir=False)
else:
target_abs_path = base_user_path
except KeyError as e:
# Invalid user detected by get_request_user_id inside get_request_user_filepath
logging.warning(f"Access denied for user: {e}")
return web.Response(status=403, text="Invalid user specified in request")
if not target_abs_path:
# Path traversal or other issue detected by get_request_user_filepath
return web.Response(status=400, text="Invalid path requested")
# Handle cases where the user directory or target path doesn't exist
if not os.path.exists(target_abs_path):
# Check if it's the base user directory that's missing (new user case)
if target_abs_path == base_user_path:
# It's okay if the base user directory doesn't exist yet, return empty list
return web.json_response([])
else:
# A specific subdirectory was requested but doesn't exist
return web.Response(status=404, text="Requested path not found")
if not os.path.isdir(target_abs_path):
return web.Response(status=400, text="Requested path is not a directory")
results = []
try:
for root, dirs, files in os.walk(target_abs_path, topdown=True):
# Process directories
for dir_name in dirs:
dir_path = os.path.join(root, dir_name)
rel_path = os.path.relpath(dir_path, base_user_path).replace(os.sep, '/')
results.append({
"name": dir_name,
"path": rel_path,
"type": "directory"
})
# Process files
for file_name in files:
file_path = os.path.join(root, file_name)
rel_path = os.path.relpath(file_path, base_user_path).replace(os.sep, '/')
entry_info = {
"name": file_name,
"path": rel_path,
"type": "file"
}
try:
stats = os.stat(file_path) # Use os.stat for potentially better performance with os.walk
entry_info["size"] = stats.st_size
entry_info["modified"] = stats.st_mtime
except OSError as stat_error:
logging.warning(f"Could not stat file {file_path}: {stat_error}")
pass # Include file with available info
results.append(entry_info)
except OSError as e:
logging.error(f"Error listing directory {target_abs_path}: {e}")
return web.Response(status=500, text="Error reading directory contents")
# Sort results alphabetically, directories first then files
results.sort(key=lambda x: (x['type'] != 'directory', x['name'].lower()))
if split_path:
results = [[x] + x.split(os.sep) for x in results]
return web.json_response(results)
@ -309,14 +138,14 @@ class UserManager():
file = request.match_info.get(param, None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
if check_exists and not os.path.exists(path):
return web.Response(status=404)
return path
@routes.get("/userdata/{file}")
@ -324,63 +153,25 @@ class UserManager():
path = get_user_data_path(request, check_exists=True)
if not isinstance(path, str):
return path
return web.FileResponse(path)
@routes.post("/userdata/{file}")
async def post_userdata(request):
"""
Upload or update a user data file.
This endpoint handles file uploads to a user's data directory, with options for
controlling overwrite behavior and response format.
Query Parameters:
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
If "false", returns only the relative file path.
Path Parameters:
- file: The target file path (URL encoded if necessary).
Returns:
- 400: If 'file' parameter is missing.
- 403: If the requested path is not allowed.
- 409: If overwrite=false and the file already exists.
- 200: JSON response with either:
- Full file information (if full_info=true)
- Relative file path (if full_info=false)
The request body should contain the raw file content to be written.
"""
path = get_user_data_path(request)
if not isinstance(path, str):
return path
overwrite = request.query.get("overwrite", 'true') != "false"
full_info = request.query.get('full_info', 'false').lower() == "true"
overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(path):
return web.Response(status=409, text="File already exists")
return web.Response(status=409)
try:
body = await request.read()
with open(path, "wb") as f:
f.write(body)
except OSError as e:
logging.warning(f"Error saving file '{path}': {e}")
return web.Response(
status=400,
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
)
user_path = self.get_request_user_filepath(request, None)
if full_info:
resp = get_file_info(path, user_path)
else:
resp = os.path.relpath(path, user_path)
body = await request.read()
with open(path, "wb") as f:
f.write(body)
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
return web.json_response(resp)
@routes.delete("/userdata/{file}")
@ -390,56 +181,25 @@ class UserManager():
return path
os.remove(path)
return web.Response(status=204)
@routes.post("/userdata/{file}/move/{dest}")
async def move_userdata(request):
"""
Move or rename a user data file.
This endpoint handles moving or renaming files within a user's data directory, with options for
controlling overwrite behavior and response format.
Path Parameters:
- file: The source file path (URL encoded if necessary)
- dest: The destination file path (URL encoded if necessary)
Query Parameters:
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
If "false", returns only the relative file path.
Returns:
- 400: If either 'file' or 'dest' parameter is missing
- 403: If either requested path is not allowed
- 404: If the source file does not exist
- 409: If overwrite=false and the destination file already exists
- 200: JSON response with either:
- Full file information (if full_info=true)
- Relative file path (if full_info=false)
"""
source = get_user_data_path(request, check_exists=True)
if not isinstance(source, str):
return source
dest = get_user_data_path(request, check_exists=False, param="dest")
if not isinstance(source, str):
return dest
overwrite = request.query.get("overwrite", 'true') != "false"
full_info = request.query.get('full_info', 'false').lower() == "true"
overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(dest):
return web.Response(status=409, text="File already exists")
return web.Response(status=409)
logging.info(f"moving '{source}' -> '{dest}'")
print(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest)
user_path = self.get_request_user_filepath(request, None)
if full_info:
resp = get_file_info(dest, user_path)
else:
resp = os.path.relpath(dest, user_path)
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
return web.json_response(resp)

View File

@ -1,91 +0,0 @@
from .wav2vec2 import Wav2Vec2Model
from .whisper import WhisperLargeV3
import comfy.model_management
import comfy.ops
import comfy.utils
import logging
import torchaudio
class AudioEncoderModel():
def __init__(self, config):
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
model_type = config.pop("model_type")
model_config = dict(config)
model_config.update({
"dtype": self.dtype,
"device": offload_device,
"operations": comfy.ops.manual_cast
})
if model_type == "wav2vec2":
self.model = Wav2Vec2Model(**model_config)
elif model_type == "whisper3":
self.model = WhisperLargeV3(**model_config)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()
def encode_audio(self, audio, sample_rate):
comfy.model_management.load_model_gpu(self.patcher)
audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
out, all_layers = self.model(audio.to(self.load_device))
outputs = {}
outputs["encoded_audio"] = out
outputs["encoded_audio_all_layers"] = all_layers
outputs["audio_samples"] = audio.shape[2]
return outputs
def load_audio_encoder_from_sd(sd, prefix=""):
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
if "encoder.layer_norm.bias" in sd: #wav2vec2
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
if embed_dim == 1024:# large
config = {
"model_type": "wav2vec2",
"embed_dim": 1024,
"num_heads": 16,
"num_layers": 24,
"conv_norm": True,
"conv_bias": True,
"do_normalize": True,
"do_stable_layer_norm": True
}
elif embed_dim == 768: # base
config = {
"model_type": "wav2vec2",
"embed_dim": 768,
"num_heads": 12,
"num_layers": 12,
"conv_norm": False,
"conv_bias": False,
"do_normalize": False, # chinese-wav2vec2-base has this False
"do_stable_layer_norm": False
}
else:
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
elif "model.encoder.embed_positions.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
config = {
"model_type": "whisper3",
}
else:
raise RuntimeError("ERROR: audio encoder not supported.")
audio_encoder = AudioEncoderModel(config)
m, u = audio_encoder.load_sd(sd)
if len(m) > 0:
logging.warning("missing audio encoder: {}".format(m))
if len(u) > 0:
logging.warning("unexpected audio encoder: {}".format(u))
return audio_encoder

View File

@ -1,252 +0,0 @@
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention_masked
class LayerNormConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
super().__init__()
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)
def forward(self, x):
x = self.conv(x)
return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
class LayerGroupNormConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
super().__init__()
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
def forward(self, x):
x = self.conv(x)
return torch.nn.functional.gelu(self.layer_norm(x))
class ConvNoNorm(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
super().__init__()
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
def forward(self, x):
x = self.conv(x)
return torch.nn.functional.gelu(x)
class ConvFeatureEncoder(nn.Module):
def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
super().__init__()
if conv_norm:
self.conv_layers = nn.ModuleList([
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
])
else:
self.conv_layers = nn.ModuleList([
LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
])
def forward(self, x):
x = x.unsqueeze(1)
for conv in self.conv_layers:
x = conv(x)
return x.transpose(1, 2)
class FeatureProjection(nn.Module):
def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
super().__init__()
self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)
def forward(self, x):
x = self.layer_norm(x)
x = self.projection(x)
return x
class PositionalConvEmbedding(nn.Module):
def __init__(self, embed_dim=768, kernel_size=128, groups=16):
super().__init__()
self.conv = nn.Conv1d(
embed_dim,
embed_dim,
kernel_size=kernel_size,
padding=kernel_size // 2,
groups=groups,
)
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
self.activation = nn.GELU()
def forward(self, x):
x = x.transpose(1, 2)
x = self.conv(x)[:, :, :-1]
x = self.activation(x)
x = x.transpose(1, 2)
return x
class TransformerEncoder(nn.Module):
def __init__(
self,
embed_dim=768,
num_heads=12,
num_layers=12,
mlp_ratio=4.0,
do_stable_layer_norm=True,
dtype=None, device=None, operations=None
):
super().__init__()
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
self.layers = nn.ModuleList([
TransformerEncoderLayer(
embed_dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
do_stable_layer_norm=do_stable_layer_norm,
device=device, dtype=dtype, operations=operations
)
for _ in range(num_layers)
])
self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
self.do_stable_layer_norm = do_stable_layer_norm
def forward(self, x, mask=None):
x = x + self.pos_conv_embed(x)
all_x = ()
if not self.do_stable_layer_norm:
x = self.layer_norm(x)
for layer in self.layers:
all_x += (x,)
x = layer(x, mask)
if self.do_stable_layer_norm:
x = self.layer_norm(x)
all_x += (x,)
return x, all_x
class Attention(nn.Module):
def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
def forward(self, x, mask=None):
assert (mask is None) # TODO?
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
out = optimized_attention_masked(q, k, v, self.num_heads)
return self.out_proj(out)
class FeedForward(nn.Module):
def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
super().__init__()
self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)
def forward(self, x):
x = self.intermediate_dense(x)
x = torch.nn.functional.gelu(x)
x = self.output_dense(x)
return x
class TransformerEncoderLayer(nn.Module):
def __init__(
self,
embed_dim=768,
num_heads=12,
mlp_ratio=4.0,
do_stable_layer_norm=True,
dtype=None, device=None, operations=None
):
super().__init__()
self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)
self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
self.do_stable_layer_norm = do_stable_layer_norm
def forward(self, x, mask=None):
residual = x
if self.do_stable_layer_norm:
x = self.layer_norm(x)
x = self.attention(x, mask=mask)
x = residual + x
if not self.do_stable_layer_norm:
x = self.layer_norm(x)
return self.final_layer_norm(x + self.feed_forward(x))
else:
return x + self.feed_forward(self.final_layer_norm(x))
class Wav2Vec2Model(nn.Module):
"""Complete Wav2Vec 2.0 model."""
def __init__(
self,
embed_dim=1024,
final_dim=256,
num_heads=16,
num_layers=24,
conv_norm=True,
conv_bias=True,
do_normalize=True,
do_stable_layer_norm=True,
dtype=None, device=None, operations=None
):
super().__init__()
conv_dim = 512
self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
self.do_normalize = do_normalize
self.encoder = TransformerEncoder(
embed_dim=embed_dim,
num_heads=num_heads,
num_layers=num_layers,
do_stable_layer_norm=do_stable_layer_norm,
device=device, dtype=dtype, operations=operations
)
def forward(self, x, mask_time_indices=None, return_dict=False):
x = torch.mean(x, dim=1)
if self.do_normalize:
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
features = self.feature_extractor(x)
features = self.feature_projection(features)
batch_size, seq_len, _ = features.shape
x, all_x = self.encoder(features)
return x, all_x

View File

@ -1,186 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from typing import Optional
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.ops
class WhisperFeatureExtractor(nn.Module):
def __init__(self, n_mels=128, device=None):
super().__init__()
self.sample_rate = 16000
self.n_fft = 400
self.hop_length = 160
self.n_mels = n_mels
self.chunk_length = 30
self.n_samples = 480000
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
sample_rate=self.sample_rate,
n_fft=self.n_fft,
hop_length=self.hop_length,
n_mels=self.n_mels,
f_min=0,
f_max=8000,
norm="slaney",
mel_scale="slaney",
).to(device)
def __call__(self, audio):
audio = torch.mean(audio, dim=1)
batch_size = audio.shape[0]
processed_audio = []
for i in range(batch_size):
aud = audio[i]
if aud.shape[0] > self.n_samples:
aud = aud[:self.n_samples]
elif aud.shape[0] < self.n_samples:
aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
processed_audio.append(aud)
audio = torch.stack(processed_audio)
mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)
log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
log_mel_spec = (log_mel_spec + 4.0) / 4.0
return log_mel_spec
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, seq_len, _ = query.shape
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
attn_output = self.out_proj(attn_output)
return attn_output
class EncoderLayer(nn.Module):
def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
residual = x
x = self.self_attn_layer_norm(x)
x = self.self_attn(x, x, x, attention_mask)
x = residual + x
residual = x
x = self.final_layer_norm(x)
x = self.fc1(x)
x = F.gelu(x)
x = self.fc2(x)
x = residual + x
return x
class AudioEncoder(nn.Module):
def __init__(
self,
n_mels: int = 128,
n_ctx: int = 1500,
n_state: int = 1280,
n_head: int = 20,
n_layer: int = 32,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)
self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)
self.layers = nn.ModuleList([
EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
for _ in range(n_layer)
])
self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.transpose(1, 2)
x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)
all_x = ()
for layer in self.layers:
all_x += (x,)
x = layer(x)
x = self.layer_norm(x)
all_x += (x,)
return x, all_x
class WhisperLargeV3(nn.Module):
def __init__(
self,
n_mels: int = 128,
n_audio_ctx: int = 1500,
n_audio_state: int = 1280,
n_audio_head: int = 20,
n_audio_layer: int = 32,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)
self.encoder = AudioEncoder(
n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
dtype=dtype, device=device, operations=operations
)
def forward(self, audio):
mel = self.feature_extractor(audio)
x, all_x = self.encoder(mel)
return x, all_x

View File

@ -2,9 +2,11 @@
#and modified
import torch
import torch as th
import torch.nn as nn
from ..ldm.modules.diffusionmodules.util import (
zero_module,
timestep_embedding,
)
@ -160,6 +162,7 @@ class ControlNet(nn.Module):
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "sequential":
assert adm_in_channels is not None
@ -412,6 +415,7 @@ class ControlNet(nn.Module):
out_output = []
out_middle = []
hs = []
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)

View File

@ -1,120 +0,0 @@
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
class ControlNetEmbedder(nn.Module):
def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
attention_head_dim: int,
num_attention_heads: int,
adm_in_channels: int,
num_layers: int,
main_model_double: int,
double_y_emb: bool,
device: torch.device,
dtype: torch.dtype,
pos_embed_max_size: Optional[int] = None,
operations = None,
):
super().__init__()
self.main_model_double = main_model_double
self.dtype = dtype
self.hidden_size = num_attention_heads * attention_head_dim
self.patch_size = patch_size
self.x_embedder = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=pos_embed_max_size is None,
device=device,
dtype=dtype,
operations=operations,
)
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
self.double_y_emb = double_y_emb
if self.double_y_emb:
self.orig_y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.y_embedder = VectorEmbedder(
self.hidden_size, self.hidden_size, dtype, device, operations=operations
)
else:
self.y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.transformer_blocks = nn.ModuleList(
DismantledBlock(
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(num_layers)
)
# self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
# TODO double check this logic when 8b
self.use_y_embedder = True
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=False,
device=device,
dtype=dtype,
operations=operations,
)
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
hint = None,
) -> Tuple[Tensor, List[Tensor]]:
x_shape = list(x.shape)
x = self.x_embedder(x)
if not self.double_y_emb:
h = (x_shape[-2] + 1) // self.patch_size
w = (x_shape[-1] + 1) // self.patch_size
x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
c = self.t_embedder(timesteps, dtype=x.dtype)
if y is not None and self.y_embedder is not None:
if self.double_y_emb:
y = self.orig_y_embedder(y)
y = self.y_embedder(y)
c = c + y
x = x + self.pos_embed_input(hint)
block_out = ()
repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
for i in range(len(self.transformer_blocks)):
out = self.transformer_blocks[i](x, c)
if not self.double_y_emb:
x = out
block_out += (self.controlnet_blocks[i](out),) * repeat
return {"output": block_out}

View File

@ -1,12 +1,11 @@
import torch
from typing import Optional
from typing import Dict, Optional
import comfy.ldm.modules.diffusionmodules.mmdit
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
def __init__(
self,
num_blocks = None,
control_latent_channels = None,
dtype = None,
device = None,
operations = None,
@ -18,13 +17,10 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
for _ in range(len(self.joint_blocks)):
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
if control_latent_channels is None:
control_latent_channels = self.in_channels
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
None,
self.patch_size,
control_latent_channels,
self.in_channels,
self.hidden_size,
bias=True,
strict_img_size=False,

View File

@ -1,6 +1,7 @@
import argparse
import enum
import os
from typing import Optional
import comfy.options
@ -35,22 +36,20 @@ class EnumAction(argparse.Action):
parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
@ -61,13 +60,10 @@ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
@ -81,15 +77,12 @@ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Stor
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
@ -99,20 +92,10 @@ class LatentPreviewMethod(enum.Enum):
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
@ -129,43 +112,21 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
@ -184,14 +145,13 @@ parser.add_argument(
""",
)
def is_valid_directory(path: str) -> str:
"""Validate if the given path is a directory, and check permissions."""
if not os.path.exists(path):
raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")
def is_valid_directory(path: Optional[str]) -> Optional[str]:
"""Validate if the given path is a directory."""
if path is None:
return None
if not os.path.isdir(path):
raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
if not os.access(path, os.R_OK):
raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
return path
parser.add_argument(
@ -201,22 +161,6 @@ parser.add_argument(
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")
parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
parser.add_argument(
"--comfy-api-base",
type=str,
default="https://api.comfy.org",
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
)
database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
if comfy.options.args_parsing:
args = parser.parse_args()
else:
@ -228,16 +172,9 @@ if args.windows_standalone_build:
if args.disable_auto_launch:
args.auto_launch = False
if args.force_fp16:
args.fp16_unet = True
import logging
logging_level = logging.INFO
if args.verbose:
logging_level = logging.DEBUG
# '--fast' is not provided, use an empty set
if args.fast is None:
args.fast = set()
# '--fast' is provided with an empty list, enable all optimizations
elif args.fast == []:
args.fast = set(PerformanceFeature)
# '--fast' is provided with a list of performance features, use that list
else:
args.fast = set(args.fast)
logging.basicConfig(format="%(message)s", level=logging_level)

View File

@ -5,7 +5,7 @@
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 49407,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_size": 1280,
"initializer_factor": 1.0,

View File

@ -1,6 +1,5 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ops
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
@ -23,7 +22,6 @@ class CLIPAttention(torch.nn.Module):
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu,
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
}
class CLIPMLP(torch.nn.Module):
@ -61,12 +59,8 @@ class CLIPEncoder(torch.nn.Module):
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
all_intermediate = None
if intermediate_output is not None:
if intermediate_output == "all":
all_intermediate = []
intermediate_output = None
elif intermediate_output < 0:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
@ -74,22 +68,16 @@ class CLIPEncoder(torch.nn.Module):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1)
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
super().__init__()
self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, dtype=torch.float32):
return self.token_embedding(input_tokens, out_dtype=dtype) + comfy.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device)
def forward(self, input_tokens):
return self.token_embedding(input_tokens) + self.position_embedding.weight
class CLIPTextModel_(torch.nn.Module):
@ -99,27 +87,20 @@ class CLIPTextModel_(torch.nn.Module):
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
num_positions = config_dict["max_position_embeddings"]
self.eos_token_id = config_dict["eos_token_id"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
if embeds is not None:
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
else:
x = self.embeddings(input_tokens, dtype=dtype)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
x = self.embeddings(input_tokens)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:
@ -130,10 +111,7 @@ class CLIPTextModel_(torch.nn.Module):
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)
if num_tokens is not None:
pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
else:
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):
@ -143,6 +121,7 @@ class CLIPTextModel(torch.nn.Module):
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
embed_dim = config_dict["hidden_size"]
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
def get_input_embeddings(self):
@ -158,35 +137,27 @@ class CLIPTextModel(torch.nn.Module):
class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
super().__init__()
num_patches = (image_size // patch_size) ** 2
if model_type == "siglip_vision_model":
self.class_embedding = None
patch_bias = True
else:
num_patches = num_patches + 1
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
patch_bias = False
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
self.patch_embedding = operations.Conv2d(
in_channels=num_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=patch_bias,
bias=False,
dtype=dtype,
device=device
)
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
num_patches = (image_size // patch_size) ** 2
num_positions = num_patches + 1
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
if self.class_embedding is not None:
embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
class CLIPVision(torch.nn.Module):
@ -197,15 +168,9 @@ class CLIPVision(torch.nn.Module):
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
model_type = config_dict["model_type"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type == "siglip_vision_model":
self.pre_layrnorm = lambda a: a
self.output_layernorm = True
else:
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.output_layernorm = False
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.post_layernorm = operations.LayerNorm(embed_dim)
@ -214,41 +179,16 @@ class CLIPVision(torch.nn.Module):
x = self.pre_layrnorm(x)
#TODO: attention_mask?
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
if self.output_layernorm:
x = self.post_layernorm(x)
pooled_output = x
else:
pooled_output = self.post_layernorm(x[:, 0, :])
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
class LlavaProjector(torch.nn.Module):
def __init__(self, in_dim, out_dim, dtype, device, operations):
super().__init__()
self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
def forward(self, x):
return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
if "projection_dim" in config_dict:
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
else:
self.visual_projection = lambda a: a
if "llava3" == config_dict.get("projector_type", None):
self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
else:
self.multi_modal_projector = None
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs)
out = self.visual_projection(x[2])
projected = None
if self.multi_modal_projector is not None:
projected = self.multi_modal_projector(x[1])
return (x[0], x[1], out, projected)
return (x[0], x[1], out)

View File

@ -9,7 +9,6 @@ import comfy.model_patcher
import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
class Output:
def __getitem__(self, key):
@ -17,50 +16,29 @@ class Output:
def __setitem__(self, key, item):
setattr(self, key, item)
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
def clip_preprocess(image, size=224):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
if crop:
scale = (size / min(image.shape[2], image.shape[3]))
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
scale = (size / min(image.shape[2], image.shape[3]))
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
IMAGE_ENCODERS = {
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
}
class ClipVisionModel():
def __init__(self, json_config):
with open(json_config) as f:
config = json.load(f)
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_type = config.get("model_type", "clip_vision_model")
model_class = IMAGE_ENCODERS.get(model_type)
if model_type == "siglip_vision_model":
self.return_all_hidden_states = True
else:
self.return_all_hidden_states = False
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
@ -71,22 +49,15 @@ class ClipVisionModel():
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image, crop=True):
def encode_image(self, image):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output()
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
if self.return_all_hidden_states:
all_hs = out[1].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = all_hs[:, -2]
outputs["all_hidden_states"] = all_hs
else:
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
outputs["mm_projected"] = out[3]
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
return outputs
def convert_to_transformers(sd, prefix):
@ -123,25 +94,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
if embed_shape == 729:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif embed_shape == 1024:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
elif embed_shape == 577:
if "multi_modal_projector.linear_1.bias" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
# Dinov2
elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
else:
return None
@ -153,7 +109,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
keys = list(sd.keys())
for k in keys:
if k not in u:
sd.pop(k)
t = sd.pop(k)
del t
return clip
def load(ckpt_path):

View File

@ -1,19 +0,0 @@
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"image_size": 336,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-5,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"projector_type": "llava3",
"torch_dtype": "float32"
}

View File

@ -1,13 +0,0 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}

View File

@ -1,13 +0,0 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 512,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 16,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}

View File

@ -1,43 +0,0 @@
# Comfy Typing
## Type hinting for ComfyUI Node development
This module provides type hinting and concrete convenience types for node developers.
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
```python
from comfy.comfy_types import IO, ComfyNodeABC, CheckLazyMixin
class ExampleNode(ComfyNodeABC):
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {"required": {}}
```
Full example is in [examples/example_nodes.py](examples/example_nodes.py).
# Types
A few primary types are documented below. More complete information is available via the docstrings on each type.
## `IO`
A string enum of built-in and a few custom data types. Includes the following special types and their requisite plumbing:
- `ANY`: `"*"`
- `NUMBER`: `"FLOAT,INT"`
- `PRIMITIVE`: `"STRING,FLOAT,INT,BOOLEAN"`
## `ComfyNodeABC`
An abstract base class for nodes, offering type-hinting / autocomplete, and somewhat-alright docstrings.
### Type hinting for `INPUT_TYPES`
![INPUT_TYPES auto-completion in Visual Studio Code](examples/input_types.png)
### `INPUT_TYPES` return dict
![INPUT_TYPES return value type hinting in Visual Studio Code](examples/required_hint.png)
### Options for individual inputs
![INPUT_TYPES return value option auto-completion in Visual Studio Code](examples/input_options.png)

View File

@ -1,28 +0,0 @@
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
from inspect import cleandoc
class ExampleNode(ComfyNodeABC):
"""An example node that just adds 1 to an input integer.
* Requires a modern IDE to provide any benefit (detail: an IDE configured with analysis paths etc).
* This node is intended as an example for developers only.
"""
DESCRIPTION = cleandoc(__doc__)
CATEGORY = "examples"
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"input_int": (IO.INT, {"defaultInput": True}),
}
}
RETURN_TYPES = (IO.INT,)
RETURN_NAMES = ("input_plus_one",)
FUNCTION = "execute"
def execute(self, input_int: int):
return (input_int + 1,)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 19 KiB

View File

@ -1,350 +0,0 @@
"""Comfy-specific type hinting"""
from __future__ import annotations
from typing import Literal, TypedDict, Optional
from typing_extensions import NotRequired
from abc import ABC, abstractmethod
from enum import Enum
class StrEnum(str, Enum):
"""Base class for string enums. Python's StrEnum is not available until 3.11."""
def __str__(self) -> str:
return self.value
class IO(StrEnum):
"""Node input/output data types.
Includes functionality for ``"*"`` (`ANY`) and ``"MULTI,TYPES"``.
"""
STRING = "STRING"
IMAGE = "IMAGE"
MASK = "MASK"
LATENT = "LATENT"
BOOLEAN = "BOOLEAN"
INT = "INT"
FLOAT = "FLOAT"
COMBO = "COMBO"
CONDITIONING = "CONDITIONING"
SAMPLER = "SAMPLER"
SIGMAS = "SIGMAS"
GUIDER = "GUIDER"
NOISE = "NOISE"
CLIP = "CLIP"
CONTROL_NET = "CONTROL_NET"
VAE = "VAE"
MODEL = "MODEL"
LORA_MODEL = "LORA_MODEL"
LOSS_MAP = "LOSS_MAP"
CLIP_VISION = "CLIP_VISION"
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
STYLE_MODEL = "STYLE_MODEL"
GLIGEN = "GLIGEN"
UPSCALE_MODEL = "UPSCALE_MODEL"
AUDIO = "AUDIO"
WEBCAM = "WEBCAM"
POINT = "POINT"
FACE_ANALYSIS = "FACE_ANALYSIS"
BBOX = "BBOX"
SEGS = "SEGS"
VIDEO = "VIDEO"
ANY = "*"
"""Always matches any type, but at a price.
Causes some functionality issues (e.g. reroutes, link types), and should be avoided whenever possible.
"""
NUMBER = "FLOAT,INT"
"""A float or an int - could be either"""
PRIMITIVE = "STRING,FLOAT,INT,BOOLEAN"
"""Could be any of: string, float, int, or bool"""
def __ne__(self, value: object) -> bool:
if self == "*" or value == "*":
return False
if not isinstance(value, str):
return True
a = frozenset(self.split(","))
b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b))
class RemoteInputOptions(TypedDict):
route: str
"""The route to the remote source."""
refresh_button: bool
"""Specifies whether to show a refresh button in the UI below the widget."""
control_after_refresh: Literal["first", "last"]
"""Specifies the control after the refresh button is clicked. If "first", the first item will be automatically selected, and so on."""
timeout: int
"""The maximum amount of time to wait for a response from the remote source in milliseconds."""
max_retries: int
"""The maximum number of retries before aborting the request."""
refresh: int
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
class MultiSelectOptions(TypedDict):
placeholder: NotRequired[str]
"""The placeholder text to display in the multi-select widget when no items are selected."""
chip: NotRequired[bool]
"""Specifies whether to use chips instead of comma separated values for the multi-select widget."""
class InputTypeOptions(TypedDict):
"""Provides type hinting for the return type of the INPUT_TYPES node function.
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes
"""
default: NotRequired[bool | str | float | int | list | tuple]
"""The default value of the widget"""
defaultInput: NotRequired[bool]
"""@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
- defaultInput on required inputs should be dropped.
- defaultInput on optional inputs should be replaced with forceInput.
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
"""
forceInput: NotRequired[bool]
"""Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
lazy: NotRequired[bool]
"""Declares that this input uses lazy evaluation"""
rawLink: NotRequired[bool]
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
tooltip: NotRequired[str]
"""Tooltip for the input (or widget), shown on pointer hover"""
socketless: NotRequired[bool]
"""All inputs (including widgets) have an input socket to connect links. When ``true``, if there is a widget for this input, no socket will be created.
Available from frontend v1.17.5
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548
"""
widgetType: NotRequired[str]
"""Specifies a type to be used for widget initialization if different from the input type.
Available from frontend v1.18.0
https://github.com/Comfy-Org/ComfyUI_frontend/pull/3550"""
# class InputTypeNumber(InputTypeOptions):
# default: float | int
min: NotRequired[float]
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
max: NotRequired[float]
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
step: NotRequired[float]
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
round: NotRequired[float]
"""Floats are rounded by this value (``FLOAT``)"""
# class InputTypeBoolean(InputTypeOptions):
# default: bool
label_on: NotRequired[str]
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
label_off: NotRequired[str]
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
# class InputTypeString(InputTypeOptions):
# default: str
multiline: NotRequired[bool]
"""Use a multiline text box (``STRING``)"""
placeholder: NotRequired[str]
"""Placeholder text to display in the UI when empty (``STRING``)"""
# Deprecated:
# defaultVal: str
dynamicPrompts: NotRequired[bool]
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
# class InputTypeCombo(InputTypeOptions):
image_upload: NotRequired[bool]
"""Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`."""
image_folder: NotRequired[Literal["input", "output", "temp"]]
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
"""
remote: NotRequired[RemoteInputOptions]
"""Specifies the configuration for a remote input.
Available after ComfyUI frontend v1.9.7
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
control_after_generate: NotRequired[bool]
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
options: NotRequired[list[str | int | float]]
"""COMBO type only. Specifies the selectable options for the combo widget.
Prefer:
["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
Over:
[["Option 1", "Option 2", "Option 3"]]
"""
multi_select: NotRequired[MultiSelectOptions]
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
class HiddenInputTypeDict(TypedDict):
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
node_id: NotRequired[Literal["UNIQUE_ID"]]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
unique_id: NotRequired[Literal["UNIQUE_ID"]]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
prompt: NotRequired[Literal["PROMPT"]]
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
extra_pnginfo: NotRequired[Literal["EXTRA_PNGINFO"]]
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
dynprompt: NotRequired[Literal["DYNPROMPT"]]
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
class InputTypeDict(TypedDict):
"""Provides type hinting for node INPUT_TYPES.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs
"""
required: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
"""Describes all inputs that must be connected for the node to execute."""
optional: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
"""Describes inputs which do not need to be connected."""
hidden: NotRequired[HiddenInputTypeDict]
"""Offers advanced functionality and server-client communication.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
"""
class ComfyNodeABC(ABC):
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview
"""
DESCRIPTION: str
"""Node description, shown as a tooltip when hovering over the node.
Usage::
# Explicitly define the description
DESCRIPTION = "Example description here."
# Use the docstring of the node class.
DESCRIPTION = cleandoc(__doc__)
"""
CATEGORY: str
"""The category of the node, as per the "Add Node" menu.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#category
"""
EXPERIMENTAL: bool
"""Flags a node as experimental, informing users that it may change or not work as expected."""
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
API_NODE: Optional[bool]
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
@classmethod
@abstractmethod
def INPUT_TYPES(s) -> InputTypeDict:
"""Defines node inputs.
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
* The ``optional`` key can be added to describe inputs which do not need to be connected.
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#input-types
"""
return {"required": {}}
OUTPUT_NODE: bool
"""Flags this node as an output node, causing any inputs it requires to be executed.
If a node is not connected to any output nodes, that node will not be executed. Usage::
OUTPUT_NODE = True
From the docs:
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
"""
INPUT_IS_LIST: bool
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
From the docs:
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
"""
OUTPUT_IS_LIST: tuple[bool, ...]
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
A ``tuple[bool]``, where the items match those in `RETURN_TYPES`::
RETURN_TYPES = (IO.INT, IO.INT, IO.STRING)
OUTPUT_IS_LIST = (True, True, False) # The string output will be handled normally
From the docs:
In order to tell Comfy that the list being returned should not be wrapped, but treated as a series of data for sequential processing,
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
specifying which outputs which should be so treated.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
"""
RETURN_TYPES: tuple[IO, ...]
"""A tuple representing the outputs of this node.
Usage::
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
"""
RETURN_NAMES: tuple[str, ...]
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
"""
OUTPUT_TOOLTIPS: tuple[str, ...]
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
FUNCTION: str
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#function
"""
class CheckLazyMixin:
"""Provides a basic check_lazy_status implementation and type hinting for nodes that use lazy inputs."""
def check_lazy_status(self, **kwargs) -> list[str]:
"""Returns a list of input names that should be evaluated.
This basic mixin impl. requires all inputs.
:kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \
When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``.
Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status
"""
need = [name for name in kwargs if kwargs[name] is None]
return need
class FileLocator(TypedDict):
"""Provides type hinting for the file location"""
filename: str
"""The filename of the file."""
subfolder: str
"""The subfolder of the file."""
type: Literal["input", "output", "temp"]
"""The root folder of the file."""

View File

@ -1,9 +1,11 @@
import torch
import math
import comfy.utils
import logging
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)
class CONDRegular:
def __init__(self, cond):
self.cond = cond
@ -11,15 +13,12 @@ class CONDRegular:
def _copy_with(self, cond):
return self.__class__(cond)
def process_cond(self, batch_size, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
def can_concat(self, other):
if self.cond.shape != other.cond.shape:
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device, skipping concat.")
return False
return True
def concat(self, others):
@ -28,19 +27,15 @@ class CONDRegular:
conds.append(x.cond)
return torch.cat(conds)
def size(self):
return list(self.cond.size())
class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, area, **kwargs):
def process_cond(self, batch_size, device, area, **kwargs):
data = self.cond
if area is not None:
dims = len(area) // 2
for i in range(dims):
data = data.narrow(i + 2, area[i + dims], area[i])
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
class CONDCrossAttn(CONDRegular):
@ -51,13 +46,10 @@ class CONDCrossAttn(CONDRegular):
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False
mult_min = math.lcm(s1[1], s2[1])
mult_min = lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device: skipping concat.")
return False
return True
def concat(self, others):
@ -65,7 +57,7 @@ class CONDCrossAttn(CONDRegular):
crossattn_max_len = self.cond.shape[1]
for x in others:
c = x.cond
crossattn_max_len = math.lcm(crossattn_max_len, c.shape[1])
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
conds.append(c)
out = []
@ -75,12 +67,11 @@ class CONDCrossAttn(CONDRegular):
out.append(c)
return torch.cat(out)
class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, **kwargs):
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(self.cond)
def can_concat(self, other):
@ -90,48 +81,3 @@ class CONDConstant(CONDRegular):
def concat(self, others):
return self.cond
def size(self):
return [1]
class CONDList(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, **kwargs):
out = []
for c in self.cond:
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
return self._copy_with(out)
def can_concat(self, other):
if len(self.cond) != len(other.cond):
return False
for i in range(len(self.cond)):
if self.cond[i].shape != other.cond[i].shape:
return False
return True
def concat(self, others):
out = []
for i in range(len(self.cond)):
o = [self.cond[i]]
for x in others:
o.append(x.cond[i])
out.append(torch.cat(o))
return out
def size(self): # hackish implementation to make the mem estimation work
o = 0
c = 1
for c in self.cond:
size = c.size()
o += math.prod(size)
if len(size) > 1:
c = size[1]
return [1, c, o // c]

View File

@ -1,540 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable
import torch
import numpy as np
import collections
from dataclasses import dataclass
from abc import ABC, abstractmethod
import logging
import comfy.model_management
import comfy.patcher_extension
if TYPE_CHECKING:
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase
class ContextWindowABC(ABC):
def __init__(self):
...
@abstractmethod
def get_tensor(self, full: torch.Tensor) -> torch.Tensor:
"""
Get torch.Tensor applicable to current window.
"""
raise NotImplementedError("Not implemented.")
@abstractmethod
def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor:
"""
Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy.
"""
raise NotImplementedError("Not implemented.")
class ContextHandlerABC(ABC):
def __init__(self):
...
@abstractmethod
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
raise NotImplementedError("Not implemented.")
@abstractmethod
def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list:
raise NotImplementedError("Not implemented.")
@abstractmethod
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
raise NotImplementedError("Not implemented.")
class IndexListContextWindow(ContextWindowABC):
def __init__(self, index_list: list[int], dim: int=0):
self.index_list = index_list
self.context_length = len(index_list)
self.dim = dim
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
if dim is None:
dim = self.dim
if dim == 0 and full.shape[dim] == 1:
return full
idx = [slice(None)] * dim + [self.index_list]
return full[idx].to(device)
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
if dim is None:
dim = self.dim
idx = [slice(None)] * dim + [self.index_list]
full[idx] += to_add
return full
class IndexListCallbacks:
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
EXECUTE_START = "execute_start"
EXECUTE_CLEANUP = "execute_cleanup"
def init_callbacks(self):
return {}
@dataclass
class ContextSchedule:
name: str
func: Callable
@dataclass
class ContextFuseMethod:
name: str
func: Callable
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
self.context_overlap = context_overlap
self.context_stride = context_stride
self.closed_loop = closed_loop
self.dim = dim
self._step = 0
self.callbacks = {}
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
if x_in.size(self.dim) > self.context_length:
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
return True
return False
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
if control.previous_controlnet is not None:
self.prepare_control_objects(control.previous_controlnet, device)
return control
def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list:
if cond_in is None:
return None
# reuse or resize cond items to match context requirements
resized_cond = []
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
for actual_cond in cond_in:
resized_actual_cond = actual_cond.copy()
# now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
for key in actual_cond:
try:
cond_item = actual_cond[key]
if isinstance(cond_item, torch.Tensor):
# check that tensor is the expected length - x.size(0)
if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim):
# if so, it's subsetting time - tell controls the expected indeces so they can handle them
actual_cond_item = window.get_tensor(cond_item)
resized_actual_cond[key] = actual_cond_item.to(device)
else:
resized_actual_cond[key] = cond_item.to(device)
# look for control
elif key == "control":
resized_actual_cond[key] = self.prepare_control_objects(cond_item, device)
elif isinstance(cond_item, dict):
new_cond_item = cond_item.copy()
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items():
if isinstance(cond_value, torch.Tensor):
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# if has cond that is a Tensor, check if needs to be subset
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
elif cond_key == "num_video_frames": # for SVD
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
new_cond_item[cond_key].cond = window.context_length
resized_actual_cond[key] = new_cond_item
else:
resized_actual_cond[key] = cond_item
finally:
del cond_item # just in case to prevent VRAM issues
resized_cond.append(resized_actual_cond)
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
self._step = int(matches[0].item())
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model
context_windows = self.context_schedule.func(full_length, self, model_options)
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
self.set_step(timestep, model_options)
context_windows = self.get_context_windows(model, x_in, model_options)
enumerated_context_windows = list(enumerate(context_windows))
conds_final = [torch.zeros_like(x_in) for _ in conds]
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
else:
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options)
for enum_window in enumerated_context_windows:
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
for result in results:
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
conds_final, counts_final, biases_final)
try:
# finalize conds
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
# relative is already normalized, so return as is
del counts_final
return conds_final
else:
# normalize conds via division by context usage counts
for i in range(len(conds_final)):
conds_final[i] /= counts_final[i]
del counts_final
return conds_final
finally:
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options)
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
model_options, device=None, first_device=None):
results: list[ContextResults] = []
for window_idx, window in enumerated_context_windows:
# allow processing to end between context window executions for faster Cancel
comfy.model_management.throw_exception_if_processing_interrupted()
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
# update exposed params
model_options["transformer_options"]["context_window"] = window
# get subsections of x, timestep, conds
sub_x = window.get_tensor(x_in, device)
sub_timestep = window.get_tensor(timestep, device, dim=0)
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
if device is not None:
for i in range(len(sub_conds_out)):
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
return results
def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
for pos, idx in enumerate(window.index_list):
# bias is the influence of a specific index in relation to the whole context window
bias = 1 - abs(idx - (window.index_list[0] + window.index_list[-1]) / 2) / ((window.index_list[-1] - window.index_list[0] + 1e-2) / 2)
bias = max(1e-2, bias)
# take weighted average relative to total bias of current idx
for i in range(len(sub_conds_out)):
bias_total = biases_final[i][idx]
prev_weight = (bias_total / (bias_total + bias))
new_weight = (bias / (bias_total + bias))
# account for dims of tensors
idx_window = [slice(None)] * self.dim + [idx]
pos_window = [slice(None)] * self.dim + [pos]
# apply new values
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
biases_final[i][idx] = bias_total + bias
else:
# add conds and counts based on weights of fuse method
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
for i in range(len(sub_conds_out)):
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
window.add_window(counts_final[i], weights_tensor)
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
# limit noise_shape length to context_length for more accurate vram use estimation
model_options = kwargs.get("model_options", None)
if model_options is None:
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
handler: IndexListContextHandler = model_options.get("context_handler", None)
if handler is not None:
noise_shape = list(noise_shape)
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
return executor(model, noise_shape, *args, **kwargs)
def create_prepare_sampling_wrapper(model: ModelPatcher):
model.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
"ContextWindows_prepare_sampling",
_prepare_sampling_wrapper
)
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
total_dims = len(x_in.shape)
weights_tensor = torch.Tensor(weights).to(device=device)
for _ in range(dim):
weights_tensor = weights_tensor.unsqueeze(0)
for _ in range(total_dims - dim - 1):
weights_tensor = weights_tensor.unsqueeze(-1)
return weights_tensor
def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
total_dims = len(x_in.shape)
shape = []
for _ in range(dim):
shape.append(1)
shape.append(x_in.shape[dim])
for _ in range(total_dims - dim - 1):
shape.append(1)
return shape
class ContextSchedules:
UNIFORM_LOOPED = "looped_uniform"
UNIFORM_STANDARD = "standard_uniform"
STATIC_STANDARD = "standard_static"
BATCHED = "batched"
# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py
def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
windows = []
if num_frames < handler.context_length:
windows.append(list(range(num_frames)))
return windows
context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
# obtain uniform windows as normal, looping and all
for context_step in 1 << np.arange(context_stride):
pad = int(round(num_frames * ordered_halving(handler._step)))
for j in range(
int(ordered_halving(handler._step) * context_step) + pad,
num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
(handler.context_length * context_step - handler.context_overlap),
):
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
return windows
def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
# unlike looped, uniform_straight does NOT allow windows that loop back to the beginning;
# instead, they get shifted to the corresponding end of the frames.
# in the case that a window (shifted or not) is identical to the previous one, it gets skipped.
windows = []
if num_frames <= handler.context_length:
windows.append(list(range(num_frames)))
return windows
context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
# first, obtain uniform windows as normal, looping and all
for context_step in 1 << np.arange(context_stride):
pad = int(round(num_frames * ordered_halving(handler._step)))
for j in range(
int(ordered_halving(handler._step) * context_step) + pad,
num_frames + pad + (-handler.context_overlap),
(handler.context_length * context_step - handler.context_overlap),
):
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
# now that windows are created, shift any windows that loop, and delete duplicate windows
delete_idxs = []
win_i = 0
while win_i < len(windows):
# if window is rolls over itself, need to shift it
is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
if is_roll:
roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
shift_window_to_end(windows[win_i], num_frames=num_frames)
# check if next window (cyclical) is missing roll_val
if roll_val not in windows[(win_i+1) % len(windows)]:
# need to insert new window here - just insert window starting at roll_val
windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
# delete window if it's not unique
for pre_i in range(0, win_i):
if windows[win_i] == windows[pre_i]:
delete_idxs.append(win_i)
break
win_i += 1
# reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
delete_idxs.reverse()
for i in delete_idxs:
windows.pop(i)
return windows
def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
windows = []
if num_frames <= handler.context_length:
windows.append(list(range(num_frames)))
return windows
# always return the same set of windows
delta = handler.context_length - handler.context_overlap
for start_idx in range(0, num_frames, delta):
# if past the end of frames, move start_idx back to allow same context_length
ending = start_idx + handler.context_length
if ending >= num_frames:
final_delta = ending - num_frames
final_start_idx = start_idx - final_delta
windows.append(list(range(final_start_idx, final_start_idx + handler.context_length)))
break
windows.append(list(range(start_idx, start_idx + handler.context_length)))
return windows
def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
windows = []
if num_frames <= handler.context_length:
windows.append(list(range(num_frames)))
return windows
# always return the same set of windows;
# no overlap, just cut up based on context_length;
# last window size will be different if num_frames % opts.context_length != 0
for start_idx in range(0, num_frames, handler.context_length):
windows.append(list(range(start_idx, min(start_idx + handler.context_length, num_frames))))
return windows
def create_windows_default(num_frames: int, handler: IndexListContextHandler):
return [list(range(num_frames))]
CONTEXT_MAPPING = {
ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
ContextSchedules.BATCHED: create_windows_batched,
}
def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
func = CONTEXT_MAPPING.get(context_schedule, None)
if func is None:
raise ValueError(f"Unknown context_schedule '{context_schedule}'.")
return ContextSchedule(context_schedule, func)
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
def create_weights_flat(length: int, **kwargs) -> list[float]:
# weight is the same for all
return [1.0] * length
def create_weights_pyramid(length: int, **kwargs) -> list[float]:
# weight is based on the distance away from the edge of the context window;
# based on weighted average concept in FreeNoise paper
if length % 2 == 0:
max_weight = length // 2
weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
else:
max_weight = (length + 1) // 2
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
return weight_sequence
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
# only expected overlap is given different weights
weights_torch = torch.ones((length))
# blend left-side on all except first window
if min(idxs) > 0:
ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
weights_torch[:handler.context_overlap] = ramp_up
# blend right-side on all except last window
if max(idxs) < full_length-1:
ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
weights_torch[-handler.context_overlap:] = ramp_down
return weights_torch
class ContextFuseMethods:
FLAT = "flat"
PYRAMID = "pyramid"
RELATIVE = "relative"
OVERLAP_LINEAR = "overlap-linear"
LIST = [PYRAMID, FLAT, OVERLAP_LINEAR]
LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR]
FUSE_MAPPING = {
ContextFuseMethods.FLAT: create_weights_flat,
ContextFuseMethods.PYRAMID: create_weights_pyramid,
ContextFuseMethods.RELATIVE: create_weights_pyramid,
ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
}
def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
func = FUSE_MAPPING.get(fuse_method, None)
if func is None:
raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
return ContextFuseMethod(fuse_method, func)
# Returns fraction that has denominator that is a power of 2
def ordered_halving(val):
# get binary value, padded with 0s for 64 bits
bin_str = f"{val:064b}"
# flip binary value, padding included
bin_flip = bin_str[::-1]
# convert binary to int
as_int = int(bin_flip, 2)
# divide by 1 << 64, equivalent to 2**64, or 18446744073709551616,
# or b10000000000000000000000000000000000000000000000000000000000000000 (1 with 64 zero's)
return as_int / (1 << 64)
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
all_indexes = list(range(num_frames))
for w in windows:
for val in w:
try:
all_indexes.remove(val)
except ValueError:
pass
return all_indexes
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
prev_val = -1
for i, val in enumerate(window):
val = val % num_frames
if val < prev_val:
return True, i
prev_val = val
return False, -1
def shift_window_to_start(window: list[int], num_frames: int):
start_val = window[0]
for i in range(len(window)):
# 1) subtract each element by start_val to move vals relative to the start of all frames
# 2) add num_frames and take modulus to get adjusted vals
window[i] = ((window[i] - start_val) + num_frames) % num_frames
def shift_window_to_end(window: list[int], num_frames: int):
# 1) shift window to start
shift_window_to_start(window, num_frames)
end_val = window[-1]
end_delta = num_frames - end_val - 1
for i in range(len(window)):
# 2) add end_delta to each val to slide windows to end
window[i] = window[i] + end_delta

View File

@ -1,24 +1,4 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
from enum import Enum
import math
import os
import logging
@ -28,23 +8,16 @@ import comfy.model_detection
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats
import comfy.model_base
import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.ldm.qwen_image.controlnet
import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.hooks import HookGroup
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor
@ -60,12 +33,8 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
else:
return torch.cat([tensor] * batched_number, dim=0)
class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
class ControlBase:
def __init__(self):
def __init__(self, device=None):
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
@ -77,26 +46,18 @@ class ControlBase:
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
self.extra_concat_orig = []
self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
if self.latent_format is not None:
if vae is None:
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
self.vae = vae
self.extra_concat_orig = extra_concat.copy()
if self.concat_mask and len(self.extra_concat_orig) == 0:
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
return self
def pre_run(self, model, percent_to_timestep_function):
@ -111,9 +72,9 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
self.cond_hint = None
self.extra_concat = None
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.timestep_range = None
def get_models(self):
@ -122,14 +83,6 @@ class ControlBase:
out += self.previous_controlnet.get_models()
return out
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
out.append(self.extra_hooks)
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_extra_hooks()
return out
def copy_to(self, c):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
@ -140,12 +93,6 @@ class ControlBase:
c.latent_format = self.latent_format
c.extra_args = self.extra_args.copy()
c.vae = self.vae
c.extra_conds = self.extra_conds.copy()
c.strength_type = self.strength_type
c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy()
c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None
c.preprocess_image = self.preprocess_image
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@ -166,12 +113,9 @@ class ControlBase:
if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
applied_to.add(x)
if self.strength_type == StrengthType.CONSTANT:
x *= self.strength
elif self.strength_type == StrengthType.LINEAR_UP:
x *= (self.strength ** float(len(control_output) - i))
x *= self.strength
if output_dtype is not None and x.dtype != output_dtype:
if x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].append(x)
@ -198,8 +142,8 @@ class ControlBase:
class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
super().__init__()
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
super().__init__(device)
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
@ -210,15 +154,11 @@ class ControlNet(ControlBase):
self.model_sampling_current = None
self.manual_cast_dtype = manual_cast_dtype
self.latent_format = latent_format
self.extra_conds += extra_conds
self.strength_type = strength_type
self.concat_mask = concat_mask
self.preprocess_image = preprocess_image
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
@ -231,51 +171,34 @@ class ControlNet(ControlBase):
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
compression_ratio = self.compression_ratio
if self.vae is not None:
compression_ratio *= self.vae.spacial_compression_encode()
else:
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
self.cond_hint = self.preprocess_image(self.cond_hint)
compression_ratio *= self.vae.downscale_ratio
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
comfy.model_management.load_models_gpu(loaded_models)
if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint)
if len(self.extra_concat_orig) > 0:
to_concat = []
for c in self.extra_concat_orig:
c = c.to(self.cond_hint.device)
c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
if c.ndim < self.cond_hint.ndim:
c = c.unsqueeze(2)
c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
extra = self.extra_args.copy()
for c in self.extra_conds:
temp = cond.get(c, None)
if temp is not None:
extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)
y = cond.get('y', None)
if y is not None:
y = y.to(dtype)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
return self.control_merge(control, control_prev, output_dtype=None)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
return self.control_merge(control, control_prev, output_dtype)
def copy(self):
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
@ -301,6 +224,7 @@ class ControlLoraOps:
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
@ -310,13 +234,11 @@ class ControlLoraOps:
self.bias = None
def forward(self, input):
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
weight, bias = comfy.ops.cast_bias_weight(self, input)
if self.up is not None:
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
x = torch.nn.functional.linear(input, weight, bias)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
return torch.nn.functional.linear(input, weight, bias)
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(
@ -352,20 +274,18 @@ class ControlLoraOps:
def forward(self, input):
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
weight, bias = comfy.ops.cast_bias_weight(self, input)
if self.up is not None:
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
ControlBase.__init__(self)
def __init__(self, control_weights, global_average_pooling=False, device=None):
ControlBase.__init__(self, device)
self.control_weights = control_weights
self.global_average_pooling = global_average_pooling
self.extra_conds += ["y"]
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
@ -388,6 +308,7 @@ class ControlLora(ControlNet):
self.control_model.to(comfy.model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
cm = self.control_model.state_dict()
for k in sd:
weight = sd[k]
@ -397,9 +318,8 @@ class ControlLora(ControlNet):
pass
for k in self.control_weights:
if (k not in {"lora_controlnet"}):
if (k.endswith(".up") or k.endswith(".down") or k.endswith(".weight") or k.endswith(".bias")) and ("__" not in k):
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
if k not in {"lora_controlnet"}:
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
def copy(self):
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
@ -418,204 +338,43 @@ class ControlLora(ControlNet):
def inference_memory_requirements(self, dtype):
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
def controlnet_config(sd, model_options={}):
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
def load_controlnet_mmdit(sd):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = comfy.utils.weight_dtype(sd)
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
supported_inference_dtypes = model_config.supported_inference_dtypes
controlnet_config = model_config.unet_config
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
offload_device = comfy.model_management.unet_offload_device()
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model
def load_controlnet_mmdit(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
concat_mask = False
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
if control_latent_channels == 17: #inpaint controlnet
concat_mask = True
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.SD3()
latent_format.shift_factor = 0 #SD3 controlnet weirdness
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
class ControlNetSD35(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
if self.control_model.double_y_emb:
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
else:
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
super().pre_run(model, percent_to_timestep_function)
def copy(self):
c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c
def load_controlnet_sd35(sd, model_options={}):
control_type = -1
if "control_type" in sd:
control_type = round(sd.pop("control_type").item())
# blur_cnet = control_type == 0
canny_cnet = control_type == 1
depth_cnet = control_type == 2
new_sd = {}
for k in comfy.utils.MMDIT_MAP_BASIC:
if k[1] in sd:
new_sd[k[0]] = sd.pop(k[1])
for k in sd:
new_sd[k] = sd[k]
sd = new_sd
y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
depth = y_emb_shape[0] // 64
hidden_size = 64 * depth
num_heads = depth
head_dim = hidden_size // num_heads
num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.unet_offload_device()
unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
patch_size=2,
in_chans=16,
num_layers=num_blocks,
main_model_double=depth,
double_y_emb=y_emb_shape[0] == y_emb_shape[1],
attention_head_dim=head_dim,
num_attention_heads=num_heads,
adm_in_channels=2048,
device=offload_device,
dtype=unet_dtype,
operations=operations)
control_model = controlnet_load_state_dict(control_model, sd)
latent_format = comfy.latent_formats.SD3()
preprocess_image = lambda a: a
if canny_cnet:
preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
elif depth_cnet:
preprocess_image = lambda a: 1.0 - a
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
return control
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data)
latent_format = comfy.latent_formats.SDXL()
extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
return control
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet_flux_instantx(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
for k in sd:
new_sd[k] = sd[k]
num_union_modes = 0
union_cnet = "controlnet_mode_embedder.weight"
if union_cnet in new_sd:
num_union_modes = new_sd[union_cnet].shape[0]
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
concat_mask = False
if control_latent_channels == 17:
concat_mask = True
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.Flux()
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet_qwen_instantx(sd, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
extra_condition_channels = 0
concat_mask = False
if control_latent_channels == 68: #inpaint controlnet
extra_condition_channels = control_latent_channels - 64
concat_mask = True
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
latent_format = comfy.latent_formats.Wan21()
extra_conds = []
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
controlnet_data = state_dict
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
def load_controlnet(ckpt_path, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data, model_options=model_options)
return ControlLora(controlnet_data)
controlnet_config = None
supported_inference_dtypes = None
@ -670,21 +429,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
if len(leftover_keys) > 0:
logging.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd
elif "controlnet_blocks.0.weight" in controlnet_data:
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
elif "pos_embed_input.proj.weight" in controlnet_data:
if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
else:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
return load_controlnet_mmdit(controlnet_data)
pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
@ -696,35 +442,26 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
elif key in controlnet_data:
prefix = ""
else:
net = load_t2i_adapter(controlnet_data, model_options=model_options)
net = load_t2i_adapter(controlnet_data)
if net is None:
logging.error("error could not detect control model type.")
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
return net
if controlnet_config is None:
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
supported_inference_dtypes = model_config.supported_inference_dtypes
controlnet_config = model_config.unet_config
unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
if supported_inference_dtypes is None:
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
load_device = comfy.model_management.get_torch_device()
if supported_inference_dtypes is None:
unet_dtype = comfy.model_management.unet_dtype()
else:
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
controlnet_config["operations"] = operations
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
controlnet_config["dtype"] = unet_dtype
controlnet_config["device"] = comfy.model_management.unet_offload_device()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@ -758,33 +495,22 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
global_average_pooling = model_options.get("global_average_pooling", False)
global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
def load_controlnet(ckpt_path, model=None, model_options={}):
model_options = model_options.copy()
if "global_average_pooling" not in model_options:
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
model_options["global_average_pooling"] = True
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
if cnet is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
return cnet
class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
super().__init__()
super().__init__(device)
self.t2i_model = t2i_model
self.channels_in = channels_in
self.control_input = None
self.compression_ratio = compression_ratio
self.upscale_algorithm = upscale_algorithm
if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
def scale_image_to(self, width, height):
unshuffle_amount = self.t2i_model.unshuffle_amount
@ -792,10 +518,10 @@ class T2IAdapter(ControlBase):
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
return width, height
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
@ -832,7 +558,7 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
def load_t2i_adapter(t2i_data):
compression_ratio = 8
upscale_algorithm = 'nearest-exact'
@ -843,7 +569,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
for i in range(4):
for j in range(2):
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
prefix_replace["adapter.body.{}.".format(i, )] = "body.{}.".format(i * 2)
prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
prefix_replace["adapter."] = ""
t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
keys = t2i_data.keys()

View File

@ -4,6 +4,105 @@ import logging
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
# =================#
# UNet Conversion #
# =================#
unet_conversion_map = [
# (stable-diffusion, HF Diffusers)
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("out.0.weight", "conv_norm_out.weight"),
("out.0.bias", "conv_norm_out.bias"),
("out.2.weight", "conv_out.weight"),
("out.2.bias", "conv_out.bias"),
]
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0", "norm1"),
("in_layers.2", "conv1"),
("out_layers.0", "norm2"),
("out_layers.3", "conv2"),
("emb_layers.1", "time_emb_proj"),
("skip_connection", "conv_shortcut"),
]
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0:
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
def convert_unet_state_dict(unet_state_dict):
# buyer beware: this is a *brittle* function,
# and correct output requires that all of these pieces interact in
# the exact order in which I have arranged them.
mapping = {k: k for k in unet_state_dict.keys()}
for sd_name, hf_name in unet_conversion_map:
mapping[hf_name] = sd_name
for k, v in mapping.items():
if "resnets" in k:
for sd_part, hf_part in unet_conversion_map_resnet:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
for sd_part, hf_part in unet_conversion_map_layer:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
# ================#
# VAE Conversion #
# ================#
@ -58,23 +157,16 @@ vae_conversion_map_attn = [
]
def reshape_weight_for_sd(w, conv3d=False):
def reshape_weight_for_sd(w):
# convert HF linear weights to SD conv2d weights
if conv3d:
return w.reshape(*w.shape, 1, 1, 1)
else:
return w.reshape(*w.shape, 1, 1)
return w.reshape(*w.shape, 1, 1)
def convert_vae_state_dict(vae_state_dict):
mapping = {k: k for k in vae_state_dict.keys()}
conv3d = False
for k, v in mapping.items():
for sd_part, hf_part in vae_conversion_map:
v = v.replace(hf_part, sd_part)
if v.endswith(".conv.weight"):
if not conv3d and vae_state_dict[k].ndim == 5:
conv3d = True
mapping[k] = v
for k, v in mapping.items():
if "attentions" in k:
@ -87,7 +179,7 @@ def convert_vae_state_dict(vae_state_dict):
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
logging.debug(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v, conv3d=conv3d)
new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict
@ -114,7 +206,6 @@ textenc_pattern = re.compile("|".join(protected.keys()))
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
def cat_tensors(tensors):
x = 0
@ -131,7 +222,6 @@ def cat_tensors(tensors):
return out
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
new_state_dict = {}
capture_qkv_weight = {}
@ -187,3 +277,5 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
def convert_text_enc_state_dict(text_enc_dict):
return text_enc_dict

View File

@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
if text_encoder2_path is not None:
text_encoder_paths.append(text_encoder2_path)
unet = comfy.sd.load_diffusion_model(unet_path)
unet = comfy.sd.load_unet(unet_path)
clip = None
if output_clip:

View File

@ -1,10 +1,10 @@
#code taken from: https://github.com/wl-zhao/UniPC and modified
import torch
import torch.nn.functional as F
import math
import logging
from tqdm.auto import trange
from tqdm.auto import trange, tqdm
class NoiseScheduleVP:
@ -16,7 +16,7 @@ class NoiseScheduleVP:
continuous_beta_0=0.1,
continuous_beta_1=20.,
):
r"""Create a wrapper class for the forward SDE (VP type).
"""Create a wrapper class for the forward SDE (VP type).
***
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
@ -80,7 +80,7 @@ class NoiseScheduleVP:
'linear' or 'cosine' for continuous-time DPMs.
Returns:
A wrapper object of the forward SDE (VP type).
===============================================================
Example:
@ -208,7 +208,7 @@ def model_wrapper(
arXiv preprint arXiv:2202.00512 (2022).
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
arXiv preprint arXiv:2210.02303 (2022).
4. "score": marginal score function. (Trained by denoising score matching).
Note that the score function and the noise prediction model follows a simple relationship:
```
@ -226,7 +226,7 @@ def model_wrapper(
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
``
The input `classifier_fn` has the following format:
``
@ -240,12 +240,12 @@ def model_wrapper(
The input `model` has the following format:
``
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
``
``
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
arXiv preprint arXiv:2207.12598 (2022).
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
or continuous-time labels (i.e. epsilon to T).
@ -254,7 +254,7 @@ def model_wrapper(
``
def model_fn(x, t_continuous) -> noise:
t_input = get_model_input_time(t_continuous)
return noise_pred(model, x, t_input, **model_kwargs)
return noise_pred(model, x, t_input, **model_kwargs)
``
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
@ -359,7 +359,7 @@ class UniPC:
max_val=1.,
variant='bh1',
):
"""Construct a UniPC.
"""Construct a UniPC.
We support both data_prediction and noise_prediction.
"""
@ -372,7 +372,7 @@ class UniPC:
def dynamic_thresholding_fn(self, x0, t=None):
"""
The dynamic thresholding method.
The dynamic thresholding method.
"""
dims = x0.dim()
p = self.dynamic_thresholding_ratio
@ -404,7 +404,7 @@ class UniPC:
def model_fn(self, x, t):
"""
Convert the model to the noise prediction model or the data prediction model.
Convert the model to the noise prediction model or the data prediction model.
"""
if self.predict_x0:
return self.data_prediction_fn(x, t)
@ -461,7 +461,7 @@ class UniPC:
def denoise_to_zero_fn(self, x, s):
"""
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
"""
return self.data_prediction_fn(x, s)
@ -475,7 +475,7 @@ class UniPC:
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
logging.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
ns = self.noise_schedule
assert order <= len(model_prev_list)
@ -510,7 +510,7 @@ class UniPC:
col = torch.ones_like(rks)
for k in range(1, K + 1):
C.append(col)
col = col * rks / (k + 1)
col = col * rks / (k + 1)
C = torch.stack(C, dim=1)
if len(D1s) > 0:
@ -519,6 +519,7 @@ class UniPC:
A_p = C_inv_p
if use_corrector:
print('using corrector')
C_inv = torch.linalg.inv(C)
A_c = C_inv
@ -621,12 +622,12 @@ class UniPC:
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= (i + 1)
h_phi_k = h_phi_k / hh - 1 / factorial_i
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=x.device)
@ -661,7 +662,7 @@ class UniPC:
if x_t is None:
if use_predictor:
pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_p, D1s)
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
@ -669,7 +670,7 @@ class UniPC:
if use_corrector:
model_t = self.model_fn(x_t, t)
if D1s is not None:
corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = (model_t - model_prev_0)
@ -703,6 +704,7 @@ class UniPC:
):
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
# t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
steps = len(timesteps) - 1
if method == 'multistep':
assert steps >= order
@ -870,4 +872,4 @@ def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=F
return x
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')

View File

@ -1,67 +0,0 @@
import torch
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
mantissa_scaled = torch.where(
normal_mask,
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
)
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
#Not 100% sure about this
def manual_stochastic_round_to_float8(x, dtype, generator=None):
if dtype == torch.float8_e4m3fn:
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
elif dtype == torch.float8_e5m2:
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
else:
raise ValueError("Unsupported dtype")
x = x.half()
sign = torch.sign(x)
abs_x = x.abs()
sign = torch.where(abs_x == 0, 0, sign)
# Combine exponent calculation and clamping
exponent = torch.clamp(
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
0, 2**EXPONENT_BITS - 1
)
# Combine mantissa calculation and rounding
normal_mask = ~(exponent == 0)
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
sign *= torch.where(
normal_mask,
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
)
inf = torch.finfo(dtype)
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
return sign
def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.float32:
return value.to(dtype=torch.float32)
if dtype == torch.float16:
return value.to(dtype=torch.float16)
if dtype == torch.bfloat16:
return value.to(dtype=torch.bfloat16)
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
generator = torch.Generator(device=value.device)
generator.manual_seed(seed)
output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices))
for i in range(0, value.shape[0], slice_size):
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
return output
return value.to(dtype=dtype)

View File

@ -1,10 +1,54 @@
import math
import torch
from torch import nn
from .ldm.modules.attention import CrossAttention, FeedForward
from .ldm.modules.attention import CrossAttention
from inspect import isfunction
import comfy.ops
ops = comfy.ops.manual_cast
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = ops.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * torch.nn.functional.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
ops.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
ops.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
class GatedCrossAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):

View File

@ -1,785 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable
import enum
import math
import torch
import numpy as np
import itertools
import logging
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher, PatcherInjection
from comfy.model_base import BaseModel
from comfy.sd import CLIP
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from node_helpers import conditioning_set_values
# #######################################################################################################
# Hooks explanation
# -------------------
# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
# make explicit special cases like it does for ControlNet and GLIGEN.
#
# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
# that should run special code when a 'marked' cond is used in sampling.
# #######################################################################################################
class EnumHookMode(enum.Enum):
'''
Priority of hook memory optimization vs. speed, mostly related to WeightHooks.
MinVram: No caching will occur for any operations related to hooks.
MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
'''
MinVram = "minvram"
MaxSpeed = "maxspeed"
class EnumHookType(enum.Enum):
'''
Hook types, each of which has different expected behavior.
'''
Weight = "weight"
ObjectPatch = "object_patch"
AdditionalModels = "add_models"
TransformerOptions = "transformer_options"
Injections = "add_injections"
class EnumWeightTarget(enum.Enum):
Model = "model"
Clip = "clip"
class EnumHookScope(enum.Enum):
'''
Determines if hook should be limited in its influence over sampling.
AllConditioning: hook will affect all conds used in sampling.
HookedOnly: hook will only affect the conds it was attached to.
'''
AllConditioning = "all_conditioning"
HookedOnly = "hooked_only"
class _HookRef:
pass
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
'''Example for how custom_should_register function can look like.'''
return True
def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
'''Creates base dictionary for use with Hooks' target param.'''
d = {}
if target is not None:
d['target'] = target
d.update(kwargs)
return d
class Hook:
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
self.hook_type = hook_type
'''Enum identifying the general class of this hook.'''
self.hook_ref = hook_ref if hook_ref else _HookRef()
'''Reference shared between hook clones that have the same value. Should NOT be modified.'''
self.hook_id = hook_id
'''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.'''
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
'''Keyframe storage that can be referenced to get strength for current sampling step.'''
self.hook_scope = hook_scope
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
self.custom_should_register = default_should_register
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
@property
def strength(self):
return self.hook_keyframe.strength
def initialize_timesteps(self, model: BaseModel):
self.reset()
self.hook_keyframe.initialize_timesteps(model)
def reset(self):
self.hook_keyframe.reset()
def clone(self):
c: Hook = self.__class__()
c.hook_type = self.hook_type
c.hook_ref = self.hook_ref
c.hook_id = self.hook_id
c.hook_keyframe = self.hook_keyframe
c.hook_scope = self.hook_scope
c.custom_should_register = self.custom_should_register
return c
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
return self.custom_should_register(self, model, model_options, target_dict, registered)
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
def __eq__(self, other: Hook):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
def __hash__(self):
return hash(self.hook_ref)
class WeightHook(Hook):
'''
Hook responsible for tracking weights to be applied to some model/clip.
Note, value of hook_scope is ignored and is treated as HookedOnly.
'''
def __init__(self, strength_model=1.0, strength_clip=1.0):
super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
self.weights: dict = None
self.weights_clip: dict = None
self.need_weight_init = True
self._strength_model = strength_model
self._strength_clip = strength_clip
self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs
@property
def strength_model(self):
return self._strength_model * self.strength
@property
def strength_clip(self):
return self._strength_clip * self.strength
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
weights = None
target = target_dict.get('target', None)
if target == EnumWeightTarget.Clip:
strength = self._strength_clip
else:
strength = self._strength_model
if self.need_weight_init:
key_map = {}
if target == EnumWeightTarget.Clip:
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else:
if target == EnumWeightTarget.Clip:
weights = self.weights_clip
else:
weights = self.weights
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.add(self)
return True
# TODO: add logs about any keys that were not applied
def clone(self):
c: WeightHook = super().clone()
c.weights = self.weights
c.weights_clip = self.weights_clip
c.need_weight_init = self.need_weight_init
c._strength_model = self._strength_model
c._strength_clip = self._strength_clip
return c
class ObjectPatchHook(Hook):
def __init__(self, object_patches: dict[str]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches = object_patches
self.hook_scope = hook_scope
def clone(self):
c: ObjectPatchHook = super().clone()
c.object_patches = self.object_patches
return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.")
class AdditionalModelsHook(Hook):
'''
Hook responsible for telling model management any additional models that should be loaded.
Note, value of hook_scope is ignored and is treated as AllConditioning.
'''
def __init__(self, models: list[ModelPatcher]=None, key: str=None):
super().__init__(hook_type=EnumHookType.AdditionalModels)
self.models = models
self.key = key
def clone(self):
c: AdditionalModelsHook = super().clone()
c.models = self.models.copy() if self.models else self.models
c.key = self.key
return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
registered.add(self)
return True
class TransformerOptionsHook(Hook):
'''
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
'''
def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.TransformerOptions)
self.transformers_dict = transformers_dict
self.hook_scope = hook_scope
self._skip_adding = False
'''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.'''
def clone(self):
c: TransformerOptionsHook = super().clone()
c.transformers_dict = self.transformers_dict
c._skip_adding = self._skip_adding
return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
self._skip_adding = False
if self.hook_scope == EnumHookScope.AllConditioning:
add_model_options = {"transformer_options": self.transformers_dict,
"to_load_options": self.transformers_dict}
# skip_adding if included in AllConditioning to avoid double loading
self._skip_adding = True
else:
add_model_options = {"to_load_options": self.transformers_dict}
registered.add(self)
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
return True
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
if not self._skip_adding:
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
WrapperHook = TransformerOptionsHook
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
class InjectionsHook(Hook):
def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.Injections)
self.key = key
self.injections = injections
self.hook_scope = hook_scope
def clone(self):
c: InjectionsHook = super().clone()
c.key = self.key
c.injections = self.injections.copy() if self.injections else self.injections
return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.")
class HookGroup:
'''
Stores groups of hooks, and allows them to be queried by type.
To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
always use the provided functions on HookGroup.
'''
def __init__(self):
self.hooks: list[Hook] = []
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
def __len__(self):
return len(self.hooks)
def add(self, hook: Hook):
if hook not in self.hooks:
self.hooks.append(hook)
self._hook_dict.setdefault(hook.hook_type, []).append(hook)
def remove(self, hook: Hook):
if hook in self.hooks:
self.hooks.remove(hook)
self._hook_dict[hook.hook_type].remove(hook)
def get_type(self, hook_type: EnumHookType):
return self._hook_dict.get(hook_type, [])
def contains(self, hook: Hook):
return hook in self.hooks
def is_subset_of(self, other: HookGroup):
self_hooks = set(self.hooks)
other_hooks = set(other.hooks)
return self_hooks.issubset(other_hooks)
def new_with_common_hooks(self, other: HookGroup):
c = HookGroup()
for hook in self.hooks:
if other.contains(hook):
c.add(hook.clone())
return c
def clone(self):
c = HookGroup()
for hook in self.hooks:
c.add(hook.clone())
return c
def clone_and_combine(self, other: HookGroup):
c = self.clone()
if other is not None:
for hook in other.hooks:
c.add(hook.clone())
return c
def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
if hook_kf is None:
hook_kf = HookKeyframeGroup()
else:
hook_kf = hook_kf.clone()
for hook in self.hooks:
hook.hook_keyframe = hook_kf
def get_hooks_for_clip_schedule(self):
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
# only care about WeightHooks, for now
for hook in self.get_type(EnumHookType.Weight):
hook: WeightHook
hook_schedule = []
# if no hook keyframes, assign default value
if len(hook.hook_keyframe.keyframes) == 0:
hook_schedule.append(((0.0, 1.0), None))
scheduled_hooks[hook] = hook_schedule
continue
# find ranges of values
prev_keyframe = hook.hook_keyframe.keyframes[0]
for keyframe in hook.hook_keyframe.keyframes:
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
prev_keyframe = keyframe
elif keyframe.start_percent == prev_keyframe.start_percent:
prev_keyframe = keyframe
# create final range, assuming last start_percent was not 1.0
if not math.isclose(prev_keyframe.start_percent, 1.0):
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
scheduled_hooks[hook] = hook_schedule
# hooks should not have their schedules in a list of tuples
all_ranges: list[tuple[float, float]] = []
for range_kfs in scheduled_hooks.values():
for t_range, keyframe in range_kfs:
all_ranges.append(t_range)
# turn list of ranges into boundaries
boundaries_set = set(itertools.chain.from_iterable(all_ranges))
boundaries_set.add(0.0)
boundaries = sorted(boundaries_set)
real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)]
# with real ranges defined, give appropriate hooks w/ keyframes for each range
scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
for t_range in real_ranges:
hooks_schedule = []
for hook, val in scheduled_hooks.items():
keyframe = None
# check if is a keyframe that works for the current t_range
for stored_range, stored_kf in val:
# if stored start is less than current end, then fits - give it assigned keyframe
if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]:
keyframe = stored_kf
break
hooks_schedule.append((hook, keyframe))
scheduled_keyframes.append((t_range, hooks_schedule))
return scheduled_keyframes
def reset(self):
for hook in self.hooks:
hook.reset()
@staticmethod
def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
actual: list[HookGroup] = []
for group in hooks_list:
if group is not None:
actual.append(group)
if len(actual) < require_count:
raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
# if no hooks, then return None
if len(actual) == 0:
return None
# if only 1 hook, just return itself without cloning
elif len(actual) == 1:
return actual[0]
final_hook: HookGroup = None
for hook in actual:
if final_hook is None:
final_hook = hook.clone()
else:
final_hook = final_hook.clone_and_combine(hook)
return final_hook
class HookKeyframe:
def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
self.strength = strength
# scheduling
self.start_percent = float(start_percent)
self.start_t = 999999999.9
self.guarantee_steps = guarantee_steps
def get_effective_guarantee_steps(self, max_sigma: torch.Tensor):
'''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
if self.start_t > max_sigma:
return 0
return self.guarantee_steps
def clone(self):
c = HookKeyframe(strength=self.strength,
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
c.start_t = self.start_t
return c
class HookKeyframeGroup:
def __init__(self):
self.keyframes: list[HookKeyframe] = []
self._current_keyframe: HookKeyframe = None
self._current_used_steps = 0
self._current_index = 0
self._current_strength = None
self._curr_t = -1.
# properties shadow those of HookWeightsKeyframe
@property
def strength(self):
if self._current_keyframe is not None:
return self._current_keyframe.strength
return 1.0
def reset(self):
self._current_keyframe = None
self._current_used_steps = 0
self._current_index = 0
self._current_strength = None
self.curr_t = -1.
self._set_first_as_current()
def add(self, keyframe: HookKeyframe):
# add to end of list, then sort
self.keyframes.append(keyframe)
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
self._set_first_as_current()
def _set_first_as_current(self):
if len(self.keyframes) > 0:
self._current_keyframe = self.keyframes[0]
else:
self._current_keyframe = None
def has_guarantee_steps(self):
for kf in self.keyframes:
if kf.guarantee_steps > 0:
return True
return False
def has_index(self, index: int):
return index >= 0 and index < len(self.keyframes)
def is_empty(self):
return len(self.keyframes) == 0
def clone(self):
c = HookKeyframeGroup()
for keyframe in self.keyframes:
c.keyframes.append(keyframe.clone())
c._set_first_as_current()
return c
def initialize_timesteps(self, model: BaseModel):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool:
if self.is_empty():
return False
if curr_t == self._curr_t:
return False
max_sigma = torch.max(transformer_options["sample_sigmas"])
prev_index = self._current_index
prev_strength = self._current_strength
# if met guaranteed steps, look for next keyframe in case need to switch
if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma):
# if has next index, loop through and see if need to switch
if self.has_index(self._current_index+1):
for i in range(self._current_index+1, len(self.keyframes)):
eval_c = self.keyframes[i]
# check if start_t is greater or equal to curr_t
# NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
if eval_c.start_t >= curr_t:
self._current_index = i
self._current_strength = eval_c.strength
self._current_keyframe = eval_c
self._current_used_steps = 0
# if guarantee_steps greater than zero, stop searching for other keyframes
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
break
# if eval_c is outside the percent range, stop looking further
else: break
# update steps current context is used
self._current_used_steps += 1
# update current timestep this was performed on
self._curr_t = curr_t
# return True if keyframe changed, False if no change
return prev_index != self._current_index and prev_strength != self._current_strength
class InterpolationMethod:
LINEAR = "linear"
EASE_IN = "ease_in"
EASE_OUT = "ease_out"
EASE_IN_OUT = "ease_in_out"
_LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
@classmethod
def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
diff = num_to - num_from
if method == cls.LINEAR:
weights = torch.linspace(num_from, num_to, length)
elif method == cls.EASE_IN:
index = torch.linspace(0, 1, length)
weights = diff * np.power(index, 2) + num_from
elif method == cls.EASE_OUT:
index = torch.linspace(0, 1, length)
weights = diff * (1 - np.power(1 - index, 2)) + num_from
elif method == cls.EASE_IN_OUT:
index = torch.linspace(0, 1, length)
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
else:
raise ValueError(f"Unrecognized interpolation method '{method}'.")
if reverse:
weights = weights.flip(dims=(0,))
return weights
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
if not objects:
return objects
elif len(objects) <= 1:
return [x for x in objects]
# now that we know we have to sort, do it following these rules:
# a) if objects have same value of attribute, maintain their relative order
# b) perform sorting of the groups of objects with same attributes
unique_attrs = {}
for o in objects:
val_attr = getattr(o, attr)
attr_list: list = unique_attrs.get(val_attr, list())
attr_list.append(o)
if val_attr not in unique_attrs:
unique_attrs[val_attr] = attr_list
# now that we have the unique attr values grouped together in relative order, sort them by key
sorted_attrs = dict(sorted(unique_attrs.items()))
# now flatten out the dict into a list to return
sorted_list = []
for object_list in sorted_attrs.values():
sorted_list.extend(object_list)
return sorted_list
def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
# if no hooks or is not a ModelPatcher for sampling, return empty dict
if hooks is None or model.is_clip:
return {}
if transformer_options is None:
transformer_options = {}
for hook in hooks.get_type(EnumHookType.TransformerOptions):
hook: TransformerOptionsHook
hook.on_apply_hooks(model, transformer_options)
return transformer_options
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
hook_group.add(hook)
hook.weights = lora
return hook_group
def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
hook_group.add(hook)
patches_model = None
patches_clip = None
if weights_model is not None:
patches_model = {}
for key in weights_model:
patches_model[key] = ("model_as_lora", (weights_model[key],))
if weights_clip is not None:
patches_clip = {}
for key in weights_clip:
patches_clip[key] = ("model_as_lora", (weights_clip[key],))
hook.weights = patches_model
hook.weights_clip = patches_clip
hook.need_weight_init = False
return hook_group
def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
if model is None:
return None
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
if discard_model_sampling:
# do not include ANY model_sampling components of the model that should act as a patch
for key in list(patches_model.keys()):
if key.startswith("model_sampling"):
patches_model.pop(key, None)
return patches_model
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
strength_model: float, strength_clip: float):
key_map = {}
if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
hook_group = HookGroup()
hook = WeightHook()
hook_group.add(hook)
loaded: dict[str] = comfy.lora.load_lora(lora, key_map)
if model is not None:
new_modelpatcher = model.clone()
k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model)
else:
k = ()
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
else:
k1 = ()
new_clip = None
k = set(k)
k1 = set(k1)
for x in loaded:
if (x not in k) and (x not in k1):
logging.warning(f"NOT LOADED {x}")
return (new_modelpatcher, new_clip, hook_group)
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
hooks_key = 'hooks'
# if hooks only exist in one dict, do what's needed so that it ends up in c_dict
if hooks_key not in values:
return
if hooks_key not in c_dict:
hooks_value = values.get(hooks_key, None)
if hooks_value is not None:
c_dict[hooks_key] = hooks_value
return
# otherwise, need to combine with minimum duplication via cache
hooks_tuple = (c_dict[hooks_key], values[hooks_key])
cached_hooks = cache.get(hooks_tuple, None)
if cached_hooks is None:
new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
cache[hooks_tuple] = new_hooks
c_dict[hooks_key] = new_hooks
else:
c_dict[hooks_key] = cache[hooks_tuple]
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True,
cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
c = []
if cache is None:
cache = {}
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
if append_hooks and k == 'hooks':
_combine_hooks_from_values(n[1], values, cache)
else:
n[1][k] = values[k]
c.append(n)
return c
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
if hooks is None:
return cond
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache)
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
if timestep_range is None:
return cond
return conditioning_set_values(cond, {"start_percent": timestep_range[0],
"end_percent": timestep_range[1]})
def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float):
if mask is None:
return cond
set_area_to_bounds = False
if set_cond_area != 'default':
set_area_to_bounds = True
if len(mask.shape) < 3:
mask = mask.unsqueeze(0)
return conditioning_set_values(cond, {'mask': mask,
'set_area_to_bounds': set_area_to_bounds,
'mask_strength': strength})
def combine_conditioning(conds: list):
combined_conds = []
for cond in conds:
combined_conds.extend(cond)
return combined_conds
def combine_with_new_conds(conds: list, new_conds: list):
combined_conds = []
for c, new_c in zip(conds, new_conds):
combined_conds.append(combine_conditioning([c, new_c]))
return combined_conds
def set_conds_props(conds: list, strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
final_conds = []
cache = {}
for c in conds:
# first, apply lora_hook to conditioning, if provided
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache)
# next, apply mask to conditioning
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
# apply timesteps, if present
c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range)
# finally, apply mask to conditioning and store
final_conds.append(c)
return final_conds
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
cache = {}
for c, masked_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache)
# next, apply mask to new conditioning, if provided
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
# apply timesteps, if present
masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range)
# finally, combine with existing conditioning and store
combined_conds.append(combine_conditioning([c, masked_c]))
return combined_conds
def set_default_conds_and_combine(conds: list, new_conds: list,
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
cache = {}
for c, new_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache)
# next, add default_cond key to cond so that during sampling, it can be identified
new_c = conditioning_set_values(new_c, {'default': True})
# apply timesteps, if present
new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range)
# finally, combine with existing conditioning and store
combined_conds.append(combine_conditioning([c, new_c]))
return combined_conds

View File

@ -1,160 +0,0 @@
import torch
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
class Dino2AttentionOutput(torch.nn.Module):
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
def forward(self, x):
return self.dense(x)
class Dino2AttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
def forward(self, x, mask, optimized_attention):
return self.output(self.attention(x, mask, optimized_attention))
class LayerScale(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
class Dinov2MLP(torch.nn.Module):
def __init__(self, hidden_size: int, dtype, device, operations):
super().__init__()
mlp_ratio = 4
hidden_features = int(hidden_size * mlp_ratio)
self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.fc1(hidden_state)
hidden_state = torch.nn.functional.gelu(hidden_state)
hidden_state = self.fc2(hidden_state)
return hidden_state
class SwiGLUFFN(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
in_features = out_features = dim
hidden_features = int(dim * 4)
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
def forward(self, x):
x = self.weights_in(x)
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
return self.weights_out(x)
class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
if use_swiglu_ffn:
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
else:
self.mlp = Dinov2MLP(dim, dtype, device, operations)
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, x, optimized_attention):
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
return x
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
super().__init__()
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
for _ in range(num_layers)])
def forward(self, x, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
for i, layer in enumerate(self.layer):
x = layer(x, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class Dino2PatchEmbeddings(torch.nn.Module):
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
super().__init__()
self.projection = operations.Conv2d(
in_channels=num_channels,
out_channels=dim,
kernel_size=patch_size,
stride=patch_size,
bias=True,
dtype=dtype,
device=device
)
def forward(self, pixel_values):
return self.projection(pixel_values).flatten(2).transpose(1, 2)
class Dino2Embeddings(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
patch_size = 14
image_size = 518
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
def forward(self, pixel_values):
x = self.patch_embeddings(pixel_values)
# TODO: mask_token?
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
return x
class Dinov2Model(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
num_layers = config_dict["num_hidden_layers"]
dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
x = self.embeddings(pixel_values)
x, i = self.encoder(x, intermediate_output=intermediate_output)
x = self.layernorm(x)
pooled_output = x[:, 0, :]
return x, i, pooled_output, None

View File

@ -1,21 +0,0 @@
{
"attention_probs_dropout_prob": 0.0,
"drop_path_rate": 0.0,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 1536,
"image_size": 518,
"initializer_range": 0.02,
"layer_norm_eps": 1e-06,
"layerscale_value": 1.0,
"mlp_ratio": 4,
"model_type": "dinov2",
"num_attention_heads": 24,
"num_channels": 3,
"num_hidden_layers": 40,
"patch_size": 14,
"qkv_bias": true,
"use_swiglu_ffn": true,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225]
}

View File

@ -1,22 +0,0 @@
{
"hidden_size": 1024,
"use_mask_token": true,
"patch_size": 14,
"image_size": 518,
"num_channels": 3,
"num_attention_heads": 16,
"initializer_range": 0.02,
"attention_probs_dropout_prob": 0.0,
"hidden_dropout_prob": 0.0,
"hidden_act": "gelu",
"mlp_ratio": 4,
"model_type": "dinov2",
"num_hidden_layers": 24,
"layer_norm_eps": 1e-6,
"qkv_bias": true,
"use_swiglu_ffn": false,
"layerscale_value": 1.0,
"drop_path_rate": 0.0,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225]
}

View File

@ -11,6 +11,7 @@ import numpy as np
# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d

View File

@ -1,121 +0,0 @@
# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
# Codebase ref: https://github.com/scxue/SA-Solver
import math
from typing import Union, Callable
import torch
def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
"""Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t with exp((1 + tau^2) * t) factored out, using integration by parts.
Integral of exp((1 + tau^2) * x) * x^p dx
= product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx,
with base case p=0 where integral equals product_terms[0].
where
product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).
Construct a recursive coefficient matrix following the above recursive relation to compute all integral terms up to p = (solver_order - 1).
Return coefficients used by the SA-Solver in data prediction mode.
Args:
s: Start time s.
t: End time t.
solver_order: Current order of the solver.
tau_t: Stochastic strength parameter in the SDE.
Returns:
Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p=0 to p=solver_order1, shape (solver_order,).
"""
tau_mul = 1 + tau_t ** 2
h = t - s
p = torch.arange(solver_order, dtype=s.dtype, device=s.device)
# product_terms after factoring out exp((1 + tau^2) * t)
# Includes (1 + tau^2) factor from outside the integral
product_terms_factored = (t ** p - s ** p * (-tau_mul * h).exp())
# Lower triangular recursive coefficient matrix
# Accumulates recursive coefficients based on p / (1 + tau^2)
recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
log_factorial = (p + 1).lgamma()
recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
if tau_t > 0:
recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul))
signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0)
recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()
return recursive_coeff_mat @ product_terms_factored
def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
"""Compute simple order-2 b coefficients from SA-Solver paper (Appendix D. Implementation Details)."""
tau_mul = 1 + tau_t ** 2
h = lambda_t - lambda_s
alpha_t = sigma_next * lambda_t.exp()
if is_corrector_step:
# Simplified 1-step (order-2) corrector
b_1 = alpha_t * (0.5 * tau_mul * h)
b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1
else:
# Simplified 2-step predictor
b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2
return torch.stack([b_2, b_1])
def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
"""Compute b_i coefficients for the SA-Solver (see eqs. 15 and 18).
The solver order corresponds to the number of input lambdas (half-logSNR points).
Args:
sigma_next: Sigma at end time t.
curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
lambda_s: Lambda at start time s.
lambda_t: Lambda at end time t.
tau_t: Stochastic strength parameter in the SDE.
simple_order_2: Whether to enable the simple order-2 scheme.
is_corrector_step: Flag for corrector step in simple order-2 mode.
Returns:
b_i coefficients for the SA-Solver, shape (N,), where N is the solver order.
"""
num_timesteps = curr_lambdas.shape[0]
if simple_order_2 and num_timesteps == 2:
return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)
# Compute coefficients by solving a linear system from Lagrange basis interpolation
exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t)
vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T
lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs)
# (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
# = sigma_t * exp(lambda_t) = alpha_t
# exp((1 + tau^2) * lambda_t) is extracted from the integral
alpha_t = sigma_next * lambda_t.exp()
return alpha_t * lagrange_integrals
def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
"""Return a function that controls the stochasticity of SA-Solver.
When eta = 0, SA-Solver runs as ODE. The official approach uses
time t to determine the SDE interval, while here we use sigma instead.
See:
https://github.com/scxue/SA-Solver/blob/main/README.md
"""
def tau_func(sigma: Union[torch.Tensor, float]) -> float:
if eta <= 0:
return 0.0 # ODE
if isinstance(sigma, torch.Tensor):
sigma = sigma.item()
return eta if start_sigma >= sigma >= end_sigma else 0.0
return tau_func

File diff suppressed because it is too large Load Diff

View File

@ -3,9 +3,7 @@ import torch
class LatentFormat:
scale_factor = 1.0
latent_channels = 4
latent_dimensions = 2
latent_rgb_factors = None
latent_rgb_factors_bias = None
taesd_decoder_name = None
def process_in(self, latent):
@ -32,13 +30,11 @@ class SDXL(LatentFormat):
def __init__(self):
self.latent_rgb_factors = [
# R G B
[ 0.3651, 0.4232, 0.4341],
[-0.2533, -0.0042, 0.1068],
[ 0.1076, 0.1111, -0.0362],
[-0.3165, -0.2492, -0.2188]
[ 0.3920, 0.4054, 0.4549],
[-0.2634, -0.0196, 0.0653],
[ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076]
]
self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
self.taesd_decoder_name = "taesdxl_decoder"
class SDXL_Playground_2_5(LatentFormat):
@ -116,24 +112,23 @@ class SD3(LatentFormat):
self.scale_factor = 1.5305
self.shift_factor = 0.0609
self.latent_rgb_factors = [
[-0.0922, -0.0175, 0.0749],
[ 0.0311, 0.0633, 0.0954],
[ 0.1994, 0.0927, 0.0458],
[ 0.0856, 0.0339, 0.0902],
[ 0.0587, 0.0272, -0.0496],
[-0.0006, 0.1104, 0.0309],
[ 0.0978, 0.0306, 0.0427],
[-0.0042, 0.1038, 0.1358],
[-0.0194, 0.0020, 0.0669],
[-0.0488, 0.0130, -0.0268],
[ 0.0922, 0.0988, 0.0951],
[-0.0278, 0.0524, -0.0542],
[ 0.0332, 0.0456, 0.0895],
[-0.0069, -0.0030, -0.0810],
[-0.0596, -0.0465, -0.0293],
[-0.1448, -0.1463, -0.1189]
[-0.0645, 0.0177, 0.1052],
[ 0.0028, 0.0312, 0.0650],
[ 0.1848, 0.0762, 0.0360],
[ 0.0944, 0.0360, 0.0889],
[ 0.0897, 0.0506, -0.0364],
[-0.0020, 0.1203, 0.0284],
[ 0.0855, 0.0118, 0.0283],
[-0.0539, 0.0658, 0.1047],
[-0.0057, 0.0116, 0.0700],
[-0.0412, 0.0281, -0.0039],
[ 0.1106, 0.1171, 0.1220],
[-0.0248, 0.0682, -0.0481],
[ 0.0815, 0.0846, 0.1207],
[-0.0120, -0.0055, -0.0867],
[-0.0749, -0.0634, -0.0456],
[-0.1418, -0.1457, -0.1259]
]
self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
self.taesd_decoder_name = "taesd3_decoder"
def process_in(self, latent):
@ -144,565 +139,3 @@ class SD3(LatentFormat):
class StableAudio1(LatentFormat):
latent_channels = 64
latent_dimensions = 1
class Flux(SD3):
latent_channels = 16
def __init__(self):
self.scale_factor = 0.3611
self.shift_factor = 0.1159
self.latent_rgb_factors =[
[-0.0346, 0.0244, 0.0681],
[ 0.0034, 0.0210, 0.0687],
[ 0.0275, -0.0668, -0.0433],
[-0.0174, 0.0160, 0.0617],
[ 0.0859, 0.0721, 0.0329],
[ 0.0004, 0.0383, 0.0115],
[ 0.0405, 0.0861, 0.0915],
[-0.0236, -0.0185, -0.0259],
[-0.0245, 0.0250, 0.1180],
[ 0.1008, 0.0755, -0.0421],
[-0.0515, 0.0201, 0.0011],
[ 0.0428, -0.0012, -0.0036],
[ 0.0817, 0.0765, 0.0749],
[-0.1264, -0.0522, -0.1103],
[-0.0280, -0.0881, -0.0499],
[-0.1262, -0.0982, -0.0778]
]
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.taesd_decoder_name = "taef1_decoder"
def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
class Mochi(LatentFormat):
latent_channels = 12
latent_dimensions = 3
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([-0.06730895953510081, -0.038011381506090416, -0.07477820912866141,
-0.05565264470995561, 0.012767231469026969, -0.04703542746246419,
0.043896967884726704, -0.09346305707025976, -0.09918314763016893,
-0.008729793427399178, -0.011931556316503654, -0.0321993391887285]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([0.9263795028493863, 0.9248894543193766, 0.9393059390890617,
0.959253732819592, 0.8244560132752793, 0.917259975397747,
0.9294154431013696, 1.3720942357788521, 0.881393668867029,
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
self.latent_rgb_factors =[
[-0.0069, -0.0045, 0.0018],
[ 0.0154, -0.0692, -0.0274],
[ 0.0333, 0.0019, 0.0206],
[-0.1390, 0.0628, 0.1678],
[-0.0725, 0.0134, -0.1898],
[ 0.0074, -0.0270, -0.0209],
[-0.0176, -0.0277, -0.0221],
[ 0.5294, 0.5204, 0.3852],
[-0.0326, -0.0446, -0.0143],
[-0.0659, 0.0153, -0.0153],
[ 0.0185, -0.0217, 0.0014],
[-0.0396, -0.0495, -0.0281]
]
self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
self.taesd_decoder_name = None #TODO
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std
def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class LTXV(LatentFormat):
latent_channels = 128
latent_dimensions = 3
def __init__(self):
self.latent_rgb_factors = [
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
[ 8.6031e-02, 6.5813e-02, 9.5409e-04],
[-1.2576e-02, -7.5734e-03, -4.0528e-03],
[ 9.4063e-03, -2.1688e-03, 2.6093e-03],
[ 3.7636e-03, 1.2765e-02, 9.1548e-03],
[ 2.1024e-02, -5.2973e-03, 3.4373e-03],
[-8.8896e-03, -1.9703e-02, -1.8761e-02],
[-1.3160e-02, -1.0523e-02, 1.9709e-03],
[-1.5152e-03, -6.9891e-03, -7.5810e-03],
[-1.7247e-03, 4.6560e-04, -3.3839e-03],
[ 1.3617e-02, 4.7077e-03, -2.0045e-03],
[ 1.0256e-02, 7.7318e-03, 1.3948e-02],
[-1.6108e-02, -6.2151e-03, 1.1561e-03],
[ 7.3407e-03, 1.5628e-02, 4.4865e-04],
[ 9.5357e-04, -2.9518e-03, -1.4760e-02],
[ 1.9143e-02, 1.0868e-02, 1.2264e-02],
[ 4.4575e-03, 3.6682e-05, -6.8508e-03],
[-4.5681e-04, 3.2570e-03, 7.7929e-03],
[ 3.3902e-02, 3.3405e-02, 3.7454e-02],
[-2.3001e-02, -2.4877e-03, -3.1033e-03],
[ 5.0265e-02, 3.8841e-02, 3.3539e-02],
[-4.1018e-03, -1.1095e-03, 1.5859e-03],
[-1.2689e-01, -1.3107e-01, -2.1005e-01],
[ 2.6276e-02, 1.4189e-02, -3.5963e-03],
[-4.8679e-03, 8.8486e-03, 7.8029e-03],
[-1.6610e-03, -4.8597e-03, -5.2060e-03],
[-2.1010e-03, 2.3610e-03, 9.3796e-03],
[-2.2482e-02, -2.1305e-02, -1.5087e-02],
[-1.5753e-02, -1.0646e-02, -6.5083e-03],
[-4.6975e-03, 5.0288e-03, -6.7390e-03],
[ 1.1951e-02, 2.0712e-02, 1.6191e-02],
[-6.3704e-03, -8.4827e-03, -9.5483e-03],
[ 7.2610e-03, -9.9326e-03, -2.2978e-02],
[-9.1904e-04, 6.2882e-03, 9.5720e-03],
[-3.7178e-02, -3.7123e-02, -5.6713e-02],
[-1.3373e-01, -1.0720e-01, -5.3801e-02],
[-5.3702e-03, 8.1256e-03, 8.8397e-03],
[-1.5247e-01, -2.1437e-01, -2.1843e-01],
[ 3.1441e-02, 7.0335e-03, -9.7541e-03],
[ 2.1528e-03, -8.9817e-03, -2.1023e-02],
[ 3.8461e-03, -5.8957e-03, -1.5014e-02],
[-4.3470e-03, -1.2940e-02, -1.5972e-02],
[-5.4781e-03, -1.0842e-02, -3.0204e-03],
[-6.5347e-03, 3.0806e-03, -1.0163e-02],
[-5.0414e-03, -7.1503e-03, -8.9686e-04],
[-8.5851e-03, -2.4351e-03, 1.0674e-03],
[-9.0016e-03, -9.6493e-03, 1.5692e-03],
[ 5.0914e-03, 1.2099e-02, 1.9968e-02],
[ 1.3758e-02, 1.1669e-02, 8.1958e-03],
[-1.0518e-02, -1.1575e-02, -4.1307e-03],
[-2.8410e-02, -3.1266e-02, -2.2149e-02],
[ 2.9336e-03, 3.6511e-02, 1.8717e-02],
[-1.6703e-02, -1.6696e-02, -4.4529e-03],
[ 4.8818e-02, 4.0063e-02, 8.7410e-03],
[-1.5066e-02, -5.7328e-04, 2.9785e-03],
[-1.7613e-02, -8.1034e-03, 1.3086e-02],
[-9.2633e-03, 1.0803e-02, -6.3489e-03],
[ 3.0851e-03, 4.7750e-04, 1.2347e-02],
[-2.2785e-02, -2.3043e-02, -2.6005e-02],
[-2.4787e-02, -1.5389e-02, -2.2104e-02],
[-2.3572e-02, 1.0544e-03, 1.2361e-02],
[-7.8915e-03, -1.2271e-03, -6.0968e-03],
[-1.1478e-02, -1.2543e-03, 6.2679e-03],
[-5.4229e-02, 2.6644e-02, 6.3394e-03],
[ 4.4216e-03, -7.3338e-03, -1.0464e-02],
[-4.5013e-03, 1.6082e-03, 1.4420e-02],
[ 1.3673e-02, 8.8877e-03, 4.1253e-03],
[-1.0145e-02, 9.0072e-03, 1.5695e-02],
[-5.6234e-03, 1.1847e-03, 8.1261e-03],
[-3.7171e-03, -5.3538e-03, 1.2590e-03],
[ 2.9476e-02, 2.1424e-02, 3.0424e-02],
[-3.4925e-02, -2.4340e-02, -2.5316e-02],
[-3.4127e-02, -2.2406e-02, -1.0589e-02],
[-1.7342e-02, -1.3249e-02, -1.0719e-02],
[-2.1478e-03, -8.6051e-03, -2.9878e-03],
[ 1.2089e-03, -4.2391e-03, -6.8569e-03],
[ 9.0411e-04, -6.6886e-03, -6.7547e-05],
[ 1.6048e-02, -1.0057e-02, -2.8929e-02],
[ 1.2290e-03, 1.0163e-02, 1.8861e-02],
[ 1.7264e-02, 2.7257e-04, 1.3785e-02],
[-1.3482e-02, -3.6427e-03, 6.7481e-04],
[ 4.6782e-03, -5.2423e-03, 2.4467e-03],
[-5.9113e-03, -6.2244e-03, -1.8162e-03],
[ 1.5496e-02, 1.4582e-02, 1.9514e-03],
[ 7.4958e-03, 1.5886e-03, -8.2305e-03],
[ 1.9086e-02, 1.6360e-03, -3.9674e-03],
[-5.7021e-03, -2.7307e-03, -4.1066e-03],
[ 1.7450e-03, 1.4602e-02, 2.5794e-02],
[-8.2788e-04, 2.2902e-03, 4.5161e-03],
[ 1.1632e-02, 8.9193e-03, -7.2813e-03],
[ 7.5721e-03, 2.6784e-03, 1.1393e-02],
[ 5.1939e-03, 3.6903e-03, 1.4049e-02],
[-1.8383e-02, -2.2529e-02, -2.4477e-02],
[ 5.8842e-04, -5.7874e-03, -1.4770e-02],
[-1.6125e-02, -8.6101e-03, -1.4533e-02],
[ 2.0540e-02, 2.0729e-02, 6.4338e-03],
[ 3.3587e-03, -1.1226e-02, -1.6444e-02],
[-1.4742e-03, -1.0489e-02, 1.7097e-03],
[ 2.8130e-02, 2.3546e-02, 3.2791e-02],
[-1.8532e-02, -1.2842e-02, -8.7756e-03],
[-8.0533e-03, -1.0771e-02, -1.7536e-02],
[-3.9009e-03, 1.6150e-02, 3.3359e-02],
[-7.4554e-03, -1.4154e-02, -6.1910e-03],
[ 3.4734e-03, -1.1370e-02, -1.0581e-02],
[ 1.1476e-02, 3.9281e-03, 2.8231e-03],
[ 7.1639e-03, -1.4741e-03, -3.8066e-03],
[ 2.2250e-03, -8.7552e-03, -9.5719e-03],
[ 2.4146e-02, 2.1696e-02, 2.8056e-02],
[-5.4365e-03, -2.4291e-02, -1.7802e-02],
[ 7.4263e-03, 1.0510e-02, 1.2705e-02],
[ 6.2669e-03, 6.2658e-03, 1.9211e-02],
[ 1.6378e-02, 9.4933e-03, 6.6971e-03],
[ 1.7173e-02, 2.3601e-02, 2.3296e-02],
[-1.4568e-02, -9.8279e-03, -1.1556e-02],
[ 1.4431e-02, 1.4430e-02, 6.6362e-03],
[-6.8230e-03, 1.8863e-02, 1.4555e-02],
[ 6.1156e-03, 3.4700e-03, -2.6662e-03],
[-2.6983e-03, -5.9402e-03, -9.2276e-03],
[ 1.0235e-02, 7.4173e-03, -7.6243e-03],
[-1.3255e-02, 1.9322e-02, -9.2153e-04],
[ 2.4222e-03, -4.8039e-03, -1.5759e-02],
[ 2.6244e-02, 2.5951e-02, 2.0249e-02],
[ 1.5711e-02, 1.8498e-02, 2.7407e-03],
[-2.1714e-03, 4.7214e-03, -2.2443e-02],
[-7.4747e-03, 7.4166e-03, 1.4430e-02],
[-8.3906e-03, -7.9776e-03, 9.7927e-03],
[ 3.8321e-02, 9.6622e-03, -1.9268e-02],
[-1.4605e-02, -6.7032e-03, 3.9675e-03]
]
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
class HunyuanVideo(LatentFormat):
latent_channels = 16
latent_dimensions = 3
scale_factor = 0.476986
latent_rgb_factors = [
[-0.0395, -0.0331, 0.0445],
[ 0.0696, 0.0795, 0.0518],
[ 0.0135, -0.0945, -0.0282],
[ 0.0108, -0.0250, -0.0765],
[-0.0209, 0.0032, 0.0224],
[-0.0804, -0.0254, -0.0639],
[-0.0991, 0.0271, -0.0669],
[-0.0646, -0.0422, -0.0400],
[-0.0696, -0.0595, -0.0894],
[-0.0799, -0.0208, -0.0375],
[ 0.1166, 0.1627, 0.0962],
[ 0.1165, 0.0432, 0.0407],
[-0.2315, -0.1920, -0.1355],
[-0.0270, 0.0401, -0.0821],
[-0.0616, -0.0997, -0.0727],
[ 0.0249, -0.0469, -0.1703]
]
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
class Cosmos1CV8x8x8(LatentFormat):
latent_channels = 16
latent_dimensions = 3
latent_rgb_factors = [
[ 0.1817, 0.2284, 0.2423],
[-0.0586, -0.0862, -0.3108],
[-0.4703, -0.4255, -0.3995],
[ 0.0803, 0.1963, 0.1001],
[-0.0820, -0.1050, 0.0400],
[ 0.2511, 0.3098, 0.2787],
[-0.1830, -0.2117, -0.0040],
[-0.0621, -0.2187, -0.0939],
[ 0.3619, 0.1082, 0.1455],
[ 0.3164, 0.3922, 0.2575],
[ 0.1152, 0.0231, -0.0462],
[-0.1434, -0.3609, -0.3665],
[ 0.0635, 0.1471, 0.1680],
[-0.3635, -0.1963, -0.3248],
[-0.1865, 0.0365, 0.2346],
[ 0.0447, 0.0994, 0.0881]
]
latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]
class Wan21(LatentFormat):
latent_channels = 16
latent_dimensions = 3
latent_rgb_factors = [
[-0.1299, -0.1692, 0.2932],
[ 0.0671, 0.0406, 0.0442],
[ 0.3568, 0.2548, 0.1747],
[ 0.0372, 0.2344, 0.1420],
[ 0.0313, 0.0189, -0.0328],
[ 0.0296, -0.0956, -0.0665],
[-0.3477, -0.4059, -0.2925],
[ 0.0166, 0.1902, 0.1975],
[-0.0412, 0.0267, -0.1364],
[-0.1293, 0.0740, 0.1636],
[ 0.0680, 0.3019, 0.1128],
[ 0.0032, 0.0581, 0.0639],
[-0.1251, 0.0927, 0.1699],
[ 0.0060, -0.0633, 0.0005],
[ 0.3477, 0.2275, 0.2950],
[ 0.1984, 0.0913, 0.1861]
]
latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360]
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]).view(1, self.latent_channels, 1, 1, 1)
self.taesd_decoder_name = None #TODO
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std
def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class Wan22(Wan21):
latent_channels = 48
latent_dimensions = 3
latent_rgb_factors = [
[ 0.0119, 0.0103, 0.0046],
[-0.1062, -0.0504, 0.0165],
[ 0.0140, 0.0409, 0.0491],
[-0.0813, -0.0677, 0.0607],
[ 0.0656, 0.0851, 0.0808],
[ 0.0264, 0.0463, 0.0912],
[ 0.0295, 0.0326, 0.0590],
[-0.0244, -0.0270, 0.0025],
[ 0.0443, -0.0102, 0.0288],
[-0.0465, -0.0090, -0.0205],
[ 0.0359, 0.0236, 0.0082],
[-0.0776, 0.0854, 0.1048],
[ 0.0564, 0.0264, 0.0561],
[ 0.0006, 0.0594, 0.0418],
[-0.0319, -0.0542, -0.0637],
[-0.0268, 0.0024, 0.0260],
[ 0.0539, 0.0265, 0.0358],
[-0.0359, -0.0312, -0.0287],
[-0.0285, -0.1032, -0.1237],
[ 0.1041, 0.0537, 0.0622],
[-0.0086, -0.0374, -0.0051],
[ 0.0390, 0.0670, 0.2863],
[ 0.0069, 0.0144, 0.0082],
[ 0.0006, -0.0167, 0.0079],
[ 0.0313, -0.0574, -0.0232],
[-0.1454, -0.0902, -0.0481],
[ 0.0714, 0.0827, 0.0447],
[-0.0304, -0.0574, -0.0196],
[ 0.0401, 0.0384, 0.0204],
[-0.0758, -0.0297, -0.0014],
[ 0.0568, 0.1307, 0.1372],
[-0.0055, -0.0310, -0.0380],
[ 0.0239, -0.0305, 0.0325],
[-0.0663, -0.0673, -0.0140],
[-0.0416, -0.0047, -0.0023],
[ 0.0166, 0.0112, -0.0093],
[-0.0211, 0.0011, 0.0331],
[ 0.1833, 0.1466, 0.2250],
[-0.0368, 0.0370, 0.0295],
[-0.3441, -0.3543, -0.2008],
[-0.0479, -0.0489, -0.0420],
[-0.0660, -0.0153, 0.0800],
[-0.0101, 0.0068, 0.0156],
[-0.0690, -0.0452, -0.0927],
[-0.0145, 0.0041, 0.0015],
[ 0.0421, 0.0451, 0.0373],
[ 0.0504, -0.0483, -0.0356],
[-0.0837, 0.0168, 0.0055]
]
latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388]
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]).view(1, self.latent_channels, 1, 1, 1)
class HunyuanImage21(LatentFormat):
latent_channels = 64
latent_dimensions = 2
scale_factor = 0.75289
latent_rgb_factors = [
[-0.0154, -0.0397, -0.0521],
[ 0.0005, 0.0093, 0.0006],
[-0.0805, -0.0773, -0.0586],
[-0.0494, -0.0487, -0.0498],
[-0.0212, -0.0076, -0.0261],
[-0.0179, -0.0417, -0.0505],
[ 0.0158, 0.0310, 0.0239],
[ 0.0409, 0.0516, 0.0201],
[ 0.0350, 0.0553, 0.0036],
[-0.0447, -0.0327, -0.0479],
[-0.0038, -0.0221, -0.0365],
[-0.0423, -0.0718, -0.0654],
[ 0.0039, 0.0368, 0.0104],
[ 0.0655, 0.0217, 0.0122],
[ 0.0490, 0.1638, 0.2053],
[ 0.0932, 0.0829, 0.0650],
[-0.0186, -0.0209, -0.0135],
[-0.0080, -0.0076, -0.0148],
[-0.0284, -0.0201, 0.0011],
[-0.0642, -0.0294, -0.0777],
[-0.0035, 0.0076, -0.0140],
[ 0.0519, 0.0731, 0.0887],
[-0.0102, 0.0095, 0.0704],
[ 0.0068, 0.0218, -0.0023],
[-0.0726, -0.0486, -0.0519],
[ 0.0260, 0.0295, 0.0263],
[ 0.0250, 0.0333, 0.0341],
[ 0.0168, -0.0120, -0.0174],
[ 0.0226, 0.1037, 0.0114],
[ 0.2577, 0.1906, 0.1604],
[-0.0646, -0.0137, -0.0018],
[-0.0112, 0.0309, 0.0358],
[-0.0347, 0.0146, -0.0481],
[ 0.0234, 0.0179, 0.0201],
[ 0.0157, 0.0313, 0.0225],
[ 0.0423, 0.0675, 0.0524],
[-0.0031, 0.0027, -0.0255],
[ 0.0447, 0.0555, 0.0330],
[-0.0152, 0.0103, 0.0299],
[-0.0755, -0.0489, -0.0635],
[ 0.0853, 0.0788, 0.1017],
[-0.0272, -0.0294, -0.0471],
[ 0.0440, 0.0400, -0.0137],
[ 0.0335, 0.0317, -0.0036],
[-0.0344, -0.0621, -0.0984],
[-0.0127, -0.0630, -0.0620],
[-0.0648, 0.0360, 0.0924],
[-0.0781, -0.0801, -0.0409],
[ 0.0363, 0.0613, 0.0499],
[ 0.0238, 0.0034, 0.0041],
[-0.0135, 0.0258, 0.0310],
[ 0.0614, 0.1086, 0.0589],
[ 0.0428, 0.0350, 0.0205],
[ 0.0153, 0.0173, -0.0018],
[-0.0288, -0.0455, -0.0091],
[ 0.0344, 0.0109, -0.0157],
[-0.0205, -0.0247, -0.0187],
[ 0.0487, 0.0126, 0.0064],
[-0.0220, -0.0013, 0.0074],
[-0.0203, -0.0094, -0.0048],
[-0.0719, 0.0429, -0.0442],
[ 0.1042, 0.0497, 0.0356],
[-0.0659, -0.0578, -0.0280],
[-0.0060, -0.0322, -0.0234]]
latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
class HunyuanImage21Refiner(LatentFormat):
latent_channels = 64
latent_dimensions = 3
scale_factor = 1.03682
def process_in(self, latent):
out = latent * self.scale_factor
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
def process_out(self, latent):
z = latent / self.scale_factor
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
return z
class HunyuanVideo15(LatentFormat):
latent_rgb_factors = [
[ 0.0568, -0.0521, -0.0131],
[ 0.0014, 0.0735, 0.0326],
[ 0.0186, 0.0531, -0.0138],
[-0.0031, 0.0051, 0.0288],
[ 0.0110, 0.0556, 0.0432],
[-0.0041, -0.0023, -0.0485],
[ 0.0530, 0.0413, 0.0253],
[ 0.0283, 0.0251, 0.0339],
[ 0.0277, -0.0372, -0.0093],
[ 0.0393, 0.0944, 0.1131],
[ 0.0020, 0.0251, 0.0037],
[-0.0017, 0.0012, 0.0234],
[ 0.0468, 0.0436, 0.0203],
[ 0.0354, 0.0439, -0.0233],
[ 0.0090, 0.0123, 0.0346],
[ 0.0382, 0.0029, 0.0217],
[ 0.0261, -0.0300, 0.0030],
[-0.0088, -0.0220, -0.0283],
[-0.0272, -0.0121, -0.0363],
[-0.0664, -0.0622, 0.0144],
[ 0.0414, 0.0479, 0.0529],
[ 0.0355, 0.0612, -0.0247],
[ 0.0147, 0.0264, 0.0174],
[ 0.0438, 0.0038, 0.0542],
[ 0.0431, -0.0573, -0.0033],
[-0.0162, -0.0211, -0.0406],
[-0.0487, -0.0295, -0.0393],
[ 0.0005, -0.0109, 0.0253],
[ 0.0296, 0.0591, 0.0353],
[ 0.0119, 0.0181, -0.0306],
[-0.0085, -0.0362, 0.0229],
[ 0.0005, -0.0106, 0.0242]
]
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
latent_channels = 32
latent_dimensions = 3
scale_factor = 1.03682
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 0.9990943042622529
class Hunyuan3Dv2_1(LatentFormat):
scale_factor = 1.0039506158752403
latent_channels = 64
latent_dimensions = 1
class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 1.0188137142395404
class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2
class ChromaRadiance(LatentFormat):
latent_channels = 3
def __init__(self):
self.latent_rgb_factors = [
# R G B
[ 1.0, 0.0, 0.0 ],
[ 0.0, 1.0, 0.0 ],
[ 0.0, 0.0, 1.0 ]
]
def process_in(self, latent):
return latent
def process_out(self, latent):
return latent

View File

@ -1,768 +0,0 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/attention.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union, Optional
import torch
import torch.nn.functional as F
from torch import nn
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention
class Attention(nn.Module):
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
kv_heads: Optional[int] = None,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
processor=None,
out_dim: int = None,
out_context_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
is_causal: bool = False,
dtype=None, device=None, operations=None
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
self.use_bias = bias
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.is_causal = is_causal
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self.added_kv_proj_dim = added_kv_proj_dim
self.only_cross_attention = only_cross_attention
if self.added_kv_proj_dim is None and self.only_cross_attention:
raise ValueError(
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
)
self.group_norm = None
self.spatial_norm = None
self.norm_q = None
self.norm_k = None
self.norm_cross = None
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
else:
self.to_k = None
self.to_v = None
self.added_proj_bias = added_proj_bias
if self.added_kv_proj_dim is not None:
self.add_k_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
self.add_v_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
if self.context_pre_only is not None:
self.add_q_proj = operations.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias, dtype=dtype, device=device)
else:
self.add_q_proj = None
self.add_k_proj = None
self.add_v_proj = None
if not self.pre_only:
self.to_out = nn.ModuleList([])
self.to_out.append(operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device))
self.to_out.append(nn.Dropout(dropout))
else:
self.to_out = None
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
else:
self.to_add_out = None
self.norm_added_q = None
self.norm_added_k = None
self.processor = processor
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
transformer_options={},
**cross_attention_kwargs,
) -> torch.Tensor:
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
transformer_options=transformer_options,
**cross_attention_kwargs,
)
class CustomLiteLAProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""
def __init__(self):
self.kernel_func = nn.ReLU(inplace=False)
self.eps = 1e-15
self.pad_val = 1.0
def apply_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
hidden_states_len = hidden_states.shape[1]
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
if encoder_hidden_states is not None:
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = hidden_states.shape[0]
# `sample` projections.
dtype = hidden_states.dtype
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# `context` projections.
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# attention
if not attn.is_cross_attention:
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
else:
query = hidden_states
key = encoder_hidden_states
value = encoder_hidden_states
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
# RoPE需要 [B, H, S, D] 输入
# 此时 query是 [B, H, D, S], 需要转成 [B, H, S, D] 才能应用RoPE
query = query.permute(0, 1, 3, 2) # [B, H, S, D] (从 [B, H, D, S])
# Apply query and key normalization if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if rotary_freqs_cis is not None:
query = self.apply_rotary_emb(query, rotary_freqs_cis)
if not attn.is_cross_attention:
key = self.apply_rotary_emb(key, rotary_freqs_cis)
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
# 此时 query是 [B, H, S, D],需要还原成 [B, H, D, S]
query = query.permute(0, 1, 3, 2) # [B, H, D, S]
if attention_mask is not None:
# attention_mask: [B, S] -> [B, 1, S, 1]
attention_mask = attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S, 1]
query = query * attention_mask.permute(0, 1, 3, 2) # [B, H, S, D] * [B, 1, S, 1]
if not attn.is_cross_attention:
key = key * attention_mask # key: [B, h, S, D] 与 mask [B, 1, S, 1] 相乘
value = value * attention_mask.permute(0, 1, 3, 2) # 如果 value 是 [B, h, D, S]那么需调整mask以匹配S维度
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S_enc, 1]
# 此时 key: [B, h, S_enc, D], value: [B, h, D, S_enc]
key = key * encoder_attention_mask # [B, h, S_enc, D] * [B, 1, S_enc, 1]
value = value * encoder_attention_mask.permute(0, 1, 3, 2) # [B, h, D, S_enc] * [B, 1, 1, S_enc]
query = self.kernel_func(query)
key = self.kernel_func(key)
query, key, value = query.float(), key.float(), value.float()
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
vk = torch.matmul(value, key)
hidden_states = torch.matmul(vk, query)
if hidden_states.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.float()
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
hidden_states = hidden_states.to(dtype)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.to(dtype)
# Split the attention outputs.
if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
hidden_states, encoder_hidden_states = (
hidden_states[:, : hidden_states_len],
hidden_states[:, hidden_states_len:],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if encoder_hidden_states is not None and context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if torch.get_autocast_gpu_dtype() == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return hidden_states, encoder_hidden_states
class CustomerAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def apply_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
transformer_options={},
*args,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if rotary_freqs_cis is not None:
query = self.apply_rotary_emb(query, rotary_freqs_cis)
if not attn.is_cross_attention:
key = self.apply_rotary_emb(key, rotary_freqs_cis)
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
# attention_mask: N x S1
# encoder_attention_mask: N x S2
# cross attention 整合attention_mask和encoder_attention_mask
combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)
elif not attn.is_cross_attention and attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = optimized_attention(
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
).to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore
"""Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
if isinstance(x, (list, tuple)):
return list(x)
return [x for _ in range(repeat_time)]
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore
"""Return tuple with min_len by repeating element at idx_repeat."""
# convert to list first
x = val2list(x)
# repeat elements if necessary
if len(x) > 0:
x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
return tuple(x)
def t2i_modulate(x, shift, scale):
return x * (1 + scale) + shift
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
if isinstance(kernel_size, tuple):
return tuple([get_same_padding(ks) for ks in kernel_size])
else:
assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
return kernel_size // 2
class ConvLayer(nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int,
kernel_size=3,
stride=1,
dilation=1,
groups=1,
padding: Union[int, None] = None,
use_bias=False,
norm=None,
act=None,
dtype=None, device=None, operations=None
):
super().__init__()
if padding is None:
padding = get_same_padding(kernel_size)
padding *= dilation
self.in_dim = in_dim
self.out_dim = out_dim
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.groups = groups
self.padding = padding
self.use_bias = use_bias
self.conv = operations.Conv1d(
in_dim,
out_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=use_bias,
device=device,
dtype=dtype
)
if norm is not None:
self.norm = operations.RMSNorm(out_dim, elementwise_affine=False, dtype=dtype, device=device)
else:
self.norm = None
if act is not None:
self.act = nn.SiLU(inplace=True)
else:
self.act = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
if self.norm:
x = self.norm(x)
if self.act:
x = self.act(x)
return x
class GLUMBConv(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int,
out_feature=None,
kernel_size=3,
stride=1,
padding: Union[int, None] = None,
use_bias=False,
norm=(None, None, None),
act=("silu", "silu", None),
dilation=1,
dtype=None, device=None, operations=None
):
out_feature = out_feature or in_features
super().__init__()
use_bias = val2tuple(use_bias, 3)
norm = val2tuple(norm, 3)
act = val2tuple(act, 3)
self.glu_act = nn.SiLU(inplace=False)
self.inverted_conv = ConvLayer(
in_features,
hidden_features * 2,
1,
use_bias=use_bias[0],
norm=norm[0],
act=act[0],
dtype=dtype,
device=device,
operations=operations,
)
self.depth_conv = ConvLayer(
hidden_features * 2,
hidden_features * 2,
kernel_size,
stride=stride,
groups=hidden_features * 2,
padding=padding,
use_bias=use_bias[1],
norm=norm[1],
act=None,
dilation=dilation,
dtype=dtype,
device=device,
operations=operations,
)
self.point_conv = ConvLayer(
hidden_features,
out_feature,
1,
use_bias=use_bias[2],
norm=norm[2],
act=act[2],
dtype=dtype,
device=device,
operations=operations,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.transpose(1, 2)
x = self.inverted_conv(x)
x = self.depth_conv(x)
x, gate = torch.chunk(x, 2, dim=1)
gate = self.glu_act(gate)
x = x * gate
x = self.point_conv(x)
x = x.transpose(1, 2)
return x
class LinearTransformerBlock(nn.Module):
"""
A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
use_adaln_single=True,
cross_attention_dim=None,
added_kv_proj_dim=None,
context_pre_only=False,
mlp_ratio=4.0,
add_cross_attention=False,
add_cross_attention_dim=None,
qk_norm=None,
dtype=None, device=None, operations=None
):
super().__init__()
self.norm1 = operations.RMSNorm(dim, elementwise_affine=False, eps=1e-6)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
added_kv_proj_dim=added_kv_proj_dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
qk_norm=qk_norm,
processor=CustomLiteLAProcessor2_0(),
dtype=dtype,
device=device,
operations=operations,
)
self.add_cross_attention = add_cross_attention
self.context_pre_only = context_pre_only
if add_cross_attention and add_cross_attention_dim is not None:
self.cross_attn = Attention(
query_dim=dim,
cross_attention_dim=add_cross_attention_dim,
added_kv_proj_dim=add_cross_attention_dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=context_pre_only,
bias=True,
qk_norm=qk_norm,
processor=CustomerAttnProcessor2_0(),
dtype=dtype,
device=device,
operations=operations,
)
self.norm2 = operations.RMSNorm(dim, 1e-06, elementwise_affine=False)
self.ff = GLUMBConv(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
use_bias=(True, True, False),
norm=(None, None, None),
act=("silu", "silu", None),
dtype=dtype,
device=device,
operations=operations,
)
self.use_adaln_single = use_adaln_single
if use_adaln_single:
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, dtype=dtype, device=device))
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: torch.FloatTensor = None,
encoder_attention_mask: torch.FloatTensor = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
temb: torch.FloatTensor = None,
transformer_options={},
):
N = hidden_states.shape[0]
# step 1: AdaLN single
if self.use_adaln_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
if self.use_adaln_single:
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
# step 2: attention
if not self.add_cross_attention:
attn_output, encoder_hidden_states = self.attn(
hidden_states=norm_hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
transformer_options=transformer_options,
)
else:
attn_output, _ = self.attn(
hidden_states=norm_hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=None,
transformer_options=transformer_options,
)
if self.use_adaln_single:
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if self.add_cross_attention:
attn_output = self.cross_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
transformer_options=transformer_options,
)
hidden_states = attn_output + hidden_states
# step 3: add norm
norm_hidden_states = self.norm2(hidden_states)
if self.use_adaln_single:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
# step 4: feed forward
ff_output = self.ff(norm_hidden_states)
if self.use_adaln_single:
ff_output = gate_mlp * ff_output
hidden_states = hidden_states + ff_output
return hidden_states

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More