Compare commits

..

6 Commits

Author SHA1 Message Date
146669b3e8 Merge branch 'master' into comfyanonymous-patch-1 2026-06-20 08:02:15 +08:00
69d34f2654 Rename a bunch of nodes (#14547) 2026-06-20 08:01:28 +08:00
cd77c551d6 feat: Context Windows sampling with LTX2 models and IC-LoRa guides (CORE-3) (#13325) 2026-06-20 07:47:31 +08:00
4e716f7c57 Add jobs-namespace cancel endpoints (POST /api/jobs/{job_id}/cancel, POST /api/jobs/cancel) (#14493)
* Add jobs-namespace cancel endpoints

Add two cancel endpoints under the jobs namespace so a job can be
cancelled by id without the caller needing to know whether the job is
running or pending, or branching between /interrupt and /queue.

- POST /api/jobs/{job_id}/cancel cancels one job by id. Idempotent: an
  already-finished or unknown id returns 200 {"cancelled": false} rather
  than an error.
- POST /api/jobs/cancel takes {"job_ids": [...]} and cancels a batch.
  Fail-fast: if any id is unknown the request returns 404 listing the
  unknown ids and cancels nothing (no partial side effects).

Both are state-agnostic and map onto the existing queue mechanics: a
running job is interrupted (same path as /interrupt), a pending job is
dequeued (same path as /queue {"delete": [...]}). The cancel logic lives
in comfy_execution.jobs as pure, unit-tested helpers; the server handlers
are thin wrappers. openapi.yaml documents both routes.

* fix: resolve review feedback on cancel endpoints

- Guard cancel_job() against TOCTOU: when dequeue() returns False the
  pending job left the queue between snapshot and delete; return
  CANCEL_UNKNOWN so callers never report cancelled=True for a remove
  that did not happen.
- Validate each job_ids element in the batch cancel endpoint before
  any queue access; unhashable or non-UUID values now return 400
  instead of raising TypeError (500).
- Update batch HTTP tests to use canonical UUID ids (required now that
  the endpoint validates id format) and add tests for the new guards.

* fix: make job cancel atomic and best-effort

Addresses two cancel races/edges raised in review.

Targeted, atomic interrupt. cancel_job's interrupt callback now takes the
prompt id and returns whether it fired; the single-cancel route backs it
with the new PromptQueue.interrupt_if_running, which checks the running set
and signals the interrupt under the queue mutex. This closes the TOCTOU
where a pending job that starts executing between the snapshot and dequeue
(or a running job that finishes between the snapshot and interrupt) could be
missed or, worse, cause an unrelated prompt to be interrupted. The per-prompt
interrupt-flag reset in execute_async keeps a finished job from leaking the
interrupt onto its successor.

Best-effort batch cancel. POST /api/jobs/cancel no longer fails the whole
batch with 404 when one id is unknown/finished; such ids are treated as
no-ops, so "cancel all" still cancels the in-progress jobs even if some
finished between the client's snapshot and the request. Malformed ids are
still rejected with 400.
2026-06-19 16:39:35 -07:00
2ab3816dcf feat: add Load3DAdvanced node (#14316) 2026-06-20 07:06:55 +08:00
93e3fd4c47 Small anima optimization. 2026-06-19 12:41:51 -07:00
14 changed files with 1267 additions and 99 deletions

View File

@ -8,6 +8,8 @@ from abc import ABC, abstractmethod
import logging
import comfy.model_management
import comfy.patcher_extension
import comfy.utils
import comfy.conds
if TYPE_CHECKING:
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
@ -51,12 +53,18 @@ class ContextHandlerABC(ABC):
class IndexListContextWindow(ContextWindowABC):
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0):
    """Describe one context window as a list of indices along ``dim``.

    Args:
        index_list: indices (along ``dim``) covered by this window.
        dim: tensor dimension the indices refer to.
        total_frames: length of the full sequence along ``dim``.
        modality_windows: optional ``{mod_idx: IndexListContextWindow}`` with
            the windows of the non-primary modalities.
        context_overlap: overlap associated with this window; stored as-is.
    """
    self.index_list = index_list
    self.context_length = len(index_list)
    self.context_overlap = context_overlap
    self.dim = dim
    self.total_frames = total_frames
    # Relative position of the window centre within the full sequence.
    # total_frames defaults to 0 and index_list may be empty; fall back to 0.0
    # instead of raising ZeroDivisionError / ValueError in those cases.
    if len(index_list) > 0 and total_frames > 0:
        self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
    else:
        self.center_ratio = 0.0
    self.modality_windows = modality_windows  # dict of {mod_idx: IndexListContextWindow}
    # Guide-frame bookkeeping; starts empty and is filled in when guide frames are injected.
    self.guide_frames_indices: list[int] = []
    self.guide_overlap_info: list[tuple[int, int]] = []
    self.guide_kf_local_positions: list[int] = []
    self.guide_downscale_factors: list[int] = []
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
if dim is None:
@ -85,6 +93,11 @@ class IndexListContextWindow(ContextWindowABC):
region_idx = int(self.center_ratio * num_regions)
return min(max(region_idx, 0), num_regions - 1)
def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow':
    """Return the window to use for the given modality.

    Index 0 is the primary modality and maps to this window itself; any
    other index is looked up in ``self.modality_windows``.
    """
    return self if modality_idx == 0 else self.modality_windows[modality_idx]
class IndexListCallbacks:
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
@ -148,6 +161,172 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
return cond_value._copy_with(sliced)
def compute_guide_overlap(guide_entries: list[dict], keyframe_idxs: torch.Tensor, temporal_downscale_ratio: int, window_index_list: list[int]):
    """Compute which concatenated guide frames overlap with a context window.

    Each guide's latent-space start is derived from its first token's pixel-t-start
    in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), divided by the
    model's temporal_downscale_ratio.

    Args:
        guide_entries: list of guide_attention_entry dicts
        keyframe_idxs: per-token pixel coords cond tensor for the modality
        temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio
        window_index_list: the window's frame indices into the video portion

    Returns:
        suffix_indices: indices into the guide_frames tensor for frame selection
        overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
        kf_local_positions: window-local frame positions for keyframe_idxs regeneration
        total_overlap: total number of overlapping guide frames
    """
    # Map each frame to its first position in the window. This gives the same
    # answer as list.index() but in O(1) per lookup instead of a linear scan.
    window_pos: dict[int, int] = {}
    for pos, frame in enumerate(window_index_list):
        window_pos.setdefault(frame, pos)

    suffix_indices = []
    overlap_info = []
    kf_local_positions = []
    suffix_base = 0   # offset of the current entry inside the concatenated guide frames
    token_offset = 0  # offset of the current entry's first token in keyframe_idxs

    for entry_idx, entry in enumerate(guide_entries):
        first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item())
        # Ceiling division: pixel start -> latent start.
        latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio
        guide_len = entry["latent_shape"][0]

        entry_overlap = 0
        for local_offset in range(guide_len):
            local_pos = window_pos.get(latent_start + local_offset)
            if local_pos is None:
                continue
            suffix_indices.append(suffix_base + local_offset)
            kf_local_positions.append(local_pos)
            entry_overlap += 1

        if entry_overlap > 0:
            overlap_info.append((entry_idx, entry_overlap))
        suffix_base += guide_len
        token_offset += entry["pre_filter_count"]

    return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
@dataclass
class WindowingState:
    """Per-modality context windowing state for each step,
    built using IndexListContextHandler._build_window_state().
    For non-multimodal models the lists are length 1
    """
    latents: list[torch.Tensor]  # per-modality working latents (guide frames stripped)
    guide_latents: list[torch.Tensor | None]  # per-modality guide frames stripped from latents
    guide_entries: list[list[dict] | None]  # per-modality guide_attention_entry metadata
    keyframe_idxs: list[torch.Tensor | None]  # per-modality keyframe_idxs tensor for guide latent_start derivation
    latent_shapes: list | None  # original packed shapes for unpack/pack (None if not multimodal)
    dim: int = 0  # primary modality temporal dim for context windowing
    is_multimodal: bool = False
    temporal_downscale_ratio: int = 1  # model's pixel-to-latent temporal compression ratio

    def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow:
        """Reformat window for multimodal contexts by deriving per-modality index lists.

        Non-multimodal contexts return the input window unchanged.

        Raises:
            NotImplementedError: if ``model`` has no
                ``map_context_window_to_modalities`` attribute.
        """
        if not self.is_multimodal:
            return window

        x = self.latents[0]
        primary_total = self.latent_shapes[0][self.dim]
        primary_overlap = window.context_overlap

        # If the working primary latent differs in length from the recorded
        # shape, hand the model a copy of the shapes with the current length.
        map_shapes = self.latent_shapes
        if x.size(self.dim) != primary_total:
            map_shapes = list(self.latent_shapes)
            video_shape = list(self.latent_shapes[0])
            video_shape[self.dim] = x.size(self.dim)
            map_shapes[0] = torch.Size(video_shape)

        # Probe for the hook instead of wrapping the call in `except AttributeError`,
        # so an AttributeError raised inside the hook is not misreported as
        # "not implemented".
        mapper = getattr(model, "map_context_window_to_modalities", None)
        if mapper is None:
            raise NotImplementedError(
                f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")
        per_modality_indices = mapper(window.index_list, map_shapes, self.dim)

        modality_windows = {}
        for mod_idx in range(1, len(self.latents)):
            modality_total_frames = self.latents[mod_idx].shape[self.dim]
            # Scale the primary overlap by the modality/primary length ratio.
            ratio = modality_total_frames / primary_total if primary_total > 0 else 1
            modality_overlap = max(round(primary_overlap * ratio), 0)
            modality_windows[mod_idx] = IndexListContextWindow(
                per_modality_indices[mod_idx], dim=self.dim,
                total_frames=modality_total_frames,
                context_overlap=modality_overlap)

        return IndexListContextWindow(
            window.index_list, dim=self.dim, total_frames=x.shape[self.dim],
            modality_windows=modality_windows, context_overlap=primary_overlap)

    def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]:
        """Slice latents for a context window, injecting guide frames where applicable.

        For multimodal contexts, uses the modality-specific windows derived in prepare_window().

        Returns:
            A pair ``(sliced, guide_frame_counts)`` with one entry per modality.
        """
        sliced = []
        guide_frame_counts = []
        for idx, latent in enumerate(self.latents):
            modality_window = window.get_window_for_modality(idx)
            # retain_index_list is only applied to the primary modality.
            retain = retain_index_list if idx == 0 else []
            s = modality_window.get_tensor(latent, device, retain_index_list=retain)
            if self.guide_entries[idx] is not None:
                s, ng = self._inject_guide_frames(s, modality_window, modality_idx=idx)
            else:
                ng = 0
            sliced.append(s)
            guide_frame_counts.append(ng)
        return sliced, guide_frame_counts

    def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow):
        """Strip injected guide frames from per-cond, per-modality outputs in place."""
        for idx in range(len(self.latents)):
            if guide_frame_counts[idx] <= 0:
                continue
            # Keep only the leading window-length part; guide frames were appended after it.
            window_len = len(window.get_window_for_modality(idx).index_list)
            for cond_out in out_per_modality:
                cond_out[idx] = cond_out[idx].narrow(self.dim, 0, window_len)

    def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]:
        """Append the guide frames overlapping ``window`` to ``latent_slice``.

        Also records the overlap bookkeeping on ``window`` (guide_frames_indices,
        guide_overlap_info, guide_kf_local_positions, guide_downscale_factors).

        Returns:
            ``(tensor, guide_frame_count)``; the input slice and 0 when nothing overlaps.
        """
        guide_entries = self.guide_entries[modality_idx]
        guide_frames = self.guide_latents[modality_idx]
        keyframe_idxs = self.keyframe_idxs[modality_idx]
        suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(
            guide_entries, keyframe_idxs, self.temporal_downscale_ratio, window.index_list)

        # Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0.
        anchor_idx = getattr(window, 'causal_anchor_index', None)
        if anchor_idx is not None and anchor_idx >= 0:
            kf_local_pos = [p + 1 for p in kf_local_pos]

        window.guide_frames_indices = suffix_idx
        window.guide_overlap_info = overlap_info
        window.guide_kf_local_positions = kf_local_pos

        if guide_frame_count <= 0:
            window.guide_downscale_factors = []
            return latent_slice, 0

        # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
        # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
        full_H = guide_frames.shape[3]
        window.guide_downscale_factors = [
            full_H // guide_entries[entry_idx]["latent_shape"][1]
            for entry_idx, _ in overlap_info
        ]

        # Select the overlapping guide frames along self.dim and append them.
        idx = tuple([slice(None)] * self.dim + [suffix_idx])
        return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count

    def patch_latent_shapes(self, sub_conds, new_shapes):
        """Replace every 'latent_shapes' model cond in ``sub_conds`` with ``new_shapes``.

        No-op for non-multimodal state. Mutates the cond dicts in place.
        """
        if not self.is_multimodal:
            return

        for cond_list in sub_conds:
            if cond_list is None:
                continue
            for cond_dict in cond_list:
                model_conds = cond_dict.get('model_conds', {})
                if 'latent_shapes' in model_conds:
                    model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
@dataclass
class ContextSchedule:
name: str
@ -162,7 +341,7 @@ ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_co
class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
causal_window_fix: bool=True):
latent_retain_index_list: list[int]=[], causal_window_fix: bool=True):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
@ -174,17 +353,118 @@ class IndexListContextHandler(ContextHandlerABC):
self.freenoise = freenoise
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
self.split_conds_to_windows = split_conds_to_windows
self.latent_retain_index_list = [int(x.strip()) for x in latent_retain_index_list.split(",")] if latent_retain_index_list else []
self.causal_window_fix = causal_window_fix
self.callbacks = {}
@staticmethod
def _get_latent_shapes(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
if 'latent_shapes' in model_conds:
return model_conds['latent_shapes'].cond
return None
@staticmethod
def _get_guide_entries(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
entries = model_conds.get('guide_attention_entries')
if entries is not None and hasattr(entries, 'cond') and entries.cond:
return entries.cond
return None
@staticmethod
def _get_keyframe_idxs(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
kf = model_conds.get('keyframe_idxs')
if kf is not None and hasattr(kf, 'cond') and kf.cond is not None:
return kf.cond
return None
def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
    """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.

    If guide frames are present on the primary modality, only the video portion is shuffled.
    """
    guide_entries = self._get_guide_entries(conds)
    guide_count = 0
    if guide_entries:
        guide_count = sum(e["latent_shape"][0] for e in guide_entries)

    latent_shapes = self._get_latent_shapes(conds)
    is_multimodal = latent_shapes is not None and len(latent_shapes) > 1

    if not is_multimodal:
        # Single latent: shuffle only the leading video part, leaving the trailing guide frames alone.
        video_count = noise.size(self.dim) - guide_count
        apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed)
        return noise

    modalities = comfy.utils.unpack_latents(noise, latent_shapes)
    primary_total = latent_shapes[0][self.dim]
    primary_video_count = modalities[0].size(self.dim) - guide_count
    # The return value is not used for the primary modality; the call operates on a narrowed view.
    apply_freenoise(modalities[0].narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed)

    for i in range(1, len(modalities)):
        # Scale the context length/overlap by this modality's length relative to the primary one.
        mod_total = latent_shapes[i][self.dim]
        ratio = mod_total / primary_total if primary_total > 0 else 1
        mod_ctx_len = max(round(self.context_length * ratio), 1)
        mod_ctx_overlap = max(round(self.context_overlap * ratio), 0)
        modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed)

    noise, _ = comfy.utils.pack_latents(modalities)
    return noise
def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState:
    """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds."""
    latent_shapes = self._get_latent_shapes(conds)
    is_multimodal = latent_shapes is not None and len(latent_shapes) > 1
    if is_multimodal:
        unpacked_latents = comfy.utils.unpack_latents(x_in, latent_shapes)
    else:
        unpacked_latents = [x_in]

    num_modalities = len(unpacked_latents)
    latents = list(unpacked_latents)
    guide_latents = [None] * num_modalities
    guide_entries = [None] * num_modalities
    keyframe_idxs = [None] * num_modalities

    found_entries = self._get_guide_entries(conds)
    found_keyframe_idxs = self._get_keyframe_idxs(conds)

    # Strip guide frames (only from first modality for now)
    guide_count = 0
    if found_entries is not None:
        guide_count = sum(e["latent_shape"][0] for e in found_entries)
    # NOTE(review): the guide metadata is only recorded when guide_count > 0;
    # this nesting was reconstructed from flattened source - confirm.
    if guide_count > 0:
        primary = unpacked_latents[0]
        latent_count = primary.size(self.dim) - guide_count
        # Guide frames are the trailing guide_count entries along self.dim.
        latents[0] = primary.narrow(self.dim, 0, latent_count)
        guide_latents[0] = primary.narrow(self.dim, latent_count, guide_count)
        guide_entries[0] = found_entries
        keyframe_idxs[0] = found_keyframe_idxs

    return WindowingState(
        latents=latents,
        guide_latents=guide_latents,
        guide_entries=guide_entries,
        keyframe_idxs=keyframe_idxs,
        latent_shapes=latent_shapes,
        dim=self.dim,
        is_multimodal=is_multimodal,
        temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio)
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
    """Decide whether context windows are needed for this input.

    Returns True when the primary latent of the window state (as built by
    ``_build_window_state``) is longer than ``context_length`` along
    ``self.dim``; otherwise False. Logs the decision either way.
    """
    # for now, assume first dim is batch - should have stored on BaseModel in actual implementation
    # build window_state to check frame counts, will be built again in execute
    window_state = self._build_window_state(x_in, conds, model)
    total_frame_count = window_state.latents[0].size(self.dim)

    if total_frame_count <= self.context_length:
        logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).")
        return False

    logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
    if self.cond_retain_index_list:
        logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
    if self.latent_retain_index_list:
        logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}")
    return True
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
@ -275,7 +555,9 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
sample_sigmas = model_options["transformer_options"]["sample_sigmas"]
current_timestep = timestep[0].to(sample_sigmas.dtype)
mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
return # substep from multi-step sampler: keep self._step from the last full step
@ -284,54 +566,98 @@ class IndexListContextHandler(ContextHandlerABC):
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
    """Run the context schedule over ``x_in`` and wrap each index list in a window.

    Every window carries this handler's ``dim`` and ``context_overlap`` and
    the full length of ``x_in`` along ``self.dim``.
    """
    full_length = x_in.size(self.dim)  # TODO: choose dim based on model
    index_lists = self.context_schedule.func(full_length, self, model_options)
    windows = []
    for index_list in index_lists:
        windows.append(IndexListContextWindow(
            index_list, dim=self.dim, total_frames=full_length,
            context_overlap=self.context_overlap))
    return windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
self._model = model
self.set_step(timestep, model_options)
context_windows = self.get_context_windows(model, x_in, model_options)
enumerated_context_windows = list(enumerate(context_windows))
conds_final = [torch.zeros_like(x_in) for _ in conds]
window_state = self._build_window_state(x_in, conds, model)
num_modalities = len(window_state.latents)
context_windows = self.get_context_windows(model, window_state.latents[0], model_options)
enumerated_context_windows = list(enumerate(context_windows))
total_windows = len(enumerated_context_windows)
# Initialize per-modality accumulators (length 1 for single-modality)
accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents]
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
else:
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents]
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options)
# accumulate results from each context window
for enum_window in enumerated_context_windows:
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
results = self.evaluate_context_windows(
calc_cond_batch, model, x_in, conds, timestep, [enum_window],
model_options, window_state=window_state, total_windows=total_windows)
for result in results:
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
conds_final, counts_final, biases_final)
# result.sub_conds_out is per-cond, per-modality: list[list[Tensor]]
for mod_idx in range(num_modalities):
mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))]
modality_window = result.window.get_window_for_modality(mod_idx)
self.combine_context_window_results(
window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window,
result.window_idx, total_windows, timestep,
accum[mod_idx], counts[mod_idx], biases[mod_idx])
# fuse accumulated results into final conds
try:
# finalize conds
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
# relative is already normalized, so return as is
del counts_final
return conds_final
else:
# normalize conds via division by context usage counts
for i in range(len(conds_final)):
conds_final[i] /= counts_final[i]
del counts_final
return conds_final
result_out = []
for ci in range(len(conds)):
finalized = []
for mod_idx in range(num_modalities):
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
accum[mod_idx][ci] /= counts[mod_idx][ci]
f = accum[mod_idx][ci]
# if guide frames were injected, append them to the end of the fused latents for the next step
if window_state.guide_latents[mod_idx] is not None:
f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim)
finalized.append(f)
# pack modalities together if needed
if window_state.is_multimodal and len(finalized) > 1:
packed, _ = comfy.utils.pack_latents(finalized)
else:
packed = finalized[0]
result_out.append(packed)
return result_out
finally:
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options)
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
model_options, device=None, first_device=None):
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds,
timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
model_options, window_state: WindowingState, total_windows: int = None,
device=None, first_device=None):
"""Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out
For each window:
1. Builds windows (for each modality if multimodal)
2. Slices window for each modality
3. Injects concatenated latent guide frames where present
4. Packs together if needed and calls model
5. Unpacks and strips any guides from outputs
"""
x = window_state.latents[0]
results: list[ContextResults] = []
for window_idx, window in enumerated_context_windows:
# allow processing to end between context window executions for faster Cancel
comfy.model_management.throw_exception_if_processing_interrupted()
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward
# prepare the window accounting for multimodal windows
window = window_state.prepare_window(window, model)
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward.
# Set anchor before slice_for_window so the latent slice and downstream cond slices both pick it up.
anchor_applied = False
if self.causal_window_fix:
anchor_idx = window.index_list[0] - 1
@ -339,27 +665,46 @@ class IndexListContextHandler(ContextHandlerABC):
window.causal_anchor_index = anchor_idx
anchor_applied = True
# slice the window for each modality, injecting guide frames where applicable
sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device)
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
# update exposed params
logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}"
+ (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "")
)
# if multimodal, pack modalities together
if window_state.is_multimodal and len(sliced) > 1:
sub_x, sub_shapes = comfy.utils.pack_latents(sliced)
else:
sub_x, sub_shapes = sliced[0], [sliced[0].shape]
# get resized conds for window
model_options["transformer_options"]["context_window"] = window
# get subsections of x, timestep, conds
sub_x = window.get_tensor(x_in, device)
sub_timestep = window.get_tensor(timestep, device, dim=0)
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
sub_timestep = window.get_tensor(timestep, dim=0)
sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds]
# if multimodal, patch latent_shapes in conds for correct unpacking in model
window_state.patch_latent_shapes(sub_conds, sub_shapes)
# call model on window
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
if device is not None:
for i in range(len(sub_conds_out)):
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
# strip causal_window_fix anchor if applied
# unpack outputs
out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
# strip causal_window_fix anchor from primary modality before guide strip so window_len math stays correct
if anchor_applied:
for i in range(len(sub_conds_out)):
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
for ci in range(len(out_per_modality)):
t = out_per_modality[ci][0]
out_per_modality[ci][0] = t.narrow(self.dim, 1, t.shape[self.dim] - 1)
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
# strip injected guide frames
window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window)
results.append(ContextResults(window_idx, out_per_modality, sub_conds, window))
return results
@ -383,7 +728,7 @@ class IndexListContextHandler(ContextHandlerABC):
biases_final[i][idx] = bias_total + bias
else:
# add conds and counts based on weights of fuse method
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap)
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
for i in range(len(sub_conds_out)):
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
@ -393,16 +738,22 @@ class IndexListContextHandler(ContextHandlerABC):
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs):
    """Clamp ``noise_shape`` to one context window before calling ``executor``.

    Scale noise_shape to a single context window so VRAM estimation budgets per-window.

    Raises:
        Exception: if ``model_options`` is missing from ``kwargs``.
    """
    model_options = kwargs.get("model_options", None)
    if model_options is None:
        raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")

    handler = model_options.get("context_handler", None)
    if handler is None:
        return executor(model, noise_shape, conds, *args, **kwargs)

    shape = list(noise_shape)
    is_packed = len(shape) == 3 and shape[1] == 1
    # TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a
    # per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM.
    if not is_packed and handler.dim < len(shape) and shape[handler.dim] > handler.context_length:
        shape[handler.dim] = min(shape[handler.dim], handler.context_length)
    return executor(model, shape, conds, *args, **kwargs)
def create_prepare_sampling_wrapper(model: ModelPatcher):
@ -422,11 +773,12 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
if not handler.freenoise:
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
conds = [guider.conds.get('positive', guider.conds.get('negative', []))]
noise = handler._apply_freenoise(noise, conds, extra_args["seed"])
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
def create_sampler_sample_wrapper(model: ModelPatcher):
model.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
@ -434,7 +786,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher):
_sampler_sample_wrapper
)
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
total_dims = len(x_in.shape)
weights_tensor = torch.Tensor(weights).to(device=device)
@ -580,8 +931,9 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
return ContextSchedule(context_schedule, func)
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None, context_overlap: int=None):
    """Compute fuse weights for a window via the handler's fuse method.

    ``context_overlap`` falls back to ``handler.context_overlap`` when it is
    ``None``; an explicit value (including 0) is passed through as given.
    """
    if context_overlap is None:
        context_overlap = handler.context_overlap
    fuse = handler.fuse_method.func
    return fuse(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs, context_overlap=context_overlap)
def create_weights_flat(length: int, **kwargs) -> list[float]:
@ -599,18 +951,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
return weight_sequence
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs):
    """Build fuse weights that are 1 everywhere except linearly ramped overlap edges.

    Args:
        length: Number of positions in the window (size of the returned tensor).
        full_length: Total number of positions along the windowed dimension.
        idxs: Indices covered by this window; used to detect first/last window.
        context_overlap: Number of positions at each edge to blend.
        **kwargs: Ignored; accepted so all fuse functions share one call shape.

    Returns:
        A 1-D tensor of ``length`` weights.
    """
    # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
    # only expected overlap is given different weights
    weights_torch = torch.ones((length))
    # Clamp the ramp to the window size: a ramp longer than the window cannot
    # be assigned into it.
    overlap = max(0, min(context_overlap, length))
    # With no overlap there is nothing to blend. This guard is required, not
    # just an optimization: ``weights_torch[-0:]`` selects the WHOLE tensor, so
    # assigning an empty ramp to it would raise a shape error.
    if overlap == 0:
        return weights_torch
    # blend left-side on all except first window
    if min(idxs) > 0:
        weights_torch[:overlap] = torch.linspace(1e-37, 1, overlap)
    # blend right-side on all except last window
    if max(idxs) < full_length-1:
        weights_torch[-overlap:] = torch.linspace(1, 1e-37, overlap)
    return weights_torch
class ContextFuseMethods:

View File

@ -515,7 +515,7 @@ class Block(nn.Module):
h=H,
w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_self_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
def _x_fn(
_x_B_T_H_W_D: torch.Tensor,
@ -548,7 +548,7 @@ class Block(nn.Module):
shift_cross_attn_B_T_1_1_D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_cross_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
@ -557,7 +557,7 @@ class Block(nn.Module):
shift_mlp_B_T_1_1_D,
)
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_mlp_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
return x_B_T_H_W_D

View File

@ -1085,7 +1085,7 @@ class LTXVModel(LTXBaseModel):
)
grid_mask = None
if keyframe_idxs is not None:
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
additional_args.update({ "orig_patchified_shape": list(x.shape)})
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
@ -1330,7 +1330,7 @@ class LTXVModel(LTXBaseModel):
x = x * (1 + scale) + shift
x = self.proj_out(x)
if keyframe_idxs is not None:
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
grid_mask = kwargs["grid_mask"]
orig_patchified_shape = kwargs["orig_patchified_shape"]
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)

View File

@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
import comfy.ldm.lightricks.av_model
import comfy.ldm.lightricks.symmetric_patchifier
import comfy.context_windows
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
@ -1204,6 +1205,127 @@ class LTXAV(BaseModel):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
    """Translate a primary-modality window into index lists for every modality.

    The first entry of the result is ``primary_indices`` itself. For each
    further shape in ``latent_shapes``, every primary index is mapped to its
    proportional range along ``dim`` of that modality; ranges are concatenated
    in primary order with duplicates dropped, so wrapped/strided window
    geometry is preserved.

    Returns:
        A list with one index list per entry of ``latent_shapes`` (or just
        ``[primary_indices]`` when fewer than two shapes are given).
    """
    mapped = [primary_indices]
    if len(latent_shapes) < 2:
        return mapped

    primary_total = latent_shapes[0][dim]
    for shape_pos in range(1, len(latent_shapes)):
        modality_total = latent_shapes[shape_pos][dim]
        # A dict keeps first-insertion order, giving an ordered, de-duplicated
        # collection of modality indices.
        ordered = {}
        for frame in primary_indices:
            start = min(int(round(frame * modality_total / primary_total)), modality_total - 1)
            stop = min(int(round((frame + 1) * modality_total / primary_total)), modality_total)
            # Every primary index contributes at least one modality index.
            stop = max(stop, start + 1)
            for modality_index in range(start, stop):
                ordered.setdefault(modality_index, None)
        mapped.append(list(ordered))
    return mapped
@staticmethod
def _get_guide_entries(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
entries = model_conds.get('guide_attention_entries')
if entries is not None and hasattr(entries, 'cond') and entries.cond:
return entries.cond
return None
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Resize model-specific conditioning values to match one context window.

    Handles four cond keys: ``audio_denoise_mask``, ``denoise_mask``,
    ``keyframe_idxs`` and ``guide_attention_entries``. For a handled key the
    result is ``cond_value._copy_with(...)`` carrying the window-sized value.

    Returns ``None`` when the key is not handled here, or when a handled key's
    preconditions are not met (e.g. a ``denoise_mask`` with no guide portion).
    NOTE(review): presumably ``None`` tells the caller to apply its generic
    resizing — confirm against the caller.

    ``retain_index_list`` is only forwarded to ``window.get_tensor``; it is
    never mutated in this method.
    """
    # Audio denoise mask — slice using audio modality window
    if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:
        # Key 1 of modality_windows is treated as the audio modality here.
        audio_window = window.modality_windows.get(1)
        if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
            sliced = audio_window.get_tensor(cond_value.cond, device, dim=2)
            return cond_value._copy_with(sliced)
    # Video denoise mask — split into video + guide portions, slice each
    if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
        cond_tensor = cond_value.cond
        # Anything the mask has beyond x_in along the window dim is treated as guide frames.
        guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
        if guide_count > 0:
            T_video = x_in.size(window.dim)
            video_mask = cond_tensor.narrow(window.dim, 0, T_video)
            guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
            sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
            suffix_indices = window.guide_frames_indices
            if suffix_indices:
                # Index only along window.dim; all leading dims are taken whole.
                idx = tuple([slice(None)] * window.dim + [suffix_indices])
                sliced_guide = guide_mask[idx].to(device)
                return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
            else:
                return cond_value._copy_with(sliced_video)
    # Keyframe indices — regenerate pixel coords for window, select guide positions
    if cond_key == "keyframe_idxs":
        kf_local_pos = window.guide_kf_local_positions
        if not kf_local_pos:
            return cond_value._copy_with(cond_value.cond[:, :, :0, :])  # empty
        # assumes x_in is laid out with height/width at dims 3 and 4 — TODO confirm
        H, W = x_in.shape[3], x_in.shape[4]
        window_len = len(window.index_list)
        # account for causal_window_fix anchor in coord space size
        anchor_idx = getattr(window, 'causal_anchor_index', None)
        if anchor_idx is not None and anchor_idx >= 0:
            window_len += 1
        patchifier = self.diffusion_model.patchifier
        latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
        scale_factors = self.diffusion_model.vae_scale_factors
        pixel_coords = comfy.ldm.lightricks.symmetric_patchifier.latent_to_pixel_coords(
            latent_coords,
            scale_factors,
            causal_fix=self.diffusion_model.causal_temporal_positioning)
        # Collect the H*W token indices belonging to each local keyframe position.
        tokens = []
        for pos in kf_local_pos:
            tokens.extend(range(pos * H * W, (pos + 1) * H * W))
        pixel_coords = pixel_coords[:, :, tokens, :]
        # Adjust spatial end positions for dilated (downscaled) guides.
        # Each guide entry may have a different downscale factor; expand the
        # per-entry factor to cover all tokens belonging to that entry.
        downscale_factors = window.guide_downscale_factors
        overlap_info = window.guide_overlap_info
        if downscale_factors:
            per_token_factor = []
            for (entry_idx, overlap_count), dsf in zip(overlap_info, downscale_factors):
                per_token_factor.extend([dsf] * (overlap_count * H * W))
            factor_tensor = torch.tensor(per_token_factor, device=pixel_coords.device, dtype=pixel_coords.dtype)
            spatial_end_offset = (factor_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1) - 1) * torch.tensor(
                scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype,
            ).view(1, -1, 1, 1)
            pixel_coords[:, 1:, :, 1:] += spatial_end_offset
        # Match the batch size of the incoming cond.
        B = cond_value.cond.shape[0]
        if B > 1:
            pixel_coords = pixel_coords.expand(B, -1, -1, -1)
        return cond_value._copy_with(pixel_coords)
    # Guide attention entries — adjust per-guide counts based on window overlap
    if cond_key == "guide_attention_entries":
        overlap_info = window.guide_overlap_info
        H, W = x_in.shape[3], x_in.shape[4]
        new_entries = []
        for entry_idx, overlap_count in overlap_info:
            e = cond_value.cond[entry_idx]
            # Copy the entry, overriding only the fields that depend on the overlap.
            new_entries.append({**e,
                                "pre_filter_count": overlap_count * H * W,
                                "latent_shape": [overlap_count, H, W]})
        return cond_value._copy_with(new_entries)
    return None
class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)

View File

@ -4,11 +4,22 @@ Provides normalization and helper functions for job status tracking.
"""
import uuid
from typing import Optional
from typing import Callable, Optional
from comfy_api.internal import prune_dict
# Result of classifying a job for cancellation.
# 'running' -> job is currently executing (interrupt it)
# 'pending' -> job is queued but not started (dequeue it)
# 'terminal' -> job already finished (present in history); cancel is a no-op
# 'unknown' -> job id is not present anywhere
CANCEL_RUNNING = 'running'
CANCEL_PENDING = 'pending'
CANCEL_TERMINAL = 'terminal'
CANCEL_UNKNOWN = 'unknown'
class JobStatus:
"""Job status constants."""
PENDING = 'pending'
@ -407,3 +418,71 @@ def get_all_jobs(
jobs = jobs[:limit]
return (jobs, total_count)
def classify_job_for_cancel(prompt_id: str, running: list, queued: list, history: dict) -> str:
    """Decide which cancellation state a job id is in.

    Lookup order is running, then queued, then history. Queue items are
    tuples carrying the prompt id at index 1; ``history`` is keyed by prompt
    id, so membership there means the job already finished.

    Returns:
        ``CANCEL_RUNNING``, ``CANCEL_PENDING``, ``CANCEL_TERMINAL`` or
        ``CANCEL_UNKNOWN``.
    """
    if any(entry[1] == prompt_id for entry in running):
        return CANCEL_RUNNING
    if any(entry[1] == prompt_id for entry in queued):
        return CANCEL_PENDING
    return CANCEL_TERMINAL if prompt_id in history else CANCEL_UNKNOWN
def cancel_job(
    prompt_id: str,
    running: list,
    queued: list,
    history: dict,
    interrupt: Callable[[str], bool],
    dequeue: Callable[[str], bool],
) -> str:
    """Cancel one job by id, whatever state it is in.

    The job is classified from the supplied snapshot and the matching
    runtime action is invoked:

    - running  -> ``interrupt(prompt_id)``
    - pending  -> ``dequeue(prompt_id)``
    - terminal -> nothing (already finished)
    - unknown  -> nothing (callers needing fail-fast should validate ids
      beforehand with ``classify_job_for_cancel``)

    Both callbacks return whether they really acted, so the result describes
    what happened rather than the possibly stale classification:

    - classified running but ``interrupt`` returns False (job finished in the
      meantime): ``CANCEL_UNKNOWN``.
    - classified pending but ``dequeue`` returns False (job left the queue):
      ``interrupt`` is tried next; ``CANCEL_RUNNING`` if it fires, otherwise
      ``CANCEL_UNKNOWN``.

    ``interrupt`` must be atomic (interrupt only if that job is still the one
    running) so a cancel cannot hit an unrelated prompt that started in the
    meantime (see ``execution.PromptQueue.interrupt_if_running``).
    """
    def _interrupt_outcome() -> str:
        # Report RUNNING only when the interrupt actually landed on this job.
        return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN

    state = classify_job_for_cancel(prompt_id, running, queued, history)

    if state == CANCEL_PENDING:
        # A failed dequeue means the job left the pending queue after the
        # snapshot: it is either executing now (interrupt it) or finished.
        return CANCEL_PENDING if dequeue(prompt_id) else _interrupt_outcome()

    if state == CANCEL_RUNNING:
        return _interrupt_outcome()

    # CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops.
    return state

View File

@ -13,21 +13,22 @@ class ContextWindowsManualNode(io.ComfyNode):
description="Manually set context windows.",
inputs=[
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True),
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True),
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
io.Combo.Input("context_schedule", options=[
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
comfy.context_windows.ContextSchedules.BATCHED,
], tooltip="The stride of the context window."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
], default=comfy.context_windows.ContextSchedules.STATIC_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window. For concat-style I2V models (e.g. Wan I2V, HunyuanVideo I2V, Cosmos I2V, SVD) the encoded start image lives in the c_concat conditioning channels; setting this to '0' will retain that start image content at sub-pos 0 of every window."),
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
io.String.Input("latent_retain_index_list", default="", tooltip="List of latent indices to retain in the noise latent itself for each window. Use for workflows where reference content (e.g. a start image) lives directly in the noise latent rather than in separate conditioning channels (e.g. inplace-style I2V like LTXV, AnimateDiff). Independent of cond_retain_index_list."),
io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."),
],
outputs=[
@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode):
@classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model:
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, latent_retain_index_list: list[int]=[], causal_window_fix: bool=True) -> io.Model:
model = model.clone()
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@ -51,6 +52,7 @@ class ContextWindowsManualNode(io.ComfyNode):
freenoise=freenoise,
cond_retain_index_list=cond_retain_index_list,
split_conds_to_windows=split_conds_to_windows,
latent_retain_index_list=latent_retain_index_list,
causal_window_fix=causal_window_fix,
)
# make memory usage calculation only take into account the context window latents
@ -65,33 +67,71 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
schema = super().define_schema()
schema.node_id = "WanContextWindowsManual"
schema.display_name = "WAN Context Windows (Manual)"
schema.description = "Manually set context windows for WAN-like models (dim=2)."
schema.display_name = "Wan Context Windows"
schema.description = "Set context windows for Wan-like models."
schema.category="model/patch/wan"
schema.inputs = [
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window.", advanced=True),
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window.", advanced=True),
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window in real frames. Must be 4*n + 1."),
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window in real frames."),
io.Combo.Input("context_schedule", options=[
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
comfy.context_windows.ContextSchedules.BATCHED,
], tooltip="The stride of the context window."),
], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True),
io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first I2V frame in every context window (may help retain initial reference)."),
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True),
]
return schema
@classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
retain_first_frame: bool=False, split_conds_to_windows: bool=False) -> io.Model:
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
context_overlap = max(context_overlap // 4, 0) # at least overlap 0
retain_index_list = "0" if retain_first_frame else ""
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows)
class LTXVContextWindowsNode(ContextWindowsManualNode):
    """Context-windows node preset for LTXV-like models.

    Reuses the parent schema but replaces its inputs: lengths are entered in
    real frames, and there is no ``dim`` input (``execute`` always passes
    ``dim=2``).
    """
    @classmethod
    def define_schema(cls) -> io.Schema:
        """Return the parent schema with LTXV-specific id, name, description and inputs."""
        schema = super().define_schema()
        schema.node_id = "LTXVContextWindows"
        schema.display_name = "LTXV Context Windows"
        schema.description = "Set context windows for LTXV-like models."
        schema.inputs = [
            io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
            io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=8, default=145, tooltip="The length of the context window in real frames. Must be 8*n + 1."),
            io.Int.Input("context_overlap", min=0, step=8, default=40, tooltip="The overlap of the context window in real frames."),
            io.Combo.Input("context_schedule", options=[
                comfy.context_windows.ContextSchedules.STATIC_STANDARD,
                comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
                comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
                comfy.context_windows.ContextSchedules.BATCHED,
            ], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
            io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
            io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True),
            io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
            io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True),
            io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first latent frame in every context window (may help retain initial reference)."),
            io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True),
        ]
        return schema
    @classmethod
    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, fuse_method: str, freenoise: bool,
                retain_first_frame: bool=False, split_conds_to_windows: bool=False, context_stride: int=1, closed_loop: bool=False) -> io.Model:
        """Convert the frame-based inputs and delegate to the parent ``execute``.

        ``context_length`` and ``context_overlap`` arrive in real frames and
        are reduced by a factor of 8 before being forwarded.
        NOTE(review): 8 is presumably the model's temporal compression factor,
        matching the "8*n + 1" tooltip — confirm.
        """
        # 8*n + 1 real frames -> n + 1, never below 1.
        context_length = max(((context_length - 1) // 8) + 1, 1)  # at least length 1
        context_overlap = max(context_overlap // 8, 0)  # at least overlap 0
        # The retain list is passed as a string: "0" keeps index 0, "" keeps nothing.
        retain_index_list = "0" if retain_first_frame else ""
        # The same retain list is applied to both the conditioning and the latent.
        return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise,
                               cond_retain_index_list=retain_index_list, latent_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows)
class ContextWindowsExtension(ComfyExtension):
@ -99,6 +139,7 @@ class ContextWindowsExtension(ComfyExtension):
return [
ContextWindowsManualNode,
WanContextWindowsManualNode,
LTXVContextWindowsNode,
]
def comfy_entrypoint():

View File

@ -77,7 +77,7 @@ class FrameInterpolate(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="FrameInterpolate",
display_name="Frame Interpolate",
display_name="Run Frame Interpolation Model",
category="video",
search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"],
inputs=[

View File

@ -89,7 +89,8 @@ class SwitchNode(io.ComfyNode):
template = io.MatchType.Template("switch")
return io.Schema(
node_id="ComfySwitchNode",
display_name="Switch",
search_aliases=["if", "then", "switch", "conditional", "branch"],
display_name="If/Else Switch",
category="utilities/logic",
is_experimental=True,
inputs=[

View File

@ -10,12 +10,11 @@ class String(io.ComfyNode):
return io.Schema(
node_id="PrimitiveString",
search_aliases=["text", "string", "text box", "prompt"],
display_name="Text String",
display_name="Text String (DEPRECATED)",
category="utilities/primitive",
inputs=[
io.String.Input("value"),
],
inputs=[io.String.Input("value")],
outputs=[io.String.Output()],
is_deprecated=True
)
@classmethod
@ -29,12 +28,10 @@ class StringMultiline(io.ComfyNode):
return io.Schema(
node_id="PrimitiveStringMultiline",
search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"],
display_name="Text String (Multiline)",
display_name="Input Text",
category="utilities/primitive",
essentials_category="Basics",
inputs=[
io.String.Input("value", multiline=True),
],
inputs=[io.String.Input("value", multiline=True)],
outputs=[io.String.Output()],
)

View File

@ -233,13 +233,8 @@ class VideoSlice(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Video Slice",
display_name="Video Slice",
search_aliases=[
"trim video duration",
"skip first frames",
"frame load cap",
"start time",
],
display_name="Trim Video",
search_aliases=["trim video duration", "skip first frames", "frame load cap", "start time"],
category="video",
essentials_category="Video Tools",
inputs=[

View File

@ -1308,6 +1308,25 @@ class PromptQueue:
queued = copy.copy(self.queue)
return (running, queued)
def interrupt_if_running(self, prompt_id):
    """Atomically interrupt the running prompt that has this id.

    The live running set is inspected and the interrupt is signalled while
    holding the queue mutex, so the worker cannot move the job to done (and
    start the next prompt) between the check and the signal. That atomicity
    is what prevents a cancel from landing on an unrelated prompt that began
    after a separate is-running check: the global interrupt flag is reset at
    the start of every prompt (execute_async), so a job that finishes before
    consuming the flag cannot leak the interrupt onto its successor.

    Returns:
        True if a matching job was running and an interrupt was signalled,
        False otherwise.
    """
    with self.mutex:
        is_running = any(
            entry[1] == prompt_id for entry in self.currently_running.values()
        )
        if is_running:
            nodes.interrupt_processing()
        return is_running
def get_tasks_remaining(self):
with self.mutex:
return len(self.queue) + len(self.currently_running)

111
server.py
View File

@ -8,7 +8,15 @@ import time
import nodes
import folder_paths
import execution
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs, validate_job_id
from comfy_execution.jobs import (
JobStatus,
get_job,
get_all_jobs,
validate_job_id,
cancel_job,
CANCEL_PENDING,
CANCEL_RUNNING,
)
import uuid
import urllib
import json
@ -899,6 +907,107 @@ class PromptServer():
return web.json_response(job)
def _cancel_job_by_id(job_id):
    """Cancel a single job by id using the queue's existing mechanics.

    Running jobs are interrupted (same mechanism as /interrupt); pending
    jobs are dequeued (same mechanism as /queue {"delete": [...]}).
    Already-finished or unknown ids are no-ops. State-agnostic.

    Returns True when a cancel was actually dispatched (running or
    pending job), False when the call was a no-op (terminal/unknown id).
    """
    running, queued = self.prompt_queue.get_current_queue()
    # Only this id's entry is needed for the "already finished" check, so
    # ask for it alone instead of fetching the entire history on every call
    # (the batch endpoint calls this once per id).
    history = self.prompt_queue.get_history(prompt_id=job_id)

    def interrupt(prompt_id):
        logging.info(f"Cancelling running prompt {prompt_id}")
        # Atomic: only interrupts if the job is still the one running,
        # so a cancel can't land on a prompt that started in the gap
        # since the snapshot above. Returns whether it actually fired.
        return self.prompt_queue.interrupt_if_running(prompt_id)

    def dequeue(prompt_id):
        logging.info(f"Cancelling pending prompt {prompt_id}")
        return self.prompt_queue.delete_queue_item(lambda a: a[1] == prompt_id)

    classification = cancel_job(job_id, running, queued, history, interrupt, dequeue)
    return classification in (CANCEL_RUNNING, CANCEL_PENDING)
@routes.post("/api/jobs/{job_id}/cancel")
async def cancel_job_by_id(request):
    """Handle POST /api/jobs/{job_id}/cancel: cancel one job in any state.

    Idempotent: an id that has already finished, or is not known at all,
    yields 200 with {"cancelled": false} instead of an error. A missing id
    yields 400.
    """
    job_id = request.match_info.get("job_id", None)
    if job_id:
        return web.json_response({"cancelled": _cancel_job_by_id(job_id)})
    return web.json_response({"error": "job_id is required"}, status=400)
@routes.post("/api/jobs/cancel")
async def cancel_jobs_batch(request):
    """Cancel a batch of jobs by id.

    Body: {"job_ids": ["<uuid>", ...]}

    Best-effort and idempotent: every well-formed id is cancelled if it
    is running or pending; ids that are already finished or unknown are
    no-ops, not errors. A batch of all no-ops still returns 200 with
    {"cancelled": false}. This matches the single-cancel endpoint and
    means "cancel all" still cancels the in-progress jobs even if some
    finished between the client's snapshot and the request. Malformed
    ids are rejected up front with 400, before any job is touched.
    """
    try:
        json_data = await request.json()
    except json.JSONDecodeError:
        return web.json_response(
            {"error": "Request body must be valid JSON"},
            status=400
        )

    job_ids = json_data.get("job_ids") if isinstance(json_data, dict) else None
    if not isinstance(job_ids, list):
        return web.json_response(
            {"error": "job_ids must be a list"},
            status=400
        )

    # Validate every element before doing anything else, so a bad id can
    # never surface later as a 500 or leave the batch half-applied.
    invalid_ids = []
    for jid in job_ids:
        # A non-string (null, number, nested list/dict, ...) is never a
        # valid id. Reject it explicitly rather than relying on which
        # exception type validate_job_id happens to raise for it.
        if not isinstance(jid, str):
            invalid_ids.append(repr(jid))
            continue
        try:
            validate_job_id(jid)
        except (ValueError, AttributeError):
            invalid_ids.append(jid)
    if invalid_ids:
        return web.json_response(
            {"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids},
            status=400,
        )

    # Best-effort: cancel each id that is still running/pending; an id
    # that has finished or never existed is a no-op rather than a reason
    # to fail the whole batch.
    cancelled = False
    for jid in job_ids:
        if _cancel_job_by_id(jid):
            cancelled = True
    return web.json_response({"cancelled": cancelled})
@routes.get("/history")
async def get_history(request):
max_items = request.rel_url.query.get("max_items", None)

View File

View File

@ -0,0 +1,453 @@
"""Tests for the jobs-namespace cancel endpoints.
Covers both layers:
* the pure cancel helpers in ``comfy_execution.jobs``
(``classify_job_for_cancel`` / ``cancel_job``), which hold the business
logic of mapping a cancel onto interrupt-vs-dequeue, and
* the HTTP contract of ``POST /api/jobs/{job_id}/cancel`` and
``POST /api/jobs/cancel`` (status codes, single-cancel idempotency, and
best-effort batch cancellation that treats unknown/finished ids as no-ops
while still rejecting malformed ids with 400).
The HTTP layer is exercised against a small aiohttp app whose handlers are a
faithful copy of the wiring in ``server.py`` driven by a fake queue that
mirrors ``execution.PromptQueue`` (``get_current_queue`` / ``get_history`` /
``delete_queue_item``). This keeps the test free of the heavy ComfyUI runtime
(torch, nodes, ...) while still testing the real cancel logic.
"""
import json
import pytest
from aiohttp import web
from comfy_execution.jobs import (
CANCEL_PENDING,
CANCEL_RUNNING,
CANCEL_TERMINAL,
CANCEL_UNKNOWN,
cancel_job,
classify_job_for_cancel,
validate_job_id,
)
# Classifications for which a cancel was actually dispatched (vs a no-op).
# The test app's cancel helper reports ``cancelled=True`` only for these.
_CANCELLED = (CANCEL_RUNNING, CANCEL_PENDING)

# Canonical UUID ids for HTTP-layer tests (the batch endpoint validates UUID format).
_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa"
_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb"
_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc"
_UUID_D = "dddddddd-dddd-4ddd-dddd-dddddddddddd"
# Well-formed id that is absent from the fake queue wherever it is used, so it
# exercises the unknown-id path rather than the malformed-id path.
_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff"
def make_queue_item(prompt_id, number=0):
    """Return a five-slot queue entry with ``prompt_id`` stored at index 1.

    ``number`` occupies slot 0; the remaining three slots are fresh, empty
    placeholders (two dicts and a list) so the tuple has the same shape as a
    real queue item.
    """
    placeholders = ({}, {}, [])
    return (number, prompt_id, *placeholders)
class FakePromptQueue:
    """Lightweight double of execution.PromptQueue covering the cancel paths.

    Holds running and pending queue items plus a history mapping, and counts
    every interrupt it signals so tests can assert on side effects.
    """

    def __init__(self, running=None, pending=None, history=None):
        # Copy the inputs so the fake owns its own containers.
        self._running = [*running] if running else []
        self._pending = [*pending] if pending else []
        self._history = dict(history) if history else {}
        # Number of interrupts signalled through interrupt_if_running.
        self.interrupt_count = 0

    def get_current_queue(self):
        """Return ``(running, pending)`` as fresh list copies."""
        running_copy = list(self._running)
        pending_copy = list(self._pending)
        return (running_copy, pending_copy)

    def get_history(self, prompt_id=None):
        """Return a copy of the full history, or only ``prompt_id``'s entry.

        An id with no history entry yields an empty dict.
        """
        if prompt_id is None:
            return self._history.copy()
        try:
            return {prompt_id: self._history[prompt_id]}
        except KeyError:
            return {}

    def delete_queue_item(self, function):
        """Remove the first pending item matching ``function``.

        Returns True when an item was removed, False when nothing matched.
        """
        for position, entry in enumerate(self._pending):
            if not function(entry):
                continue
            del self._pending[position]
            return True
        return False

    def interrupt_if_running(self, prompt_id):
        """Count an interrupt only if ``prompt_id`` is in the running set.

        Mirrors execution.PromptQueue.interrupt_if_running: an id that is not
        running is left alone and False is returned.
        """
        for entry in self._running:
            if entry[1] == prompt_id:
                self.interrupt_count += 1
                return True
        return False
def build_app(queue):
    """Build an aiohttp app exposing the cancel routes against ``queue``.

    Handler bodies mirror server.py exactly.

    Args:
        queue: Queue object providing ``get_current_queue``, ``get_history``,
            ``delete_queue_item`` and ``interrupt_if_running``.

    Returns:
        A ``web.Application`` with ``POST /api/jobs/{job_id}/cancel`` and
        ``POST /api/jobs/cancel`` registered.
    """
    def _cancel_job_by_id(job_id):
        # Snapshot the queue and history, then hand cancel_job the two
        # callbacks it may invoke to act on the queue.
        running, pending = queue.get_current_queue()
        history = queue.get_history()

        def interrupt(prompt_id):
            return queue.interrupt_if_running(prompt_id)

        def dequeue(prompt_id):
            # Queue items carry their id at index 1.
            return queue.delete_queue_item(lambda a: a[1] == prompt_id)

        classification = cancel_job(
            job_id, running, pending, history, interrupt, dequeue
        )
        # True only for the running/pending classifications.
        return classification in _CANCELLED

    async def cancel_job_by_id(request):
        # Single cancel: always 200 with a boolean, 400 only for an empty id.
        job_id = request.match_info.get("job_id", None)
        if not job_id:
            return web.json_response({"error": "job_id is required"}, status=400)
        cancelled = _cancel_job_by_id(job_id)
        return web.json_response({"cancelled": cancelled})

    async def cancel_jobs_batch(request):
        try:
            json_data = await request.json()
        except json.JSONDecodeError:
            return web.json_response(
                {"error": "Request body must be valid JSON"}, status=400
            )
        # A non-object body is treated the same as a missing job_ids key.
        job_ids = json_data.get("job_ids") if isinstance(json_data, dict) else None
        if not isinstance(job_ids, list):
            return web.json_response({"error": "job_ids must be a list"}, status=400)
        # Validate every element before touching the queue so a malformed
        # (e.g. unhashable or non-UUID) id yields 400 with no side effects.
        invalid_ids = []
        for jid in job_ids:
            try:
                validate_job_id(jid)
            except (ValueError, AttributeError):
                # Non-string values are echoed back as their repr.
                invalid_ids.append(jid if isinstance(jid, str) else repr(jid))
        if invalid_ids:
            return web.json_response(
                {"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids},
                status=400,
            )
        # Best-effort: cancel each id in turn; the response is true when at
        # least one cancel was dispatched.
        cancelled = False
        for jid in job_ids:
            if _cancel_job_by_id(jid):
                cancelled = True
        return web.json_response({"cancelled": cancelled})

    app = web.Application()
    app.router.add_post("/api/jobs/{job_id}/cancel", cancel_job_by_id)
    app.router.add_post("/api/jobs/cancel", cancel_jobs_batch)
    return app
# ---------------------------------------------------------------------------
# Pure helper tests: classification + cancel side effects
# ---------------------------------------------------------------------------
class TestClassifyJobForCancel:
    """One case per classification returned by ``classify_job_for_cancel``."""

    def test_running(self):
        in_flight = [make_queue_item("a")]
        verdict = classify_job_for_cancel("a", in_flight, [], {})
        assert verdict == CANCEL_RUNNING

    def test_pending(self):
        waiting = [make_queue_item("b")]
        verdict = classify_job_for_cancel("b", [], waiting, {})
        assert verdict == CANCEL_PENDING

    def test_terminal(self):
        # The id appears only in history.
        finished_entry = {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}
        verdict = classify_job_for_cancel("c", [], [], {"c": finished_entry})
        assert verdict == CANCEL_TERMINAL

    def test_unknown(self):
        # The id is absent from running, pending and history alike.
        verdict = classify_job_for_cancel("z", [], [], {})
        assert verdict == CANCEL_UNKNOWN
class TestCancelJobHelper:
    """Side-effect tests for ``cancel_job``.

    ``interrupt`` and ``dequeue`` both take the id and return whether they
    actually acted, so cancel_job's return reflects the real outcome.
    """

    @staticmethod
    def _spy(calls, outcome):
        """Return a callback that logs the id it receives, then returns ``outcome``."""
        def callback(pid):
            calls.append(pid)
            return outcome
        return callback

    def test_running_is_interrupted_not_dequeued(self):
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "a", [make_queue_item("a")], [], {},
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, True),
        )
        assert outcome == CANCEL_RUNNING
        assert interrupted == ["a"]
        assert dequeued == []

    def test_pending_is_dequeued_not_interrupted(self):
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "b", [], [make_queue_item("b")], {},
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, True),
        )
        assert outcome == CANCEL_PENDING
        assert dequeued == ["b"]
        assert interrupted == []

    def test_terminal_is_noop(self):
        finished = {"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}}
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "c", [], [], finished,
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, True),
        )
        assert outcome == CANCEL_TERMINAL
        assert interrupted == []
        assert dequeued == []

    def test_unknown_is_noop(self):
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "z", [], [], {},
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, True),
        )
        assert outcome == CANCEL_UNKNOWN
        assert interrupted == []
        assert dequeued == []

    def test_running_but_finished_before_interrupt_returns_unknown(self):
        """A stale snapshot says RUNNING, but the job finished before the
        atomic interrupt fired, so the interrupt returns False. cancel_job
        then reports UNKNOWN instead of claiming a cancel that never happened,
        and the atomic interrupt guarantees no unrelated job was hit."""
        interrupted = []
        outcome = cancel_job(
            "a", [make_queue_item("a")], [], {},
            interrupt=self._spy(interrupted, False),
            dequeue=lambda pid: True,
        )
        assert outcome == CANCEL_UNKNOWN
        assert interrupted == ["a"]  # the interrupt was still attempted

    def test_pending_started_running_is_interrupted(self):
        """Pending->running race: dequeue misses (returns False) because the
        job has started executing. The atomic interrupt catches the
        now-running job, so cancel_job interrupts it and reports
        CANCEL_RUNNING."""
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "b", [], [make_queue_item("b")], {},
            interrupt=self._spy(interrupted, True),
            dequeue=self._spy(dequeued, False),
        )
        assert outcome == CANCEL_RUNNING
        assert dequeued == ["b"]  # dequeue was tried first
        assert interrupted == ["b"]  # then the running job was interrupted

    def test_pending_dequeue_miss_not_running_returns_unknown(self):
        """Dequeue miss where the job is no longer running (it finished): the
        atomic interrupt finds nothing and returns False, so cancel_job is a
        no-op reporting UNKNOWN. It never reports a cancel that did not
        happen and never interrupts a bystander."""
        interrupted, dequeued = [], []
        outcome = cancel_job(
            "b", [], [make_queue_item("b")], {},
            interrupt=self._spy(interrupted, False),
            dequeue=self._spy(dequeued, False),
        )
        assert outcome == CANCEL_UNKNOWN
        assert dequeued == ["b"]
        assert interrupted == ["b"]  # attempted, but nothing was running
# ---------------------------------------------------------------------------
# HTTP contract tests: POST /api/jobs/{job_id}/cancel
# ---------------------------------------------------------------------------
class TestSingleCancelEndpoint:
    """HTTP contract of ``POST /api/jobs/{job_id}/cancel`` against the test app."""

    @pytest.mark.asyncio
    async def test_cancel_running_job_interrupts(self, aiohttp_client):
        fake = FakePromptQueue(running=[make_queue_item("a")])
        http = await aiohttp_client(build_app(fake))

        response = await http.post("/api/jobs/a/cancel")

        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": True}
        assert fake.interrupt_count == 1

    @pytest.mark.asyncio
    async def test_cancel_pending_job_dequeues(self, aiohttp_client):
        fake = FakePromptQueue(pending=[make_queue_item("b")])
        http = await aiohttp_client(build_app(fake))

        response = await http.post("/api/jobs/b/cancel")

        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": True}
        # The pending entry is gone and no interrupt was signalled.
        _, still_pending = fake.get_current_queue()
        assert still_pending == []
        assert fake.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_cancel_terminal_job_is_idempotent_noop(self, aiohttp_client):
        finished = {"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}}
        fake = FakePromptQueue(history=finished)
        http = await aiohttp_client(build_app(fake))

        response = await http.post("/api/jobs/c/cancel")

        # A finished job is answered with 200 and cancelled=false, not an error.
        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": False}
        assert fake.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_cancel_unknown_id_is_200_noop(self, aiohttp_client):
        fake = FakePromptQueue()
        http = await aiohttp_client(build_app(fake))

        response = await http.post("/api/jobs/does-not-exist/cancel")

        # An unknown id on the single-cancel route is an idempotent no-op.
        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": False}
        assert fake.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_cancel_pending_that_started_running_interrupts(self, aiohttp_client):
        """End-to-end pending->running race: the job is pending when the
        snapshot is taken but is executing by the time the dequeue runs, so
        the delete misses. The live re-check sees it running and interrupts
        it; the cancel is not dropped and the caller gets cancelled=True."""

        class RacingQueue(FakePromptQueue):
            def delete_queue_item(self, function):
                # Simulate the worker picking the job up just before removal:
                # everything pending becomes running and the delete misses.
                self._running, self._pending = list(self._pending), []
                return False

        fake = RacingQueue(pending=[make_queue_item("b")])
        http = await aiohttp_client(build_app(fake))

        response = await http.post("/api/jobs/b/cancel")

        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": True}
        assert fake.interrupt_count == 1
# ---------------------------------------------------------------------------
# HTTP contract tests: POST /api/jobs/cancel (batch)
# ---------------------------------------------------------------------------
class TestBatchCancelEndpoint:
    """HTTP contract of ``POST /api/jobs/cancel`` (batch) against the test app."""

    @pytest.mark.asyncio
    async def test_batch_happy_path(self, aiohttp_client):
        fake = FakePromptQueue(
            running=[make_queue_item(_UUID_A)],
            pending=[make_queue_item(_UUID_B, number=1)],
        )
        http = await aiohttp_client(build_app(fake))

        request_body = {"job_ids": [_UUID_A, _UUID_B]}
        response = await http.post("/api/jobs/cancel", json=request_body)

        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": True}
        assert fake.interrupt_count == 1  # the running job was interrupted
        _, still_pending = fake.get_current_queue()
        assert still_pending == []  # the pending job was dequeued

    @pytest.mark.asyncio
    async def test_batch_best_effort_skips_unknown_id(self, aiohttp_client):
        """An unknown id inside the batch is a no-op rather than a reason to
        abort: the running and pending jobs are still cancelled (200,
        cancelled=true). This is the "cancel all as a job finishes" case from
        review."""
        fake = FakePromptQueue(
            running=[make_queue_item(_UUID_A)],
            pending=[make_queue_item(_UUID_B, number=1)],
        )
        http = await aiohttp_client(build_app(fake))

        request_body = {"job_ids": [_UUID_A, _UUID_MISSING, _UUID_B]}
        response = await http.post("/api/jobs/cancel", json=request_body)

        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": True}
        assert fake.interrupt_count == 1  # the running job was interrupted
        _, still_pending = fake.get_current_queue()
        assert still_pending == []  # the pending job was dequeued

    @pytest.mark.asyncio
    async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client):
        finished = {
            job_id: {"prompt": make_queue_item(job_id), "outputs": {}, "status": {}}
            for job_id in (_UUID_C, _UUID_D)
        }
        fake = FakePromptQueue(history=finished)
        http = await aiohttp_client(build_app(fake))

        request_body = {"job_ids": [_UUID_C, _UUID_D]}
        response = await http.post("/api/jobs/cancel", json=request_body)

        # Every id is known but terminal: 200, cancelled=false, no dispatch.
        assert response.status == 200
        payload = await response.json()
        assert payload == {"cancelled": False}
        assert fake.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_batch_missing_job_ids_is_400(self, aiohttp_client):
        fake = FakePromptQueue()
        http = await aiohttp_client(build_app(fake))

        response = await http.post("/api/jobs/cancel", json={})

        assert response.status == 400

    @pytest.mark.asyncio
    async def test_batch_unhashable_element_is_400_not_500(self, aiohttp_client):
        """An unhashable element such as a dict or list must yield 400, not 500.

        Previously a body like {"job_ids": [{}]} reached the classify loop,
        where ``prompt_id in history`` raises TypeError on an unhashable
        type and surfaced as an unhandled 500. The input-validation guard
        must reject it before any queue or history access.
        """
        fake = FakePromptQueue()
        http = await aiohttp_client(build_app(fake))

        response = await http.post("/api/jobs/cancel", json={"job_ids": [{}]})

        assert response.status == 400
        payload = await response.json()
        assert "invalid_ids" in payload
        # The queue was left untouched.
        assert fake.interrupt_count == 0

    @pytest.mark.asyncio
    async def test_batch_non_uuid_string_element_is_400(self, aiohttp_client):
        """A string that is not a valid UUID is rejected with 400."""
        fake = FakePromptQueue()
        http = await aiohttp_client(build_app(fake))

        request_body = {"job_ids": ["not-a-uuid"]}
        response = await http.post("/api/jobs/cancel", json=request_body)

        assert response.status == 400
        payload = await response.json()
        assert "invalid_ids" in payload