Compare commits

..

2 Commits

Author SHA1 Message Date
b3609acf30 Merge branch 'master' into node-essentials-category 2026-02-20 23:40:42 -08:00
7bca096a30 fix: swap essentials_category from CLIPTextEncode to PrimitiveStringMultiline
Remove CLIPTextEncode from Basics essentials category and add
PrimitiveStringMultiline (String Multiline) in its place.

Amp-Thread-ID: https://ampcode.com/threads/T-019c7efb-d916-7244-8c43-77b615ba0622
2026-02-20 23:16:55 -08:00
18 changed files with 68 additions and 435 deletions

View File

@ -1,7 +1,6 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US" language: "en-US"
early_access: false early_access: false
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
reviews: reviews:
profile: "chill" profile: "chill"
@ -36,14 +35,6 @@ reviews:
- "!**/*.bat" - "!**/*.bat"
path_instructions: path_instructions:
- path: "**"
instructions: |
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
de-indented, or reformatted without logic changes. If code appears in the diff
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
treat it as unchanged. Contributors should not feel obligated to address
pre-existing issues outside the scope of their contribution.
- path: "comfy/**" - path: "comfy/**"
instructions: | instructions: |
Core ML/diffusion engine. Focus on: Core ML/diffusion engine. Focus on:
@ -83,11 +74,7 @@ reviews:
auto_review: auto_review:
enabled: true enabled: true
auto_incremental_review: true auto_incremental_review: true
drafts: false drafts: true
ignore_title_keywords:
- "WIP"
- "DO NOT REVIEW"
- "DO NOT MERGE"
finishing_touches: finishing_touches:
docstrings: docstrings:
@ -97,7 +84,7 @@ reviews:
tools: tools:
ruff: ruff:
enabled: false enabled: true
pylint: pylint:
enabled: false enabled: false
flake8: flake8:

View File

@ -53,7 +53,7 @@ class SubgraphManager:
return entry_id, entry return entry_id, entry
async def load_entry_data(self, entry: SubgraphEntry): async def load_entry_data(self, entry: SubgraphEntry):
with open(entry['path'], 'r', encoding='utf-8') as f: with open(entry['path'], 'r') as f:
entry['data'] = f.read() entry['data'] = f.read()
return entry return entry

View File

@ -9,7 +9,6 @@ from comfy.ldm.lightricks.model import (
LTXVModel, LTXVModel,
) )
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit import comfy.ldm.common_dit
class CompressedTimestep: class CompressedTimestep:
@ -451,29 +450,6 @@ class LTXAVModel(LTXVModel):
operations=self.operations, operations=self.operations,
) )
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=self.operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=self.operations,
)
def preprocess_text_embeds(self, context):
if context.shape[-1] == self.caption_channels * 2:
return context
out_vid = self.video_embeddings_connector(context)[0]
out_audio = self.audio_embeddings_connector(context)[0]
return torch.concat((out_vid, out_audio), dim=-1)
def _init_transformer_blocks(self, device, dtype, **kwargs): def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks for LTXAV.""" """Initialize transformer blocks for LTXAV."""
self.transformer_blocks = nn.ModuleList( self.transformer_blocks = nn.ModuleList(

View File

@ -234,7 +234,7 @@ class Embeddings1DConnector(nn.Module):
return indices return indices
def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None): def precompute_freqs_cis(self, indices_grid, spacing="exp"):
dim = self.inner_dim dim = self.inner_dim
n_elem = 2 # 2 because of cos and sin n_elem = 2 # 2 because of cos and sin
freqs = self.precompute_freqs(indices_grid, spacing) freqs = self.precompute_freqs(indices_grid, spacing)
@ -247,7 +247,7 @@ class Embeddings1DConnector(nn.Module):
) )
else: else:
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
def forward( def forward(
self, self,
@ -288,7 +288,7 @@ class Embeddings1DConnector(nn.Module):
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
) )
indices_grid = indices_grid[None, None, :] indices_grid = indices_grid[None, None, :]
freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype) freqs_cis = self.precompute_freqs_cis(indices_grid)
# 2. Blocks # 2. Blocks
for block_idx, block in enumerate(self.transformer_1d_blocks): for block_idx, block in enumerate(self.transformer_1d_blocks):

View File

@ -78,4 +78,4 @@ def interpret_gathered_like(tensors, gathered):
return dest_views return dest_views
aimdo_enabled = False aimdo_allocator = None

View File

@ -988,14 +988,10 @@ class LTXAV(BaseModel):
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs) out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None) attention_mask = kwargs.get("attention_mask", None)
device = kwargs["device"]
if attention_mask is not None: if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None) cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None: if cross_attn is not None:
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))

View File

@ -836,7 +836,7 @@ def unet_inital_load_device(parameters, dtype):
mem_dev = get_free_memory(torch_dev) mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev) mem_cpu = get_free_memory(cpu_dev)
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled: if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None:
return torch_dev return torch_dev
else: else:
return cpu_dev return cpu_dev
@ -1121,6 +1121,7 @@ def get_cast_buffer(offload_stream, device, size, ref):
synchronize() synchronize()
del STREAM_CAST_BUFFERS[offload_stream] del STREAM_CAST_BUFFERS[offload_stream]
del cast_buffer del cast_buffer
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
soft_empty_cache() soft_empty_cache()
with wf_context: with wf_context:
cast_buffer = torch.empty((size), dtype=torch.int8, device=device) cast_buffer = torch.empty((size), dtype=torch.int8, device=device)

View File

@ -3,6 +3,7 @@ import os
from transformers import T5TokenizerFast from transformers import T5TokenizerFast
from .spiece_tokenizer import SPieceTokenizer from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.genmo import comfy.text_encoders.genmo
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import torch import torch
import comfy.utils import comfy.utils
import math import math
@ -108,6 +109,22 @@ class LTXAVTEModel(torch.nn.Module):
operations = self.gemma3_12b.operations # TODO operations = self.gemma3_12b.operations # TODO
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=operations,
)
def set_clip_options(self, options): def set_clip_options(self, options):
self.execution_device = options.get("execution_device", self.execution_device) self.execution_device = options.get("execution_device", self.execution_device)
self.gemma3_12b.set_clip_options(options) self.gemma3_12b.set_clip_options(options)
@ -129,6 +146,10 @@ class LTXAVTEModel(torch.nn.Module):
out = out.reshape((out.shape[0], out.shape[1], -1)) out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out) out = self.text_embedding_projection(out)
out = out.float() out = out.float()
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
return out.to(out_device), pooled return out.to(out_device), pooled
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
@ -138,14 +159,14 @@ class LTXAVTEModel(torch.nn.Module):
if "model.layers.47.self_attn.q_norm.weight" in sd: if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd) return self.gemma3_12b.load_sd(sd)
else: else:
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True) sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
if len(sdo) == 0: if len(sdo) == 0:
sdo = sd sdo = sd
missing_all = [] missing_all = []
unexpected_all = [] unexpected_all = []
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection)]: for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
if component_sd: if component_sd:
missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False)) missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))

View File

@ -1154,7 +1154,7 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
def model_trange(*args, **kwargs): def model_trange(*args, **kwargs):
if not comfy.memory_management.aimdo_enabled: if comfy.memory_management.aimdo_allocator is None:
return trange(*args, **kwargs) return trange(*args, **kwargs)
pbar = trange(*args, **kwargs, smoothing=1.0) pbar = trange(*args, **kwargs, smoothing=1.0)

View File

@ -1237,82 +1237,6 @@ class BoundingBox(ComfyTypeIO):
return d return d
@comfytype(io_type="CURVE")
class Curve(ComfyTypeIO):
Type = list
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: list=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = [[0, 0], [1, 1]]
def as_dict(self):
return super().as_dict()
@comfytype(io_type="RANGE")
class Range(ComfyTypeIO):
Type = dict # {"min": float, "max": float, "midpoint"?: float}
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None,
display: str=None,
gradient_stops: list=None,
show_midpoint: bool=None,
midpoint_scale: str=None,
value_min: float=None,
value_max: float=None,
advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = {"min": 0.0, "max": 1.0}
self.display = display
self.gradient_stops = gradient_stops
self.show_midpoint = show_midpoint
self.midpoint_scale = midpoint_scale
self.value_min = value_min
self.value_max = value_max
def as_dict(self):
return super().as_dict() | prune_dict({
"display": self.display,
"gradient_stops": self.gradient_stops,
"show_midpoint": self.show_midpoint,
"midpoint_scale": self.midpoint_scale,
"value_min": self.value_min,
"value_max": self.value_max,
})
@comfytype(io_type="COLOR_CURVES")
class ColorCurves(ComfyTypeIO):
class ColorCurvesDict(TypedDict):
rgb: list[list[float]]
red: list[list[float]]
green: list[list[float]]
blue: list[list[float]]
Type = ColorCurvesDict
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = {
"rgb": [[0, 0], [1, 1]],
"red": [[0, 0], [1, 1]],
"green": [[0, 0], [1, 1]],
"blue": [[0, 0], [1, 1]]
}
def as_dict(self):
return super().as_dict()
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func DYNAMIC_INPUT_LOOKUP[io_type] = func
@ -2299,7 +2223,5 @@ __all__ = [
"PriceBadgeDepends", "PriceBadgeDepends",
"PriceBadge", "PriceBadge",
"BoundingBox", "BoundingBox",
"Curve",
"ColorCurves",
"NodeReplace", "NodeReplace",
] ]

View File

@ -1,137 +0,0 @@
from typing_extensions import override
import torch
import numpy as np
from comfy_api.latest import ComfyExtension, io, ui
def _monotone_cubic_hermite(xs, ys, x_query):
"""Evaluate monotone cubic Hermite interpolation at x_query points."""
n = len(xs)
if n == 0:
return np.zeros_like(x_query)
if n == 1:
return np.full_like(x_query, ys[0])
# Compute slopes
deltas = np.diff(ys) / np.maximum(np.diff(xs), 1e-10)
# Compute tangents (Fritsch-Carlson)
slopes = np.zeros(n)
slopes[0] = deltas[0]
slopes[-1] = deltas[-1]
for i in range(1, n - 1):
if deltas[i - 1] * deltas[i] <= 0:
slopes[i] = 0
else:
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
# Enforce monotonicity
for i in range(n - 1):
if deltas[i] == 0:
slopes[i] = 0
slopes[i + 1] = 0
else:
alpha = slopes[i] / deltas[i]
beta = slopes[i + 1] / deltas[i]
s = alpha ** 2 + beta ** 2
if s > 9:
t = 3 / np.sqrt(s)
slopes[i] = t * alpha * deltas[i]
slopes[i + 1] = t * beta * deltas[i]
# Evaluate
result = np.zeros_like(x_query, dtype=np.float64)
indices = np.searchsorted(xs, x_query, side='right') - 1
indices = np.clip(indices, 0, n - 2)
for i in range(n - 1):
mask = indices == i
if not np.any(mask):
continue
dx = xs[i + 1] - xs[i]
if dx == 0:
result[mask] = ys[i]
continue
t = (x_query[mask] - xs[i]) / dx
t2 = t * t
t3 = t2 * t
h00 = 2 * t3 - 3 * t2 + 1
h10 = t3 - 2 * t2 + t
h01 = -2 * t3 + 3 * t2
h11 = t3 - t2
result[mask] = h00 * ys[i] + h10 * dx * slopes[i] + h01 * ys[i + 1] + h11 * dx * slopes[i + 1]
# Clamp edges
result[x_query <= xs[0]] = ys[0]
result[x_query >= xs[-1]] = ys[-1]
return result
def _build_lut(points):
"""Build a 256-entry LUT from curve control points in [0,1] space."""
if not points or len(points) < 2:
return np.arange(256, dtype=np.float64) / 255.0
pts = sorted(points, key=lambda p: p[0])
xs = np.array([p[0] for p in pts], dtype=np.float64)
ys = np.array([p[1] for p in pts], dtype=np.float64)
x_query = np.linspace(0, 1, 256)
lut = _monotone_cubic_hermite(xs, ys, x_query)
return np.clip(lut, 0, 1)
class ColorCurvesNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ColorCurves",
display_name="Color Curves",
category="image/adjustment",
inputs=[
io.Image.Input("image"),
io.ColorCurves.Input("settings"),
],
outputs=[
io.Image.Output(),
],
)
@classmethod
def execute(cls, image: torch.Tensor, settings: dict) -> io.NodeOutput:
rgb_pts = settings.get("rgb", [[0, 0], [1, 1]])
red_pts = settings.get("red", [[0, 0], [1, 1]])
green_pts = settings.get("green", [[0, 0], [1, 1]])
blue_pts = settings.get("blue", [[0, 0], [1, 1]])
rgb_lut = _build_lut(rgb_pts)
red_lut = _build_lut(red_pts)
green_lut = _build_lut(green_pts)
blue_lut = _build_lut(blue_pts)
# Convert to numpy for LUT application
img_np = image.cpu().numpy().copy()
# Apply per-channel curves then RGB master curve.
# Index with floor(val * 256) clamped to [0, 255] to match GPU NEAREST
# texture sampling on a 256-wide LUT texture.
for ch, ch_lut in enumerate([red_lut, green_lut, blue_lut]):
indices = np.clip((img_np[..., ch] * 256).astype(np.int32), 0, 255)
img_np[..., ch] = ch_lut[indices]
indices = np.clip((img_np[..., ch] * 256).astype(np.int32), 0, 255)
img_np[..., ch] = rgb_lut[indices]
result = torch.from_numpy(np.clip(img_np, 0, 1)).to(image.device, dtype=image.dtype)
return io.NodeOutput(result, ui=ui.PreviewImage(result))
class ColorCurvesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [ColorCurvesNode]
async def comfy_entrypoint() -> ColorCurvesExtension:
return ColorCurvesExtension()

View File

@ -716,12 +716,12 @@ def _render_shader_batch(
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glUseProgram(0) gl.glUseProgram(0)
for tex in input_textures: if input_textures:
gl.glDeleteTextures(tex) gl.glDeleteTextures(len(input_textures), input_textures)
for tex in output_textures: if output_textures:
gl.glDeleteTextures(tex) gl.glDeleteTextures(len(output_textures), output_textures)
for tex in ping_pong_textures: if ping_pong_textures:
gl.glDeleteTextures(tex) gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
if fbo is not None: if fbo is not None:
gl.glDeleteFramebuffers(1, [fbo]) gl.glDeleteFramebuffers(1, [fbo])
for pp_fbo in ping_pong_fbos: for pp_fbo in ping_pong_fbos:

View File

@ -10,7 +10,7 @@ class NAGuidance(io.ComfyNode):
node_id="NAGuidance", node_id="NAGuidance",
display_name="Normalized Attention Guidance", display_name="Normalized Attention Guidance",
description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.", description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
category="advanced/guidance", category="",
is_experimental=True, is_experimental=True,
inputs=[ inputs=[
io.Model.Input("model", tooltip="The model to apply NAG to."), io.Model.Input("model", tooltip="The model to apply NAG to."),

View File

@ -1,8 +1,10 @@
import os import os
import importlib.util import importlib.util
from comfy.cli_args import args, PerformanceFeature from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import subprocess import subprocess
import comfy_aimdo.control
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names(): def get_gpu_names():
if os.name == 'nt': if os.name == 'nt':
@ -85,6 +87,10 @@ if not args.cuda_malloc:
except: except:
pass pass
if enables_dynamic_vram() and comfy_aimdo.control.init():
args.cuda_malloc = False
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ""
if args.disable_cuda_malloc: if args.disable_cuda_malloc:
args.cuda_malloc = False args.cuda_malloc = False

View File

@ -9,6 +9,7 @@ import traceback
from enum import Enum from enum import Enum
from typing import List, Literal, NamedTuple, Optional, Union from typing import List, Literal, NamedTuple, Optional, Union
import asyncio import asyncio
from contextlib import nullcontext
import torch import torch
@ -520,14 +521,19 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0) GraphBuilder.set_default_prefix(unique_id, call_index, 0)
try: #Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) #will cause all sorts of incompatible memory shapes to fragment the pytorch alloc
finally: #that we just want to cull out each model run.
if comfy.memory_management.aimdo_enabled: allocator = comfy.memory_management.aimdo_allocator
if args.verbose == "DEBUG": with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
comfy_aimdo.control.analyze() try:
comfy.model_management.reset_cast_buffers() output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
comfy_aimdo.model_vbar.vbars_reset_watermark_limits() finally:
if allocator is not None:
if args.verbose == "DEBUG":
comfy_aimdo.model_vbar.vbars_analyze()
comfy.model_management.reset_cast_buffers()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if has_pending_tasks: if has_pending_tasks:
pending_async_nodes[unique_id] = output_data pending_async_nodes[unique_id] = output_data

11
main.py
View File

@ -173,10 +173,6 @@ import gc
if 'torch' in sys.modules: if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
import comfy_aimdo.control
if enables_dynamic_vram():
comfy_aimdo.control.init()
import comfy.utils import comfy.utils
@ -192,9 +188,13 @@ import hook_breaker_ac10a0
import comfy.memory_management import comfy.memory_management
import comfy.model_patcher import comfy.model_patcher
import comfy_aimdo.control
import comfy_aimdo.torch
if enables_dynamic_vram(): if enables_dynamic_vram():
if comfy.model_management.torch_version_numeric < (2, 8): if comfy.model_management.torch_version_numeric < (2, 8):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
comfy.memory_management.aimdo_allocator = None
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
if args.verbose == 'DEBUG': if args.verbose == 'DEBUG':
comfy_aimdo.control.set_log_debug() comfy_aimdo.control.set_log_debug()
@ -208,10 +208,11 @@ if enables_dynamic_vram():
comfy_aimdo.control.set_log_info() comfy_aimdo.control.set_log_info()
comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic
comfy.memory_management.aimdo_enabled = True comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator()
logging.info("DynamicVRAM support detected and enabled") logging.info("DynamicVRAM support detected and enabled")
else: else:
logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
comfy.memory_management.aimdo_allocator = None
def cuda_malloc_warning(): def cuda_malloc_warning():

146
nodes.py
View File

@ -2035,144 +2035,6 @@ class ImagePadForOutpaint:
return (new_image, mask.unsqueeze(0)) return (new_image, mask.unsqueeze(0))
class TestCurveWidget:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"curve": ("CURVE", {"default": [[0, 0], [1, 1]]}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("points",)
FUNCTION = "execute"
OUTPUT_NODE = True
CATEGORY = "testing"
def execute(self, curve):
import json
result = json.dumps(curve, indent=2)
print("Curve points:", result)
return {"ui": {"text": [result]}, "result": (result,)}
class TestRangePlain:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"range": ("RANGE", {"default": {"min": 0.0, "max": 1.0}}),
"range_midpoint": ("RANGE", {
"default": {"min": 0.2, "max": 0.8, "midpoint": 0.5},
"show_midpoint": True,
}),
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
OUTPUT_NODE = True
CATEGORY = "testing"
def execute(self, **kwargs):
import json
result = json.dumps(kwargs, indent=2)
return {"ui": {"text": [result]}, "result": (result,)}
class TestRangeGradient:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"range": ("RANGE", {
"default": {"min": 0.0, "max": 1.0},
"display": "gradient",
"gradient_stops": [
{"offset": 0.0, "color": [0, 0, 0]},
{"offset": 1.0, "color": [255, 255, 255]}
],
}),
"range_midpoint": ("RANGE", {
"default": {"min": 0.0, "max": 1.0, "midpoint": 0.5},
"display": "gradient",
"gradient_stops": [
{"offset": 0.0, "color": [0, 0, 0]},
{"offset": 1.0, "color": [255, 255, 255]}
],
"show_midpoint": True,
"midpoint_scale": "gamma",
}),
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
OUTPUT_NODE = True
CATEGORY = "testing"
def execute(self, **kwargs):
import json
result = json.dumps(kwargs, indent=2)
return {"ui": {"text": [result]}, "result": (result,)}
class TestRangeHistogram:
RANGE_OPTS = {
"display": "histogram",
"show_midpoint": True,
"midpoint_scale": "gamma",
"value_min": 0,
"value_max": 255,
}
@classmethod
def INPUT_TYPES(s):
default = {"min": 0, "max": 255, "midpoint": 0.5}
return {
"required": {
"image": ("IMAGE",),
"rgb": ("RANGE", {"default": {**default}, **s.RANGE_OPTS}),
"red": ("RANGE", {"default": {**default}, **s.RANGE_OPTS}),
"green": ("RANGE", {"default": {**default}, **s.RANGE_OPTS}),
"blue": ("RANGE", {"default": {**default}, **s.RANGE_OPTS}),
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
OUTPUT_NODE = True
CATEGORY = "testing"
def execute(self, image, rgb, red, green, blue):
import json
import numpy as np
img = image[0].cpu().numpy() # (H, W, C)
# Per-channel histograms
hist_r, _ = np.histogram(img[:, :, 0].flatten(), bins=256, range=(0.0, 1.0))
hist_g, _ = np.histogram(img[:, :, 1].flatten(), bins=256, range=(0.0, 1.0))
hist_b, _ = np.histogram(img[:, :, 2].flatten(), bins=256, range=(0.0, 1.0))
# Luminance histogram (BT.709)
luminance = 0.2126 * img[:, :, 0] + 0.7152 * img[:, :, 1] + 0.0722 * img[:, :, 2]
hist_rgb, _ = np.histogram(luminance.flatten(), bins=256, range=(0.0, 1.0))
result = json.dumps({"rgb": rgb, "red": red, "green": green, "blue": blue}, indent=2)
return {
"ui": {
"text": [result],
"range_histogram_rgb": hist_rgb.astype(np.uint32).tolist(),
"range_histogram_red": hist_r.astype(np.uint32).tolist(),
"range_histogram_green": hist_g.astype(np.uint32).tolist(),
"range_histogram_blue": hist_b.astype(np.uint32).tolist(),
},
"result": (result,)
}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"KSampler": KSampler, "KSampler": KSampler,
"CheckpointLoaderSimple": CheckpointLoaderSimple, "CheckpointLoaderSimple": CheckpointLoaderSimple,
@ -2241,10 +2103,6 @@ NODE_CLASS_MAPPINGS = {
"ConditioningZeroOut": ConditioningZeroOut, "ConditioningZeroOut": ConditioningZeroOut,
"ConditioningSetTimestepRange": ConditioningSetTimestepRange, "ConditioningSetTimestepRange": ConditioningSetTimestepRange,
"LoraLoaderModelOnly": LoraLoaderModelOnly, "LoraLoaderModelOnly": LoraLoaderModelOnly,
"TestCurveWidget": TestCurveWidget,
"TestRangePlain": TestRangePlain,
"TestRangeGradient": TestRangeGradient,
"TestRangeHistogram": TestRangeHistogram,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
@ -2313,10 +2171,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
# _for_testing # _for_testing
"VAEDecodeTiled": "VAE Decode (Tiled)", "VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)",
"TestCurveWidget": "Test Curve Widget",
"TestRangePlain": "Test Range (Plain)",
"TestRangeGradient": "Test Range (Gradient)",
"TestRangeHistogram": "Test Range (Histogram)",
} }
EXTENSION_WEB_DIRS = {} EXTENSION_WEB_DIRS = {}

View File

@ -22,7 +22,7 @@ alembic
SQLAlchemy SQLAlchemy
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.7 comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.0 comfy-aimdo>=0.1.8
requests requests
#non essential dependencies: #non essential dependencies: