mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-21 13:57:43 +08:00
Compare commits
1 Commits
comfyanony
...
comfyui-wi
| Author | SHA1 | Date | |
|---|---|---|---|
| 4e36f45820 |
@ -8,8 +8,6 @@ from abc import ABC, abstractmethod
|
||||
import logging
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
import comfy.conds
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
@ -53,18 +51,12 @@ class ContextHandlerABC(ABC):
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.context_overlap = context_overlap
|
||||
self.dim = dim
|
||||
self.total_frames = total_frames
|
||||
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||
self.modality_windows = modality_windows # dict of {mod_idx: IndexListContextWindow}
|
||||
self.guide_frames_indices: list[int] = []
|
||||
self.guide_overlap_info: list[tuple[int, int]] = []
|
||||
self.guide_kf_local_positions: list[int] = []
|
||||
self.guide_downscale_factors: list[int] = []
|
||||
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||
if dim is None:
|
||||
@ -93,11 +85,6 @@ class IndexListContextWindow(ContextWindowABC):
|
||||
region_idx = int(self.center_ratio * num_regions)
|
||||
return min(max(region_idx, 0), num_regions - 1)
|
||||
|
||||
def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow':
|
||||
if modality_idx == 0:
|
||||
return self
|
||||
return self.modality_windows[modality_idx]
|
||||
|
||||
|
||||
class IndexListCallbacks:
|
||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||
@ -161,172 +148,6 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
|
||||
def compute_guide_overlap(guide_entries: list[dict], keyframe_idxs: torch.Tensor, temporal_downscale_ratio: int, window_index_list: list[int]):
    """Compute which concatenated guide frames overlap with a context window.

    Each guide's latent-space start is derived from its first token's pixel-t-start
    in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), divided
    (rounding up) by the model's temporal_downscale_ratio.

    Args:
        guide_entries: list of guide_attention_entry dicts; each entry is read for
            "latent_shape" (first element = number of guide frames) and
            "pre_filter_count" (number of tokens the entry occupies in keyframe_idxs)
        keyframe_idxs: per-token pixel coords cond tensor for the modality
        temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio
        window_index_list: the window's frame indices into the video portion

    Returns:
        suffix_indices: indices into the guide_frames tensor for frame selection
        overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
        kf_local_positions: window-local frame positions for keyframe_idxs regeneration
        total_overlap: total number of overlapping guide frames
    """
    # Resolve every frame index to its first position in the window up front, so the
    # per-frame lookup below is O(1) instead of a linear list.index() scan per hit.
    # setdefault keeps the first occurrence, matching list.index() semantics.
    window_position: dict[int, int] = {}
    for local_pos, frame_idx in enumerate(window_index_list):
        window_position.setdefault(frame_idx, local_pos)

    suffix_indices: list[int] = []
    overlap_info: list[tuple[int, int]] = []
    kf_local_positions: list[int] = []
    suffix_base = 0   # offset of the current guide within the concatenated guide frames
    token_offset = 0  # offset of the current guide's first token within keyframe_idxs

    for entry_idx, entry in enumerate(guide_entries):
        first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item())
        # ceil(first_t_pixel / temporal_downscale_ratio) using integer arithmetic
        latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio
        guide_len = entry["latent_shape"][0]
        entry_overlap = 0

        for local_offset in range(guide_len):
            local_pos = window_position.get(latent_start + local_offset)
            if local_pos is None:
                continue
            suffix_indices.append(suffix_base + local_offset)
            kf_local_positions.append(local_pos)
            entry_overlap += 1

        if entry_overlap > 0:
            overlap_info.append((entry_idx, entry_overlap))
        # Offsets advance for every entry, whether or not it overlapped the window.
        suffix_base += guide_len
        token_offset += entry["pre_filter_count"]

    return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
|
||||
|
||||
|
||||
@dataclass
class WindowingState:
    """Per-modality context windowing state for each step,
    built using IndexListContextHandler._build_window_state().
    For non-multimodal models the lists are length 1
    """
    latents: list[torch.Tensor]  # per-modality working latents (guide frames stripped)
    guide_latents: list[torch.Tensor | None]  # per-modality guide frames stripped from latents
    guide_entries: list[list[dict] | None]  # per-modality guide_attention_entry metadata
    keyframe_idxs: list[torch.Tensor | None]  # per-modality keyframe_idxs tensor for guide latent_start derivation
    latent_shapes: list | None  # original packed shapes for unpack/pack (None if not multimodal)
    dim: int = 0  # primary modality temporal dim for context windowing
    is_multimodal: bool = False
    temporal_downscale_ratio: int = 1  # model's pixel-to-latent temporal compression ratio

    def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow:
        """Reformat window for multimodal contexts by deriving per-modality index lists.

        Non-multimodal contexts return the input window unchanged.

        Raises:
            NotImplementedError: if the state is multimodal and `model` has no
                map_context_window_to_modalities attribute. Errors raised inside an
                existing implementation propagate unchanged.
        """
        if not self.is_multimodal:
            return window

        x = self.latents[0]
        primary_total = self.latent_shapes[0][self.dim]
        primary_overlap = window.context_overlap

        # When the working latent's length along dim differs from the recorded packed
        # shape, hand the model a copy of the shapes with the primary entry adjusted.
        map_shapes = self.latent_shapes
        if x.size(self.dim) != primary_total:
            map_shapes = list(self.latent_shapes)
            video_shape = list(self.latent_shapes[0])
            video_shape[self.dim] = x.size(self.dim)
            map_shapes[0] = torch.Size(video_shape)

        # Look the hook up first and call it outside any handler, so an AttributeError
        # raised *inside* a real implementation is not misreported as a missing method.
        mapper = getattr(model, "map_context_window_to_modalities", None)
        if mapper is None:
            raise NotImplementedError(
                f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")
        per_modality_indices = mapper(window.index_list, map_shapes, self.dim)

        modality_windows = {}
        for mod_idx in range(1, len(self.latents)):
            modality_total_frames = self.latents[mod_idx].shape[self.dim]
            # Scale the primary overlap by this modality's frame count relative to the primary.
            ratio = modality_total_frames / primary_total if primary_total > 0 else 1
            modality_overlap = max(round(primary_overlap * ratio), 0)
            modality_windows[mod_idx] = IndexListContextWindow(
                per_modality_indices[mod_idx], dim=self.dim,
                total_frames=modality_total_frames,
                context_overlap=modality_overlap)
        return IndexListContextWindow(
            window.index_list, dim=self.dim, total_frames=x.shape[self.dim],
            modality_windows=modality_windows, context_overlap=primary_overlap)

    def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]:
        """Slice latents for a context window, injecting guide frames where applicable.

        For multimodal contexts, uses the modality-specific windows derived in prepare_window().

        Returns:
            (sliced, guide_frame_counts): one sliced tensor and one injected-guide-frame
            count per modality, in modality order.
        """
        sliced = []
        guide_frame_counts = []
        for idx, latent in enumerate(self.latents):
            modality_window = window.get_window_for_modality(idx)
            # retain_index_list is only applied to the first (primary) modality
            retain = retain_index_list if idx == 0 else []
            piece = modality_window.get_tensor(latent, device, retain_index_list=retain)
            injected = 0
            if self.guide_entries[idx] is not None:
                piece, injected = self._inject_guide_frames(piece, modality_window, modality_idx=idx)
            sliced.append(piece)
            guide_frame_counts.append(injected)
        return sliced, guide_frame_counts

    def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow):
        """Strip injected guide frames from per-cond, per-modality outputs in place."""
        for idx in range(len(self.latents)):
            if guide_frame_counts[idx] <= 0:
                continue
            # Keep only the leading window_len frames along dim.
            window_len = len(window.get_window_for_modality(idx).index_list)
            for cond_out in out_per_modality:
                cond_out[idx] = cond_out[idx].narrow(self.dim, 0, window_len)

    def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]:
        """Append the guide frames overlapping `window` to `latent_slice` along dim.

        Also records the overlap bookkeeping on `window` (guide_frames_indices,
        guide_overlap_info, guide_kf_local_positions, guide_downscale_factors).

        Returns:
            (tensor, guide_frame_count): the possibly extended slice and the number of
            guide frames appended (0 when nothing overlaps).
        """
        guide_entries = self.guide_entries[modality_idx]
        guide_frames = self.guide_latents[modality_idx]
        keyframe_idxs = self.keyframe_idxs[modality_idx]
        suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(
            guide_entries, keyframe_idxs, self.temporal_downscale_ratio, window.index_list)
        # Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0.
        anchor_idx = getattr(window, 'causal_anchor_index', None)
        if anchor_idx is not None and anchor_idx >= 0:
            kf_local_pos = [p + 1 for p in kf_local_pos]
        window.guide_frames_indices = suffix_idx
        window.guide_overlap_info = overlap_info
        window.guide_kf_local_positions = kf_local_pos

        if guide_frame_count <= 0:
            window.guide_downscale_factors = []
            return latent_slice, 0

        # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
        # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
        # NOTE(review): shape[3] is read as the height axis - confirm against the latent layout.
        full_H = guide_frames.shape[3]
        window.guide_downscale_factors = [
            full_H // guide_entries[entry_idx]["latent_shape"][1]
            for entry_idx, _ in overlap_info
        ]

        # Select the overlapping guide frames along dim and append them after the window frames.
        idx = tuple([slice(None)] * self.dim + [suffix_idx])
        return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count

    def patch_latent_shapes(self, sub_conds, new_shapes):
        """Overwrite the 'latent_shapes' model cond in place with `new_shapes`.

        Does nothing for non-multimodal state; cond dicts without a 'latent_shapes'
        entry are left untouched.
        """
        if not self.is_multimodal:
            return

        for cond_list in sub_conds:
            if cond_list is None:
                continue
            for cond_dict in cond_list:
                model_conds = cond_dict.get('model_conds', {})
                if 'latent_shapes' in model_conds:
                    model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
@ -341,7 +162,7 @@ ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_co
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
|
||||
latent_retain_index_list: list[int]=[], causal_window_fix: bool=True):
|
||||
causal_window_fix: bool=True):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
@ -353,118 +174,17 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self.freenoise = freenoise
|
||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||
self.split_conds_to_windows = split_conds_to_windows
|
||||
self.latent_retain_index_list = [int(x.strip()) for x in latent_retain_index_list.split(",")] if latent_retain_index_list else []
|
||||
self.causal_window_fix = causal_window_fix
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
@staticmethod
|
||||
def _get_latent_shapes(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
if 'latent_shapes' in model_conds:
|
||||
return model_conds['latent_shapes'].cond
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_guide_entries(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
entries = model_conds.get('guide_attention_entries')
|
||||
if entries is not None and hasattr(entries, 'cond') and entries.cond:
|
||||
return entries.cond
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_keyframe_idxs(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
kf = model_conds.get('keyframe_idxs')
|
||||
if kf is not None and hasattr(kf, 'cond') and kf.cond is not None:
|
||||
return kf.cond
|
||||
return None
|
||||
|
||||
def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
    """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.

    If guide frames are present on the primary modality, only the video portion is shuffled.
    """
    entries = self._get_guide_entries(conds)
    guide_count = 0
    if entries:
        guide_count = sum(entry["latent_shape"][0] for entry in entries)

    shapes = self._get_latent_shapes(conds)
    if shapes is None or len(shapes) <= 1:
        # Single modality: shuffle only the leading video frames of the noise.
        video_count = noise.size(self.dim) - guide_count
        # NOTE(review): the return value is discarded here, so this presumably relies on
        # apply_freenoise modifying the narrowed view in place - confirm against apply_freenoise.
        apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed)
        return noise

    modalities = comfy.utils.unpack_latents(noise, shapes)
    primary_total = shapes[0][self.dim]
    primary = modalities[0]
    primary_video_count = primary.size(self.dim) - guide_count
    # Same discarded-return pattern as the single-modality path above.
    apply_freenoise(primary.narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed)

    for mod_idx in range(1, len(modalities)):
        mod_total = shapes[mod_idx][self.dim]
        # Scale window length/overlap by this modality's frame count relative to the primary.
        ratio = mod_total / primary_total if primary_total > 0 else 1
        scaled_length = max(round(self.context_length * ratio), 1)
        scaled_overlap = max(round(self.context_overlap * ratio), 0)
        modalities[mod_idx] = apply_freenoise(modalities[mod_idx], self.dim, scaled_length, scaled_overlap, seed)

    packed, _ = comfy.utils.pack_latents(modalities)
    return packed
|
||||
|
||||
def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState:
    """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds."""
    shapes = self._get_latent_shapes(conds)
    multimodal = shapes is not None and len(shapes) > 1
    if multimodal:
        unpacked = comfy.utils.unpack_latents(x_in, shapes)
    else:
        unpacked = [x_in]

    modality_count = len(unpacked)
    latents = list(unpacked)
    guide_latents = [None] * modality_count
    guide_entries = [None] * modality_count
    keyframe_idxs = [None] * modality_count

    found_entries = self._get_guide_entries(conds)
    found_keyframes = self._get_keyframe_idxs(conds)

    # Strip guide frames (only from first modality for now)
    # NOTE(review): nesting reconstructed from a flattened listing - all four assignments
    # are taken to happen only when guide_count > 0; confirm against the upstream file.
    guide_count = 0 if found_entries is None else sum(e["latent_shape"][0] for e in found_entries)
    if guide_count > 0:
        primary = unpacked[0]
        latent_count = primary.size(self.dim) - guide_count
        # Leading frames are the working latent; the trailing guide_count frames are guides.
        latents[0] = primary.narrow(self.dim, 0, latent_count)
        guide_latents[0] = primary.narrow(self.dim, latent_count, guide_count)
        guide_entries[0] = found_entries
        keyframe_idxs[0] = found_keyframes

    return WindowingState(
        latents=latents,
        guide_latents=guide_latents,
        guide_entries=guide_entries,
        keyframe_idxs=keyframe_idxs,
        latent_shapes=shapes,
        dim=self.dim,
        is_multimodal=multimodal,
        temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio)
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
window_state = self._build_window_state(x_in, conds, model) # build window_state to check frame counts, will be built again in execute
|
||||
total_frame_count = window_state.latents[0].size(self.dim)
|
||||
if total_frame_count > self.context_length:
|
||||
logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
||||
if self.cond_retain_index_list:
|
||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||
if self.latent_retain_index_list:
|
||||
logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}")
|
||||
return True
|
||||
logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).")
|
||||
return False
|
||||
|
||||
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
|
||||
@ -555,9 +275,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return resized_cond
|
||||
|
||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||
sample_sigmas = model_options["transformer_options"]["sample_sigmas"]
|
||||
current_timestep = timestep[0].to(sample_sigmas.dtype)
|
||||
mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001)
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
return # substep from multi-step sampler: keep self._step from the last full step
|
||||
@ -566,98 +284,54 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length, context_overlap=self.context_overlap) for window in context_windows]
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
self._model = model
|
||||
self.set_step(timestep, model_options)
|
||||
|
||||
window_state = self._build_window_state(x_in, conds, model)
|
||||
num_modalities = len(window_state.latents)
|
||||
|
||||
context_windows = self.get_context_windows(model, window_state.latents[0], model_options)
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
total_windows = len(enumerated_context_windows)
|
||||
|
||||
# Initialize per-modality accumulators (length 1 for single-modality)
|
||||
accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents]
|
||||
conds_final = [torch.zeros_like(x_in) for _ in conds]
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
|
||||
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
else:
|
||||
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
|
||||
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents]
|
||||
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
# accumulate results from each context window
|
||||
for enum_window in enumerated_context_windows:
|
||||
results = self.evaluate_context_windows(
|
||||
calc_cond_batch, model, x_in, conds, timestep, [enum_window],
|
||||
model_options, window_state=window_state, total_windows=total_windows)
|
||||
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
|
||||
for result in results:
|
||||
# result.sub_conds_out is per-cond, per-modality: list[list[Tensor]]
|
||||
for mod_idx in range(num_modalities):
|
||||
mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))]
|
||||
modality_window = result.window.get_window_for_modality(mod_idx)
|
||||
self.combine_context_window_results(
|
||||
window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window,
|
||||
result.window_idx, total_windows, timestep,
|
||||
accum[mod_idx], counts[mod_idx], biases[mod_idx])
|
||||
|
||||
# fuse accumulated results into final conds
|
||||
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
|
||||
conds_final, counts_final, biases_final)
|
||||
try:
|
||||
result_out = []
|
||||
for ci in range(len(conds)):
|
||||
finalized = []
|
||||
for mod_idx in range(num_modalities):
|
||||
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
|
||||
accum[mod_idx][ci] /= counts[mod_idx][ci]
|
||||
f = accum[mod_idx][ci]
|
||||
|
||||
# if guide frames were injected, append them to the end of the fused latents for the next step
|
||||
if window_state.guide_latents[mod_idx] is not None:
|
||||
f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim)
|
||||
finalized.append(f)
|
||||
|
||||
# pack modalities together if needed
|
||||
if window_state.is_multimodal and len(finalized) > 1:
|
||||
packed, _ = comfy.utils.pack_latents(finalized)
|
||||
else:
|
||||
packed = finalized[0]
|
||||
|
||||
result_out.append(packed)
|
||||
return result_out
|
||||
# finalize conds
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
# relative is already normalized, so return as is
|
||||
del counts_final
|
||||
return conds_final
|
||||
else:
|
||||
# normalize conds via division by context usage counts
|
||||
for i in range(len(conds_final)):
|
||||
conds_final[i] /= counts_final[i]
|
||||
del counts_final
|
||||
return conds_final
|
||||
finally:
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds,
|
||||
timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, window_state: WindowingState, total_windows: int = None,
|
||||
device=None, first_device=None):
|
||||
"""Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out
|
||||
|
||||
For each window:
|
||||
1. Builds windows (for each modality if multimodal)
|
||||
2. Slices window for each modality
|
||||
3. Injects concatenated latent guide frames where present
|
||||
4. Packs together if needed and calls model
|
||||
5. Unpacks and strips any guides from outputs
|
||||
"""
|
||||
x = window_state.latents[0]
|
||||
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, device=None, first_device=None):
|
||||
results: list[ContextResults] = []
|
||||
for window_idx, window in enumerated_context_windows:
|
||||
# allow processing to end between context window executions for faster Cancel
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
# prepare the window accounting for multimodal windows
|
||||
window = window_state.prepare_window(window, model)
|
||||
|
||||
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward.
|
||||
# Set anchor before slice_for_window so the latent slice and downstream cond slices both pick it up.
|
||||
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward
|
||||
anchor_applied = False
|
||||
if self.causal_window_fix:
|
||||
anchor_idx = window.index_list[0] - 1
|
||||
@ -665,46 +339,27 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
window.causal_anchor_index = anchor_idx
|
||||
anchor_applied = True
|
||||
|
||||
# slice the window for each modality, injecting guide frames where applicable
|
||||
sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device)
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||
|
||||
logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}"
|
||||
+ (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "")
|
||||
)
|
||||
|
||||
# if multimodal, pack modalities together
|
||||
if window_state.is_multimodal and len(sliced) > 1:
|
||||
sub_x, sub_shapes = comfy.utils.pack_latents(sliced)
|
||||
else:
|
||||
sub_x, sub_shapes = sliced[0], [sliced[0].shape]
|
||||
|
||||
# get resized conds for window
|
||||
# update exposed params
|
||||
model_options["transformer_options"]["context_window"] = window
|
||||
sub_timestep = window.get_tensor(timestep, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds]
|
||||
# get subsections of x, timestep, conds
|
||||
sub_x = window.get_tensor(x_in, device)
|
||||
sub_timestep = window.get_tensor(timestep, device, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
|
||||
|
||||
# if multimodal, patch latent_shapes in conds for correct unpacking in model
|
||||
window_state.patch_latent_shapes(sub_conds, sub_shapes)
|
||||
|
||||
# call model on window
|
||||
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
|
||||
if device is not None:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
||||
|
||||
# unpack outputs
|
||||
out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
|
||||
|
||||
# strip causal_window_fix anchor from primary modality before guide strip so window_len math stays correct
|
||||
# strip causal_window_fix anchor if applied
|
||||
if anchor_applied:
|
||||
for ci in range(len(out_per_modality)):
|
||||
t = out_per_modality[ci][0]
|
||||
out_per_modality[ci][0] = t.narrow(self.dim, 1, t.shape[self.dim] - 1)
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
|
||||
|
||||
# strip injected guide frames
|
||||
window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window)
|
||||
|
||||
results.append(ContextResults(window_idx, out_per_modality, sub_conds, window))
|
||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||
return results
|
||||
|
||||
|
||||
@ -728,7 +383,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
biases_final[i][idx] = bias_total + bias
|
||||
else:
|
||||
# add conds and counts based on weights of fuse method
|
||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap)
|
||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
|
||||
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
|
||||
for i in range(len(sub_conds_out)):
|
||||
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
|
||||
@ -738,22 +393,16 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
|
||||
|
||||
|
||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs):
|
||||
# Scale noise_shape to a single context window so VRAM estimation budgets per-window.
|
||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
|
||||
# limit noise_shape length to context_length for more accurate vram use estimation
|
||||
model_options = kwargs.get("model_options", None)
|
||||
if model_options is None:
|
||||
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is not None:
|
||||
noise_shape = list(noise_shape)
|
||||
is_packed = len(noise_shape) == 3 and noise_shape[1] == 1
|
||||
if is_packed:
|
||||
# TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a
|
||||
# per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM.
|
||||
pass
|
||||
elif handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length:
|
||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||
return executor(model, noise_shape, conds, *args, **kwargs)
|
||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||
return executor(model, noise_shape, *args, **kwargs)
|
||||
|
||||
|
||||
def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||
@ -773,12 +422,11 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
|
||||
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
if not handler.freenoise:
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
|
||||
conds = [guider.conds.get('positive', guider.conds.get('negative', []))]
|
||||
noise = handler._apply_freenoise(noise, conds, extra_args["seed"])
|
||||
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
|
||||
|
||||
def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
||||
@ -786,6 +434,7 @@ def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
_sampler_sample_wrapper
|
||||
)
|
||||
|
||||
|
||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||
total_dims = len(x_in.shape)
|
||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||
@ -931,9 +580,8 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
|
||||
return ContextSchedule(context_schedule, func)
|
||||
|
||||
|
||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None, context_overlap: int=None):
|
||||
context_overlap = handler.context_overlap if context_overlap is None else context_overlap
|
||||
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs, context_overlap=context_overlap)
|
||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
|
||||
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
|
||||
|
||||
|
||||
def create_weights_flat(length: int, **kwargs) -> list[float]:
|
||||
@ -951,18 +599,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
|
||||
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
|
||||
return weight_sequence
|
||||
|
||||
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs):
|
||||
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
|
||||
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
|
||||
# only expected overlap is given different weights
|
||||
weights_torch = torch.ones((length))
|
||||
# blend left-side on all except first window
|
||||
if min(idxs) > 0:
|
||||
ramp_up = torch.linspace(1e-37, 1, context_overlap)
|
||||
weights_torch[:context_overlap] = ramp_up
|
||||
ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
|
||||
weights_torch[:handler.context_overlap] = ramp_up
|
||||
# blend right-side on all except last window
|
||||
if max(idxs) < full_length-1:
|
||||
ramp_down = torch.linspace(1, 1e-37, context_overlap)
|
||||
weights_torch[-context_overlap:] = ramp_down
|
||||
ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
|
||||
weights_torch[-handler.context_overlap:] = ramp_down
|
||||
return weights_torch
|
||||
|
||||
class ContextFuseMethods:
|
||||
|
||||
@ -515,7 +515,7 @@ class Block(nn.Module):
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_self_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||
|
||||
def _x_fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
@ -548,7 +548,7 @@ class Block(nn.Module):
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_cross_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
|
||||
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
@ -557,7 +557,7 @@ class Block(nn.Module):
|
||||
shift_mlp_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
|
||||
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_mlp_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||
return x_B_T_H_W_D
|
||||
|
||||
|
||||
|
||||
@ -1085,7 +1085,7 @@ class LTXVModel(LTXBaseModel):
|
||||
)
|
||||
|
||||
grid_mask = None
|
||||
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
|
||||
if keyframe_idxs is not None:
|
||||
additional_args.update({ "orig_patchified_shape": list(x.shape)})
|
||||
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
|
||||
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
|
||||
@ -1330,7 +1330,7 @@ class LTXVModel(LTXBaseModel):
|
||||
x = x * (1 + scale) + shift
|
||||
x = self.proj_out(x)
|
||||
|
||||
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
|
||||
if keyframe_idxs is not None:
|
||||
grid_mask = kwargs["grid_mask"]
|
||||
orig_patchified_shape = kwargs["orig_patchified_shape"]
|
||||
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
|
||||
|
||||
@ -21,7 +21,6 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import comfy.ldm.lightricks.av_model
|
||||
import comfy.ldm.lightricks.symmetric_patchifier
|
||||
import comfy.context_windows
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
@ -1205,127 +1204,6 @@ class LTXAV(BaseModel):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
    """Map a primary context window onto every additional modality.

    Args:
        primary_indices: window indices along ``dim`` of the primary latent
            (``latent_shapes[0]``). They may be wrapped or strided; their
            order is preserved in the mapped output.
        latent_shapes: one shape per modality; entry 0 is the primary one.
        dim: index into each shape giving the length the window slides over.

    Returns:
        A list with one index list per modality. Entry 0 is
        ``primary_indices`` itself (not a copy). Every further entry holds
        the de-duplicated modality indices covering the proportional range
        of each primary index, in first-seen order. A modality (or primary)
        whose length along ``dim`` is not positive maps to an empty list.
    """
    result = [primary_indices]
    if len(latent_shapes) < 2:
        return result

    video_total = latent_shapes[0][dim]

    for mod_shape in latent_shapes[1:]:
        mod_total = mod_shape[dim]

        # Nothing can be selected from an empty axis. Without this guard an
        # empty modality produced the invalid index -1 and an empty primary
        # axis raised ZeroDivisionError.
        if mod_total <= 0 or video_total <= 0:
            result.append([])
            continue

        # Map each primary index to its proportional range of modality indices and
        # concatenate in order. Preserves wrapped/strided geometry so the modality
        # covers the same relative regions as the primary window.
        mod_indices = []
        seen = set()
        for v_idx in primary_indices:
            a_start = min(int(round(v_idx * mod_total / video_total)), mod_total - 1)
            a_end = min(int(round((v_idx + 1) * mod_total / video_total)), mod_total)
            # every primary index contributes at least one modality index
            a_end = max(a_end, a_start + 1)
            for a in range(a_start, a_end):
                if a not in seen:
                    seen.add(a)
                    mod_indices.append(a)
        result.append(mod_indices)

    return result
|
||||
|
||||
@staticmethod
|
||||
def _get_guide_entries(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
entries = model_conds.get('guide_attention_entries')
|
||||
if entries is not None and hasattr(entries, 'cond') and entries.cond:
|
||||
return entries.cond
|
||||
return None
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Resize one model cond so it matches a single context window.

    Handles four cond keys; any other key (or a handled key whose
    preconditions are not met) returns ``None``.

    * ``"audio_denoise_mask"``: sliced along dim 2 with the window stored
      under key ``1`` of ``window.modality_windows``.
    * ``"denoise_mask"``: when the mask is longer than ``x_in`` along
      ``window.dim``, the leading part is windowed via ``window.get_tensor``
      and the trailing (guide) part is indexed with
      ``window.guide_frames_indices``.
    * ``"keyframe_idxs"``: pixel coordinates are regenerated for the window
      and restricted to ``window.guide_kf_local_positions``.
    * ``"guide_attention_entries"``: per-entry counts are rewritten from
      ``window.guide_overlap_info``.

    Returns the result of ``cond_value._copy_with(...)`` or ``None``.
    NOTE(review): the ``retain_index_list`` default is never mutated here.
    """
    # Audio denoise mask: slice using the audio modality window
    if cond_key == "audio_denoise_mask":
        modality_windows = getattr(window, 'modality_windows', None)
        if modality_windows:
            audio_window = modality_windows.get(1)
            if audio_window is not None and isinstance(getattr(cond_value, "cond", None), torch.Tensor):
                return cond_value._copy_with(audio_window.get_tensor(cond_value.cond, device, dim=2))

    # Video denoise mask: split into video + guide portions, slice each
    if cond_key == "denoise_mask" and isinstance(getattr(cond_value, "cond", None), torch.Tensor):
        full_mask = cond_value.cond
        video_len = x_in.size(window.dim)
        guide_len = full_mask.size(window.dim) - video_len
        if guide_len > 0:
            video_part = full_mask.narrow(window.dim, 0, video_len)
            guide_part = full_mask.narrow(window.dim, video_len, guide_len)
            windowed_video = window.get_tensor(video_part, device, retain_index_list=retain_index_list)
            kept_guides = window.guide_frames_indices
            if not kept_guides:
                return cond_value._copy_with(windowed_video)
            # full slices on every leading dim, explicit indices on window.dim
            selector = tuple([slice(None)] * window.dim + [kept_guides])
            windowed_guides = guide_part[selector].to(device)
            return cond_value._copy_with(torch.cat([windowed_video, windowed_guides], dim=window.dim))

    # Keyframe indices: regenerate pixel coords for the window, select guide positions
    if cond_key == "keyframe_idxs":
        local_positions = window.guide_kf_local_positions
        if not local_positions:
            return cond_value._copy_with(cond_value.cond[:, :, :0, :])  # empty
        H, W = x_in.shape[3], x_in.shape[4]
        tokens_per_frame = H * W

        frame_count = len(window.index_list)
        # account for causal_window_fix anchor in coord space size
        anchor = getattr(window, 'causal_anchor_index', None)
        if anchor is not None and anchor >= 0:
            frame_count += 1

        model = self.diffusion_model
        latent_coords = model.patchifier.get_latent_coords(frame_count, H, W, 1, cond_value.cond.device)
        scale_factors = model.vae_scale_factors
        pixel_coords = comfy.ldm.lightricks.symmetric_patchifier.latent_to_pixel_coords(
            latent_coords,
            scale_factors,
            causal_fix=model.causal_temporal_positioning)
        token_ids = [
            token
            for pos in local_positions
            for token in range(pos * tokens_per_frame, (pos + 1) * tokens_per_frame)
        ]
        pixel_coords = pixel_coords[:, :, token_ids, :]

        # Adjust spatial end positions for dilated (downscaled) guides.
        # Each guide entry may have a different downscale factor; expand the
        # per-entry factor to cover all tokens belonging to that entry.
        downscale_factors = window.guide_downscale_factors
        if downscale_factors:
            per_token_factor = []
            for (_, overlap_count), factor in zip(window.guide_overlap_info, downscale_factors):
                per_token_factor.extend([factor] * (overlap_count * tokens_per_frame))
            factor_tensor = torch.tensor(per_token_factor, device=pixel_coords.device, dtype=pixel_coords.dtype)
            spatial_scale = torch.tensor(
                scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype,
            ).view(1, -1, 1, 1)
            broadcast_factor = factor_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
            pixel_coords[:, 1:, :, 1:] += (broadcast_factor - 1) * spatial_scale

        batch = cond_value.cond.shape[0]
        if batch > 1:
            pixel_coords = pixel_coords.expand(batch, -1, -1, -1)
        return cond_value._copy_with(pixel_coords)

    # Guide attention entries: adjust per-guide counts based on window overlap
    if cond_key == "guide_attention_entries":
        H, W = x_in.shape[3], x_in.shape[4]
        rewritten = [
            {**cond_value.cond[entry_idx],
             "pre_filter_count": overlap_count * H * W,
             "latent_shape": [overlap_count, H, W]}
            for entry_idx, overlap_count in window.guide_overlap_info
        ]
        return cond_value._copy_with(rewritten)

    return None
|
||||
|
||||
class HunyuanVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
|
||||
@ -10,7 +10,6 @@ from pydantic import BaseModel, Field, confloat
|
||||
class LumaIO:
|
||||
LUMA_REF = "LUMA_REF"
|
||||
LUMA_CONCEPTS = "LUMA_CONCEPTS"
|
||||
LUMA_RAY32_KEYFRAME = "LUMA_RAY32_KEYFRAME"
|
||||
|
||||
|
||||
class LumaReference:
|
||||
@ -21,14 +20,13 @@ class LumaReference:
|
||||
def create_api_model(self, download_url: str):
|
||||
return LumaImageRef(url=download_url, weight=self.weight)
|
||||
|
||||
|
||||
class LumaReferenceChain:
|
||||
def __init__(self, first_ref: LumaReference = None):
|
||||
def __init__(self, first_ref: LumaReference=None):
|
||||
self.refs: list[LumaReference] = []
|
||||
if first_ref:
|
||||
self.refs.append(first_ref)
|
||||
|
||||
def add(self, luma_ref: LumaReference = None):
|
||||
def add(self, luma_ref: LumaReference=None):
|
||||
self.refs.append(luma_ref)
|
||||
|
||||
def create_api_model(self, download_urls: list[str], max_refs=4):
|
||||
@ -126,7 +124,7 @@ def get_luma_concepts(include_none=False):
|
||||
"pull_out",
|
||||
"aerial",
|
||||
"crane_up",
|
||||
"eye_level",
|
||||
"eye_level"
|
||||
]
|
||||
|
||||
|
||||
@ -164,8 +162,8 @@ class LumaVideoModelOutputDuration(str, Enum):
|
||||
|
||||
|
||||
class LumaGenerationType(str, Enum):
|
||||
video = "video"
|
||||
image = "image"
|
||||
video = 'video'
|
||||
image = 'image'
|
||||
|
||||
|
||||
class LumaState(str, Enum):
|
||||
@ -176,109 +174,86 @@ class LumaState(str, Enum):
|
||||
|
||||
|
||||
class LumaAssets(BaseModel):
|
||||
video: Optional[str] = Field(None, description="The URL of the video")
|
||||
image: Optional[str] = Field(None, description="The URL of the image")
|
||||
progress_video: Optional[str] = Field(None, description="The URL of the progress video")
|
||||
video: Optional[str] = Field(None, description='The URL of the video')
|
||||
image: Optional[str] = Field(None, description='The URL of the image')
|
||||
progress_video: Optional[str] = Field(None, description='The URL of the progress video')
|
||||
|
||||
|
||||
class LumaImageRef(BaseModel):
|
||||
"""Used for image gen"""
|
||||
|
||||
url: str = Field(..., description="The URL of the image reference")
|
||||
weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference")
|
||||
url: str = Field(..., description='The URL of the image reference')
|
||||
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
|
||||
|
||||
|
||||
class LumaImageReference(BaseModel):
|
||||
"""Used for video gen"""
|
||||
|
||||
type: Optional[str] = Field("image", description="Input type, defaults to image")
|
||||
url: str = Field(..., description="The URL of the image")
|
||||
type: Optional[str] = Field('image', description='Input type, defaults to image')
|
||||
url: str = Field(..., description='The URL of the image')
|
||||
|
||||
|
||||
class LumaModifyImageRef(BaseModel):
|
||||
url: str = Field(..., description="The URL of the image reference")
|
||||
weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference")
|
||||
url: str = Field(..., description='The URL of the image reference')
|
||||
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
|
||||
|
||||
|
||||
class LumaCharacterRef(BaseModel):
|
||||
identity0: LumaImageIdentity = Field(..., description="The image identity object")
|
||||
identity0: LumaImageIdentity = Field(..., description='The image identity object')
|
||||
|
||||
|
||||
class LumaImageIdentity(BaseModel):
|
||||
images: list[str] = Field(..., description="The URLs of the image identity")
|
||||
images: list[str] = Field(..., description='The URLs of the image identity')
|
||||
|
||||
|
||||
class LumaGenerationReference(BaseModel):
|
||||
type: str = Field("generation", description="Input type, defaults to generation")
|
||||
id: str = Field(..., description="The ID of the generation")
|
||||
type: str = Field('generation', description='Input type, defaults to generation')
|
||||
id: str = Field(..., description='The ID of the generation')
|
||||
|
||||
|
||||
class LumaKeyframes(BaseModel):
|
||||
frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="")
|
||||
frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="")
|
||||
frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
|
||||
frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
|
||||
|
||||
|
||||
class LumaConceptObject(BaseModel):
|
||||
key: str = Field(..., description="Camera Concept name")
|
||||
key: str = Field(..., description='Camera Concept name')
|
||||
|
||||
|
||||
class LumaImageGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description="The prompt of the generation")
|
||||
model: LumaImageModel = Field(LumaImageModel.photon_1, description="The image model used for the generation")
|
||||
aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9)
|
||||
image_ref: Optional[list[LumaImageRef]] = Field(None, description="List of image reference objects")
|
||||
style_ref: Optional[list[LumaImageRef]] = Field(None, description="List of style reference objects")
|
||||
character_ref: Optional[LumaCharacterRef] = Field(None, description="The image identity object")
|
||||
modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description="The modify image reference object")
|
||||
prompt: str = Field(..., description='The prompt of the generation')
|
||||
model: LumaImageModel = Field(LumaImageModel.photon_1, description='The image model used for the generation')
|
||||
aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9, description='The aspect ratio of the generation')
|
||||
image_ref: Optional[list[LumaImageRef]] = Field(None, description='List of image reference objects')
|
||||
style_ref: Optional[list[LumaImageRef]] = Field(None, description='List of style reference objects')
|
||||
character_ref: Optional[LumaCharacterRef] = Field(None, description='The image identity object')
|
||||
modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description='The modify image reference object')
|
||||
|
||||
|
||||
class LumaGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description="The prompt of the generation")
|
||||
model: LumaVideoModel = Field(LumaVideoModel.ray_2, description="The video model used for the generation")
|
||||
duration: Optional[LumaVideoModelOutputDuration] = Field(None, description="The duration of the generation")
|
||||
aspect_ratio: Optional[LumaAspectRatio] = Field(None, description="The aspect ratio of the generation")
|
||||
resolution: Optional[LumaVideoOutputResolution] = Field(None, description="The resolution of the generation")
|
||||
loop: Optional[bool] = Field(None, description="Whether to loop the video")
|
||||
keyframes: Optional[LumaKeyframes] = Field(None, description="The keyframes of the generation")
|
||||
concepts: Optional[list[LumaConceptObject]] = Field(None, description="Camera Concepts to apply to generation")
|
||||
prompt: str = Field(..., description='The prompt of the generation')
|
||||
model: LumaVideoModel = Field(LumaVideoModel.ray_2, description='The video model used for the generation')
|
||||
duration: Optional[LumaVideoModelOutputDuration] = Field(None, description='The duration of the generation')
|
||||
aspect_ratio: Optional[LumaAspectRatio] = Field(None, description='The aspect ratio of the generation')
|
||||
resolution: Optional[LumaVideoOutputResolution] = Field(None, description='The resolution of the generation')
|
||||
loop: Optional[bool] = Field(None, description='Whether to loop the video')
|
||||
keyframes: Optional[LumaKeyframes] = Field(None, description='The keyframes of the generation')
|
||||
concepts: Optional[list[LumaConceptObject]] = Field(None, description='Camera Concepts to apply to generation')
|
||||
|
||||
|
||||
class LumaGeneration(BaseModel):
|
||||
id: str = Field(..., description="The ID of the generation")
|
||||
generation_type: LumaGenerationType = Field(..., description="Generation type, image or video")
|
||||
state: LumaState = Field(..., description="The state of the generation")
|
||||
failure_reason: Optional[str] = Field(None, description="The reason for the state of the generation")
|
||||
created_at: str = Field(..., description="The date and time when the generation was created")
|
||||
assets: Optional[LumaAssets] = Field(None, description="The assets of the generation")
|
||||
model: str = Field(..., description="The model used for the generation")
|
||||
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(...)
|
||||
id: str = Field(..., description='The ID of the generation')
|
||||
generation_type: LumaGenerationType = Field(..., description='Generation type, image or video')
|
||||
state: LumaState = Field(..., description='The state of the generation')
|
||||
failure_reason: Optional[str] = Field(None, description='The reason for the state of the generation')
|
||||
created_at: str = Field(..., description='The date and time when the generation was created')
|
||||
assets: Optional[LumaAssets] = Field(None, description='The assets of the generation')
|
||||
model: str = Field(..., description='The model used for the generation')
|
||||
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation")
|
||||
|
||||
|
||||
class Luma2ImageRef(BaseModel):
    # Reference to an image (or a prior generation); every field is optional
    # and defaults to None.
    url: str | None = Field(default=None)
    data: str | None = Field(default=None)
    media_type: str | None = Field(default=None)
    generation_id: str | None = Field(
        default=None,
        description="reference a prior generation (extend / source reuse)",
    )
|
||||
|
||||
|
||||
class Luma2VideoEdit(BaseModel):
    """Edit controls for Ray 3.2 ``video_edit`` generations."""

    # Both controls are optional; None is the default for each.
    auto_controls: bool | None = Field(
        default=None,
        description="derive a conditioning schedule from the source (recommended)",
    )
    strength: str | None = Field(
        default=None,
        description="'adhere_1' .. 'reimagine_3'; constrained by IO.Combo",
    )
|
||||
|
||||
|
||||
class Luma2VideoOptions(BaseModel):
    """Ray 3.2 ``video`` output settings (text / image / keyframe / edit / extend)."""

    # Output format settings.
    resolution: str | None = Field(default=None, description="360p | 540p | 720p | 1080p")
    duration: str | None = Field(default=None, description="5s | 10s")
    loop: bool | None = None

    # Optional image anchors for the generation.
    start_frame: Luma2ImageRef | None = None
    end_frame: Luma2ImageRef | None = None
    keyframes: list[Luma2ImageRef] | None = None
    keyframe_indexes: list[int] | None = None

    # Optional edit controls.
    edit: Luma2VideoEdit | None = None
|
||||
|
||||
|
||||
class Luma2GenerationRequest(BaseModel):
|
||||
@ -291,7 +266,6 @@ class Luma2GenerationRequest(BaseModel):
|
||||
web_search: bool | None = None
|
||||
image_ref: list[Luma2ImageRef] | None = None
|
||||
source: Luma2ImageRef | None = None
|
||||
video: Luma2VideoOptions | None = Field(None)
|
||||
|
||||
|
||||
class Luma2Generation(BaseModel):
|
||||
@ -303,31 +277,3 @@ class Luma2Generation(BaseModel):
|
||||
output: list[LumaImageReference] | None = None
|
||||
failure_reason: str | None = None
|
||||
failure_code: str | None = None
|
||||
|
||||
|
||||
# --- Ray 3.2 multi-keyframe chain ---
|
||||
|
||||
LUMA_KEYFRAME_MODE_FRACTION = "fraction" # value in [0.0, 1.0] of the output video duration
|
||||
LUMA_KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the output
|
||||
|
||||
|
||||
class LumaRay32KeyframeItem:
    """A single guide image pinned to a position on the Ray 3.2 output timeline."""

    def __init__(self, image: torch.Tensor, mode: str, value: float):
        # ``mode`` is LUMA_KEYFRAME_MODE_FRACTION or LUMA_KEYFRAME_MODE_SECONDS
        # and tells how ``value`` is to be read.
        self.image, self.mode, self.value = image, mode, value
|
||||
|
||||
|
||||
class LumaRay32KeyframeChain:
    # Ordered collection of LumaRay32KeyframeItem objects.

    def __init__(self):
        self.items: "list[LumaRay32KeyframeItem]" = []

    def add(self, item: "LumaRay32KeyframeItem") -> None:
        """Append ``item`` to the end of the chain."""
        self.items.append(item)

    def clone(self) -> "LumaRay32KeyframeChain":
        """Return a new chain holding a shallow copy of the item list."""
        duplicate = LumaRay32KeyframeChain()
        duplicate.items = [*self.items]
        return duplicate
|
||||
|
||||
@ -3,13 +3,9 @@ from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.luma import (
|
||||
LUMA_KEYFRAME_MODE_FRACTION,
|
||||
LUMA_KEYFRAME_MODE_SECONDS,
|
||||
Luma2Generation,
|
||||
Luma2GenerationRequest,
|
||||
Luma2ImageRef,
|
||||
Luma2VideoEdit,
|
||||
Luma2VideoOptions,
|
||||
LumaAspectRatio,
|
||||
LumaCharacterRef,
|
||||
LumaConceptChain,
|
||||
@ -22,8 +18,6 @@ from comfy_api_nodes.apis.luma import (
|
||||
LumaIO,
|
||||
LumaKeyframes,
|
||||
LumaModifyImageRef,
|
||||
LumaRay32KeyframeChain,
|
||||
LumaRay32KeyframeItem,
|
||||
LumaReference,
|
||||
LumaReferenceChain,
|
||||
LumaVideoModel,
|
||||
@ -39,7 +33,6 @@ from comfy_api_nodes.util import (
|
||||
sync_op,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
@ -699,10 +692,7 @@ async def _luma2_upload_image_refs(
|
||||
async def _luma2_submit_and_poll(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: Luma2GenerationRequest,
|
||||
*,
|
||||
estimated_duration: int | None = None,
|
||||
) -> Luma2Generation:
|
||||
"""Submit a Luma Agents generation and poll until done; returns the completed generation."""
|
||||
) -> Input.Image:
|
||||
initial = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma_2/generations", method="POST"),
|
||||
@ -710,21 +700,21 @@ async def _luma2_submit_and_poll(
|
||||
data=request,
|
||||
)
|
||||
if not initial.id:
|
||||
raise RuntimeError("Luma API did not return a generation id.")
|
||||
raise RuntimeError("Luma 2 API did not return a generation id.")
|
||||
final = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"),
|
||||
response_model=Luma2Generation,
|
||||
status_extractor=lambda r: r.state,
|
||||
progress_extractor=lambda r: None,
|
||||
estimated_duration=estimated_duration,
|
||||
)
|
||||
if not final.output or not final.output[0].url:
|
||||
if not final.output:
|
||||
msg = final.failure_reason or "no output returned"
|
||||
if final.failure_code:
|
||||
msg = f"{msg} [{final.failure_code}]"
|
||||
raise RuntimeError(f"Luma generation failed: {msg}")
|
||||
return final
|
||||
raise RuntimeError(f"Luma 2 generation failed: {msg}")
|
||||
url = final.output[0].url
|
||||
if not url:
|
||||
raise RuntimeError("Luma 2 generation completed without an output URL.")
|
||||
return await download_url_to_image_tensor(url)
|
||||
|
||||
|
||||
class LumaImageNode(IO.ComfyNode):
|
||||
@ -853,8 +843,7 @@ class LumaImageNode(IO.ComfyNode):
|
||||
web_search=model["web_search"],
|
||||
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9),
|
||||
)
|
||||
final = await _luma2_submit_and_poll(cls, request)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url))
|
||||
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
|
||||
|
||||
|
||||
class LumaImageEditNode(IO.ComfyNode):
|
||||
@ -940,533 +929,7 @@ class LumaImageEditNode(IO.ComfyNode):
|
||||
web_search=model["web_search"],
|
||||
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8),
|
||||
)
|
||||
final = await _luma2_submit_and_poll(cls, request)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url))
|
||||
|
||||
|
||||
_BADGE_RAY32_VIDEO = IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]),
|
||||
expr="""
|
||||
(
|
||||
$p := {
|
||||
"360p": {"5s": 0.06, "10s": 0.18},
|
||||
"540p": {"5s": 0.15, "10s": 0.45},
|
||||
"720p": {"5s": 0.3, "10s": 0.9},
|
||||
"1080p": {"5s": 1.2, "10s": 3.6}
|
||||
};
|
||||
{"type": "usd", "usd": $lookup($lookup($p, widgets.resolution), widgets.duration)}
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
_BADGE_RAY32_VIDEO_5S = IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$p := {"360p": 0.06, "540p": 0.15, "720p": 0.3, "1080p": 1.2};
|
||||
{"type": "usd", "usd": $lookup($p, widgets.resolution)}
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
_BADGE_RAY32_EDIT = IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$p := {
|
||||
"360p": {"min": 0.54, "max": 1.08},
|
||||
"540p": {"min": 0.72, "max": 1.44},
|
||||
"720p": {"min": 1.08, "max": 2.16},
|
||||
"1080p": {"min": 2.16, "max": 4.32}
|
||||
};
|
||||
$r := $lookup($p, widgets.resolution);
|
||||
{"type": "range_usd", "min_usd": $r.min, "max_usd": $r.max, "format": {"note": "(by source length)"}}
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
_BADGE_RAY32_REFRAME = IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$p := {"360p": 0.03, "540p": 0.06, "720p": 0.12, "1080p": 0.36};
|
||||
{"type": "usd", "usd": $lookup($p, widgets.resolution), "format": {"suffix": "/second"}}
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
def _ray32_seed_input() -> IO.Input:
    """Build the ``seed`` integer input shared by the Ray 3.2 nodes."""
    seed_options = {
        "default": 0,
        "min": 0,
        "max": 0xFFFFFFFFFFFFFFFF,
        "control_after_generate": True,
        "tooltip": "Seed to determine if node should re-run; results are nondeterministic regardless of seed.",
    }
    return IO.Int.Input("seed", **seed_options)
|
||||
|
||||
|
||||
async def _ray32_generate(cls: type[IO.ComfyNode], request: Luma2GenerationRequest) -> IO.NodeOutput:
    """Run a ray-3.2 generation and return (video, generation_id).

    Submits ``request`` and polls it with an estimated duration of 120, then
    downloads the first output URL as a video. The generation id falls back
    to an empty string when the completed generation has none.
    """
    generation = await _luma2_submit_and_poll(cls, request, estimated_duration=120)
    first_output_url = generation.output[0].url
    video = await download_url_to_video_output(first_output_url)
    generation_id = generation.id or ""
    return IO.NodeOutput(video, generation_id)
|
||||
|
||||
|
||||
class LumaRay32TextToVideoNode(IO.ComfyNode):
    # API node: text prompt -> video via Luma's Ray 3.2 model.

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node's inputs, outputs, hidden inputs and price badge."""
        node_inputs = [
            IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
            IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]),
            IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
            IO.Combo.Input("duration", options=["5s", "10s"]),
            IO.Boolean.Input(
                "loop",
                default=False,
                tooltip="Make the video loop seamlessly. Only available with 5s duration.",
            ),
            _ray32_seed_input(),
        ]
        node_outputs = [IO.Video.Output(), IO.String.Output(display_name="generation_id")]
        hidden_inputs = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="LumaRay32TextToVideoNode",
            display_name="Luma Ray 3.2 Text to Video",
            category="partner/video/Luma",
            description="Generate a video from a text prompt using Luma's Ray 3.2 model.",
            inputs=node_inputs,
            outputs=node_outputs,
            hidden=hidden_inputs,
            is_api_node=True,
            price_badge=_BADGE_RAY32_VIDEO,
        )

    @classmethod
    async def execute(
        cls, prompt: str, aspect_ratio: str, resolution: str, duration: str, loop: bool, seed: int
    ) -> IO.NodeOutput:
        """Validate the inputs, build the ray-3.2 video request and run it.

        Raises:
            ValueError: if ``loop`` is requested together with the 10s duration.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
        wants_long_loop = loop and duration == "10s"
        if wants_long_loop:
            raise ValueError("Looping is only available with 5s duration on Ray 3.2.")
        # ``loop`` is sent as None rather than False when looping is off
        video_options = Luma2VideoOptions(resolution=resolution, duration=duration, loop=loop or None)
        request = Luma2GenerationRequest(
            prompt=prompt,
            model="ray-3.2",
            type="video",
            aspect_ratio=aspect_ratio,
            video=video_options,
        )
        return await _ray32_generate(cls, request)
|
||||
|
||||
|
||||
class LumaRay32ImageToVideoNode(IO.ComfyNode):
    """API node that requests a Luma Ray 3.2 video anchored on a start and/or end image."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: prompt/resolution/loop/seed plus two optional frame images."""
        inputs = [
            IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
            IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
            IO.Boolean.Input(
                "loop",
                default=False,
                tooltip="Make the video loop seamlessly. Not available when an end_frame is set.",
            ),
            _ray32_seed_input(),
            IO.Image.Input("start_frame", optional=True, tooltip="First frame of the generated video."),
            IO.Image.Input("end_frame", optional=True, tooltip="Last frame of the generated video."),
        ]
        outputs = [IO.Video.Output(), IO.String.Output(display_name="generation_id")]
        hidden = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="LumaRay32ImageToVideoNode",
            display_name="Luma Ray 3.2 Image to Video",
            category="partner/video/Luma",
            description=(
                "Generate a video from a start and/or end frame using Luma's Ray 3.2 model. "
                "Image-anchored generations are always 5 seconds."
            ),
            inputs=inputs,
            outputs=outputs,
            hidden=hidden,
            is_api_node=True,
            price_badge=_BADGE_RAY32_VIDEO_5S,
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        resolution: str,
        loop: bool,
        seed: int,
        start_frame: torch.Tensor | None = None,
        end_frame: torch.Tensor | None = None,
    ) -> IO.NodeOutput:
        """Upload the supplied frame image(s) and submit a fixed-"5s" Ray 3.2 video request.

        ``seed`` is part of the signature but is not read in this method.

        Raises:
            ValueError: if neither frame is given, or if ``loop`` is combined with ``end_frame``.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
        if start_frame is None and end_frame is None:
            raise ValueError("Provide at least one of start_frame / end_frame.")
        if loop and end_frame is not None:
            raise ValueError("Looping is not available when an end_frame is set.")

        options = Luma2VideoOptions(resolution=resolution, duration="5s", loop=loop or None)
        # Start frame is uploaded before the end frame; each present image becomes a URL reference.
        for field_name, frame in (("start_frame", start_frame), ("end_frame", end_frame)):
            if frame is None:
                continue
            uploaded_url = await upload_image_to_comfyapi(cls, frame, mime_type="image/png")
            setattr(options, field_name, Luma2ImageRef(url=uploaded_url))

        request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=options)
        return await _ray32_generate(cls, request)
|
||||
|
||||
|
||||
class LumaRay32KeyframeNode(IO.ComfyNode):
    """Node that appends one guide image, with a timeline position, to a Ray 3.2 keyframe chain."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: an image, a fraction-or-seconds position, and an optional upstream chain."""
        image_input = IO.Image.Input("image", tooltip="Guide image to place at the chosen moment of the output video.")

        fraction_option = IO.DynamicCombo.Option(
            "Fraction of duration (0.0-1.0)",
            [
                IO.Float.Input(
                    "fraction",
                    default=0.0,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Where in the output video this image applies (0.0 = start, 1.0 = end).",
                ),
            ],
        )
        seconds_option = IO.DynamicCombo.Option(
            "Absolute time (seconds)",
            [
                IO.Float.Input(
                    "seconds",
                    default=0.0,
                    min=0.0,
                    max=10.0,
                    step=0.1,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Time in seconds from the start of the output video where this image applies.",
                ),
            ],
        )
        position_input = IO.DynamicCombo.Input(
            "position",
            options=[fraction_option, seconds_option],
            tooltip="How to place this image on the output video's timeline.",
        )
        chain_input = IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input(
            "keyframes",
            optional=True,
            tooltip="Optional earlier keyframes to chain with this one.",
        )

        return IO.Schema(
            node_id="LumaRay32KeyframeNode",
            display_name="Luma Ray 3.2 Keyframe",
            category="partner/video/Luma",
            description=(
                "Anchor a guide image to a position on the Ray 3.2 output video timeline. Connect this to "
                "the 'keyframes' input of the Luma Ray 3.2 Keyframes to Video node; chain several together via the "
                "optional 'keyframes' input below."
            ),
            inputs=[image_input, position_input, chain_input],
            outputs=[IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Output(display_name="keyframes")],
        )

    @classmethod
    def execute(
        cls,
        image: torch.Tensor,
        position: dict,
        keyframes: LumaRay32KeyframeChain | None = None,
    ) -> IO.NodeOutput:
        """Return a chain holding the upstream keyframes (if any) plus this image.

        ``position`` carries the selected option label under "position" and the
        option's own value under "seconds" or "fraction".
        """
        # Work on a clone so the upstream chain object is not modified by this node.
        if keyframes is None:
            chain = LumaRay32KeyframeChain()
        else:
            chain = keyframes.clone()

        if position["position"] == "Absolute time (seconds)":
            mode = LUMA_KEYFRAME_MODE_SECONDS
            value = float(position["seconds"])
        else:
            mode = LUMA_KEYFRAME_MODE_FRACTION
            value = float(position["fraction"])

        chain.add(LumaRay32KeyframeItem(image=image, mode=mode, value=value))
        return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class LumaRay32KeyframesToVideoNode(IO.ComfyNode):
    """API node that turns a chain of positioned guide images into a Ray 3.2 video request."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: prompt/resolution/duration/seed plus the required keyframe chain."""
        inputs = [
            IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
            IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
            IO.Combo.Input("duration", options=["5s", "10s"]),
            _ray32_seed_input(),
            IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input(
                "keyframes",
                tooltip="Keyframe sequence from Luma Ray 3.2 Keyframe nodes (at least 2).",
            ),
        ]
        outputs = [IO.Video.Output(), IO.String.Output(display_name="generation_id")]
        hidden = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="LumaRay32KeyframesToVideoNode",
            display_name="Luma Ray 3.2 Keyframes to Video",
            category="partner/video/Luma",
            description=(
                "Generate a video that interpolates through a sequence of guide images, each anchored to a "
                "position on the timeline, using Luma Ray 3.2. Build the sequence with Luma Ray 3.2 Keyframe nodes "
                "(at least 2)."
            ),
            inputs=inputs,
            outputs=outputs,
            hidden=hidden,
            is_api_node=True,
            price_badge=_BADGE_RAY32_VIDEO,
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        resolution: str,
        duration: str,
        seed: int,
        keyframes: LumaRay32KeyframeChain | None = None,
    ) -> IO.NodeOutput:
        """Resolve every keyframe to an output-frame index, upload the images and submit the request.

        ``seed`` is part of the signature but is not read in this method.

        Raises:
            ValueError: for fewer than 2 or more than 64 keyframes, for a seconds-mode
                position past the end of the video, or when two keyframes land on the
                same output frame.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)

        items = [] if keyframes is None else keyframes.items
        if len(items) < 2:
            raise ValueError(
                "Connect at least 2 Luma Ray 3.2 Keyframe nodes "
                "(use Luma Ray 3.2 Image to Video for a single start/end frame)."
            )
        if len(items) > 64:
            raise ValueError(f"Ray 3.2 supports at most 64 keyframes; got {len(items)}.")

        # Frame arithmetic below uses 24 frames per second: 120 frames for "5s", 240 otherwise.
        maxframe = {"5s": 120}.get(duration, 240)
        duration_seconds = maxframe / 24  # 5.0 or 10.0

        # Convert each keyframe to (frame index, image). Chain order is irrelevant:
        # the resolved position alone decides where a keyframe ends up.
        unordered: list[tuple[int, torch.Tensor]] = []
        for item in items:
            if item.mode != LUMA_KEYFRAME_MODE_SECONDS:
                frame = round(item.value * maxframe)
            else:
                if item.value > duration_seconds:
                    raise ValueError(
                        f"Keyframe position {item.value:g}s is past the end of the {duration} video; "
                        f"use 0-{duration_seconds:g}s (or switch the keyframe to fraction mode)."
                    )
                frame = round(item.value * 24)
            # Clamp into the valid index range [0, maxframe].
            unordered.append((min(max(frame, 0), maxframe), item.image))

        # Stable sort by frame index only; images are never compared.
        placed = sorted(unordered, key=lambda pair: pair[0])
        indexes = [frame for frame, _ in placed]

        # After sorting, any duplicate indexes are adjacent.
        for a, following in zip(indexes, indexes[1:]):
            if a != following:
                continue
            raise ValueError(
                f"Two keyframes resolve to the same output frame ({a}) for a {duration} video "
                f"(valid range 0-{maxframe}); give each keyframe a distinct position."
            )

        # Uploads run one after another, in timeline order, so refs line up with indexes.
        refs: list[Luma2ImageRef] = [
            Luma2ImageRef(url=await upload_image_to_comfyapi(cls, image, mime_type="image/png"))
            for _, image in placed
        ]

        video_options = Luma2VideoOptions(
            resolution=resolution, duration=duration, keyframes=refs, keyframe_indexes=indexes
        )
        request = Luma2GenerationRequest(
            prompt=prompt,
            model="ray-3.2",
            type="video",
            video=video_options,
        )
        return await _ray32_generate(cls, request)
|
||||
|
||||
|
||||
class LumaRay32VideoEditNode(IO.ComfyNode):
    """API node that submits an existing video to Luma Ray 3.2 for a prompt-driven edit."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: source video, prompt, resolution, edit strength and seed."""
        strength_input = IO.Combo.Input(
            "strength",
            options=[
                "auto",
                "adhere_1",
                "adhere_2",
                "adhere_3",
                "flex_1",
                "flex_2",
                "flex_3",
                "reimagine_1",
                "reimagine_2",
                "reimagine_3",
            ],
            default="auto",
            tooltip=(
                "How strongly to preserve vs. reimagine the source. 'auto' lets Ray 3.2 choose; "
                "adhere_* preserves the most, flex_* is balanced, reimagine_* changes the most."
            ),
        )
        inputs = [
            IO.Video.Input("video", tooltip="Source video to edit. Up to 18 seconds."),
            IO.String.Input("prompt", multiline=True, default="", tooltip="Describes the desired edit."),
            IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
            strength_input,
            _ray32_seed_input(),
        ]
        outputs = [
            IO.Video.Output(),
            IO.String.Output(display_name="generation_id"),
        ]
        hidden = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="LumaRay32VideoEditNode",
            display_name="Luma Ray 3.2 Video Edit",
            category="partner/video/Luma",
            description=(
                "Re-render an existing video under a new prompt using Luma Ray 3.2 (restyle, relight, add "
                "or remove elements) while keeping the original motion. Source video up to 18 seconds; the edited "
                "video keeps the source's length."
            ),
            inputs=inputs,
            outputs=outputs,
            hidden=hidden,
            is_api_node=True,
            price_badge=_BADGE_RAY32_EDIT,
        )

    @classmethod
    async def execute(
        cls, video: Input.Video, prompt: str, resolution: str, strength: str, seed: int
    ) -> IO.NodeOutput:
        """Upload the source video and submit a Ray 3.2 "video_edit" request.

        ``seed`` is part of the signature but is not read in this method.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)

        # Best effort: any failure while reading or comparing the source length
        # falls back to the "10s" bucket instead of aborting the node.
        try:
            fits_short_bucket = bool(video.get_duration() <= 5.0)
        except Exception:
            fits_short_bucket = False
        duration = "5s" if fits_short_bucket else "10s"

        source_url = await upload_video_to_comfyapi(cls, video, max_duration=18)

        # "auto" is expressed as auto_controls=True; any other choice is passed through as strength.
        if strength == "auto":
            edit = Luma2VideoEdit(auto_controls=True)
        else:
            edit = Luma2VideoEdit(strength=strength)

        source_ref = Luma2ImageRef(url=source_url, media_type="video/mp4")
        video_options = Luma2VideoOptions(resolution=resolution, duration=duration, edit=edit)
        request = Luma2GenerationRequest(
            prompt=prompt,
            model="ray-3.2",
            type="video_edit",
            source=source_ref,
            video=video_options,
        )
        return await _ray32_generate(cls, request)
|
||||
|
||||
|
||||
class LumaRay32VideoReframeNode(IO.ComfyNode):
    """API node that asks Luma Ray 3.2 to re-render a video at a different aspect ratio."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: source video, fill prompt, target aspect ratio, resolution and seed."""
        inputs = [
            IO.Video.Input("video", tooltip="Source video to reframe. Up to 30 seconds."),
            IO.String.Input(
                "prompt",
                multiline=True,
                default="",
                tooltip="Describes how the newly exposed canvas areas should be filled.",
            ),
            IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]),
            IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
            _ray32_seed_input(),
        ]
        outputs = [
            IO.Video.Output(),
            IO.String.Output(display_name="generation_id"),
        ]
        hidden = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="LumaRay32VideoReframeNode",
            display_name="Luma Ray 3.2 Video Reframe",
            category="partner/video/Luma",
            description=(
                "Change the aspect ratio of an existing video, using Luma Ray 3.2 to fill the newly "
                "exposed canvas areas. Source video up to 30 seconds. Billed per second of output."
            ),
            inputs=inputs,
            outputs=outputs,
            hidden=hidden,
            is_api_node=True,
            price_badge=_BADGE_RAY32_REFRAME,
        )

    @classmethod
    async def execute(
        cls, video: Input.Video, prompt: str, aspect_ratio: str, resolution: str, seed: int
    ) -> IO.NodeOutput:
        """Upload the source video and submit a Ray 3.2 "video_reframe" request.

        ``seed`` is part of the signature but is not read in this method.

        Raises:
            ValueError: if "1080p" is combined with a vertical aspect ratio (9:16 or 3:4).
        """
        # Unlike the generation nodes in this file, the prompt is validated without whitespace stripping.
        validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000)

        if resolution == "1080p":
            if aspect_ratio in {"9:16", "3:4"}:
                raise ValueError("1080p is not available for vertical aspect ratios (9:16, 3:4) when reframing.")

        source_url = await upload_video_to_comfyapi(cls, video, max_duration=30)
        source_ref = Luma2ImageRef(url=source_url, media_type="video/mp4")
        request = Luma2GenerationRequest(
            prompt=prompt,
            model="ray-3.2",
            type="video_reframe",
            aspect_ratio=aspect_ratio,
            source=source_ref,
            video=Luma2VideoOptions(resolution=resolution),
        )
        return await _ray32_generate(cls, request)
|
||||
|
||||
|
||||
class LumaRay32ExtendVideoNode(IO.ComfyNode):
    """API node that extends an earlier Ray 3.2 generation forward or backward by a 5s clip."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: source generation id, direction (with optional loop), prompt, resolution, seed."""
        source_id_input = IO.String.Input(
            "source_generation_id",
            default="",
            tooltip=(
                "generation_id of the prior Ray 3.2 video to extend."
                " Connect the generation_id output of another Luma Ray 3.2 node."
            ),
        )
        forward_option = IO.DynamicCombo.Option(
            "Forward (continue after)",
            [
                IO.Boolean.Input(
                    "loop",
                    default=False,
                    tooltip="Loop the extended video seamlessly (forward extend only).",
                ),
            ],
        )
        backward_option = IO.DynamicCombo.Option("Backward (lead-in before)", [])
        direction_input = IO.DynamicCombo.Input(
            "direction",
            options=[forward_option, backward_option],
            tooltip="Forward continues after the prior clip; backward is prepended before it.",
        )
        inputs = [
            source_id_input,
            direction_input,
            IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the new content."),
            IO.Combo.Input("resolution", options=["540p", "720p", "1080p"], default="720p"),
            _ray32_seed_input(),
        ]
        outputs = [
            IO.Video.Output(),
            IO.String.Output(display_name="generation_id"),
        ]
        hidden = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="LumaRay32ExtendVideoNode",
            display_name="Luma Ray 3.2 Extend Video",
            category="partner/video/Luma",
            description=(
                "Extend a previous Ray 3.2 generation forward (continue after it) or backward (lead-in "
                "before it). Connect the generation_id output of a prior Luma Ray 3.2 node."
                " Extensions are always 5 seconds."
            ),
            inputs=inputs,
            outputs=outputs,
            hidden=hidden,
            is_api_node=True,
            price_badge=_BADGE_RAY32_VIDEO_5S,
        )

    @classmethod
    async def execute(
        cls, source_generation_id: str, direction: dict, prompt: str, resolution: str, seed: int
    ) -> IO.NodeOutput:
        """Submit a "5s" Ray 3.2 video request anchored on a prior generation.

        The prior generation is referenced as ``start_frame`` for a forward
        extension and as ``end_frame`` otherwise. ``seed`` is part of the
        signature but is not read in this method.

        Raises:
            ValueError: if ``source_generation_id`` is empty or whitespace-only.
        """
        validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000)

        gen_id = (source_generation_id or "").strip()
        if not gen_id:
            raise ValueError(
                "source_generation_id is required (connect the generation_id output of a prior Luma Ray 3.2 node)."
            )

        options = Luma2VideoOptions(resolution=resolution, duration="5s")
        anchor = Luma2ImageRef(generation_id=gen_id)

        going_forward = direction["direction"] == "Forward (continue after)"
        if not going_forward:
            options.end_frame = anchor
        else:
            options.start_frame = anchor
            # "loop" is only offered by the forward option, hence the tolerant .get().
            if direction.get("loop"):
                options.loop = True

        request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=options)
        return await _ray32_generate(cls, request)
|
||||
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
|
||||
|
||||
|
||||
class LumaExtension(ComfyExtension):
|
||||
@ -1481,13 +944,6 @@ class LumaExtension(ComfyExtension):
|
||||
LumaConceptsNode,
|
||||
LumaImageNode,
|
||||
LumaImageEditNode,
|
||||
LumaRay32TextToVideoNode,
|
||||
LumaRay32ImageToVideoNode,
|
||||
LumaRay32KeyframeNode,
|
||||
LumaRay32KeyframesToVideoNode,
|
||||
LumaRay32VideoEditNode,
|
||||
LumaRay32VideoReframeNode,
|
||||
LumaRay32ExtendVideoNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -4,8 +4,6 @@ import os
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timezone
|
||||
from email.utils import parsedate_to_datetime
|
||||
from io import BytesIO
|
||||
|
||||
from yarl import URL
|
||||
@ -93,32 +91,6 @@ async def sleep_with_interrupt(
|
||||
await asyncio.sleep(min(1.0, end - now))
|
||||
|
||||
|
||||
def _retry_after_wait(value: str | None, fallback: float, max_wait: float) -> float:
|
||||
"""Delay before the next retry, honoring a server ``Retry-After`` header."""
|
||||
|
||||
seconds: float | None = None
|
||||
if value is not None:
|
||||
value = value.strip()
|
||||
if value.isascii() and value.isdigit():
|
||||
# delay-seconds form. The ASCII-digit guard keeps exotic Unicode "digit" characters away from float()
|
||||
# an all-digit string always converts (huge values become inf, never raising).
|
||||
seconds = float(value)
|
||||
elif value:
|
||||
# HTTP-date form. parsedate_to_datetime raises OverflowError (not a ValueError) on absurd years/offsets
|
||||
try:
|
||||
parsed = parsedate_to_datetime(value)
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
parsed = None
|
||||
if parsed is not None:
|
||||
if parsed.tzinfo is None: # naive datetime: HTTP-date is UTC
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
delta = (parsed - datetime.now(timezone.utc)).total_seconds()
|
||||
seconds = delta if delta > 0 else 0.0
|
||||
if seconds is None:
|
||||
return fallback
|
||||
return min(seconds, max_wait)
|
||||
|
||||
|
||||
def mimetype_to_extension(mime_type: str) -> str:
|
||||
"""Converts a MIME type to a file extension."""
|
||||
return mime_type.split("/")[-1].lower()
|
||||
|
||||
@ -21,7 +21,6 @@ from server import PromptServer
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
_retry_after_wait,
|
||||
default_base_url,
|
||||
get_comfy_api_headers,
|
||||
get_node_id,
|
||||
@ -83,7 +82,6 @@ class _PollUIState:
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||
_MAX_RETRY_AFTER_WAIT = 150.0 # Cap a server Retry-After at this many seconds so a large hint can't block execution
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"]
|
||||
@ -749,7 +747,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
should_retry = True
|
||||
|
||||
if should_retry:
|
||||
wait_time = _retry_after_wait(resp.headers.get("Retry-After"), wait_time, _MAX_RETRY_AFTER_WAIT)
|
||||
logging.warning(
|
||||
"HTTP %s %s -> %s. Waiting %.2fs (%s).",
|
||||
method,
|
||||
|
||||
@ -4,22 +4,11 @@ Provides normalization and helper functions for job status tracking.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
from comfy_api.internal import prune_dict
|
||||
|
||||
|
||||
# Result of classifying a job for cancellation.
|
||||
# 'running' -> job is currently executing (interrupt it)
|
||||
# 'pending' -> job is queued but not started (dequeue it)
|
||||
# 'terminal' -> job already finished (present in history); cancel is a no-op
|
||||
# 'unknown' -> job id is not present anywhere
|
||||
CANCEL_RUNNING = 'running'
|
||||
CANCEL_PENDING = 'pending'
|
||||
CANCEL_TERMINAL = 'terminal'
|
||||
CANCEL_UNKNOWN = 'unknown'
|
||||
|
||||
|
||||
class JobStatus:
|
||||
"""Job status constants."""
|
||||
PENDING = 'pending'
|
||||
@ -418,71 +407,3 @@ def get_all_jobs(
|
||||
jobs = jobs[:limit]
|
||||
|
||||
return (jobs, total_count)
|
||||
|
||||
|
||||
def classify_job_for_cancel(prompt_id: str, running: list, queued: list, history: dict) -> str:
    """Report which cancellation state a job id is currently in.

    ``running`` is searched first, then ``queued``, then ``history``; the first
    match wins. Queue entries are indexable and carry the prompt id at index 1,
    while ``history`` is keyed by prompt id, so membership there means the job
    has already finished.

    Returns:
        CANCEL_RUNNING, CANCEL_PENDING, CANCEL_TERMINAL or CANCEL_UNKNOWN.
    """
    if any(entry[1] == prompt_id for entry in running):
        return CANCEL_RUNNING
    if any(entry[1] == prompt_id for entry in queued):
        return CANCEL_PENDING
    return CANCEL_TERMINAL if prompt_id in history else CANCEL_UNKNOWN
|
||||
|
||||
|
||||
def cancel_job(
    prompt_id: str,
    running: list,
    queued: list,
    history: dict,
    interrupt: Callable[[str], bool],
    dequeue: Callable[[str], bool],
) -> str:
    """Cancel one job by id, whatever state it is in, and report what actually happened.

    The job is first classified from the supplied snapshot
    (``classify_job_for_cancel``), then:

    - running  -> ``interrupt(prompt_id)`` is called
    - pending  -> ``dequeue(prompt_id)`` is called, falling back to
      ``interrupt(prompt_id)`` when the dequeue did not act
    - terminal / unknown -> nothing is called; the classification is returned
      (callers wanting fail-fast behaviour should validate ids beforehand with
      ``classify_job_for_cancel``)

    Both callbacks receive the prompt id and return whether they acted on a job
    that really was in that state. The return value is therefore based on the
    callbacks' answers rather than on the possibly stale snapshot, which covers
    the small TOCTOU windows between snapshot and action:

    - classified running but finished before ``interrupt`` fired:
      ``interrupt`` returns False -> CANCEL_UNKNOWN.
    - classified pending but started before ``dequeue`` fired: ``dequeue``
      returns False and ``interrupt`` catches the now-running job ->
      CANCEL_RUNNING. Had it already finished, both return False ->
      CANCEL_UNKNOWN.

    ``interrupt`` must be atomic (interrupt only if that job is still the one
    running) so a cancel can never hit an unrelated prompt that started in the
    meantime (see ``execution.PromptQueue.interrupt_if_running``).
    """
    state = classify_job_for_cancel(prompt_id, running, queued, history)

    # CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops.
    if state not in (CANCEL_RUNNING, CANCEL_PENDING):
        return state

    if state == CANCEL_PENDING and dequeue(prompt_id):
        return CANCEL_PENDING

    # Reached either for a job classified as running, or for a pending job that
    # left the queue before it could be dequeued. In both cases try to interrupt:
    # success means it was (or became) the running job; failure means it has
    # already finished and the cancel is a genuine no-op.
    if interrupt(prompt_id):
        return CANCEL_RUNNING
    return CANCEL_UNKNOWN
|
||||
|
||||
@ -13,22 +13,21 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
description="Manually set context windows.",
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
|
||||
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
|
||||
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True),
|
||||
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True),
|
||||
io.Combo.Input("context_schedule", options=[
|
||||
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
||||
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
||||
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
||||
comfy.context_windows.ContextSchedules.BATCHED,
|
||||
], default=comfy.context_windows.ContextSchedules.STATIC_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
|
||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
||||
], tooltip="The stride of the context window."),
|
||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
|
||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window. For concat-style I2V models (e.g. Wan I2V, HunyuanVideo I2V, Cosmos I2V, SVD) the encoded start image lives in the c_concat conditioning channels; setting this to '0' will retain that start image content at sub-pos 0 of every window."),
|
||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
io.String.Input("latent_retain_index_list", default="", tooltip="List of latent indices to retain in the noise latent itself for each window. Use for workflows where reference content (e.g. a start image) lives directly in the noise latent rather than in separate conditioning channels (e.g. inplace-style I2V like LTXV, AnimateDiff). Independent of cond_retain_index_list."),
|
||||
io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."),
|
||||
],
|
||||
outputs=[
|
||||
@ -39,7 +38,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, latent_retain_index_list: list[int]=[], causal_window_fix: bool=True) -> io.Model:
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model:
|
||||
model = model.clone()
|
||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||
@ -52,7 +51,6 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
freenoise=freenoise,
|
||||
cond_retain_index_list=cond_retain_index_list,
|
||||
split_conds_to_windows=split_conds_to_windows,
|
||||
latent_retain_index_list=latent_retain_index_list,
|
||||
causal_window_fix=causal_window_fix,
|
||||
)
|
||||
# make memory usage calculation only take into account the context window latents
|
||||
@ -67,71 +65,33 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
|
||||
schema = super().define_schema()
|
||||
schema.node_id = "WanContextWindowsManual"
|
||||
schema.display_name = "WAN Context Windows (Manual)"
|
||||
schema.display_name = "Wan Context Windows"
|
||||
schema.description = "Set context windows for Wan-like models."
|
||||
schema.description = "Manually set context windows for WAN-like models (dim=2)."
|
||||
schema.category="model/patch/wan"
|
||||
schema.inputs = [
|
||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window in real frames. Must be 4*n + 1."),
|
||||
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window in real frames."),
|
||||
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window.", advanced=True),
|
||||
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window.", advanced=True),
|
||||
io.Combo.Input("context_schedule", options=[
|
||||
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
||||
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
||||
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
||||
comfy.context_windows.ContextSchedules.BATCHED,
|
||||
], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
|
||||
], tooltip="The stride of the context window."),
|
||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
|
||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True),
|
||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True),
|
||||
io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first I2V frame in every context window (may help retain initial reference)."),
|
||||
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True),
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
]
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
|
||||
retain_first_frame: bool=False, split_conds_to_windows: bool=False) -> io.Model:
|
||||
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
||||
context_overlap = max(context_overlap // 4, 0) # at least overlap 0
|
||||
retain_index_list = "0" if retain_first_frame else ""
|
||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows)
|
||||
|
||||
|
||||
class LTXVContextWindowsNode(ContextWindowsManualNode):
    """Context-window node with inputs tailored to LTXV-like models.

    The node reuses the parent's schema object and execution path. Length and
    overlap are entered in real frames and reduced with a divisor of 8 before
    being handed to ``ContextWindowsManualNode.execute``.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Return the parent schema with LTXV-specific identity and inputs."""
        schema = super().define_schema()
        schedules = comfy.context_windows.ContextSchedules
        fuse_methods = comfy.context_windows.ContextFuseMethods

        schema.node_id = "LTXVContextWindows"
        schema.display_name = "LTXV Context Windows"
        schema.description = "Set context windows for LTXV-like models."
        schema.inputs = [
            io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
            io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=8, default=145, tooltip="The length of the context window in real frames. Must be 8*n + 1."),
            io.Int.Input("context_overlap", min=0, step=8, default=40, tooltip="The overlap of the context window in real frames."),
            io.Combo.Input(
                "context_schedule",
                options=[
                    schedules.STATIC_STANDARD,
                    schedules.UNIFORM_STANDARD,
                    schedules.UNIFORM_LOOPED,
                    schedules.BATCHED,
                ],
                default=schedules.UNIFORM_STANDARD,
                tooltip="Step-dependent scheduling algorithm for context windows.",
            ),
            io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
            io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True),
            io.Combo.Input("fuse_method", options=fuse_methods.LIST_STATIC, default=fuse_methods.PYRAMID, tooltip="The method to use to fuse the context windows."),
            io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True),
            io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first latent frame in every context window (may help retain initial reference)."),
            io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True),
        ]
        return schema

    @classmethod
    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, fuse_method: str, freenoise: bool,
                retain_first_frame: bool=False, split_conds_to_windows: bool=False, context_stride: int=1, closed_loop: bool=False) -> io.Model:
        """Scale the frame-based settings by 8 and delegate to the parent node.

        ``context_length`` maps to ``(context_length - 1) // 8 + 1`` (never
        below 1) and ``context_overlap`` to ``context_overlap // 8`` (never
        below 0). NOTE(review): the divisor presumably matches the model's
        temporal compression factor; confirm against the LTXV VAE.
        """
        scaled_length = max(1, (context_length - 1) // 8 + 1)  # at least length 1
        scaled_overlap = max(0, context_overlap // 8)  # at least overlap 0

        # The same index string is used for both the conditioning and the
        # latent retain lists: "0" keeps the first frame, "" keeps nothing.
        if retain_first_frame:
            retained_indices = "0"
        else:
            retained_indices = ""

        return super().execute(
            model,
            scaled_length,
            scaled_overlap,
            context_schedule,
            context_stride,
            closed_loop,
            fuse_method,
            dim=2,
            freenoise=freenoise,
            cond_retain_index_list=retained_indices,
            latent_retain_index_list=retained_indices,
            split_conds_to_windows=split_conds_to_windows,
        )
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
||||
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
|
||||
|
||||
|
||||
class ContextWindowsExtension(ComfyExtension):
|
||||
@ -139,7 +99,6 @@ class ContextWindowsExtension(ComfyExtension):
|
||||
return [
|
||||
ContextWindowsManualNode,
|
||||
WanContextWindowsManualNode,
|
||||
LTXVContextWindowsNode,
|
||||
]
|
||||
|
||||
def comfy_entrypoint():
|
||||
|
||||
@ -1583,7 +1583,7 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
shard_path = os.path.join(dataset_dir, shard_file)
|
||||
|
||||
with open(shard_path, "rb") as f:
|
||||
shard_data = torch.load(f, weights_only=True)
|
||||
shard_data = torch.load(f)
|
||||
|
||||
all_latents.extend(shard_data["latents"])
|
||||
all_conditioning.extend(shard_data["conditioning"])
|
||||
|
||||
@ -77,7 +77,7 @@ class FrameInterpolate(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="FrameInterpolate",
|
||||
display_name="Run Frame Interpolation Model",
|
||||
display_name="Frame Interpolate",
|
||||
category="video",
|
||||
search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"],
|
||||
inputs=[
|
||||
|
||||
@ -317,74 +317,11 @@ class PreviewPointCloud(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
MESH_EXTENSIONS = {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
|
||||
|
||||
|
||||
class Load3DAdvanced(IO.ComfyNode):
    """Experimental node that loads a 3D model file together with viewport state."""

    @classmethod
    def define_schema(cls):
        """Build the node schema, listing mesh files found under ``<input>/3d``.

        The ``3d`` sub-directory of the input directory is created if it does
        not exist. Files are matched case-insensitively against
        ``MESH_EXTENSIONS`` and offered relative to the input directory.
        """
        models_dir = os.path.join(folder_paths.get_input_directory(), "3d")
        os.makedirs(models_dir, exist_ok=True)

        input_root = Path(folder_paths.get_input_directory())
        mesh_files = []
        for candidate in Path(models_dir).rglob("*"):
            if candidate.suffix.lower() not in MESH_EXTENSIONS:
                continue
            mesh_files.append(normalize_path(str(candidate.relative_to(input_root))))

        node_inputs = [
            IO.Combo.Input("model_file", options=["none"] + sorted(mesh_files), upload=IO.UploadType.model),
            IO.Load3D.Input("viewport_state"),
            IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
            IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
        ]
        node_outputs = [
            IO.File3DAny.Output(display_name="model_3d"),
            IO.Load3DModelInfo.Output(display_name="model_3d_info"),
            IO.Load3DCamera.Output(display_name="camera_info"),
            IO.Int.Output(display_name="width"),
            IO.Int.Output(display_name="height"),
        ]
        return IO.Schema(
            node_id="Load3DAdvanced",
            display_name="Load 3D (Advanced)",
            category="3d",
            search_aliases=[
                "load mesh",
                "load gltf",
                "load glb",
                "load obj",
                "load fbx",
                "load stl",
            ],
            is_experimental=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def validate_inputs(cls, model_file, **kwargs) -> bool | str:
        """Return ``True`` when the selection is acceptable, else an error string.

        An empty selection or the literal ``"none"`` is always accepted; any
        other value must resolve through
        ``folder_paths.exists_annotated_filepath``.
        """
        has_selection = bool(model_file) and model_file != "none"
        if has_selection and not folder_paths.exists_annotated_filepath(model_file):
            return f"Invalid 3D model file: {model_file}"
        return True

    @classmethod
    def execute(cls, model_file, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput:
        """Return the selected file (or ``None``), model info, camera info and size.

        ``model_3d_info`` falls back to an empty list when absent from
        ``viewport_state``; ``camera_info`` is read with plain indexing, so a
        missing key raises ``KeyError``.
        """
        if model_file and model_file != "none":
            loaded_file = Types.File3D(folder_paths.get_annotated_filepath(model_file))
        else:
            loaded_file = None

        info = viewport_state.get('model_3d_info', [])
        camera = viewport_state['camera_info']
        return IO.NodeOutput(loaded_file, info, camera, width, height)
|
||||
|
||||
|
||||
class Load3DExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
Load3D,
|
||||
Load3DAdvanced,
|
||||
Preview3D,
|
||||
Preview3DAdvanced,
|
||||
PreviewGaussianSplat,
|
||||
|
||||
@ -89,8 +89,7 @@ class SwitchNode(io.ComfyNode):
|
||||
template = io.MatchType.Template("switch")
|
||||
return io.Schema(
|
||||
node_id="ComfySwitchNode",
|
||||
search_aliases=["if", "then", "switch", "conditional", "branch"],
|
||||
display_name="If/Else Switch",
|
||||
display_name="Switch",
|
||||
category="utilities/logic",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
|
||||
@ -10,11 +10,12 @@ class String(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveString",
|
||||
search_aliases=["text", "string", "text box", "prompt"],
|
||||
display_name="Text String (DEPRECATED)",
|
||||
display_name="Text String",
|
||||
category="utilities/primitive",
|
||||
inputs=[io.String.Input("value")],
|
||||
inputs=[
|
||||
io.String.Input("value"),
|
||||
],
|
||||
outputs=[io.String.Output()],
|
||||
is_deprecated=True
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -28,10 +29,12 @@ class StringMultiline(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveStringMultiline",
|
||||
search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"],
|
||||
display_name="Input Text",
|
||||
display_name="Text String (Multiline)",
|
||||
category="utilities/primitive",
|
||||
essentials_category="Basics",
|
||||
inputs=[io.String.Input("value", multiline=True)],
|
||||
inputs=[
|
||||
io.String.Input("value", multiline=True),
|
||||
],
|
||||
outputs=[io.String.Output()],
|
||||
)
|
||||
|
||||
|
||||
@ -233,8 +233,13 @@ class VideoSlice(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Video Slice",
|
||||
display_name="Trim Video",
|
||||
search_aliases=["trim video duration", "skip first frames", "frame load cap", "start time"],
|
||||
display_name="Video Slice",
|
||||
search_aliases=[
|
||||
"trim video duration",
|
||||
"skip first frames",
|
||||
"frame load cap",
|
||||
"start time",
|
||||
],
|
||||
category="video",
|
||||
essentials_category="Video Tools",
|
||||
inputs=[
|
||||
|
||||
19
execution.py
19
execution.py
@ -1308,25 +1308,6 @@ class PromptQueue:
|
||||
queued = copy.copy(self.queue)
|
||||
return (running, queued)
|
||||
|
||||
def interrupt_if_running(self, prompt_id):
    """Signal an interrupt only if ``prompt_id`` is currently running.

    The running set is inspected and the interrupt is raised while holding
    the queue mutex, so the check and the signal happen as one step with
    respect to other users of that mutex. This keeps a cancel from landing
    on an unrelated prompt that started after a separate is-running check.

    Returns True when a matching running job was found and
    ``nodes.interrupt_processing()`` was called, False otherwise.
    """
    with self.mutex:
        # Index 1 of each running entry holds the prompt id.
        is_running = any(
            entry[1] == prompt_id for entry in self.currently_running.values()
        )
        if is_running:
            nodes.interrupt_processing()
        return is_running
|
||||
|
||||
def get_tasks_remaining(self):
|
||||
with self.mutex:
|
||||
return len(self.queue) + len(self.currently_running)
|
||||
|
||||
5
nodes.py
5
nodes.py
@ -20,6 +20,8 @@ from PIL.PngImagePlugin import PngInfo
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
||||
|
||||
import comfy.diffusers_load
|
||||
import comfy.samplers
|
||||
import comfy.sample
|
||||
@ -2297,9 +2299,6 @@ async def init_external_custom_nodes():
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# TODO: remove at some point when custom nodes don't break.
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
||||
|
||||
base_node_names = set(NODE_CLASS_MAPPINGS.keys())
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
node_import_times = []
|
||||
|
||||
195
openapi.yaml
195
openapi.yaml
@ -673,35 +673,6 @@ components:
|
||||
- created_at
|
||||
- updated_at
|
||||
type: object
|
||||
JobsCancelRequest:
|
||||
additionalProperties: false
|
||||
description: Request to cancel multiple jobs by ID.
|
||||
properties:
|
||||
job_ids:
|
||||
description: Job identifiers (UUIDs) to cancel.
|
||||
items:
|
||||
format: uuid
|
||||
type: string
|
||||
maxItems: 100
|
||||
minItems: 1
|
||||
type: array
|
||||
required:
|
||||
- job_ids
|
||||
type: object
|
||||
JobsCancelResponse:
|
||||
description: Response for POST /api/jobs/cancel.
|
||||
properties:
|
||||
cancelled:
|
||||
description: |
|
||||
Job IDs for which a cancel event was successfully dispatched by this
|
||||
call. Jobs already in a terminal or cancelling state are idempotently
|
||||
skipped and will not appear here.
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
required:
|
||||
- cancelled
|
||||
type: object
|
||||
JobsListResponse:
|
||||
description: Paginated list of jobs for the authenticated user.
|
||||
properties:
|
||||
@ -1035,7 +1006,7 @@ components:
|
||||
description: If true, clear all pending jobs from the queue
|
||||
type: boolean
|
||||
delete:
|
||||
description: Array of job IDs to cancel; pending and running jobs transition to cancelled
|
||||
description: Array of PENDING job IDs to cancel
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
@ -1851,83 +1822,6 @@ paths:
|
||||
summary: Update asset metadata
|
||||
tags:
|
||||
- file
|
||||
/api/assets/{id}/content:
|
||||
get:
|
||||
description: |
|
||||
Returns the binary content of an asset by ID.
|
||||
|
||||
The contract is the same across runtimes — "GET this path and you
|
||||
receive the asset's bytes" — but the mechanism differs:
|
||||
- **Local ComfyUI** streams the bytes directly (`200`,
|
||||
`application/octet-stream`).
|
||||
- **Cloud** does not proxy large files; it responds `302` with a
|
||||
`Location` redirect to a short-lived signed storage URL. Clients that
|
||||
follow redirects (browsers, `fetch`/XHR, `<img>`/`<video>`) receive
|
||||
the bytes transparently.
|
||||
|
||||
Prefer this over the filename-addressed `/api/view` when you have an
|
||||
asset ID.
|
||||
operationId: getAssetContent
|
||||
parameters:
|
||||
- description: Asset ID
|
||||
in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- description: |
|
||||
Content-Disposition for the response: `attachment` (download) or
|
||||
`inline` (render in browser). Defaults to `attachment`.
|
||||
in: query
|
||||
name: disposition
|
||||
schema:
|
||||
default: attachment
|
||||
enum:
|
||||
- inline
|
||||
- attachment
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/octet-stream:
|
||||
schema:
|
||||
format: binary
|
||||
type: string
|
||||
description: Asset content stream (local runtime streams the bytes directly)
|
||||
"302":
|
||||
description: Redirect to a signed storage URL (cloud runtime)
|
||||
headers:
|
||||
Cache-Control:
|
||||
description: Private caching directive scoped to the signed URL lifetime
|
||||
schema:
|
||||
type: string
|
||||
Location:
|
||||
description: Short-lived signed URL to the asset content in storage
|
||||
schema:
|
||||
type: string
|
||||
Vary:
|
||||
description: Partitions any cached redirect by auth credentials so a private redirect is not reused across users
|
||||
schema:
|
||||
type: string
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Asset not found
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Internal server error
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
- BearerAuth: []
|
||||
- CookieAuth: []
|
||||
summary: Get asset content
|
||||
tags:
|
||||
- file
|
||||
/api/assets/{id}/tags:
|
||||
delete:
|
||||
description: Removes one or more tags from an existing asset
|
||||
@ -2781,20 +2675,14 @@ paths:
|
||||
summary: Get internationalisation translation strings
|
||||
/api/interrupt:
|
||||
post:
|
||||
deprecated: true
|
||||
description: |
|
||||
Deprecated. Prefer the jobs-namespace cancel endpoints:
|
||||
POST /api/jobs/{job_id}/cancel for a single job, or
|
||||
POST /api/jobs/cancel to cancel jobs by ID.
|
||||
|
||||
Cancels the first active job for the authenticated user (the currently
|
||||
running job if there is one, otherwise the next pending job). Takes no
|
||||
body and cannot target a specific job — use the jobs-namespace endpoints
|
||||
for that.
|
||||
Cancel all currently RUNNING jobs for the authenticated user.
|
||||
This will interrupt any job that is currently in 'in_progress' status.
|
||||
Note: This endpoint only affects running jobs. To cancel pending jobs, use /api/queue.
|
||||
operationId: interruptJob
|
||||
responses:
|
||||
"200":
|
||||
description: Success - first active job cancelled, or no active job found
|
||||
description: Success - Job interrupted or no running job found
|
||||
"401":
|
||||
content:
|
||||
application/json:
|
||||
@ -2807,7 +2695,7 @@ paths:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Internal server error
|
||||
summary: Interrupt the first active job
|
||||
summary: Interrupt currently running jobs
|
||||
tags:
|
||||
- queue
|
||||
/api/job/{job_id}/status:
|
||||
@ -3066,64 +2954,6 @@ paths:
|
||||
summary: Cancel a job
|
||||
tags:
|
||||
- workflow
|
||||
/api/jobs/cancel:
|
||||
post:
|
||||
description: |
|
||||
Cancel one or more jobs for the authenticated user in a single request.
|
||||
|
||||
State-agnostic: cancels both pending and running jobs (both transition to
|
||||
the cancelled state via the same mechanism as the single-job endpoint).
|
||||
|
||||
Idempotent per job: a job already in a terminal or cancelling state is a
|
||||
no-op and simply will not appear in the returned `cancelled` list.
|
||||
|
||||
Fail-fast on unknown IDs: if any provided job ID does not exist for this
|
||||
user, the request returns 404 and no jobs are cancelled. This surfaces
|
||||
bad IDs to the caller rather than silently dropping them.
|
||||
|
||||
This is the canonical batch-cancel endpoint. The delete operation on
|
||||
POST /api/queue is deprecated in favour of this.
|
||||
operationId: cancelJobs
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/JobsCancelRequest'
|
||||
required: true
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/JobsCancelResponse'
|
||||
description: Success - cancel requests dispatched (or jobs were already terminal)
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Bad Request - job_ids is missing, empty, exceeds the maximum count, or contains an invalid UUID
|
||||
"401":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Unauthorized - Authentication required
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: One or more job IDs not found for this user (no jobs cancelled)
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Internal server error - cancellation failed
|
||||
summary: Cancel multiple jobs
|
||||
tags:
|
||||
- workflow
|
||||
/api/node_replacements:
|
||||
get:
|
||||
description: |
|
||||
@ -3274,18 +3104,9 @@ paths:
|
||||
tags:
|
||||
- queue
|
||||
post:
|
||||
deprecated: true
|
||||
description: |
|
||||
Deprecated. Prefer the jobs-namespace cancel endpoints:
|
||||
POST /api/jobs/cancel for cancelling jobs by ID, and
|
||||
POST /api/jobs/{job_id}/cancel for a single job.
|
||||
|
||||
Cancel specific jobs by ID (the `delete` field) or clear all pending
|
||||
jobs in the queue (the `clear` field). Despite the `delete` naming, this
|
||||
does not delete anything — listed jobs transition to the cancelled state,
|
||||
and `delete` cancels both pending and running jobs (not pending-only as
|
||||
previously documented). Job-by-ID cancellation is superseded by
|
||||
POST /api/jobs/cancel; `clear` has no jobs-namespace replacement yet.
|
||||
Cancel specific PENDING jobs by ID or clear all pending jobs in the queue.
|
||||
Note: This endpoint only affects pending jobs. To cancel running jobs, use /api/interrupt.
|
||||
operationId: manageQueue
|
||||
requestBody:
|
||||
content:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.45.19
|
||||
comfyui-frontend-package==1.45.15
|
||||
comfyui-workflow-templates==0.10.0
|
||||
comfyui-embedded-docs==0.5.4
|
||||
torch
|
||||
|
||||
111
server.py
111
server.py
@ -8,15 +8,7 @@ import time
|
||||
import nodes
|
||||
import folder_paths
|
||||
import execution
|
||||
from comfy_execution.jobs import (
|
||||
JobStatus,
|
||||
get_job,
|
||||
get_all_jobs,
|
||||
validate_job_id,
|
||||
cancel_job,
|
||||
CANCEL_PENDING,
|
||||
CANCEL_RUNNING,
|
||||
)
|
||||
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs, validate_job_id
|
||||
import uuid
|
||||
import urllib
|
||||
import json
|
||||
@ -907,107 +899,6 @@ class PromptServer():
|
||||
|
||||
return web.json_response(job)
|
||||
|
||||
def _cancel_job_by_id(job_id):
    """Cancel one job by id and report whether anything was dispatched.

    A snapshot of the running list, the pending list and the history is
    passed to ``cancel_job`` together with two callbacks: one that
    interrupts a running prompt and one that removes a pending prompt from
    the queue. State-agnostic: already-finished or unknown ids are no-ops.

    Returns True when ``cancel_job`` classifies the id as CANCEL_RUNNING or
    CANCEL_PENDING, False for every other classification.
    """
    queue = self.prompt_queue
    running, queued = queue.get_current_queue()
    history = queue.get_history()

    def _interrupt_running(prompt_id):
        logging.info(f"Cancelling running prompt {prompt_id}")
        # The queue re-checks under its own lock that this prompt is still
        # the one running, so a stale snapshot cannot interrupt a successor.
        return queue.interrupt_if_running(prompt_id)

    def _dequeue_pending(prompt_id):
        logging.info(f"Cancelling pending prompt {prompt_id}")
        return queue.delete_queue_item(lambda entry: entry[1] == prompt_id)

    outcome = cancel_job(
        job_id, running, queued, history, _interrupt_running, _dequeue_pending
    )
    dispatched_outcomes = (CANCEL_RUNNING, CANCEL_PENDING)
    return outcome in dispatched_outcomes
|
||||
|
||||
@routes.post("/api/jobs/{job_id}/cancel")
async def cancel_job_by_id(request):
    """Handle a single-job cancel request.

    Responds 400 when the path carries no job id. Otherwise responds with
    ``{"cancelled": <bool>}`` (default 200 status), where the flag is the
    result of ``_cancel_job_by_id``; finished or unknown ids therefore
    yield ``false`` rather than an error.
    """
    job_id = request.match_info.get("job_id")
    if job_id:
        was_cancelled = _cancel_job_by_id(job_id)
        return web.json_response({"cancelled": was_cancelled})

    return web.json_response({"error": "job_id is required"}, status=400)
|
||||
|
||||
@routes.post("/api/jobs/cancel")
async def cancel_jobs_batch(request):
    """Handle a batch cancel request.

    Body: {"job_ids": ["<uuid>", ...]}

    Responses:
      * 400 when the body is not valid JSON, when ``job_ids`` is not a list,
        or when any element is rejected by ``validate_job_id`` (the
        offending values are echoed back in ``invalid_ids``).
      * otherwise ``{"cancelled": <bool>}``, true when at least one id
        resulted in a dispatched cancel. Ids that are finished or unknown
        are no-ops, so a batch of only no-ops still succeeds with ``false``.
    """
    try:
        payload = await request.json()
    except json.JSONDecodeError:
        return web.json_response(
            {"error": "Request body must be valid JSON"},
            status=400
        )

    if isinstance(payload, dict):
        job_ids = payload.get("job_ids")
    else:
        job_ids = None
    if not isinstance(job_ids, list):
        return web.json_response(
            {"error": "job_ids must be a list"},
            status=400
        )

    def _is_well_formed(candidate):
        # validate_job_id signals a bad id by raising; both exception types
        # caught here are treated as "invalid".
        try:
            validate_job_id(candidate)
        except (ValueError, AttributeError):
            return False
        return True

    # Validate everything before cancelling anything, so malformed input
    # (including unhashable values such as nested lists or dicts) is
    # rejected with 400 before any job is touched. Non-string values are
    # reported through repr() so the error payload stays JSON-serialisable.
    invalid_ids = [
        jid if isinstance(jid, str) else repr(jid)
        for jid in job_ids
        if not _is_well_formed(jid)
    ]
    if invalid_ids:
        return web.json_response(
            {"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids},
            status=400,
        )

    # Build the full list first: every id must be attempted, so the cancel
    # calls must not be short-circuited by any().
    outcomes = [_cancel_job_by_id(jid) for jid in job_ids]
    return web.json_response({"cancelled": any(outcomes)})
|
||||
|
||||
@routes.get("/history")
|
||||
async def get_history(request):
|
||||
max_items = request.rel_url.query.get("max_items", None)
|
||||
|
||||
@ -1,453 +0,0 @@
|
||||
"""Tests for the jobs-namespace cancel endpoints.
|
||||
|
||||
Covers both layers:
|
||||
|
||||
* the pure cancel helpers in ``comfy_execution.jobs``
|
||||
(``classify_job_for_cancel`` / ``cancel_job``), which hold the business
|
||||
logic of mapping a cancel onto interrupt-vs-dequeue, and
|
||||
|
||||
* the HTTP contract of ``POST /api/jobs/{job_id}/cancel`` and
|
||||
``POST /api/jobs/cancel`` (status codes, single-cancel idempotency, and
|
||||
best-effort batch cancellation that treats unknown/finished ids as no-ops
|
||||
while still rejecting malformed ids with 400).
|
||||
|
||||
The HTTP layer is exercised against a small aiohttp app whose handlers are a
|
||||
faithful copy of the wiring in ``server.py`` driven by a fake queue that
|
||||
mirrors ``execution.PromptQueue`` (``get_current_queue`` / ``get_history`` /
|
||||
``delete_queue_item``). This keeps the test free of the heavy ComfyUI runtime
|
||||
(torch, nodes, ...) while still testing the real cancel logic.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from comfy_execution.jobs import (
|
||||
CANCEL_PENDING,
|
||||
CANCEL_RUNNING,
|
||||
CANCEL_TERMINAL,
|
||||
CANCEL_UNKNOWN,
|
||||
cancel_job,
|
||||
classify_job_for_cancel,
|
||||
validate_job_id,
|
||||
)
|
||||
|
||||
# Classifications for which a cancel was actually dispatched (vs a no-op).
|
||||
_CANCELLED = (CANCEL_RUNNING, CANCEL_PENDING)
|
||||
|
||||
# Canonical UUID ids for HTTP-layer tests (the batch endpoint validates UUID format).
|
||||
_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa"
|
||||
_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb"
|
||||
_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc"
|
||||
_UUID_D = "dddddddd-dddd-4ddd-dddd-dddddddddddd"
|
||||
_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff"
|
||||
|
||||
|
||||
def make_queue_item(prompt_id, number=0):
    """Return a five-element queue tuple whose index 1 is ``prompt_id``.

    Index 0 is ``number``; the remaining slots are fresh empty containers
    (two dicts and a list).
    """
    trailing_fields = ({}, {}, [])
    return (number, prompt_id) + trailing_fields
|
||||
|
||||
|
||||
class FakePromptQueue:
    """In-memory queue double used by the cancel tests.

    Holds a running list, a pending list and a history dict, and counts how
    many interrupts were signalled so tests can assert on side effects.
    """

    def __init__(self, running=None, pending=None, history=None):
        # Copy the inputs so later mutation by the caller cannot leak in.
        self._running = list(running) if running else []
        self._pending = list(pending) if pending else []
        self._history = dict(history) if history else {}
        self.interrupt_count = 0

    def get_current_queue(self):
        """Return ``(running, pending)`` as fresh list copies."""
        return (self._running[:], self._pending[:])

    def get_history(self, prompt_id=None):
        """Return a copy of the whole history, or a one-entry dict for ``prompt_id``.

        An id that is not in the history yields an empty dict.
        """
        if prompt_id is None:
            return self._history.copy()
        try:
            return {prompt_id: self._history[prompt_id]}
        except KeyError:
            return {}

    def delete_queue_item(self, function):
        """Remove the first pending item for which ``function`` is truthy.

        Returns True when an item was removed, False when nothing matched.
        """
        match_index = next(
            (idx for idx, entry in enumerate(self._pending) if function(entry)),
            None,
        )
        if match_index is None:
            return False
        del self._pending[match_index]
        return True

    def interrupt_if_running(self, prompt_id):
        """Count an interrupt and return True only if ``prompt_id`` is running."""
        for entry in self._running:
            if entry[1] == prompt_id:
                self.interrupt_count += 1
                return True
        return False
|
||||
|
||||
|
||||
def build_app(queue):
    """Return an aiohttp application exposing both cancel routes over ``queue``.

    The handlers reproduce the behaviour of the cancel handlers in server.py
    but talk to the supplied queue object instead of the real prompt queue.
    """

    def _cancel_job_by_id(job_id):
        # Snapshot the queue state, then let cancel_job decide whether to
        # interrupt (running) or dequeue (pending).
        running, pending = queue.get_current_queue()
        history = queue.get_history()

        def _interrupt(prompt_id):
            return queue.interrupt_if_running(prompt_id)

        def _dequeue(prompt_id):
            return queue.delete_queue_item(lambda entry: entry[1] == prompt_id)

        outcome = cancel_job(job_id, running, pending, history, _interrupt, _dequeue)
        return outcome in _CANCELLED

    async def cancel_job_by_id(request):
        job_id = request.match_info.get("job_id")
        if job_id:
            return web.json_response({"cancelled": _cancel_job_by_id(job_id)})
        return web.json_response({"error": "job_id is required"}, status=400)

    async def cancel_jobs_batch(request):
        try:
            payload = await request.json()
        except json.JSONDecodeError:
            return web.json_response(
                {"error": "Request body must be valid JSON"}, status=400
            )

        if isinstance(payload, dict):
            job_ids = payload.get("job_ids")
        else:
            job_ids = None
        if not isinstance(job_ids, list):
            return web.json_response({"error": "job_ids must be a list"}, status=400)

        def _is_well_formed(candidate):
            try:
                validate_job_id(candidate)
            except (ValueError, AttributeError):
                return False
            return True

        invalid_ids = [
            jid if isinstance(jid, str) else repr(jid)
            for jid in job_ids
            if not _is_well_formed(jid)
        ]
        if invalid_ids:
            return web.json_response(
                {"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids},
                status=400,
            )

        # Every id must be attempted, so collect all results before any().
        outcomes = [_cancel_job_by_id(jid) for jid in job_ids]
        return web.json_response({"cancelled": any(outcomes)})

    app = web.Application()
    route_table = (
        ("/api/jobs/{job_id}/cancel", cancel_job_by_id),
        ("/api/jobs/cancel", cancel_jobs_batch),
    )
    for path, handler in route_table:
        app.router.add_post(path, handler)
    return app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure helper tests: classification + cancel side effects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClassifyJobForCancel:
    """One test per classification returned by ``classify_job_for_cancel``."""

    def test_running(self):
        # The id appears only in the running list.
        verdict = classify_job_for_cancel("a", [make_queue_item("a")], [], {})
        assert verdict == CANCEL_RUNNING

    def test_pending(self):
        # The id appears only in the pending list.
        verdict = classify_job_for_cancel("b", [], [make_queue_item("b")], {})
        assert verdict == CANCEL_PENDING

    def test_terminal(self):
        # The id appears only as a history key.
        finished_entry = {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}
        verdict = classify_job_for_cancel("c", [], [], {"c": finished_entry})
        assert verdict == CANCEL_TERMINAL

    def test_unknown(self):
        # The id appears nowhere.
        verdict = classify_job_for_cancel("z", [], [], {})
        assert verdict == CANCEL_UNKNOWN
|
||||
|
||||
|
||||
class TestCancelJobHelper:
    """Exercise ``cancel_job`` with stand-in ``interrupt``/``dequeue`` callbacks.

    Both callbacks receive the job id and return whether they actually acted,
    so the value ``cancel_job`` returns is expected to reflect the real outcome.
    """

    @staticmethod
    def _spy(calls, outcome):
        """Build a callback that records the id it is given and returns ``outcome``."""

        def callback(pid):
            calls.append(pid)
            return outcome

        return callback

    def test_running_is_interrupted_not_dequeued(self):
        """A running job goes through interrupt only; dequeue is never called."""
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "a", [make_queue_item("a")], [], {},
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, True),
        )
        assert outcome == CANCEL_RUNNING
        assert interrupted == ["a"]
        assert dequeued == []

    def test_pending_is_dequeued_not_interrupted(self):
        """A pending job goes through dequeue only; interrupt is never called."""
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "b", [], [make_queue_item("b")], {},
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, True),
        )
        assert outcome == CANCEL_PENDING
        assert dequeued == ["b"]
        assert interrupted == []

    def test_terminal_is_noop(self):
        """A job that only exists in history triggers neither callback."""
        finished = {
            "c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}},
        }
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "c", [], [], finished,
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, True),
        )
        assert outcome == CANCEL_TERMINAL
        assert interrupted == []
        assert dequeued == []

    def test_unknown_is_noop(self):
        """An id found nowhere triggers neither callback."""
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "z", [], [], {},
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, True),
        )
        assert outcome == CANCEL_UNKNOWN
        assert interrupted == []
        assert dequeued == []

    def test_running_but_finished_before_interrupt_returns_unknown(self):
        """Stale RUNNING snapshot: the job finished before the interrupt fired.

        The interrupt callback returns False, so cancel_job reports UNKNOWN
        instead of claiming a cancel that did not happen; the atomic interrupt
        guarantees no unrelated job was hit.
        """
        interrupted = []
        outcome = cancel_job(
            "a", [make_queue_item("a")], [], {},
            interrupt=self._spy(interrupted, False),
            dequeue=lambda pid: True,
        )
        assert outcome == CANCEL_UNKNOWN
        assert interrupted == ["a"]  # interrupt was attempted atomically

    def test_pending_started_running_is_interrupted(self):
        """Pending->running race: dequeue misses because the job started executing.

        The atomic interrupt catches the now-running job, so cancel_job
        interrupts it and reports CANCEL_RUNNING.
        """
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "b", [], [make_queue_item("b")], {},
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, False),
        )
        assert outcome == CANCEL_RUNNING
        assert dequeued == ["b"]  # dequeue attempted first
        assert interrupted == ["b"]  # then the now-running job was interrupted

    def test_pending_dequeue_miss_not_running_returns_unknown(self):
        """Dequeue miss for a job that is no longer running (it finished).

        The atomic interrupt finds nothing to interrupt and returns False, so
        cancel_job is a no-op reporting UNKNOWN: it never reports a cancel that
        did not happen and never interrupts a bystander.
        """
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "b", [], [make_queue_item("b")], {},
            interrupt=self._spy(interrupted, False),
            dequeue=self._spy(dequeued, False),
        )
        assert outcome == CANCEL_UNKNOWN
        assert dequeued == ["b"]
        assert interrupted == ["b"]  # interrupt attempted, found nothing running
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP contract tests: POST /api/jobs/{job_id}/cancel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSingleCancelEndpoint:
    """HTTP contract of ``POST /api/jobs/{job_id}/cancel`` against a fake queue."""

    @pytest.mark.asyncio
    async def test_cancel_running_job_interrupts(self, aiohttp_client):
        """Cancelling a running job answers cancelled=true and interrupts once."""
        fake_queue = FakePromptQueue(running=[make_queue_item("a")])
        http = await aiohttp_client(build_app(fake_queue))

        response = await http.post("/api/jobs/a/cancel")

        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": True}
        assert fake_queue.interrupt_count == 1

    @pytest.mark.asyncio
    async def test_cancel_pending_job_dequeues(self, aiohttp_client):
        """Cancelling a pending job answers cancelled=true and empties the queue."""
        fake_queue = FakePromptQueue(pending=[make_queue_item("b")])
        http = await aiohttp_client(build_app(fake_queue))

        response = await http.post("/api/jobs/b/cancel")

        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": True}
        # Pending job removed from the queue; nothing interrupted.
        snapshot = fake_queue.get_current_queue()
        assert snapshot[1] == []
        assert fake_queue.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_cancel_terminal_job_is_idempotent_noop(self, aiohttp_client):
        """Cancelling a job that is already in history changes nothing."""
        finished = {
            "c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}},
        }
        fake_queue = FakePromptQueue(history=finished)
        http = await aiohttp_client(build_app(fake_queue))

        response = await http.post("/api/jobs/c/cancel")

        # Already-finished job: 200 no-op (cancelled=false), not an error.
        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": False}
        assert fake_queue.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_cancel_unknown_id_is_200_noop(self, aiohttp_client):
        """Cancelling an id the queue has never seen is still a 200."""
        fake_queue = FakePromptQueue()
        http = await aiohttp_client(build_app(fake_queue))

        response = await http.post("/api/jobs/does-not-exist/cancel")

        # Single-cancel of an unknown id is treated as an idempotent no-op.
        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": False}
        assert fake_queue.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_cancel_pending_that_started_running_interrupts(self, aiohttp_client):
        """Pending->running race, end to end.

        The job is pending at snapshot time but has started executing by the
        time it is dequeued (the delete misses). The live re-check sees it
        running and interrupts it, so the cancel is not dropped and the caller
        still gets cancelled=True.
        """

        class RacingQueue(FakePromptQueue):
            def delete_queue_item(self, function):
                # Simulate the worker picking the job up just before removal:
                # it moves from pending to running and the delete reports a miss.
                self._running, self._pending = list(self._pending), []
                return False

        fake_queue = RacingQueue(pending=[make_queue_item("b")])
        http = await aiohttp_client(build_app(fake_queue))

        response = await http.post("/api/jobs/b/cancel")

        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": True}
        assert fake_queue.interrupt_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP contract tests: POST /api/jobs/cancel (batch)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchCancelEndpoint:
    """HTTP contract of the batch route ``POST /api/jobs/cancel``."""

    @pytest.mark.asyncio
    async def test_batch_happy_path(self, aiohttp_client):
        """One running and one pending id: both are cancelled in a single call."""
        fake_queue = FakePromptQueue(
            running=[make_queue_item(_UUID_A)],
            pending=[make_queue_item(_UUID_B, number=1)],
        )
        http = await aiohttp_client(build_app(fake_queue))

        request_body = {"job_ids": [_UUID_A, _UUID_B]}
        response = await http.post("/api/jobs/cancel", json=request_body)

        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": True}
        assert fake_queue.interrupt_count == 1  # running job interrupted
        snapshot = fake_queue.get_current_queue()
        assert snapshot[1] == []  # pending job dequeued

    @pytest.mark.asyncio
    async def test_batch_best_effort_skips_unknown_id(self, aiohttp_client):
        """An unknown id in the batch is a no-op, not a reason to abort.

        The running and pending jobs are still cancelled (200, cancelled=true).
        This is the "cancel all as a job finishes" case from review.
        """
        fake_queue = FakePromptQueue(
            running=[make_queue_item(_UUID_A)],
            pending=[make_queue_item(_UUID_B, number=1)],
        )
        http = await aiohttp_client(build_app(fake_queue))

        request_body = {"job_ids": [_UUID_A, _UUID_MISSING, _UUID_B]}
        response = await http.post("/api/jobs/cancel", json=request_body)

        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": True}
        assert fake_queue.interrupt_count == 1  # running job interrupted
        snapshot = fake_queue.get_current_queue()
        assert snapshot[1] == []  # pending job dequeued

    @pytest.mark.asyncio
    async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client):
        """A batch made only of finished jobs succeeds without doing anything."""
        finished = {
            job_id: {"prompt": make_queue_item(job_id), "outputs": {}, "status": {}}
            for job_id in (_UUID_C, _UUID_D)
        }
        fake_queue = FakePromptQueue(history=finished)
        http = await aiohttp_client(build_app(fake_queue))

        request_body = {"job_ids": [_UUID_C, _UUID_D]}
        response = await http.post("/api/jobs/cancel", json=request_body)

        # All known but terminal: 200 with cancelled=false, nothing dispatched.
        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": False}
        assert fake_queue.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_batch_missing_job_ids_is_400(self, aiohttp_client):
        """A JSON object without ``job_ids`` is rejected with 400."""
        fake_queue = FakePromptQueue()
        http = await aiohttp_client(build_app(fake_queue))

        response = await http.post("/api/jobs/cancel", json={})

        assert response.status == 400

    @pytest.mark.asyncio
    async def test_batch_unhashable_element_is_400_not_500(self, aiohttp_client):
        """An unhashable element such as a dict or list must yield 400, not 500.

        Previously, passing e.g. {"job_ids": [{}]} would reach the classify
        loop where ``prompt_id in history`` raises TypeError on an unhashable
        type, resulting in an unhandled 500. The input-validation guard must
        catch this before any queue or history access.
        """
        fake_queue = FakePromptQueue()
        http = await aiohttp_client(build_app(fake_queue))

        response = await http.post("/api/jobs/cancel", json={"job_ids": [{}]})

        assert response.status == 400
        payload = await response.json()
        assert "invalid_ids" in payload
        # No queue side effects.
        assert fake_queue.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_batch_non_uuid_string_element_is_400(self, aiohttp_client):
        """A string that is not a valid UUID must be rejected with 400."""
        fake_queue = FakePromptQueue()
        http = await aiohttp_client(build_app(fake_queue))

        request_body = {"job_ids": ["not-a-uuid"]}
        response = await http.post("/api/jobs/cancel", json=request_body)

        assert response.status == 400
        payload = await response.json()
        assert "invalid_ids" in payload
|
||||
Reference in New Issue
Block a user