Compare commits

..

12 Commits

Author SHA1 Message Date
4d721bff59 Range Editor 2026-02-26 08:22:46 -05:00
de67bb0870 feat: add COLOR_CURVES type and ColorCurves node 2026-02-22 20:42:53 -05:00
9d70c2626b test 2026-02-22 19:24:54 -05:00
cb64b957f2 CURVE type 2026-02-22 15:06:46 -05:00
07ca6852e8 Fix dtype issue in embeddings connector. (#12570) 2026-02-22 03:18:20 -05:00
f266b8d352 Move LTXAV av embedding connectors to diffusion model. (#12569) 2026-02-21 22:29:58 -05:00
b6cb30bab5 chore: tune CodeRabbit config to limit review scope and disable for drafts (#12567)
* chore: tune CodeRabbit config to limit review scope and disable for drafts

- Add tone_instructions to focus only on newly introduced issues
- Add global path_instructions entry to ignore pre-existing issues in moved/reformatted code
- Disable draft PR reviews (drafts: false) and add WIP title keywords
- Disable ruff tool to prevent linter-based outside-diff-range comments

Addresses feedback from maintainers about CodeRabbit flagging pre-existing
issues in code that was merely moved or de-indented (e.g., PR #12557),
which can discourage community contributions and cause scope creep.

Amp-Thread-ID: https://ampcode.com/threads/T-019c82de-0481-7253-ad42-20cb595bb1ba

* chore: add 'DO NOT MERGE' to ignore_title_keywords

Amp-Thread-ID: https://ampcode.com/threads/T-019c82de-0481-7253-ad42-20cb595bb1ba
2026-02-21 18:32:15 -08:00
ee72752162 Add category to Normalized Attention Guidance node (#12565) 2026-02-21 19:51:21 -05:00
7591d781a7 fix: specify UTF-8 encoding when reading subgraph files (#12563)
On Windows, Python defaults to cp1252 encoding when no encoding is
specified. JSON files containing UTF-8 characters (e.g., non-ASCII
characters) cause UnicodeDecodeError when read with cp1252.

This fixes the error that occurs when loading blueprint subgraphs
on Windows systems.

https://claude.ai/code/session_014WHi3SL9Gzsi3U6kbSjbSb

Co-authored-by: Claude <noreply@anthropic.com>
2026-02-21 15:05:00 -08:00
0bfb936ab4 comfy-aimdo 0.2 - Improved pytorch allocator integration (#12557)
Integrate comfy-aimdo 0.2, which takes a different approach to
installing the memory allocator hook. Instead of using the complicated
and buggy pytorch MemPool+CUDAPluggableAllocator, cuda is directly hooked,
making the process much more transparent to both comfy and pytorch. As
far as pytorch knows, aimdo doesn't exist anymore; it just operates
behind the scenes.

Remove all the mempool setup stuff for dynamic_vram and bump the
comfy-aimdo version. Remove the allocator object from memory_management
and demote its use as an enablement check to a boolean flag.

Comfy-aimdo 0.2 also supports the pytorch cuda async allocator, so
remove the dynamic_vram based force disablement of cuda_malloc and
just go back to the old settings of allocators based on command line
input.
2026-02-21 10:52:57 -08:00
602b2505a4 add support for pyopengl < 3.1.4 where the size parameter does not exist (#12555) 2026-02-21 06:14:57 -08:00
04a55d5019 fix: swap essentials_category from CLIPTextEncode to PrimitiveStringMultiline (#12553)
Remove CLIPTextEncode from Basics essentials category and add
PrimitiveStringMultiline (String Multiline) in its place.

Amp-Thread-ID: https://ampcode.com/threads/T-019c7efb-d916-7244-8c43-77b615ba0622

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-20 23:46:46 -08:00
18 changed files with 435 additions and 68 deletions

View File

@ -1,6 +1,7 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US" language: "en-US"
early_access: false early_access: false
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
reviews: reviews:
profile: "chill" profile: "chill"
@ -35,6 +36,14 @@ reviews:
- "!**/*.bat" - "!**/*.bat"
path_instructions: path_instructions:
- path: "**"
instructions: |
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
de-indented, or reformatted without logic changes. If code appears in the diff
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
treat it as unchanged. Contributors should not feel obligated to address
pre-existing issues outside the scope of their contribution.
- path: "comfy/**" - path: "comfy/**"
instructions: | instructions: |
Core ML/diffusion engine. Focus on: Core ML/diffusion engine. Focus on:
@ -74,7 +83,11 @@ reviews:
auto_review: auto_review:
enabled: true enabled: true
auto_incremental_review: true auto_incremental_review: true
drafts: true drafts: false
ignore_title_keywords:
- "WIP"
- "DO NOT REVIEW"
- "DO NOT MERGE"
finishing_touches: finishing_touches:
docstrings: docstrings:
@ -84,7 +97,7 @@ reviews:
tools: tools:
ruff: ruff:
enabled: true enabled: false
pylint: pylint:
enabled: false enabled: false
flake8: flake8:

View File

@ -53,7 +53,7 @@ class SubgraphManager:
return entry_id, entry return entry_id, entry
async def load_entry_data(self, entry: SubgraphEntry): async def load_entry_data(self, entry: SubgraphEntry):
with open(entry['path'], 'r') as f: with open(entry['path'], 'r', encoding='utf-8') as f:
entry['data'] = f.read() entry['data'] = f.read()
return entry return entry

View File

@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import (
LTXVModel, LTXVModel,
) )
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit import comfy.ldm.common_dit
class CompressedTimestep: class CompressedTimestep:
@ -450,6 +451,29 @@ class LTXAVModel(LTXVModel):
operations=self.operations, operations=self.operations,
) )
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=self.operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=self.operations,
)
def preprocess_text_embeds(self, context):
    """Run text embeddings through the video and audio connectors.

    When the last dimension of ``context`` already equals
    ``self.caption_channels * 2`` the tensor is returned unchanged.
    Otherwise the first output of ``video_embeddings_connector`` and of
    ``audio_embeddings_connector`` are joined along the last dimension,
    video first.
    """
    if context.shape[-1] != self.caption_channels * 2:
        video_embeds = self.video_embeddings_connector(context)[0]
        audio_embeds = self.audio_embeddings_connector(context)[0]
        return torch.cat([video_embeds, audio_embeds], dim=-1)
    # NOTE(review): presumably this width means the embeddings were already
    # passed through the connectors — confirm against callers.
    return context
def _init_transformer_blocks(self, device, dtype, **kwargs): def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks for LTXAV.""" """Initialize transformer blocks for LTXAV."""
self.transformer_blocks = nn.ModuleList( self.transformer_blocks = nn.ModuleList(

View File

@ -234,7 +234,7 @@ class Embeddings1DConnector(nn.Module):
return indices return indices
def precompute_freqs_cis(self, indices_grid, spacing="exp"): def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
dim = self.inner_dim dim = self.inner_dim
n_elem = 2 # 2 because of cos and sin n_elem = 2 # 2 because of cos and sin
freqs = self.precompute_freqs(indices_grid, spacing) freqs = self.precompute_freqs(indices_grid, spacing)
@ -247,7 +247,7 @@ class Embeddings1DConnector(nn.Module):
) )
else: else:
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
def forward( def forward(
self, self,
@ -288,7 +288,7 @@ class Embeddings1DConnector(nn.Module):
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
) )
indices_grid = indices_grid[None, None, :] indices_grid = indices_grid[None, None, :]
freqs_cis = self.precompute_freqs_cis(indices_grid) freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
# 2. Blocks # 2. Blocks
for block_idx, block in enumerate(self.transformer_1d_blocks): for block_idx, block in enumerate(self.transformer_1d_blocks):

View File

@ -78,4 +78,4 @@ def interpret_gathered_like(tensors, gathered):
return dest_views return dest_views
aimdo_allocator = None aimdo_enabled = False

View File

@ -988,10 +988,14 @@ class LTXAV(BaseModel):
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs) out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None) attention_mask = kwargs.get("attention_mask", None)
device = kwargs["device"]
if attention_mask is not None: if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None) cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None: if cross_attn is not None:
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))

View File

@ -836,7 +836,7 @@ def unet_inital_load_device(parameters, dtype):
mem_dev = get_free_memory(torch_dev) mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev) mem_cpu = get_free_memory(cpu_dev)
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None: if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
return torch_dev return torch_dev
else: else:
return cpu_dev return cpu_dev
@ -1121,7 +1121,6 @@ def get_cast_buffer(offload_stream, device, size, ref):
synchronize() synchronize()
del STREAM_CAST_BUFFERS[offload_stream] del STREAM_CAST_BUFFERS[offload_stream]
del cast_buffer del cast_buffer
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
soft_empty_cache() soft_empty_cache()
with wf_context: with wf_context:
cast_buffer = torch.empty((size), dtype=torch.int8, device=device) cast_buffer = torch.empty((size), dtype=torch.int8, device=device)

View File

@ -3,7 +3,6 @@ import os
from transformers import T5TokenizerFast from transformers import T5TokenizerFast
from .spiece_tokenizer import SPieceTokenizer from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.genmo import comfy.text_encoders.genmo
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import torch import torch
import comfy.utils import comfy.utils
import math import math
@ -109,22 +108,6 @@ class LTXAVTEModel(torch.nn.Module):
operations = self.gemma3_12b.operations # TODO operations = self.gemma3_12b.operations # TODO
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=operations,
)
def set_clip_options(self, options): def set_clip_options(self, options):
self.execution_device = options.get("execution_device", self.execution_device) self.execution_device = options.get("execution_device", self.execution_device)
self.gemma3_12b.set_clip_options(options) self.gemma3_12b.set_clip_options(options)
@ -146,10 +129,6 @@ class LTXAVTEModel(torch.nn.Module):
out = out.reshape((out.shape[0], out.shape[1], -1)) out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out) out = self.text_embedding_projection(out)
out = out.float() out = out.float()
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
return out.to(out_device), pooled return out.to(out_device), pooled
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
@ -159,14 +138,14 @@ class LTXAVTEModel(torch.nn.Module):
if "model.layers.47.self_attn.q_norm.weight" in sd: if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd) return self.gemma3_12b.load_sd(sd)
else: else:
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True) sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
if len(sdo) == 0: if len(sdo) == 0:
sdo = sd sdo = sd
missing_all = [] missing_all = []
unexpected_all = [] unexpected_all = []
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]: for prefix, component in [("text_embedding_projection.", self.text_embedding_projection)]:
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
if component_sd: if component_sd:
missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False)) missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))

View File

@ -1154,7 +1154,7 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
def model_trange(*args, **kwargs): def model_trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None: if not comfy.memory_management.aimdo_enabled:
return trange(*args, **kwargs) return trange(*args, **kwargs)
pbar = trange(*args, **kwargs, smoothing=1.0) pbar = trange(*args, **kwargs, smoothing=1.0)

View File

@ -1237,6 +1237,82 @@ class BoundingBox(ComfyTypeIO):
return d return d
@comfytype(io_type="CURVE")
class Curve(ComfyTypeIO):
    """IO type registered as "CURVE"; values are plain lists.

    The default value is ``[[0, 0], [1, 1]]`` (presumably [x, y] control
    points — confirm with the frontend widget).
    """
    Type = list

    class Input(WidgetInput):
        """Widget input for a curve; socketless by default."""

        def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
                     socketless: bool=True, default: list=None, advanced: bool=None):
            # Arguments are forwarded positionally in the order WidgetInput
            # expects; the None placeholders are options this widget leaves unset.
            super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
            if default is not None:
                return
            # No explicit default supplied: fall back to a fresh two-point list.
            self.default = [[0, 0], [1, 1]]

        def as_dict(self):
            # Nothing curve-specific to add to the base serialization.
            base = super().as_dict()
            return base
@comfytype(io_type="RANGE")
class Range(ComfyTypeIO):
    """IO type registered as "RANGE"; values are dicts with "min"/"max" keys."""
    Type = dict  # {"min": float, "max": float, "midpoint"?: float}

    class Input(WidgetInput):
        """Widget input for a range, with optional display hints.

        Hints left as ``None`` are dropped from the serialized dict.
        """

        def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
                     socketless: bool=True, default: dict=None,
                     display: str=None,
                     gradient_stops: list=None,
                     show_midpoint: bool=None,
                     midpoint_scale: str=None,
                     value_min: float=None,
                     value_max: float=None,
                     advanced: bool=None):
            # Arguments are forwarded positionally in the order WidgetInput
            # expects; the None placeholders are options this widget leaves unset.
            super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
            if default is None:
                # Fresh dict per instance so instances never share a default.
                self.default = {"min": 0.0, "max": 1.0}
            # Range-specific options, emitted by as_dict() when set.
            self.display = display
            self.gradient_stops = gradient_stops
            self.show_midpoint = show_midpoint
            self.midpoint_scale = midpoint_scale
            self.value_min = value_min
            self.value_max = value_max

        def as_dict(self):
            """Return the base serialization merged with the range options that are set."""
            base = super().as_dict()
            extras = prune_dict({
                "display": self.display,
                "gradient_stops": self.gradient_stops,
                "show_midpoint": self.show_midpoint,
                "midpoint_scale": self.midpoint_scale,
                "value_min": self.value_min,
                "value_max": self.value_max,
            })
            return base | extras
@comfytype(io_type="COLOR_CURVES")
class ColorCurves(ComfyTypeIO):
    """IO type registered as "COLOR_CURVES": one point list per curve name."""

    class ColorCurvesDict(TypedDict):
        # One list of points for the master curve and for each colour channel.
        rgb: list[list[float]]
        red: list[list[float]]
        green: list[list[float]]
        blue: list[list[float]]

    Type = ColorCurvesDict

    class Input(WidgetInput):
        """Widget input for the four curves; socketless by default."""

        def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
                     socketless: bool=True, default: dict=None, advanced: bool=None):
            # Arguments are forwarded positionally in the order WidgetInput
            # expects; the None placeholders are options this widget leaves unset.
            super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
            if default is not None:
                return
            # Every curve gets its own fresh [[0, 0], [1, 1]] point list.
            self.default = {
                curve_name: [[0, 0], [1, 1]]
                for curve_name in ("rgb", "red", "green", "blue")
            }

        def as_dict(self):
            # Nothing type-specific to add to the base serialization.
            base = super().as_dict()
            return base
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func DYNAMIC_INPUT_LOOKUP[io_type] = func
@ -2223,5 +2299,7 @@ __all__ = [
"PriceBadgeDepends", "PriceBadgeDepends",
"PriceBadge", "PriceBadge",
"BoundingBox", "BoundingBox",
"Curve",
"ColorCurves",
"NodeReplace", "NodeReplace",
] ]

View File

@ -0,0 +1,137 @@
from typing_extensions import override
import torch
import numpy as np
from comfy_api.latest import ComfyExtension, io, ui
def _monotone_cubic_hermite(xs, ys, x_query):
"""Evaluate monotone cubic Hermite interpolation at x_query points."""
n = len(xs)
if n == 0:
return np.zeros_like(x_query)
if n == 1:
return np.full_like(x_query, ys[0])
# Compute slopes
deltas = np.diff(ys) / np.maximum(np.diff(xs), 1e-10)
# Compute tangents (Fritsch-Carlson)
slopes = np.zeros(n)
slopes[0] = deltas[0]
slopes[-1] = deltas[-1]
for i in range(1, n - 1):
if deltas[i - 1] * deltas[i] <= 0:
slopes[i] = 0
else:
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
# Enforce monotonicity
for i in range(n - 1):
if deltas[i] == 0:
slopes[i] = 0
slopes[i + 1] = 0
else:
alpha = slopes[i] / deltas[i]
beta = slopes[i + 1] / deltas[i]
s = alpha ** 2 + beta ** 2
if s > 9:
t = 3 / np.sqrt(s)
slopes[i] = t * alpha * deltas[i]
slopes[i + 1] = t * beta * deltas[i]
# Evaluate
result = np.zeros_like(x_query, dtype=np.float64)
indices = np.searchsorted(xs, x_query, side='right') - 1
indices = np.clip(indices, 0, n - 2)
for i in range(n - 1):
mask = indices == i
if not np.any(mask):
continue
dx = xs[i + 1] - xs[i]
if dx == 0:
result[mask] = ys[i]
continue
t = (x_query[mask] - xs[i]) / dx
t2 = t * t
t3 = t2 * t
h00 = 2 * t3 - 3 * t2 + 1
h10 = t3 - 2 * t2 + t
h01 = -2 * t3 + 3 * t2
h11 = t3 - t2
result[mask] = h00 * ys[i] + h10 * dx * slopes[i] + h01 * ys[i + 1] + h11 * dx * slopes[i + 1]
# Clamp edges
result[x_query <= xs[0]] = ys[0]
result[x_query >= xs[-1]] = ys[-1]
return result
def _build_lut(points):
"""Build a 256-entry LUT from curve control points in [0,1] space."""
if not points or len(points) < 2:
return np.arange(256, dtype=np.float64) / 255.0
pts = sorted(points, key=lambda p: p[0])
xs = np.array([p[0] for p in pts], dtype=np.float64)
ys = np.array([p[1] for p in pts], dtype=np.float64)
x_query = np.linspace(0, 1, 256)
lut = _monotone_cubic_hermite(xs, ys, x_query)
return np.clip(lut, 0, 1)
class ColorCurvesNode(io.ComfyNode):
    """Node that remaps an image through per-channel and master tone curves."""

    @classmethod
    def define_schema(cls):
        """Describe the node: one image and one curve-settings input, one image output."""
        node_inputs = [
            io.Image.Input("image"),
            io.ColorCurves.Input("settings"),
        ]
        node_outputs = [
            io.Image.Output(),
        ]
        return io.Schema(
            node_id="ColorCurves",
            display_name="Color Curves",
            category="image/adjustment",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, image: torch.Tensor, settings: dict) -> io.NodeOutput:
        """Apply the red/green/blue curves, then the "rgb" curve, to channels 0-2.

        Curves missing from ``settings`` default to the identity
        ``[[0, 0], [1, 1]]``. The work is done on a NumPy copy on the CPU and
        the result is moved back to the input's device and dtype.
        """
        def lut_for(curve_name):
            return _build_lut(settings.get(curve_name, [[0, 0], [1, 1]]))

        master_lut = lut_for("rgb")
        channel_luts = [lut_for(curve_name) for curve_name in ("red", "green", "blue")]

        # Copy so the LUT writes below never touch the caller's tensor memory.
        pixels = image.cpu().numpy().copy()

        # Apply per-channel curves then RGB master curve.
        # Index with floor(val * 256) clamped to [0, 255] to match GPU NEAREST
        # texture sampling on a 256-wide LUT texture.
        for channel, channel_lut in enumerate(channel_luts):
            for lut in (channel_lut, master_lut):
                plane = pixels[..., channel]
                bins = np.clip((plane * 256).astype(np.int32), 0, 255)
                pixels[..., channel] = lut[bins]

        curved = torch.from_numpy(np.clip(pixels, 0, 1)).to(image.device, dtype=image.dtype)
        return io.NodeOutput(curved, ui=ui.PreviewImage(curved))
class ColorCurvesExtension(ComfyExtension):
    """Extension whose node list contains only ColorCurvesNode."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        node_classes = [ColorCurvesNode]
        return node_classes
async def comfy_entrypoint() -> ColorCurvesExtension:
    """Create and return a new ColorCurvesExtension instance."""
    # NOTE(review): presumably looked up by name by the extension loader — confirm.
    extension = ColorCurvesExtension()
    return extension

View File

@ -716,12 +716,12 @@ def _render_shader_batch(
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glUseProgram(0) gl.glUseProgram(0)
if input_textures: for tex in input_textures:
gl.glDeleteTextures(len(input_textures), input_textures) gl.glDeleteTextures(tex)
if output_textures: for tex in output_textures:
gl.glDeleteTextures(len(output_textures), output_textures) gl.glDeleteTextures(tex)
if ping_pong_textures: for tex in ping_pong_textures:
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures) gl.glDeleteTextures(tex)
if fbo is not None: if fbo is not None:
gl.glDeleteFramebuffers(1, [fbo]) gl.glDeleteFramebuffers(1, [fbo])
for pp_fbo in ping_pong_fbos: for pp_fbo in ping_pong_fbos:

View File

@ -10,7 +10,7 @@ class NAGuidance(io.ComfyNode):
node_id="NAGuidance", node_id="NAGuidance",
display_name="Normalized Attention Guidance", display_name="Normalized Attention Guidance",
description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.", description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
category="", category="advanced/guidance",
is_experimental=True, is_experimental=True,
inputs=[ inputs=[
io.Model.Input("model", tooltip="The model to apply NAG to."), io.Model.Input("model", tooltip="The model to apply NAG to."),

View File

@ -1,10 +1,8 @@
import os import os
import importlib.util import importlib.util
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram from comfy.cli_args import args, PerformanceFeature
import subprocess import subprocess
import comfy_aimdo.control
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names(): def get_gpu_names():
if os.name == 'nt': if os.name == 'nt':
@ -87,10 +85,6 @@ if not args.cuda_malloc:
except: except:
pass pass
if enables_dynamic_vram() and comfy_aimdo.control.init():
args.cuda_malloc = False
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ""
if args.disable_cuda_malloc: if args.disable_cuda_malloc:
args.cuda_malloc = False args.cuda_malloc = False

View File

@ -9,7 +9,6 @@ import traceback
from enum import Enum from enum import Enum
from typing import List, Literal, NamedTuple, Optional, Union from typing import List, Literal, NamedTuple, Optional, Union
import asyncio import asyncio
from contextlib import nullcontext
import torch import torch
@ -521,19 +520,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0) GraphBuilder.set_default_prefix(unique_id, call_index, 0)
#Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows try:
#will cause all sorts of incompatible memory shapes to fragment the pytorch alloc output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
#that we just want to cull out each model run. finally:
allocator = comfy.memory_management.aimdo_allocator if comfy.memory_management.aimdo_enabled:
with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())): if args.verbose == "DEBUG":
try: comfy_aimdo.control.analyze()
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) comfy.model_management.reset_cast_buffers()
finally: comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if allocator is not None:
if args.verbose == "DEBUG":
comfy_aimdo.model_vbar.vbars_analyze()
comfy.model_management.reset_cast_buffers()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if has_pending_tasks: if has_pending_tasks:
pending_async_nodes[unique_id] = output_data pending_async_nodes[unique_id] = output_data

11
main.py
View File

@ -173,6 +173,10 @@ import gc
if 'torch' in sys.modules: if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
import comfy_aimdo.control
if enables_dynamic_vram():
comfy_aimdo.control.init()
import comfy.utils import comfy.utils
@ -188,13 +192,9 @@ import hook_breaker_ac10a0
import comfy.memory_management import comfy.memory_management
import comfy.model_patcher import comfy.model_patcher
import comfy_aimdo.control
import comfy_aimdo.torch
if enables_dynamic_vram(): if enables_dynamic_vram():
if comfy.model_management.torch_version_numeric < (2, 8): if comfy.model_management.torch_version_numeric < (2, 8):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
comfy.memory_management.aimdo_allocator = None
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
if args.verbose == 'DEBUG': if args.verbose == 'DEBUG':
comfy_aimdo.control.set_log_debug() comfy_aimdo.control.set_log_debug()
@ -208,11 +208,10 @@ if enables_dynamic_vram():
comfy_aimdo.control.set_log_info() comfy_aimdo.control.set_log_info()
comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic
comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator() comfy.memory_management.aimdo_enabled = True
logging.info("DynamicVRAM support detected and enabled") logging.info("DynamicVRAM support detected and enabled")
else: else:
logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
comfy.memory_management.aimdo_allocator = None
def cuda_malloc_warning(): def cuda_malloc_warning():

146
nodes.py
View File

@ -2035,6 +2035,144 @@ class ImagePadForOutpaint:
return (new_image, mask.unsqueeze(0)) return (new_image, mask.unsqueeze(0))
class TestCurveWidget:
    """Test node that echoes the value of a CURVE widget as indented JSON."""

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("points",)
    FUNCTION = "execute"
    OUTPUT_NODE = True
    CATEGORY = "testing"

    @classmethod
    def INPUT_TYPES(s):
        curve_options = {"default": [[0, 0], [1, 1]]}
        return {"required": {"curve": ("CURVE", curve_options)}}

    def execute(self, curve):
        """Serialize ``curve`` to JSON, print it, and return it as UI text and as the result."""
        import json
        serialized = json.dumps(curve, indent=2)
        print("Curve points:", serialized)
        return {"ui": {"text": [serialized]}, "result": (serialized,)}
class TestRangePlain:
    """Test node with two RANGE widgets: one plain, one showing a midpoint."""

    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"
    OUTPUT_NODE = True
    CATEGORY = "testing"

    @classmethod
    def INPUT_TYPES(s):
        plain_options = {"default": {"min": 0.0, "max": 1.0}}
        midpoint_options = {
            "default": {"min": 0.2, "max": 0.8, "midpoint": 0.5},
            "show_midpoint": True,
        }
        return {
            "required": {
                "range": ("RANGE", plain_options),
                "range_midpoint": ("RANGE", midpoint_options),
            }
        }

    def execute(self, **kwargs):
        """Return all received inputs as indented JSON, both as UI text and as the result."""
        import json
        serialized = json.dumps(kwargs, indent=2)
        return {"ui": {"text": [serialized]}, "result": (serialized,)}
class TestRangeGradient:
    """Test node with two gradient-display RANGE widgets (black-to-white stops)."""

    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"
    OUTPUT_NODE = True
    CATEGORY = "testing"

    @classmethod
    def INPUT_TYPES(s):
        def black_to_white():
            # Built fresh per widget so the two option dicts share no state.
            return [
                {"offset": 0.0, "color": [0, 0, 0]},
                {"offset": 1.0, "color": [255, 255, 255]}
            ]

        plain_options = {
            "default": {"min": 0.0, "max": 1.0},
            "display": "gradient",
            "gradient_stops": black_to_white(),
        }
        midpoint_options = {
            "default": {"min": 0.0, "max": 1.0, "midpoint": 0.5},
            "display": "gradient",
            "gradient_stops": black_to_white(),
            "show_midpoint": True,
            "midpoint_scale": "gamma",
        }
        return {
            "required": {
                "range": ("RANGE", plain_options),
                "range_midpoint": ("RANGE", midpoint_options),
            }
        }

    def execute(self, **kwargs):
        """Return all received inputs as indented JSON, both as UI text and as the result."""
        import json
        serialized = json.dumps(kwargs, indent=2)
        return {"ui": {"text": [serialized]}, "result": (serialized,)}
class TestRangeHistogram:
    """Test node with histogram-display RANGE widgets fed by image histograms."""

    # Options shared by all four RANGE widgets.
    RANGE_OPTS = {
        "display": "histogram",
        "show_midpoint": True,
        "midpoint_scale": "gamma",
        "value_min": 0,
        "value_max": 255,
    }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"
    OUTPUT_NODE = True
    CATEGORY = "testing"

    @classmethod
    def INPUT_TYPES(s):
        default = {"min": 0, "max": 255, "midpoint": 0.5}
        required = {"image": ("IMAGE",)}
        for widget_name in ("rgb", "red", "green", "blue"):
            # Each widget gets its own copy of the default dict.
            required[widget_name] = ("RANGE", {"default": {**default}, **s.RANGE_OPTS})
        return {"required": required}

    def execute(self, image, rgb, red, green, blue):
        """Return the range settings as JSON plus 256-bin histograms of the first image.

        Only ``image[0]`` is analysed. Histograms cover [0.0, 1.0] for channels
        0, 1 and 2 and for a BT.709-weighted luminance.
        """
        import json
        import numpy as np

        def histogram_256(values):
            counts, _ = np.histogram(values.ravel(), bins=256, range=(0.0, 1.0))
            return counts.astype(np.uint32).tolist()

        first = image[0].cpu().numpy()  # (H, W, C)
        red_plane = first[:, :, 0]
        green_plane = first[:, :, 1]
        blue_plane = first[:, :, 2]

        # Luminance histogram (BT.709)
        luminance = 0.2126 * red_plane + 0.7152 * green_plane + 0.0722 * blue_plane

        serialized = json.dumps({"rgb": rgb, "red": red, "green": green, "blue": blue}, indent=2)

        ui_payload = {
            "text": [serialized],
            "range_histogram_rgb": histogram_256(luminance),
            "range_histogram_red": histogram_256(red_plane),
            "range_histogram_green": histogram_256(green_plane),
            "range_histogram_blue": histogram_256(blue_plane),
        }
        return {"ui": ui_payload, "result": (serialized,)}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"KSampler": KSampler, "KSampler": KSampler,
"CheckpointLoaderSimple": CheckpointLoaderSimple, "CheckpointLoaderSimple": CheckpointLoaderSimple,
@ -2103,6 +2241,10 @@ NODE_CLASS_MAPPINGS = {
"ConditioningZeroOut": ConditioningZeroOut, "ConditioningZeroOut": ConditioningZeroOut,
"ConditioningSetTimestepRange": ConditioningSetTimestepRange, "ConditioningSetTimestepRange": ConditioningSetTimestepRange,
"LoraLoaderModelOnly": LoraLoaderModelOnly, "LoraLoaderModelOnly": LoraLoaderModelOnly,
"TestCurveWidget": TestCurveWidget,
"TestRangePlain": TestRangePlain,
"TestRangeGradient": TestRangeGradient,
"TestRangeHistogram": TestRangeHistogram,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
@ -2171,6 +2313,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
# _for_testing # _for_testing
"VAEDecodeTiled": "VAE Decode (Tiled)", "VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)",
"TestCurveWidget": "Test Curve Widget",
"TestRangePlain": "Test Range (Plain)",
"TestRangeGradient": "Test Range (Gradient)",
"TestRangeHistogram": "Test Range (Histogram)",
} }
EXTENSION_WEB_DIRS = {} EXTENSION_WEB_DIRS = {}

View File

@ -22,7 +22,7 @@ alembic
SQLAlchemy SQLAlchemy
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.7 comfy-kitchen>=0.2.7
comfy-aimdo>=0.1.8 comfy-aimdo>=0.2.0
requests requests
#non essential dependencies: #non essential dependencies: