Compare commits

..

4 Commits

Author SHA1 Message Date
63784baed5 fix(server): serialize PromptMetadataStore access with a lock
Addresses comfyanonymous's review nit on PR #13905. The store is touched
from three threads — the aiohttp event loop (``register`` via
``post_prompt``), the worker thread (``unregister`` via the
``prompt_worker`` try/finally and ``execution_error`` paths), and any
thread that fires ``send_sync`` (``inject``). Individual ``dict``
operations are GIL-atomic but ``register``'s ``len -> pop -> setitem``
and ``inject``'s ``get -> {**a, **b}`` are multi-step compounds whose
interleaving without a lock is racy. A single ``threading.Lock`` keeps
the FIFO cap honest and snapshots the envelope under the lock before
the spread runs.

Adds a stress-test that runs concurrent register/unregister/inject for
100 ms across five threads and asserts no exception escapes and the
capacity bound is held.
2026-05-14 21:26:06 -07:00
fc9820ebb9 refactor(server): spread envelope keys onto payload at top level
Switch the wire shape from nested ``metadata: {workflow_id: ...}`` to
spreading the envelope's keys directly onto each event payload. The
contract on the websocket is now identical to the prior workflow-id-on-
events work — consumers read ``event.workflow_id`` directly — but the
core executor still has no concept of workflow scope; the envelope is
captured at submission and decorated at the server transport layer.

Server-emitted fields always win on collision (``{**envelope, **d}``):
a misbehaving client cannot shadow ``prompt_id``, ``node``, etc. by
stamping the same key in their submission envelope.
2026-05-14 21:18:43 -07:00
fd89498eac fix(server): bound metadata envelope and clean up on cancel paths
Addresses review feedback on the per-prompt metadata envelope:

- Sanitize at the boundary: reject envelopes larger than 16 keys, keys
  over 64 chars, values over 256 chars, or anything that isn't a flat
  ``dict[str, str]``. Logs a warning so abuse is observable. Stops a
  malicious client from inflating broadcast volume by stamping a 10 MB
  metadata blob onto every WS event.
- Cap the in-memory store at 4096 concurrent envelopes with FIFO
  eviction. Acts as a backstop if any cleanup hook is skipped.
- Drop envelopes when prompts are cancelled before reaching the worker:
  ``PromptQueue.wipe_queue`` and ``delete_queue_item`` now call
  ``server.unregister_prompt_metadata`` for every removed item.
- Drop envelopes on hard execution failures: the worker now wraps
  ``e.execute()`` in ``try/finally``, so an uncaught exception in
  execution no longer leaks the envelope.
- Guard the WS reconnect handler: only include ``prompt_id`` in the
  ``executing`` payload when ``last_prompt_id`` is set, so clients
  with strict schemas (zod ``prompt_id: zJobId``) don't reject the
  message with a null id.
- Extract a ``PromptMetadataStore`` class that owns the dict and the
  bounds, so ``PromptServer`` becomes a thin delegating layer and the
  full register/inject/unregister cycle (plus FIFO eviction and
  sanitization) is unit-tested without torch.

44 tests passing; ruff clean on all touched files.
2026-05-14 21:03:38 -07:00
74cfcaa318 feat(server): per-prompt metadata envelope on websocket events
Replaces the workflow_id-on-every-event approach (#13684, reverted in
#13901) with a generic metadata envelope captured at submission and
injected at the server-side send chokepoint.

- POST /prompt accepts an opaque ``extra_data.metadata`` dict (falls
  back to synthesizing ``{"workflow_id": <id>}`` from
  ``extra_pnginfo.workflow.id`` so existing frontends keep working).
- ``PromptServer`` owns a ``prompt_id -> metadata`` map populated at
  submission, drained when the prompt finishes. ``send_sync`` injects
  the envelope into any outbound payload that carries a ``prompt_id``,
  including the ``(preview_image, metadata_dict)`` tuple used by
  ``PREVIEW_IMAGE_WITH_METADATA``. WS reconnect path carries it too.
- Pure helpers live in ``app/prompt_metadata.py`` so the execution
  layer never depends on workflow concepts and the helpers can be
  unit-tested without torch.

Execution layer (``execution.py``, ``comfy_execution/*``) and the jobs
API are unchanged. Backward compatible: existing fields and shapes are
preserved, only an additional ``metadata`` field is attached when
present.
2026-05-14 20:47:00 -07:00
11 changed files with 784 additions and 384 deletions

226
app/prompt_metadata.py Normal file
View File

@ -0,0 +1,226 @@
"""Per-prompt metadata envelope shared between submission and outbound events.
The metadata envelope is a small flat ``dict[str, str]`` (e.g.
``{"workflow_id": ...}``) attached to a prompt at submission and injected
by the server into every outbound execution event that carries a
``prompt_id``. It lets consumers scope state by tags they care about
(workflow, trace, tenant) without the execution layer ever needing to
know those tags exist.
This module is intentionally pure — no imports from ``server`` or
``execution`` — so ``PromptServer`` can own a ``PromptMetadataStore``
instance and the helpers can be unit-tested without the rest of the app.
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Callable, Optional
# Bounds. The envelope is forwarded to every WebSocket client connected to
# the server on every execution event for the prompt — bounding key count,
# key length, value length, and refusing nested structures keeps a
# malicious or buggy client from inflating the broadcast volume.
MAX_ENVELOPE_KEYS = 16
MAX_ENVELOPE_KEY_LEN = 64
MAX_ENVELOPE_VALUE_LEN = 256
# Cap on concurrently registered prompt envelopes. Acts as a backstop if
# the cleanup hook is ever bypassed; FIFO eviction so the oldest stale
# entry goes first.
DEFAULT_STORE_CAPACITY = 4096
def _sanitize_envelope(envelope: Any) -> Optional[dict]:
"""Validate and copy a candidate envelope.
Enforces the ``dict[str, str]`` contract that downstream consumers
(cloud projections, frontend zod schemas, OpenAPI docs) rely on:
- must be a non-empty ``dict``
- at most ``MAX_ENVELOPE_KEYS`` entries
- every key and value must be a ``str``
- keys at most ``MAX_ENVELOPE_KEY_LEN`` chars
- values at most ``MAX_ENVELOPE_VALUE_LEN`` chars
Returns a defensive shallow copy on success, ``None`` on any
violation. Logs a warning on violation so abuse is visible.
"""
if not isinstance(envelope, dict) or not envelope:
return None
if len(envelope) > MAX_ENVELOPE_KEYS:
logging.warning(
"prompt metadata envelope rejected: %d keys exceeds limit %d",
len(envelope), MAX_ENVELOPE_KEYS,
)
return None
sanitized: dict[str, str] = {}
for key, value in envelope.items():
if not isinstance(key, str) or not isinstance(value, str):
logging.warning(
"prompt metadata envelope rejected: non-string key/value (%s=%s)",
type(key).__name__, type(value).__name__,
)
return None
if len(key) > MAX_ENVELOPE_KEY_LEN or len(value) > MAX_ENVELOPE_VALUE_LEN:
logging.warning(
"prompt metadata envelope rejected: key or value exceeds length limit",
)
return None
sanitized[key] = value
return sanitized
def extract_envelope_from_extra_data(extra_data: Any) -> Optional[dict]:
"""Pull the per-prompt metadata envelope out of a submitted prompt's
``extra_data``.
Two sources, in order:
1. Explicit ``extra_data["metadata"]`` — sanitized via
``_sanitize_envelope``. Oversized or wrong-typed envelopes are
rejected (a warning is logged) rather than truncated, so the
contract stays strict at the boundary.
2. ``extra_data["extra_pnginfo"]["workflow"]["id"]`` — backward-
compatibility fallback. Frontends that already stamp the workflow
id into ``extra_pnginfo`` keep working; the synthesized envelope
is ``{"workflow_id": <id>}``. A debug log fires so the legacy path
remains observable.
Returns ``None`` when neither source yields a usable envelope.
"""
if not isinstance(extra_data, dict):
return None
if "metadata" in extra_data:
sanitized = _sanitize_envelope(extra_data["metadata"])
if sanitized is not None:
return sanitized
# Explicit metadata was supplied but rejected — do not fall
# through to the legacy path; the caller asked for something
# specific and got it wrong.
if isinstance(extra_data["metadata"], dict) and extra_data["metadata"]:
return None
extra_pnginfo = extra_data.get("extra_pnginfo")
if isinstance(extra_pnginfo, dict):
workflow = extra_pnginfo.get("workflow")
if isinstance(workflow, dict):
workflow_id = workflow.get("id")
if (
isinstance(workflow_id, str)
and workflow_id
and len(workflow_id) <= MAX_ENVELOPE_VALUE_LEN
):
logging.debug(
"prompt metadata envelope synthesized from extra_pnginfo.workflow.id"
)
return {"workflow_id": workflow_id}
return None
def inject_envelope(
data: Any,
envelope_lookup: Callable[[str], Optional[dict]],
) -> Any:
"""Return ``data`` with the per-prompt envelope's keys spread onto it.
``envelope_lookup`` is called with the payload's ``prompt_id`` and is
expected to return the registered envelope or ``None``. This keeps
the function pure and avoids depending on any specific storage.
The envelope's keys are merged onto the payload at the top level so
consumers can read them directly (e.g. ``event.workflow_id``) —
matching the wire shape of the prior workflow-id-on-events work and
avoiding an extra nesting hop for clients. Server-emitted fields on
the payload always win on collision (``{**envelope, **d}``); a
misbehaving client cannot shadow ``prompt_id``, ``node``, etc.
Two payload shapes are handled:
- **dict** carrying ``prompt_id``. A shallow copy is returned with
the envelope's keys merged onto it.
- **(preview_image, metadata_dict) tuple** — the format used by
``PREVIEW_IMAGE_WITH_METADATA``. Only the inner dict is augmented;
the binary preview is passed through by reference.
No-op for payloads without a ``prompt_id``, prompts with no
registered envelope, or any other payload shape.
"""
def inject(d: dict) -> dict:
if not isinstance(d, dict):
return d
prompt_id = d.get("prompt_id")
if not prompt_id:
return d
envelope = envelope_lookup(prompt_id)
if envelope is None:
return d
return {**envelope, **d}
if isinstance(data, dict):
return inject(data)
if isinstance(data, tuple) and len(data) == 2 and isinstance(data[1], dict):
injected = inject(data[1])
if injected is data[1]:
return data
return (data[0], injected)
return data
class PromptMetadataStore:
"""Bounded ``prompt_id -> envelope`` map.
Owned by ``PromptServer``. Populated at submission, drained when the
prompt finishes, wiped on queue cancel/delete. The FIFO cap is a
backstop: if any cleanup hook is ever skipped, the store sheds the
oldest entry instead of growing without bound.
Access is serialized through a ``threading.Lock``. ``register`` runs
on the aiohttp event-loop thread, ``unregister`` runs on the
``prompt_worker`` thread, and ``inject`` runs on whichever thread
fires ``send_sync`` (event loop, worker, asset seeder). Individual
``dict`` ops are GIL-atomic, but ``register``'s
``len() -> pop -> __setitem__`` and ``inject``'s ``get -> {**a, **b}``
are multi-step compounds whose interleaving without a lock is
racy. The lock is uncontended in steady state (sub-microsecond
critical sections) so the cost is negligible.
"""
def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY):
self._envelopes: dict[str, dict] = {}
self._capacity = capacity
self._lock = threading.Lock()
def register(self, prompt_id: str, extra_data: Any) -> None:
envelope = extract_envelope_from_extra_data(extra_data)
if envelope is None:
return
with self._lock:
if len(self._envelopes) >= self._capacity:
self._envelopes.pop(next(iter(self._envelopes)))
self._envelopes[prompt_id] = envelope
def unregister(self, prompt_id: str) -> None:
with self._lock:
self._envelopes.pop(prompt_id, None)
def inject(self, data: Any) -> Any:
# Snapshot the envelope under the lock so the spread in
# ``inject_envelope`` runs against a consistent view even if a
# concurrent ``register``/``unregister`` is mutating the map.
def locked_lookup(prompt_id: str) -> Optional[dict]:
with self._lock:
return self._envelopes.get(prompt_id)
return inject_envelope(data, locked_lookup)
def __len__(self) -> int:
with self._lock:
return len(self._envelopes)
def __contains__(self, prompt_id: str) -> bool:
with self._lock:
return prompt_id in self._envelopes

View File

@ -22,25 +22,26 @@ class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
def __init__(self, tensor: torch.Tensor, patches_per_frame: int, per_frame: bool = False):
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
"""
tensor: [batch, num_tokens, feature_dim] (per-token, default) or
[batch, num_frames, feature_dim] (per_frame=True, already compressed).
patches_per_frame: spatial patches per frame; pass None to disable compression.
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
"""
self.batch_size, n, self.feature_dim = tensor.shape
if per_frame:
self.batch_size, num_tokens, self.feature_dim = tensor.shape
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
self.patches_per_frame = patches_per_frame
self.num_frames = n
self.data = tensor
elif patches_per_frame is not None and n >= patches_per_frame and n % patches_per_frame == 0:
self.patches_per_frame = patches_per_frame
self.num_frames = n // patches_per_frame
# All patches in a frame are identical — keep only the first.
self.data = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)[:, :, 0, :].contiguous()
self.num_frames = num_tokens // patches_per_frame
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
# All patches in a frame are identical, so we only keep the first one
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
else:
# Not divisible or too small - store directly without compression
self.patches_per_frame = 1
self.num_frames = n
self.num_frames = num_tokens
self.data = tensor
def expand(self):
@ -715,35 +716,32 @@ class LTXAVModel(LTXVModel):
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
"""Prepare timestep embeddings."""
# TODO: some code reuse is needed here.
grid_mask = kwargs.get("grid_mask", None)
orig_shape = kwargs.get("orig_shape")
has_spatial_mask = kwargs.get("has_spatial_mask", None)
v_patches_per_frame = None
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
v_patches_per_frame = orig_shape[3] * orig_shape[4]
if grid_mask is not None:
timestep = timestep[:, grid_mask]
# Used by compute_prompt_timestep and the audio cross-attention paths.
timestep_scaled = (timestep[:, grid_mask] if grid_mask is not None else timestep) * self.timestep_scale_multiplier
# When patches in a frame share a timestep (no spatial mask), project one row per frame instead of one per token
per_frame_path = v_patches_per_frame is not None and (timestep.numel() // batch_size) % v_patches_per_frame == 0
if per_frame_path:
per_frame = timestep.reshape(batch_size, -1, v_patches_per_frame)[:, :, 0]
if grid_mask is not None:
# All-or-nothing per frame when has_spatial_mask=False.
per_frame = per_frame[:, grid_mask[::v_patches_per_frame]]
ts_input = per_frame * self.timestep_scale_multiplier
else:
ts_input = timestep_scaled
timestep_scaled = timestep * self.timestep_scale_multiplier
v_timestep, v_embedded_timestep = self.adaln_single(
ts_input.flatten(),
timestep_scaled.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
orig_shape = kwargs.get("orig_shape")
has_spatial_mask = kwargs.get("has_spatial_mask", None)
v_patches_per_frame = None
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
v_patches_per_frame = orig_shape[3] * orig_shape[4]
# Reshape to [batch_size, num_tokens, dim] and compress for storage
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
v_prompt_timestep = compute_prompt_timestep(
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype

View File

@ -358,61 +358,6 @@ def apply_split_rotary_emb(input_tensor, cos, sin):
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
class GuideAttentionMask:
"""Holds the two per-group masks for LTXV guide self-attention.
_attention_with_guide_mask splits queries into noisy and tracked-guide
groups, so the largest mask is (1, 1, tracked_count, T).
"""
__slots__ = ("guide_start", "tracked_count", "noisy_mask", "tracked_mask")
def __init__(self, total_tokens, guide_start, tracked_count, tracked_weights):
device = tracked_weights.device
dtype = tracked_weights.dtype
finfo = torch.finfo(dtype)
pos = tracked_weights > 0
log_w = torch.full_like(tracked_weights, finfo.min)
log_w[pos] = torch.log(tracked_weights[pos].clamp(min=finfo.tiny))
self.guide_start = guide_start
self.tracked_count = tracked_count
self.noisy_mask = torch.zeros((1, 1, 1, total_tokens), device=device, dtype=dtype)
self.noisy_mask[:, :, :, guide_start:guide_start + tracked_count] = log_w.view(1, 1, 1, -1)
self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype)
self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1)
def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
"""Apply the guide mask by partitioning Q into noisy and tracked-guide
groups, so each group needs only its own sub-mask. Avoids materializing
the (1,1,T,T) dense mask.
"""
guide_start = guide_mask.guide_start
tracked_end = guide_start + guide_mask.tracked_count
out = torch.empty_like(q)
if guide_start > 0: # In practice currently guides are always after noise, guard for safety if this changes.
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
attn_precision=attn_precision, transformer_options=transformer_options,
low_precision_attention=False, # sageattn mask support is unreliable
)
out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
attn_precision=attn_precision, transformer_options=transformer_options,
low_precision_attention=False,
)
if tracked_end < q.shape[1]: # Every guide token is tracked, and nothing comes after them, guard for safety if this changes.
out[:, tracked_end:, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, tracked_end:, :], k, v, heads,
attn_precision=attn_precision, transformer_options=transformer_options,
)
return out
class CrossAttention(nn.Module):
def __init__(
self,
@ -467,10 +412,8 @@ class CrossAttention(nn.Module):
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
elif isinstance(mask, GuideAttentionMask):
out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
# Apply per-head gating if enabled
if self.to_gate_logits is not None:
@ -1120,9 +1063,7 @@ class LTXVModel(LTXBaseModel):
additional_args["resolved_guide_entries"] = resolved_entries
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
if keyframe_idxs.shape[2] > 0: # Guard for the case of no keyframes surviving
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
# Total surviving guide tokens (all guides)
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
@ -1158,12 +1099,12 @@ class LTXVModel(LTXBaseModel):
if not resolved_entries:
return None
# strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
needs_mask = any(
e["strength"] != 1.0 or e.get("pixel_mask") is not None
# Check if any attenuation is actually needed
needs_attenuation = any(
e["strength"] < 1.0 or e.get("pixel_mask") is not None
for e in resolved_entries
)
if not needs_mask:
if not needs_attenuation:
return None
# Build per-guide-token weights for all tracked guide tokens.
@ -1218,11 +1159,16 @@ class LTXVModel(LTXBaseModel):
# Concatenate per-token weights for all tracked guides
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
# Skip when every weight is exactly 1.0 (additive bias would be 0).
if (tracked_weights == 1.0).all():
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
if (tracked_weights >= 1.0).all():
return None
return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)
# Build the mask: guide tokens are at the end of the sequence.
# Tracked guides come first (in order), untracked follow.
return self._build_self_attention_mask(
total_tokens, num_guide_tokens, total_tracked,
tracked_weights, guide_start, device, dtype,
)
@staticmethod
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
@ -1288,6 +1234,45 @@ class LTXVModel(LTXBaseModel):
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
@staticmethod
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
tracked_weights, guide_start, device, dtype):
"""Build a log-space additive self-attention bias mask.
Attenuates attention between noisy tokens and tracked guide tokens.
Untracked guide tokens (at the end of the guide portion) keep full attention.
Args:
total_tokens: Total sequence length.
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
tracked_count: Number of tracked guide tokens (first in the guide portion).
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
guide_start: Index where guide tokens begin in the sequence.
device: Target device.
dtype: Target dtype.
Returns:
(1, 1, total_tokens, total_tokens) additive bias mask.
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
"""
finfo = torch.finfo(dtype)
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
tracked_end = guide_start + tracked_count
# Convert weights to log-space bias
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
log_w = torch.full_like(w, finfo.min)
positive_mask = w > 0
if positive_mask.any():
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
# noisy → tracked guides: each noisy row gets the same per-guide weight
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
return mask
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
"""Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {})

View File

@ -136,7 +136,7 @@ class ImageFromBatch(IO.ComfyNode):
category="image/batch",
inputs=[
IO.Image.Input("image"),
IO.Int.Input("batch_index", default=0, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
IO.Int.Input("batch_index", default=0, min=0, max=4095),
IO.Int.Input("length", default=1, min=1, max=4096),
],
outputs=[IO.Image.Output()],
@ -145,9 +145,7 @@ class ImageFromBatch(IO.ComfyNode):
@classmethod
def execute(cls, image, batch_index, length) -> IO.NodeOutput:
s_in = image
if batch_index < 0:
batch_index += s_in.shape[0]
batch_index = max(0, min(s_in.shape[0] - 1, batch_index))
batch_index = min(s_in.shape[0] - 1, batch_index)
length = min(s_in.shape[0] - batch_index, length)
s = s_in[batch_index:batch_index + length].clone()
return IO.NodeOutput(s)

View File

@ -219,7 +219,7 @@ class LTXVAddGuide(io.ComfyNode):
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
@ -298,7 +298,7 @@ class LTXVAddGuide(io.ComfyNode):
else:
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,
)
@ -318,7 +318,7 @@ class LTXVAddGuide(io.ComfyNode):
mask = torch.full(
(noise_mask.shape[0], 1, cond_length, 1, 1),
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,
)

View File

@ -1296,7 +1296,10 @@ class PromptQueue:
def wipe_queue(self):
with self.mutex:
dropped_prompt_ids = [item[1] for item in self.queue]
self.queue = []
for prompt_id in dropped_prompt_ids:
self.server.unregister_prompt_metadata(prompt_id)
self.server.queue_updated()
def delete_queue_item(self, function):
@ -1306,8 +1309,9 @@ class PromptQueue:
if len(self.queue) == 1:
self.wipe_queue()
else:
self.queue.pop(x)
deleted = self.queue.pop(x)
heapq.heapify(self.queue)
self.server.unregister_prompt_metadata(deleted[1])
self.server.queue_updated()
return True
return False

38
main.py
View File

@ -27,7 +27,6 @@ from utils.mime_types import init_mime_types
import faulthandler
import logging
import sys
import traceback
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags
@ -149,14 +148,6 @@ def execute_prestartup_script():
return True
except Exception as e:
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
from nodes import record_node_startup_error
record_node_startup_error(
module_path=os.path.dirname(script_path),
source="custom_nodes",
phase="prestartup",
error=e,
tb=traceback.format_exc(),
)
return False
node_paths = folder_paths.get_folder_paths("custom_nodes")
@ -327,19 +318,26 @@ def prompt_worker(q, server_instance):
extra_data[k] = sensitive[k]
asset_seeder.pause()
e.execute(item[2], prompt_id, extra_data, item[4])
try:
e.execute(item[2], prompt_id, extra_data, item[4])
need_gc = True
need_gc = True
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
q.task_done(item_id,
e.history_result,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
q.task_done(item_id,
e.history_result,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
finally:
# Always drop the metadata envelope. If e.execute() raises
# hard before its own error handling kicks in, the
# registered envelope would otherwise leak for the
# lifetime of the process.
server_instance.unregister_prompt_metadata(prompt_id)
current_time = time.perf_counter()
execution_time = current_time - execution_start_time

View File

@ -1221,7 +1221,7 @@ class LatentFromBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"batch_index": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
@ -1232,9 +1232,7 @@ class LatentFromBatch:
def frombatch(self, samples, batch_index, length):
s = samples.copy()
s_in = samples["samples"]
if batch_index < 0:
batch_index += s_in.shape[0]
batch_index = max(0, min(s_in.shape[0] - 1, batch_index))
batch_index = min(s_in.shape[0] - 1, batch_index)
length = min(s_in.shape[0] - batch_index, length)
s["samples"] = s_in[batch_index:batch_index + length].clone()
if "noise_mask" in samples:
@ -2154,71 +2152,6 @@ EXTENSION_WEB_DIRS = {}
# Dictionary of successfully loaded module names and associated directories.
LOADED_MODULE_DIRS = {}
# Dictionary of custom node startup errors, keyed by "<source>:<module_name>"
# so that name collisions across custom_nodes / comfy_extras / comfy_api_nodes
# do not overwrite each other. Each value contains: source, module_name,
# module_path, error, traceback, phase.
#
# `source` is the same string as the internal `module_parent` used at load
# time (e.g. "custom_nodes", "comfy_extras", "comfy_api_nodes"). It is
# intentionally a free-form string rather than a fixed enum so the contract
# survives node-source layouts evolving (e.g. comfy_api_nodes eventually
# moving out of core). Consumers should treat any new value as a new bucket
# rather than rejecting it.
NODE_STARTUP_ERRORS: dict[str, dict] = {}
def _read_pyproject_metadata(module_path: str) -> dict | None:
"""Best-effort extraction of node-pack identity from pyproject.toml.
Returns a dict with the Comfy Registry-style identity (pack_id,
display_name, publisher_id, version, repository) when the module
directory contains a pyproject.toml. Returns None when no toml is
present or parsing fails for any reason — startup-error tracking
must never itself raise.
"""
if not module_path or not os.path.isdir(module_path):
return None
toml_path = os.path.join(module_path, "pyproject.toml")
if not os.path.isfile(toml_path):
return None
try:
from comfy_config import config_parser
cfg = config_parser.extract_node_configuration(module_path)
if cfg is None:
return None
meta = {
"pack_id": cfg.project.name or None,
"display_name": cfg.tool_comfy.display_name or None,
"publisher_id": cfg.tool_comfy.publisher_id or None,
"version": cfg.project.version or None,
"repository": cfg.project.urls.repository or None,
}
# Drop empty fields so the API payload stays compact.
return {k: v for k, v in meta.items() if v}
except Exception:
return None
def record_node_startup_error(
*, module_path: str, source: str, phase: str, error: BaseException, tb: str
) -> None:
"""Record a startup error for a node module so it can be exposed via the API."""
module_name = get_module_name(module_path)
entry = {
"source": source,
"module_name": module_name,
"module_path": module_path,
"error": str(error),
"traceback": tb,
"phase": phase,
}
pyproject = _read_pyproject_metadata(module_path)
if pyproject:
entry["pyproject"] = pyproject
NODE_STARTUP_ERRORS[f"{source}:{module_name}"] = entry
def get_module_name(module_path: str) -> str:
"""
@ -2328,30 +2261,14 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
return True
except Exception as e:
tb = traceback.format_exc()
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
record_node_startup_error(
module_path=module_path,
source=module_parent,
phase="entrypoint",
error=e,
tb=tb,
)
return False
else:
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or comfy_entrypoint (need one).")
return False
except Exception as e:
tb = traceback.format_exc()
logging.warning(tb)
logging.warning(traceback.format_exc())
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
record_node_startup_error(
module_path=module_path,
source=module_parent,
phase="import",
error=e,
tb=tb,
)
return False
async def init_external_custom_nodes():

View File

@ -44,6 +44,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager
from app.node_replace_manager import NodeReplaceManager
from app.prompt_metadata import PromptMetadataStore
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
@ -250,8 +251,14 @@ class PromptServer():
routes = web.RouteTableDef()
self.routes = routes
self.last_node_id = None
self.last_prompt_id = None
self.client_id = None
# Bounded prompt_id -> envelope store. Populated at submission,
# drained on completion/cancel. Keeps workflow scope (and other
# client-supplied tags) out of the execution layer.
self._prompt_metadata = PromptMetadataStore()
self.on_prompt_handlers = []
@routes.get('/ws')
@ -275,7 +282,12 @@ class PromptServer():
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
# On reconnect if we are the currently executing client send the current node
if self.client_id == sid and self.last_node_id is not None:
await self.send("executing", { "node": self.last_node_id }, sid)
payload: dict = {"node": self.last_node_id}
if self.last_prompt_id is not None:
payload["prompt_id"] = self.last_prompt_id
await self.send(
"executing", self._inject_prompt_metadata(payload), sid
)
# Flag to track if we've received the first message
first_message = True
@ -765,26 +777,6 @@ class PromptServer():
out[node_class] = node_info(node_class)
return web.json_response(out)
@routes.get("/node_startup_errors")
async def get_node_startup_errors(request):
# Group errors by source so the frontend/Manager can render them
# in distinct sections. `source` is the same string as the
# module_parent used at load time (e.g. "custom_nodes",
# "comfy_extras", "comfy_api_nodes") and is left as a free-form
# string so the contract survives node-source layouts evolving.
# The response only contains source buckets that actually had a
# failure; consumers should not assume any particular set of keys
# is always present.
#
# `module_path` is stripped because the absolute on-disk path is
# internal detail that the frontend has no use for.
grouped: dict[str, dict[str, dict]] = {}
for entry in nodes.NODE_STARTUP_ERRORS.values():
source = entry.get("source", "custom_nodes")
public_entry = {k: v for k, v in entry.items() if k != "module_path"}
grouped.setdefault(source, {})[entry["module_name"]] = public_entry
return web.json_response(grouped)
@routes.get("/api/jobs")
async def get_jobs(request):
"""List all jobs with filtering, sorting, and pagination.
@ -975,6 +967,7 @@ class PromptServer():
if sensitive_val in extra_data:
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
self.register_prompt_metadata(prompt_id, extra_data)
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
@ -1236,7 +1229,22 @@ class PromptServer():
elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_json, message)
def register_prompt_metadata(self, prompt_id: str, extra_data) -> None:
"""Record per-prompt metadata for injection into outbound execution
events. Called at submission, before the prompt is queued."""
self._prompt_metadata.register(prompt_id, extra_data)
def unregister_prompt_metadata(self, prompt_id: str) -> None:
"""Drop the per-prompt metadata envelope. Called after the prompt
has finished executing and its terminal events have been queued,
or when the prompt is cancelled before reaching the worker."""
self._prompt_metadata.unregister(prompt_id)
def _inject_prompt_metadata(self, data):
return self._prompt_metadata.inject(data)
def send_sync(self, event, data, sid=None):
data = self._inject_prompt_metadata(data)
self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid))

View File

@ -0,0 +1,412 @@
"""Unit tests for the metadata-envelope module in ``app.prompt_metadata``.
Covers the two pure helpers (``extract_envelope_from_extra_data`` and
``inject_envelope``) and the ``PromptMetadataStore`` integration class
that ``PromptServer`` owns.
"""
from __future__ import annotations
import pytest
from app.prompt_metadata import (
MAX_ENVELOPE_KEYS,
MAX_ENVELOPE_KEY_LEN,
MAX_ENVELOPE_VALUE_LEN,
PromptMetadataStore,
extract_envelope_from_extra_data,
inject_envelope,
)
class TestExtractEnvelopeFromExtraData:
def test_explicit_metadata_dict_is_used_as_is(self):
extra_data = {"metadata": {"workflow_id": "wf-1", "trace_id": "t-9"}}
assert extract_envelope_from_extra_data(extra_data) == {
"workflow_id": "wf-1",
"trace_id": "t-9",
}
def test_explicit_metadata_takes_precedence_over_extra_pnginfo(self):
extra_data = {
"metadata": {"workflow_id": "explicit"},
"extra_pnginfo": {"workflow": {"id": "fallback"}},
}
assert extract_envelope_from_extra_data(extra_data) == {
"workflow_id": "explicit"
}
def test_falls_back_to_extra_pnginfo_workflow_id(self):
extra_data = {"extra_pnginfo": {"workflow": {"id": "wf-legacy"}}}
assert extract_envelope_from_extra_data(extra_data) == {
"workflow_id": "wf-legacy"
}
def test_returns_none_when_no_metadata_and_no_workflow_id(self):
assert extract_envelope_from_extra_data({}) is None
assert (
extract_envelope_from_extra_data({"extra_pnginfo": {"workflow": {}}})
is None
)
@pytest.mark.parametrize("bad", ["", 123, None, [], {}])
def test_rejects_non_string_or_empty_workflow_id(self, bad):
extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}}
assert extract_envelope_from_extra_data(extra_data) is None
def test_rejects_non_dict_inputs_at_each_level(self):
assert extract_envelope_from_extra_data(None) is None
assert extract_envelope_from_extra_data("not-a-dict") is None
assert (
extract_envelope_from_extra_data({"extra_pnginfo": "not-a-dict"})
is None
)
assert (
extract_envelope_from_extra_data(
{"extra_pnginfo": {"workflow": "not-a-dict"}}
)
is None
)
def test_empty_explicit_metadata_falls_through_to_workflow_id(self):
extra_data = {
"metadata": {},
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
}
assert extract_envelope_from_extra_data(extra_data) == {
"workflow_id": "wf-legacy"
}
def test_returned_envelope_is_copy_not_reference(self):
original = {"workflow_id": "wf-1"}
result = extract_envelope_from_extra_data({"metadata": original})
assert result is not None
result["new_key"] = "x"
assert "new_key" not in original
def test_non_dict_explicit_metadata_falls_through_to_workflow_id(self):
extra_data = {
"metadata": "not-a-dict",
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
}
assert extract_envelope_from_extra_data(extra_data) == {
"workflow_id": "wf-legacy"
}
class TestEnvelopeSanitization:
"""The wire contract is ``dict[str, str]`` with bounded size. A bad
envelope is dropped (and a warning is logged) rather than truncated,
so the boundary stays strict."""
def test_rejects_too_many_keys(self, caplog):
envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}
with caplog.at_level("WARNING"):
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
assert any("exceeds limit" in r.message for r in caplog.records)
def test_accepts_max_keys_exactly(self):
envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS)}
assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope
def test_rejects_non_string_keys(self, caplog):
with caplog.at_level("WARNING"):
assert (
extract_envelope_from_extra_data({"metadata": {42: "v"}})
is None
)
assert any("non-string" in r.message for r in caplog.records)
def test_rejects_non_string_values(self, caplog):
for bad_value in [42, None, ["x"], {"nested": "dict"}, b"bytes"]:
with caplog.at_level("WARNING"):
assert (
extract_envelope_from_extra_data(
{"metadata": {"k": bad_value}}
)
is None
)
def test_rejects_oversized_key(self):
envelope = {"x" * (MAX_ENVELOPE_KEY_LEN + 1): "v"}
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
def test_rejects_oversized_value(self):
envelope = {"k": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
def test_accepts_max_lengths_exactly(self):
envelope = {
"x" * MAX_ENVELOPE_KEY_LEN: "y" * MAX_ENVELOPE_VALUE_LEN
}
assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope
def test_oversized_workflow_id_in_pnginfo_rejected(self):
"""The legacy synthesized path also respects the value bound."""
extra_data = {
"extra_pnginfo": {
"workflow": {"id": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
}
}
assert extract_envelope_from_extra_data(extra_data) is None
def test_invalid_explicit_metadata_does_not_fall_through(self):
"""An explicit but invalid metadata dict means the caller asked
for something specific and got it wrong; the synthesized
fallback must not silently substitute."""
extra_data = {
"metadata": {"k": 42}, # non-string value
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
}
assert extract_envelope_from_extra_data(extra_data) is None
class TestInjectEnvelope:
@staticmethod
def _lookup(table):
return table.get
def test_spreads_envelope_keys_onto_payload(self):
"""Envelope keys are merged at the top level so consumers can
read them directly (e.g. ``event.workflow_id``)."""
lookup = self._lookup({"p1": {"workflow_id": "wf-1", "trace_id": "t-9"}})
assert inject_envelope({"node": "5", "prompt_id": "p1"}, lookup) == {
"node": "5",
"prompt_id": "p1",
"workflow_id": "wf-1",
"trace_id": "t-9",
}
def test_passthrough_when_prompt_id_not_registered(self):
lookup = self._lookup({})
data = {"node": "5", "prompt_id": "unknown"}
assert inject_envelope(data, lookup) == data
def test_passthrough_when_payload_lacks_prompt_id(self):
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
data = {"status": "ok"}
assert inject_envelope(data, lookup) == data
def test_server_keys_win_on_collision_with_envelope(self):
"""A misbehaving client cannot shadow server-emitted fields by
stamping the same key in their submission envelope."""
lookup = self._lookup({
"p1": {"prompt_id": "client-claimed", "node": "spoofed", "workflow_id": "wf-1"}
})
result = inject_envelope({"prompt_id": "p1", "node": "5"}, lookup)
assert result["prompt_id"] == "p1"
assert result["node"] == "5"
assert result["workflow_id"] == "wf-1"
def test_does_not_mutate_input_dict(self):
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
original = {"node": "5", "prompt_id": "p1"}
inject_envelope(original, lookup)
assert "workflow_id" not in original
def test_does_not_mutate_envelope_dict(self):
envelope = {"workflow_id": "wf-1"}
lookup = self._lookup({"p1": envelope})
inject_envelope({"prompt_id": "p1", "node": "5"}, lookup)
assert envelope == {"workflow_id": "wf-1"}
def test_injects_into_inner_dict_of_preview_metadata_tuple(self):
"""``PREVIEW_IMAGE_WITH_METADATA`` payloads arrive as
``(preview_image, metadata_dict)``; the inner dict is the only
place the envelope can attach."""
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
preview_image = ("PNG", object(), 256)
inner = {"node_id": "5", "prompt_id": "p1"}
result = inject_envelope((preview_image, inner), lookup)
assert isinstance(result, tuple)
assert result[0] is preview_image
assert result[1] == {
"node_id": "5",
"prompt_id": "p1",
"workflow_id": "wf-1",
}
assert "workflow_id" not in inner
def test_preview_tuple_passthrough_when_no_envelope_registered(self):
lookup = self._lookup({})
preview_image = ("PNG", object(), 256)
inner = {"node_id": "5", "prompt_id": "unknown"}
result = inject_envelope((preview_image, inner), lookup)
assert result == (preview_image, inner)
@pytest.mark.parametrize("payload", [b"raw-bytes", None, 42])
def test_non_dict_non_tuple_payloads_passthrough(self, payload):
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
assert inject_envelope(payload, lookup) == payload
def test_tuple_of_wrong_arity_passthrough(self):
"""Only the 2-tuple ``(preview, metadata_dict)`` shape is
special-cased. Other tuples must not be touched."""
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
triple = (1, {"prompt_id": "p1"}, 3)
assert inject_envelope(triple, lookup) is triple
def test_envelope_lookup_called_per_invocation(self):
"""The lookup runs each time the function is called, so changes
to the backing store are immediately visible."""
store = {"p1": {"workflow_id": "wf-1"}}
first = inject_envelope({"prompt_id": "p1"}, store.get)
store["p1"] = {"workflow_id": "wf-2"}
second = inject_envelope({"prompt_id": "p1"}, store.get)
del store["p1"]
third = inject_envelope({"prompt_id": "p1"}, store.get)
assert first["workflow_id"] == "wf-1"
assert second["workflow_id"] == "wf-2"
assert "workflow_id" not in third
class TestPromptMetadataStore:
"""End-to-end wiring tests that exercise the full register/inject/
unregister cycle the way ``PromptServer`` does."""
def test_register_inject_unregister_cycle(self):
store = PromptMetadataStore()
store.register(
"p1", {"extra_pnginfo": {"workflow": {"id": "wf-1"}}}
)
injected = store.inject({"node": "5", "prompt_id": "p1"})
assert injected == {
"node": "5",
"prompt_id": "p1",
"workflow_id": "wf-1",
}
store.unregister("p1")
passthrough = store.inject({"node": "5", "prompt_id": "p1"})
assert "workflow_id" not in passthrough
def test_register_with_no_derivable_envelope_is_noop(self):
store = PromptMetadataStore()
store.register("p1", {})
assert "p1" not in store
data = {"prompt_id": "p1"}
assert store.inject(data) == data
def test_register_with_oversized_envelope_is_noop(self):
"""Sanitization rejection means nothing is registered — the
store stays empty and inject is a passthrough."""
store = PromptMetadataStore()
store.register(
"p1",
{"metadata": {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}},
)
assert "p1" not in store
def test_unregister_unknown_prompt_is_silent(self):
store = PromptMetadataStore()
store.unregister("does-not-exist")
def test_fifo_eviction_when_capacity_exceeded(self):
"""If cleanup hooks are ever bypassed, the store must shed the
oldest entry rather than grow without bound."""
store = PromptMetadataStore(capacity=3)
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
store.register("p2", {"metadata": {"workflow_id": "wf-2"}})
store.register("p3", {"metadata": {"workflow_id": "wf-3"}})
assert len(store) == 3
store.register("p4", {"metadata": {"workflow_id": "wf-4"}})
assert len(store) == 3
assert "p1" not in store
assert "p4" in store
# The newer entries are still injectable.
assert store.inject({"prompt_id": "p4"})["workflow_id"] == "wf-4"
# The evicted one is gone.
assert "workflow_id" not in store.inject({"prompt_id": "p1"})
def test_register_after_unregister_does_not_count_against_capacity(self):
"""Normal lifecycle: register, unregister, register many — the
store should not silently evict valid entries because of stale
accounting."""
store = PromptMetadataStore(capacity=2)
for i in range(10):
store.register(f"p{i}", {"metadata": {"workflow_id": f"wf-{i}"}})
store.unregister(f"p{i}")
assert len(store) == 0
def test_re_register_overwrites(self):
store = PromptMetadataStore()
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
store.register("p1", {"metadata": {"workflow_id": "wf-2"}})
assert store.inject({"prompt_id": "p1"})["workflow_id"] == "wf-2"
def test_inject_with_no_registrations_is_passthrough(self):
store = PromptMetadataStore()
data = {"prompt_id": "p1", "node": "5"}
assert store.inject(data) == data
def test_inject_into_preview_tuple(self):
store = PromptMetadataStore()
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
result = store.inject((b"image-bytes", {"prompt_id": "p1"}))
assert result == (b"image-bytes", {
"prompt_id": "p1",
"workflow_id": "wf-1",
})
def test_concurrent_access_does_not_corrupt_or_raise(self):
"""Smoke test for the store's lock. ``register`` is called from
the aiohttp event-loop thread, ``unregister`` from the worker
thread, and ``inject`` fires on every ``send_sync`` from
whichever thread emits the event. Run all three concurrently
and assert no exception escapes and the store stays internally
consistent (the FIFO cap is never exceeded)."""
import threading
store = PromptMetadataStore(capacity=64)
stop = threading.Event()
errors: list[BaseException] = []
def registrar():
i = 0
try:
while not stop.is_set():
store.register(
f"p{i % 100}",
{"metadata": {"workflow_id": f"wf-{i}"}},
)
i += 1
except BaseException as e:
errors.append(e)
def canceller():
i = 0
try:
while not stop.is_set():
store.unregister(f"p{i % 100}")
i += 1
except BaseException as e:
errors.append(e)
def injector():
i = 0
try:
while not stop.is_set():
store.inject({"prompt_id": f"p{i % 100}", "node": "5"})
i += 1
except BaseException as e:
errors.append(e)
threads = [
threading.Thread(target=registrar),
threading.Thread(target=registrar),
threading.Thread(target=canceller),
threading.Thread(target=injector),
threading.Thread(target=injector),
]
for t in threads:
t.start()
# Brief burst — long enough to interleave many ops, short enough
# not to slow CI.
threading.Event().wait(0.1)
stop.set()
for t in threads:
t.join(timeout=2.0)
assert errors == [], f"concurrent access raised: {errors[:3]}"
assert len(store) <= 64, "FIFO cap was breached under contention"

View File

@ -1,146 +0,0 @@
"""Tests for the custom node startup error tracking introduced for
Comfy-Org/ComfyUI-Launcher#303.
Covers:
- load_custom_node populates NODE_STARTUP_ERRORS with the correct source
for each module_parent (custom_nodes / comfy_extras / comfy_api_nodes).
- Composite keying prevents collisions between modules with the same name
in different sources.
- record_node_startup_error stores the expected fields.
- pyproject.toml metadata is attached when present and omitted when absent.
"""
import textwrap
import pytest
import nodes
@pytest.fixture(autouse=True)
def _clear_startup_errors():
nodes.NODE_STARTUP_ERRORS.clear()
yield
nodes.NODE_STARTUP_ERRORS.clear()
def _write_broken_module(tmp_path, name: str) -> str:
path = tmp_path / f"{name}.py"
path.write_text(textwrap.dedent("""\
# Deliberately broken module to exercise startup-error tracking.
raise RuntimeError("boom from " + __name__)
"""))
return str(path)
def test_record_node_startup_error_fields(tmp_path):
err = ValueError("kaboom")
nodes.record_node_startup_error(
module_path=str(tmp_path / "my_pack"),
source="custom_nodes",
phase="import",
error=err,
tb="traceback-text",
)
assert "custom_nodes:my_pack" in nodes.NODE_STARTUP_ERRORS
entry = nodes.NODE_STARTUP_ERRORS["custom_nodes:my_pack"]
assert entry["source"] == "custom_nodes"
assert entry["module_name"] == "my_pack"
assert entry["phase"] == "import"
assert entry["error"] == "kaboom"
assert entry["traceback"] == "traceback-text"
assert entry["module_path"].endswith("my_pack")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"module_parent",
["custom_nodes", "comfy_extras", "comfy_api_nodes"],
)
async def test_load_custom_node_records_source(tmp_path, module_parent):
# `source` in the entry should be the same string as `module_parent`.
module_path = _write_broken_module(tmp_path, "broken_pack")
success = await nodes.load_custom_node(module_path, module_parent=module_parent)
assert success is False
key = f"{module_parent}:broken_pack"
assert key in nodes.NODE_STARTUP_ERRORS, nodes.NODE_STARTUP_ERRORS
entry = nodes.NODE_STARTUP_ERRORS[key]
assert entry["source"] == module_parent
assert entry["module_name"] == "broken_pack"
assert entry["phase"] == "import"
assert "boom from" in entry["error"]
assert "RuntimeError" in entry["traceback"]
@pytest.mark.asyncio
async def test_load_custom_node_collision_across_sources(tmp_path):
# Same module name registered as both a custom node and a comfy_extra;
# composite keying should keep both entries.
cn_dir = tmp_path / "cn"
extras_dir = tmp_path / "extras"
cn_dir.mkdir()
extras_dir.mkdir()
cn_path = _write_broken_module(cn_dir, "nodes_audio")
extras_path = _write_broken_module(extras_dir, "nodes_audio")
assert await nodes.load_custom_node(cn_path, module_parent="custom_nodes") is False
assert await nodes.load_custom_node(extras_path, module_parent="comfy_extras") is False
assert "custom_nodes:nodes_audio" in nodes.NODE_STARTUP_ERRORS
assert "comfy_extras:nodes_audio" in nodes.NODE_STARTUP_ERRORS
assert (
nodes.NODE_STARTUP_ERRORS["custom_nodes:nodes_audio"]["module_path"]
!= nodes.NODE_STARTUP_ERRORS["comfy_extras:nodes_audio"]["module_path"]
)
@pytest.mark.asyncio
async def test_load_custom_node_attaches_pyproject_metadata(tmp_path):
pack_dir = tmp_path / "MyCoolPack"
pack_dir.mkdir()
(pack_dir / "__init__.py").write_text("raise RuntimeError('boom')\n")
(pack_dir / "pyproject.toml").write_text(textwrap.dedent("""\
[project]
name = "comfyui-mycoolpack"
version = "1.2.3"
[project.urls]
Repository = "https://github.com/example/comfyui-mycoolpack"
[tool.comfy]
PublisherId = "example"
DisplayName = "My Cool Pack"
"""))
success = await nodes.load_custom_node(str(pack_dir), module_parent="custom_nodes")
assert success is False
entry = nodes.NODE_STARTUP_ERRORS["custom_nodes:MyCoolPack"]
assert "pyproject" in entry, entry
py = entry["pyproject"]
assert py["pack_id"] == "comfyui-mycoolpack"
assert py["display_name"] == "My Cool Pack"
assert py["publisher_id"] == "example"
assert py["version"] == "1.2.3"
assert py["repository"] == "https://github.com/example/comfyui-mycoolpack"
@pytest.mark.asyncio
async def test_load_custom_node_no_pyproject_skips_metadata(tmp_path):
# Single-file extras-style module: no pyproject.toml exists alongside it,
# so the entry must not contain a 'pyproject' key.
module_path = _write_broken_module(tmp_path, "lonely")
assert await nodes.load_custom_node(module_path, module_parent="comfy_extras") is False
entry = nodes.NODE_STARTUP_ERRORS["comfy_extras:lonely"]
assert "pyproject" not in entry
@pytest.mark.asyncio
async def test_load_custom_node_arbitrary_module_parent_passes_through(tmp_path):
# `source` is a free-form string — an unknown module_parent (e.g. a future
# node-source bucket) should be recorded as-is, not coerced or rejected.
module_path = _write_broken_module(tmp_path, "future_pack")
assert await nodes.load_custom_node(module_path, module_parent="future_source") is False
entry = nodes.NODE_STARTUP_ERRORS["future_source:future_pack"]
assert entry["source"] == "future_source"