mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-20 08:26:42 +08:00
Compare commits
4 Commits
feature/cu
...
deepme987/
| Author | SHA1 | Date | |
|---|---|---|---|
| 63784baed5 | |||
| fc9820ebb9 | |||
| fd89498eac | |||
| 74cfcaa318 |
226
app/prompt_metadata.py
Normal file
226
app/prompt_metadata.py
Normal file
@ -0,0 +1,226 @@
|
||||
"""Per-prompt metadata envelope shared between submission and outbound events.
|
||||
|
||||
The metadata envelope is a small flat ``dict[str, str]`` (e.g.
|
||||
``{"workflow_id": ...}``) attached to a prompt at submission and injected
|
||||
by the server into every outbound execution event that carries a
|
||||
``prompt_id``. It lets consumers scope state by tags they care about
|
||||
(workflow, trace, tenant) without the execution layer ever needing to
|
||||
know those tags exist.
|
||||
|
||||
This module is intentionally pure — no imports from ``server`` or
|
||||
``execution`` — so ``PromptServer`` can own a ``PromptMetadataStore``
|
||||
instance and the helpers can be unit-tested without the rest of the app.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
# Bounds. The envelope is forwarded to every WebSocket client connected to
|
||||
# the server on every execution event for the prompt — bounding key count,
|
||||
# key length, value length, and refusing nested structures keeps a
|
||||
# malicious or buggy client from inflating the broadcast volume.
|
||||
MAX_ENVELOPE_KEYS = 16
|
||||
MAX_ENVELOPE_KEY_LEN = 64
|
||||
MAX_ENVELOPE_VALUE_LEN = 256
|
||||
|
||||
# Cap on concurrently registered prompt envelopes. Acts as a backstop if
|
||||
# the cleanup hook is ever bypassed; FIFO eviction so the oldest stale
|
||||
# entry goes first.
|
||||
DEFAULT_STORE_CAPACITY = 4096
|
||||
|
||||
|
||||
def _sanitize_envelope(envelope: Any) -> Optional[dict]:
|
||||
"""Validate and copy a candidate envelope.
|
||||
|
||||
Enforces the ``dict[str, str]`` contract that downstream consumers
|
||||
(cloud projections, frontend zod schemas, OpenAPI docs) rely on:
|
||||
|
||||
- must be a non-empty ``dict``
|
||||
- at most ``MAX_ENVELOPE_KEYS`` entries
|
||||
- every key and value must be a ``str``
|
||||
- keys at most ``MAX_ENVELOPE_KEY_LEN`` chars
|
||||
- values at most ``MAX_ENVELOPE_VALUE_LEN`` chars
|
||||
|
||||
Returns a defensive shallow copy on success, ``None`` on any
|
||||
violation. Logs a warning on violation so abuse is visible.
|
||||
"""
|
||||
if not isinstance(envelope, dict) or not envelope:
|
||||
return None
|
||||
if len(envelope) > MAX_ENVELOPE_KEYS:
|
||||
logging.warning(
|
||||
"prompt metadata envelope rejected: %d keys exceeds limit %d",
|
||||
len(envelope), MAX_ENVELOPE_KEYS,
|
||||
)
|
||||
return None
|
||||
sanitized: dict[str, str] = {}
|
||||
for key, value in envelope.items():
|
||||
if not isinstance(key, str) or not isinstance(value, str):
|
||||
logging.warning(
|
||||
"prompt metadata envelope rejected: non-string key/value (%s=%s)",
|
||||
type(key).__name__, type(value).__name__,
|
||||
)
|
||||
return None
|
||||
if len(key) > MAX_ENVELOPE_KEY_LEN or len(value) > MAX_ENVELOPE_VALUE_LEN:
|
||||
logging.warning(
|
||||
"prompt metadata envelope rejected: key or value exceeds length limit",
|
||||
)
|
||||
return None
|
||||
sanitized[key] = value
|
||||
return sanitized
|
||||
|
||||
|
||||
def extract_envelope_from_extra_data(extra_data: Any) -> Optional[dict]:
|
||||
"""Pull the per-prompt metadata envelope out of a submitted prompt's
|
||||
``extra_data``.
|
||||
|
||||
Two sources, in order:
|
||||
|
||||
1. Explicit ``extra_data["metadata"]`` — sanitized via
|
||||
``_sanitize_envelope``. Oversized or wrong-typed envelopes are
|
||||
rejected (a warning is logged) rather than truncated, so the
|
||||
contract stays strict at the boundary.
|
||||
2. ``extra_data["extra_pnginfo"]["workflow"]["id"]`` — backward-
|
||||
compatibility fallback. Frontends that already stamp the workflow
|
||||
id into ``extra_pnginfo`` keep working; the synthesized envelope
|
||||
is ``{"workflow_id": <id>}``. A debug log fires so the legacy path
|
||||
remains observable.
|
||||
|
||||
Returns ``None`` when neither source yields a usable envelope.
|
||||
"""
|
||||
if not isinstance(extra_data, dict):
|
||||
return None
|
||||
|
||||
if "metadata" in extra_data:
|
||||
sanitized = _sanitize_envelope(extra_data["metadata"])
|
||||
if sanitized is not None:
|
||||
return sanitized
|
||||
# Explicit metadata was supplied but rejected — do not fall
|
||||
# through to the legacy path; the caller asked for something
|
||||
# specific and got it wrong.
|
||||
if isinstance(extra_data["metadata"], dict) and extra_data["metadata"]:
|
||||
return None
|
||||
|
||||
extra_pnginfo = extra_data.get("extra_pnginfo")
|
||||
if isinstance(extra_pnginfo, dict):
|
||||
workflow = extra_pnginfo.get("workflow")
|
||||
if isinstance(workflow, dict):
|
||||
workflow_id = workflow.get("id")
|
||||
if (
|
||||
isinstance(workflow_id, str)
|
||||
and workflow_id
|
||||
and len(workflow_id) <= MAX_ENVELOPE_VALUE_LEN
|
||||
):
|
||||
logging.debug(
|
||||
"prompt metadata envelope synthesized from extra_pnginfo.workflow.id"
|
||||
)
|
||||
return {"workflow_id": workflow_id}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def inject_envelope(
|
||||
data: Any,
|
||||
envelope_lookup: Callable[[str], Optional[dict]],
|
||||
) -> Any:
|
||||
"""Return ``data`` with the per-prompt envelope's keys spread onto it.
|
||||
|
||||
``envelope_lookup`` is called with the payload's ``prompt_id`` and is
|
||||
expected to return the registered envelope or ``None``. This keeps
|
||||
the function pure and avoids depending on any specific storage.
|
||||
|
||||
The envelope's keys are merged onto the payload at the top level so
|
||||
consumers can read them directly (e.g. ``event.workflow_id``) —
|
||||
matching the wire shape of the prior workflow-id-on-events work and
|
||||
avoiding an extra nesting hop for clients. Server-emitted fields on
|
||||
the payload always win on collision (``{**envelope, **d}``); a
|
||||
misbehaving client cannot shadow ``prompt_id``, ``node``, etc.
|
||||
|
||||
Two payload shapes are handled:
|
||||
|
||||
- **dict** carrying ``prompt_id``. A shallow copy is returned with
|
||||
the envelope's keys merged onto it.
|
||||
- **(preview_image, metadata_dict) tuple** — the format used by
|
||||
``PREVIEW_IMAGE_WITH_METADATA``. Only the inner dict is augmented;
|
||||
the binary preview is passed through by reference.
|
||||
|
||||
No-op for payloads without a ``prompt_id``, prompts with no
|
||||
registered envelope, or any other payload shape.
|
||||
"""
|
||||
def inject(d: dict) -> dict:
|
||||
if not isinstance(d, dict):
|
||||
return d
|
||||
prompt_id = d.get("prompt_id")
|
||||
if not prompt_id:
|
||||
return d
|
||||
envelope = envelope_lookup(prompt_id)
|
||||
if envelope is None:
|
||||
return d
|
||||
return {**envelope, **d}
|
||||
|
||||
if isinstance(data, dict):
|
||||
return inject(data)
|
||||
if isinstance(data, tuple) and len(data) == 2 and isinstance(data[1], dict):
|
||||
injected = inject(data[1])
|
||||
if injected is data[1]:
|
||||
return data
|
||||
return (data[0], injected)
|
||||
return data
|
||||
|
||||
|
||||
class PromptMetadataStore:
|
||||
"""Bounded ``prompt_id -> envelope`` map.
|
||||
|
||||
Owned by ``PromptServer``. Populated at submission, drained when the
|
||||
prompt finishes, wiped on queue cancel/delete. The FIFO cap is a
|
||||
backstop: if any cleanup hook is ever skipped, the store sheds the
|
||||
oldest entry instead of growing without bound.
|
||||
|
||||
Access is serialized through a ``threading.Lock``. ``register`` runs
|
||||
on the aiohttp event-loop thread, ``unregister`` runs on the
|
||||
``prompt_worker`` thread, and ``inject`` runs on whichever thread
|
||||
fires ``send_sync`` (event loop, worker, asset seeder). Individual
|
||||
``dict`` ops are GIL-atomic, but ``register``'s
|
||||
``len() -> pop -> __setitem__`` and ``inject``'s ``get -> {**a, **b}``
|
||||
are multi-step compounds whose interleaving without a lock is
|
||||
racy. The lock is uncontended in steady state (sub-microsecond
|
||||
critical sections) so the cost is negligible.
|
||||
"""
|
||||
|
||||
def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY):
|
||||
self._envelopes: dict[str, dict] = {}
|
||||
self._capacity = capacity
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def register(self, prompt_id: str, extra_data: Any) -> None:
|
||||
envelope = extract_envelope_from_extra_data(extra_data)
|
||||
if envelope is None:
|
||||
return
|
||||
with self._lock:
|
||||
if len(self._envelopes) >= self._capacity:
|
||||
self._envelopes.pop(next(iter(self._envelopes)))
|
||||
self._envelopes[prompt_id] = envelope
|
||||
|
||||
def unregister(self, prompt_id: str) -> None:
|
||||
with self._lock:
|
||||
self._envelopes.pop(prompt_id, None)
|
||||
|
||||
def inject(self, data: Any) -> Any:
|
||||
# Snapshot the envelope under the lock so the spread in
|
||||
# ``inject_envelope`` runs against a consistent view even if a
|
||||
# concurrent ``register``/``unregister`` is mutating the map.
|
||||
def locked_lookup(prompt_id: str) -> Optional[dict]:
|
||||
with self._lock:
|
||||
return self._envelopes.get(prompt_id)
|
||||
return inject_envelope(data, locked_lookup)
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._envelopes)
|
||||
|
||||
def __contains__(self, prompt_id: str) -> bool:
|
||||
with self._lock:
|
||||
return prompt_id in self._envelopes
|
||||
@ -22,25 +22,26 @@ class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int, per_frame: bool = False):
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
"""
|
||||
tensor: [batch, num_tokens, feature_dim] (per-token, default) or
|
||||
[batch, num_frames, feature_dim] (per_frame=True, already compressed).
|
||||
patches_per_frame: spatial patches per frame; pass None to disable compression.
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
|
||||
"""
|
||||
self.batch_size, n, self.feature_dim = tensor.shape
|
||||
if per_frame:
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = n
|
||||
self.data = tensor
|
||||
elif patches_per_frame is not None and n >= patches_per_frame and n % patches_per_frame == 0:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = n // patches_per_frame
|
||||
# All patches in a frame are identical — keep only the first.
|
||||
self.data = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)[:, :, 0, :].contiguous()
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
|
||||
# All patches in a frame are identical, so we only keep the first one
|
||||
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
|
||||
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
|
||||
else:
|
||||
# Not divisible or too small - store directly without compression
|
||||
self.patches_per_frame = 1
|
||||
self.num_frames = n
|
||||
self.num_frames = num_tokens
|
||||
self.data = tensor
|
||||
|
||||
def expand(self):
|
||||
@ -715,35 +716,32 @@ class LTXAVModel(LTXVModel):
|
||||
|
||||
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||
"""Prepare timestep embeddings."""
|
||||
# TODO: some code reuse is needed here.
|
||||
grid_mask = kwargs.get("grid_mask", None)
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
has_spatial_mask = kwargs.get("has_spatial_mask", None)
|
||||
v_patches_per_frame = None
|
||||
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
# Used by compute_prompt_timestep and the audio cross-attention paths.
|
||||
timestep_scaled = (timestep[:, grid_mask] if grid_mask is not None else timestep) * self.timestep_scale_multiplier
|
||||
|
||||
# When patches in a frame share a timestep (no spatial mask), project one row per frame instead of one per token
|
||||
per_frame_path = v_patches_per_frame is not None and (timestep.numel() // batch_size) % v_patches_per_frame == 0
|
||||
if per_frame_path:
|
||||
per_frame = timestep.reshape(batch_size, -1, v_patches_per_frame)[:, :, 0]
|
||||
if grid_mask is not None:
|
||||
# All-or-nothing per frame when has_spatial_mask=False.
|
||||
per_frame = per_frame[:, grid_mask[::v_patches_per_frame]]
|
||||
ts_input = per_frame * self.timestep_scale_multiplier
|
||||
else:
|
||||
ts_input = timestep_scaled
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
|
||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||
ts_input.flatten(),
|
||||
timestep_scaled.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
|
||||
|
||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
has_spatial_mask = kwargs.get("has_spatial_mask", None)
|
||||
v_patches_per_frame = None
|
||||
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
|
||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
|
||||
# Reshape to [batch_size, num_tokens, dim] and compress for storage
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||
|
||||
v_prompt_timestep = compute_prompt_timestep(
|
||||
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||
|
||||
@ -358,61 +358,6 @@ def apply_split_rotary_emb(input_tensor, cos, sin):
|
||||
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
|
||||
|
||||
|
||||
class GuideAttentionMask:
|
||||
"""Holds the two per-group masks for LTXV guide self-attention.
|
||||
_attention_with_guide_mask splits queries into noisy and tracked-guide
|
||||
groups, so the largest mask is (1, 1, tracked_count, T).
|
||||
"""
|
||||
__slots__ = ("guide_start", "tracked_count", "noisy_mask", "tracked_mask")
|
||||
|
||||
def __init__(self, total_tokens, guide_start, tracked_count, tracked_weights):
|
||||
device = tracked_weights.device
|
||||
dtype = tracked_weights.dtype
|
||||
finfo = torch.finfo(dtype)
|
||||
|
||||
pos = tracked_weights > 0
|
||||
log_w = torch.full_like(tracked_weights, finfo.min)
|
||||
log_w[pos] = torch.log(tracked_weights[pos].clamp(min=finfo.tiny))
|
||||
|
||||
self.guide_start = guide_start
|
||||
self.tracked_count = tracked_count
|
||||
|
||||
self.noisy_mask = torch.zeros((1, 1, 1, total_tokens), device=device, dtype=dtype)
|
||||
self.noisy_mask[:, :, :, guide_start:guide_start + tracked_count] = log_w.view(1, 1, 1, -1)
|
||||
|
||||
self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype)
|
||||
self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1)
|
||||
|
||||
|
||||
def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
|
||||
"""Apply the guide mask by partitioning Q into noisy and tracked-guide
|
||||
groups, so each group needs only its own sub-mask. Avoids materializing
|
||||
the (1,1,T,T) dense mask.
|
||||
"""
|
||||
guide_start = guide_mask.guide_start
|
||||
tracked_end = guide_start + guide_mask.tracked_count
|
||||
|
||||
out = torch.empty_like(q)
|
||||
|
||||
if guide_start > 0: # In practice currently guides are always after noise, guard for safety if this changes.
|
||||
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||
q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
|
||||
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||
low_precision_attention=False, # sageattn mask support is unreliable
|
||||
)
|
||||
out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||
q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
|
||||
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||
low_precision_attention=False,
|
||||
)
|
||||
if tracked_end < q.shape[1]: # Every guide token is tracked, and nothing comes after them, guard for safety if this changes.
|
||||
out[:, tracked_end:, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||
q[:, tracked_end:, :], k, v, heads,
|
||||
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -467,10 +412,8 @@ class CrossAttention(nn.Module):
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
elif isinstance(mask, GuideAttentionMask):
|
||||
out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
|
||||
# Apply per-head gating if enabled
|
||||
if self.to_gate_logits is not None:
|
||||
@ -1120,9 +1063,7 @@ class LTXVModel(LTXBaseModel):
|
||||
additional_args["resolved_guide_entries"] = resolved_entries
|
||||
|
||||
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||
|
||||
if keyframe_idxs.shape[2] > 0: # Guard for the case of no keyframes surviving
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
|
||||
# Total surviving guide tokens (all guides)
|
||||
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
|
||||
@ -1158,12 +1099,12 @@ class LTXVModel(LTXBaseModel):
|
||||
if not resolved_entries:
|
||||
return None
|
||||
|
||||
# strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
|
||||
needs_mask = any(
|
||||
e["strength"] != 1.0 or e.get("pixel_mask") is not None
|
||||
# Check if any attenuation is actually needed
|
||||
needs_attenuation = any(
|
||||
e["strength"] < 1.0 or e.get("pixel_mask") is not None
|
||||
for e in resolved_entries
|
||||
)
|
||||
if not needs_mask:
|
||||
if not needs_attenuation:
|
||||
return None
|
||||
|
||||
# Build per-guide-token weights for all tracked guide tokens.
|
||||
@ -1218,11 +1159,16 @@ class LTXVModel(LTXBaseModel):
|
||||
# Concatenate per-token weights for all tracked guides
|
||||
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
|
||||
|
||||
# Skip when every weight is exactly 1.0 (additive bias would be 0).
|
||||
if (tracked_weights == 1.0).all():
|
||||
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
|
||||
if (tracked_weights >= 1.0).all():
|
||||
return None
|
||||
|
||||
return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)
|
||||
# Build the mask: guide tokens are at the end of the sequence.
|
||||
# Tracked guides come first (in order), untracked follow.
|
||||
return self._build_self_attention_mask(
|
||||
total_tokens, num_guide_tokens, total_tracked,
|
||||
tracked_weights, guide_start, device, dtype,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
|
||||
@ -1288,6 +1234,45 @@ class LTXVModel(LTXBaseModel):
|
||||
|
||||
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
|
||||
|
||||
@staticmethod
|
||||
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
|
||||
tracked_weights, guide_start, device, dtype):
|
||||
"""Build a log-space additive self-attention bias mask.
|
||||
|
||||
Attenuates attention between noisy tokens and tracked guide tokens.
|
||||
Untracked guide tokens (at the end of the guide portion) keep full attention.
|
||||
|
||||
Args:
|
||||
total_tokens: Total sequence length.
|
||||
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
|
||||
tracked_count: Number of tracked guide tokens (first in the guide portion).
|
||||
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
|
||||
guide_start: Index where guide tokens begin in the sequence.
|
||||
device: Target device.
|
||||
dtype: Target dtype.
|
||||
|
||||
Returns:
|
||||
(1, 1, total_tokens, total_tokens) additive bias mask.
|
||||
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
|
||||
"""
|
||||
finfo = torch.finfo(dtype)
|
||||
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
|
||||
tracked_end = guide_start + tracked_count
|
||||
|
||||
# Convert weights to log-space bias
|
||||
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
|
||||
log_w = torch.full_like(w, finfo.min)
|
||||
positive_mask = w > 0
|
||||
if positive_mask.any():
|
||||
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
|
||||
|
||||
# noisy → tracked guides: each noisy row gets the same per-guide weight
|
||||
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
|
||||
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
|
||||
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
|
||||
|
||||
return mask
|
||||
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
@ -136,7 +136,7 @@ class ImageFromBatch(IO.ComfyNode):
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("batch_index", default=0, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
|
||||
IO.Int.Input("batch_index", default=0, min=0, max=4095),
|
||||
IO.Int.Input("length", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
@ -145,9 +145,7 @@ class ImageFromBatch(IO.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, image, batch_index, length) -> IO.NodeOutput:
|
||||
s_in = image
|
||||
if batch_index < 0:
|
||||
batch_index += s_in.shape[0]
|
||||
batch_index = max(0, min(s_in.shape[0] - 1, batch_index))
|
||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||
length = min(s_in.shape[0] - batch_index, length)
|
||||
s = s_in[batch_index:batch_index + length].clone()
|
||||
return IO.NodeOutput(s)
|
||||
|
||||
@ -219,7 +219,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
|
||||
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
|
||||
),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
@ -298,7 +298,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
else:
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
@ -318,7 +318,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, cond_length, 1, 1),
|
||||
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
|
||||
@ -1296,7 +1296,10 @@ class PromptQueue:
|
||||
|
||||
def wipe_queue(self):
|
||||
with self.mutex:
|
||||
dropped_prompt_ids = [item[1] for item in self.queue]
|
||||
self.queue = []
|
||||
for prompt_id in dropped_prompt_ids:
|
||||
self.server.unregister_prompt_metadata(prompt_id)
|
||||
self.server.queue_updated()
|
||||
|
||||
def delete_queue_item(self, function):
|
||||
@ -1306,8 +1309,9 @@ class PromptQueue:
|
||||
if len(self.queue) == 1:
|
||||
self.wipe_queue()
|
||||
else:
|
||||
self.queue.pop(x)
|
||||
deleted = self.queue.pop(x)
|
||||
heapq.heapify(self.queue)
|
||||
self.server.unregister_prompt_metadata(deleted[1])
|
||||
self.server.queue_updated()
|
||||
return True
|
||||
return False
|
||||
|
||||
38
main.py
38
main.py
@ -27,7 +27,6 @@ from utils.mime_types import init_mime_types
|
||||
import faulthandler
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_api import feature_flags
|
||||
@ -149,14 +148,6 @@ def execute_prestartup_script():
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
|
||||
from nodes import record_node_startup_error
|
||||
record_node_startup_error(
|
||||
module_path=os.path.dirname(script_path),
|
||||
source="custom_nodes",
|
||||
phase="prestartup",
|
||||
error=e,
|
||||
tb=traceback.format_exc(),
|
||||
)
|
||||
return False
|
||||
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
@ -327,19 +318,26 @@ def prompt_worker(q, server_instance):
|
||||
extra_data[k] = sensitive[k]
|
||||
|
||||
asset_seeder.pause()
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
try:
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
|
||||
need_gc = True
|
||||
need_gc = True
|
||||
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
finally:
|
||||
# Always drop the metadata envelope. If e.execute() raises
|
||||
# hard before its own error handling kicks in, the
|
||||
# registered envelope would otherwise leak for the
|
||||
# lifetime of the process.
|
||||
server_instance.unregister_prompt_metadata(prompt_id)
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
|
||||
89
nodes.py
89
nodes.py
@ -1221,7 +1221,7 @@ class LatentFromBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"batch_index": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}),
|
||||
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
|
||||
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
@ -1232,9 +1232,7 @@ class LatentFromBatch:
|
||||
def frombatch(self, samples, batch_index, length):
|
||||
s = samples.copy()
|
||||
s_in = samples["samples"]
|
||||
if batch_index < 0:
|
||||
batch_index += s_in.shape[0]
|
||||
batch_index = max(0, min(s_in.shape[0] - 1, batch_index))
|
||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||
length = min(s_in.shape[0] - batch_index, length)
|
||||
s["samples"] = s_in[batch_index:batch_index + length].clone()
|
||||
if "noise_mask" in samples:
|
||||
@ -2154,71 +2152,6 @@ EXTENSION_WEB_DIRS = {}
|
||||
# Dictionary of successfully loaded module names and associated directories.
|
||||
LOADED_MODULE_DIRS = {}
|
||||
|
||||
# Dictionary of custom node startup errors, keyed by "<source>:<module_name>"
|
||||
# so that name collisions across custom_nodes / comfy_extras / comfy_api_nodes
|
||||
# do not overwrite each other. Each value contains: source, module_name,
|
||||
# module_path, error, traceback, phase.
|
||||
#
|
||||
# `source` is the same string as the internal `module_parent` used at load
|
||||
# time (e.g. "custom_nodes", "comfy_extras", "comfy_api_nodes"). It is
|
||||
# intentionally a free-form string rather than a fixed enum so the contract
|
||||
# survives node-source layouts evolving (e.g. comfy_api_nodes eventually
|
||||
# moving out of core). Consumers should treat any new value as a new bucket
|
||||
# rather than rejecting it.
|
||||
NODE_STARTUP_ERRORS: dict[str, dict] = {}
|
||||
|
||||
|
||||
def _read_pyproject_metadata(module_path: str) -> dict | None:
|
||||
"""Best-effort extraction of node-pack identity from pyproject.toml.
|
||||
|
||||
Returns a dict with the Comfy Registry-style identity (pack_id,
|
||||
display_name, publisher_id, version, repository) when the module
|
||||
directory contains a pyproject.toml. Returns None when no toml is
|
||||
present or parsing fails for any reason — startup-error tracking
|
||||
must never itself raise.
|
||||
"""
|
||||
if not module_path or not os.path.isdir(module_path):
|
||||
return None
|
||||
toml_path = os.path.join(module_path, "pyproject.toml")
|
||||
if not os.path.isfile(toml_path):
|
||||
return None
|
||||
try:
|
||||
from comfy_config import config_parser
|
||||
|
||||
cfg = config_parser.extract_node_configuration(module_path)
|
||||
if cfg is None:
|
||||
return None
|
||||
meta = {
|
||||
"pack_id": cfg.project.name or None,
|
||||
"display_name": cfg.tool_comfy.display_name or None,
|
||||
"publisher_id": cfg.tool_comfy.publisher_id or None,
|
||||
"version": cfg.project.version or None,
|
||||
"repository": cfg.project.urls.repository or None,
|
||||
}
|
||||
# Drop empty fields so the API payload stays compact.
|
||||
return {k: v for k, v in meta.items() if v}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def record_node_startup_error(
|
||||
*, module_path: str, source: str, phase: str, error: BaseException, tb: str
|
||||
) -> None:
|
||||
"""Record a startup error for a node module so it can be exposed via the API."""
|
||||
module_name = get_module_name(module_path)
|
||||
entry = {
|
||||
"source": source,
|
||||
"module_name": module_name,
|
||||
"module_path": module_path,
|
||||
"error": str(error),
|
||||
"traceback": tb,
|
||||
"phase": phase,
|
||||
}
|
||||
pyproject = _read_pyproject_metadata(module_path)
|
||||
if pyproject:
|
||||
entry["pyproject"] = pyproject
|
||||
NODE_STARTUP_ERRORS[f"{source}:{module_name}"] = entry
|
||||
|
||||
|
||||
def get_module_name(module_path: str) -> str:
|
||||
"""
|
||||
@ -2328,30 +2261,14 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
|
||||
return True
|
||||
except Exception as e:
|
||||
tb = traceback.format_exc()
|
||||
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
|
||||
record_node_startup_error(
|
||||
module_path=module_path,
|
||||
source=module_parent,
|
||||
phase="entrypoint",
|
||||
error=e,
|
||||
tb=tb,
|
||||
)
|
||||
return False
|
||||
else:
|
||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or comfy_entrypoint (need one).")
|
||||
return False
|
||||
except Exception as e:
|
||||
tb = traceback.format_exc()
|
||||
logging.warning(tb)
|
||||
logging.warning(traceback.format_exc())
|
||||
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
|
||||
record_node_startup_error(
|
||||
module_path=module_path,
|
||||
source=module_parent,
|
||||
phase="import",
|
||||
error=e,
|
||||
tb=tb,
|
||||
)
|
||||
return False
|
||||
|
||||
async def init_external_custom_nodes():
|
||||
|
||||
50
server.py
50
server.py
@ -44,6 +44,7 @@ from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from app.subgraph_manager import SubgraphManager
|
||||
from app.node_replace_manager import NodeReplaceManager
|
||||
from app.prompt_metadata import PromptMetadataStore
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from protocol import BinaryEventTypes
|
||||
@ -250,8 +251,14 @@ class PromptServer():
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
self.last_prompt_id = None
|
||||
self.client_id = None
|
||||
|
||||
# Bounded prompt_id -> envelope store. Populated at submission,
|
||||
# drained on completion/cancel. Keeps workflow scope (and other
|
||||
# client-supplied tags) out of the execution layer.
|
||||
self._prompt_metadata = PromptMetadataStore()
|
||||
|
||||
self.on_prompt_handlers = []
|
||||
|
||||
@routes.get('/ws')
|
||||
@ -275,7 +282,12 @@ class PromptServer():
|
||||
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
||||
# On reconnect if we are the currently executing client send the current node
|
||||
if self.client_id == sid and self.last_node_id is not None:
|
||||
await self.send("executing", { "node": self.last_node_id }, sid)
|
||||
payload: dict = {"node": self.last_node_id}
|
||||
if self.last_prompt_id is not None:
|
||||
payload["prompt_id"] = self.last_prompt_id
|
||||
await self.send(
|
||||
"executing", self._inject_prompt_metadata(payload), sid
|
||||
)
|
||||
|
||||
# Flag to track if we've received the first message
|
||||
first_message = True
|
||||
@ -765,26 +777,6 @@ class PromptServer():
|
||||
out[node_class] = node_info(node_class)
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/node_startup_errors")
|
||||
async def get_node_startup_errors(request):
|
||||
# Group errors by source so the frontend/Manager can render them
|
||||
# in distinct sections. `source` is the same string as the
|
||||
# module_parent used at load time (e.g. "custom_nodes",
|
||||
# "comfy_extras", "comfy_api_nodes") and is left as a free-form
|
||||
# string so the contract survives node-source layouts evolving.
|
||||
# The response only contains source buckets that actually had a
|
||||
# failure; consumers should not assume any particular set of keys
|
||||
# is always present.
|
||||
#
|
||||
# `module_path` is stripped because the absolute on-disk path is
|
||||
# internal detail that the frontend has no use for.
|
||||
grouped: dict[str, dict[str, dict]] = {}
|
||||
for entry in nodes.NODE_STARTUP_ERRORS.values():
|
||||
source = entry.get("source", "custom_nodes")
|
||||
public_entry = {k: v for k, v in entry.items() if k != "module_path"}
|
||||
grouped.setdefault(source, {})[entry["module_name"]] = public_entry
|
||||
return web.json_response(grouped)
|
||||
|
||||
@routes.get("/api/jobs")
|
||||
async def get_jobs(request):
|
||||
"""List all jobs with filtering, sorting, and pagination.
|
||||
@ -975,6 +967,7 @@ class PromptServer():
|
||||
if sensitive_val in extra_data:
|
||||
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
|
||||
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
|
||||
self.register_prompt_metadata(prompt_id, extra_data)
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
|
||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||||
return web.json_response(response)
|
||||
@ -1236,7 +1229,22 @@ class PromptServer():
|
||||
elif sid in self.sockets:
|
||||
await send_socket_catch_exception(self.sockets[sid].send_json, message)
|
||||
|
||||
def register_prompt_metadata(self, prompt_id: str, extra_data) -> None:
|
||||
"""Record per-prompt metadata for injection into outbound execution
|
||||
events. Called at submission, before the prompt is queued."""
|
||||
self._prompt_metadata.register(prompt_id, extra_data)
|
||||
|
||||
def unregister_prompt_metadata(self, prompt_id: str) -> None:
|
||||
"""Drop the per-prompt metadata envelope. Called after the prompt
|
||||
has finished executing and its terminal events have been queued,
|
||||
or when the prompt is cancelled before reaching the worker."""
|
||||
self._prompt_metadata.unregister(prompt_id)
|
||||
|
||||
def _inject_prompt_metadata(self, data):
|
||||
return self._prompt_metadata.inject(data)
|
||||
|
||||
def send_sync(self, event, data, sid=None):
|
||||
data = self._inject_prompt_metadata(data)
|
||||
self.loop.call_soon_threadsafe(
|
||||
self.messages.put_nowait, (event, data, sid))
|
||||
|
||||
|
||||
412
tests-unit/app_test/test_prompt_metadata.py
Normal file
412
tests-unit/app_test/test_prompt_metadata.py
Normal file
@ -0,0 +1,412 @@
|
||||
"""Unit tests for the metadata-envelope module in ``app.prompt_metadata``.
|
||||
|
||||
Covers the two pure helpers (``extract_envelope_from_extra_data`` and
|
||||
``inject_envelope``) and the ``PromptMetadataStore`` integration class
|
||||
that ``PromptServer`` owns.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.prompt_metadata import (
|
||||
MAX_ENVELOPE_KEYS,
|
||||
MAX_ENVELOPE_KEY_LEN,
|
||||
MAX_ENVELOPE_VALUE_LEN,
|
||||
PromptMetadataStore,
|
||||
extract_envelope_from_extra_data,
|
||||
inject_envelope,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractEnvelopeFromExtraData:
|
||||
def test_explicit_metadata_dict_is_used_as_is(self):
|
||||
extra_data = {"metadata": {"workflow_id": "wf-1", "trace_id": "t-9"}}
|
||||
assert extract_envelope_from_extra_data(extra_data) == {
|
||||
"workflow_id": "wf-1",
|
||||
"trace_id": "t-9",
|
||||
}
|
||||
|
||||
def test_explicit_metadata_takes_precedence_over_extra_pnginfo(self):
|
||||
extra_data = {
|
||||
"metadata": {"workflow_id": "explicit"},
|
||||
"extra_pnginfo": {"workflow": {"id": "fallback"}},
|
||||
}
|
||||
assert extract_envelope_from_extra_data(extra_data) == {
|
||||
"workflow_id": "explicit"
|
||||
}
|
||||
|
||||
def test_falls_back_to_extra_pnginfo_workflow_id(self):
|
||||
extra_data = {"extra_pnginfo": {"workflow": {"id": "wf-legacy"}}}
|
||||
assert extract_envelope_from_extra_data(extra_data) == {
|
||||
"workflow_id": "wf-legacy"
|
||||
}
|
||||
|
||||
def test_returns_none_when_no_metadata_and_no_workflow_id(self):
|
||||
assert extract_envelope_from_extra_data({}) is None
|
||||
assert (
|
||||
extract_envelope_from_extra_data({"extra_pnginfo": {"workflow": {}}})
|
||||
is None
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("bad", ["", 123, None, [], {}])
|
||||
def test_rejects_non_string_or_empty_workflow_id(self, bad):
|
||||
extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}}
|
||||
assert extract_envelope_from_extra_data(extra_data) is None
|
||||
|
||||
def test_rejects_non_dict_inputs_at_each_level(self):
|
||||
assert extract_envelope_from_extra_data(None) is None
|
||||
assert extract_envelope_from_extra_data("not-a-dict") is None
|
||||
assert (
|
||||
extract_envelope_from_extra_data({"extra_pnginfo": "not-a-dict"})
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
extract_envelope_from_extra_data(
|
||||
{"extra_pnginfo": {"workflow": "not-a-dict"}}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_empty_explicit_metadata_falls_through_to_workflow_id(self):
|
||||
extra_data = {
|
||||
"metadata": {},
|
||||
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
|
||||
}
|
||||
assert extract_envelope_from_extra_data(extra_data) == {
|
||||
"workflow_id": "wf-legacy"
|
||||
}
|
||||
|
||||
def test_returned_envelope_is_copy_not_reference(self):
|
||||
original = {"workflow_id": "wf-1"}
|
||||
result = extract_envelope_from_extra_data({"metadata": original})
|
||||
assert result is not None
|
||||
result["new_key"] = "x"
|
||||
assert "new_key" not in original
|
||||
|
||||
def test_non_dict_explicit_metadata_falls_through_to_workflow_id(self):
|
||||
extra_data = {
|
||||
"metadata": "not-a-dict",
|
||||
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
|
||||
}
|
||||
assert extract_envelope_from_extra_data(extra_data) == {
|
||||
"workflow_id": "wf-legacy"
|
||||
}
|
||||
|
||||
|
||||
class TestEnvelopeSanitization:
|
||||
"""The wire contract is ``dict[str, str]`` with bounded size. A bad
|
||||
envelope is dropped (and a warning is logged) rather than truncated,
|
||||
so the boundary stays strict."""
|
||||
|
||||
def test_rejects_too_many_keys(self, caplog):
|
||||
envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}
|
||||
with caplog.at_level("WARNING"):
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
|
||||
assert any("exceeds limit" in r.message for r in caplog.records)
|
||||
|
||||
def test_accepts_max_keys_exactly(self):
|
||||
envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS)}
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope
|
||||
|
||||
def test_rejects_non_string_keys(self, caplog):
|
||||
with caplog.at_level("WARNING"):
|
||||
assert (
|
||||
extract_envelope_from_extra_data({"metadata": {42: "v"}})
|
||||
is None
|
||||
)
|
||||
assert any("non-string" in r.message for r in caplog.records)
|
||||
|
||||
def test_rejects_non_string_values(self, caplog):
|
||||
for bad_value in [42, None, ["x"], {"nested": "dict"}, b"bytes"]:
|
||||
with caplog.at_level("WARNING"):
|
||||
assert (
|
||||
extract_envelope_from_extra_data(
|
||||
{"metadata": {"k": bad_value}}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_rejects_oversized_key(self):
|
||||
envelope = {"x" * (MAX_ENVELOPE_KEY_LEN + 1): "v"}
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
|
||||
|
||||
def test_rejects_oversized_value(self):
|
||||
envelope = {"k": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
|
||||
|
||||
def test_accepts_max_lengths_exactly(self):
|
||||
envelope = {
|
||||
"x" * MAX_ENVELOPE_KEY_LEN: "y" * MAX_ENVELOPE_VALUE_LEN
|
||||
}
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope
|
||||
|
||||
def test_oversized_workflow_id_in_pnginfo_rejected(self):
|
||||
"""The legacy synthesized path also respects the value bound."""
|
||||
extra_data = {
|
||||
"extra_pnginfo": {
|
||||
"workflow": {"id": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
|
||||
}
|
||||
}
|
||||
assert extract_envelope_from_extra_data(extra_data) is None
|
||||
|
||||
def test_invalid_explicit_metadata_does_not_fall_through(self):
|
||||
"""An explicit but invalid metadata dict means the caller asked
|
||||
for something specific and got it wrong; the synthesized
|
||||
fallback must not silently substitute."""
|
||||
extra_data = {
|
||||
"metadata": {"k": 42}, # non-string value
|
||||
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
|
||||
}
|
||||
assert extract_envelope_from_extra_data(extra_data) is None
|
||||
|
||||
|
||||
class TestInjectEnvelope:
|
||||
@staticmethod
|
||||
def _lookup(table):
|
||||
return table.get
|
||||
|
||||
def test_spreads_envelope_keys_onto_payload(self):
|
||||
"""Envelope keys are merged at the top level so consumers can
|
||||
read them directly (e.g. ``event.workflow_id``)."""
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1", "trace_id": "t-9"}})
|
||||
assert inject_envelope({"node": "5", "prompt_id": "p1"}, lookup) == {
|
||||
"node": "5",
|
||||
"prompt_id": "p1",
|
||||
"workflow_id": "wf-1",
|
||||
"trace_id": "t-9",
|
||||
}
|
||||
|
||||
def test_passthrough_when_prompt_id_not_registered(self):
|
||||
lookup = self._lookup({})
|
||||
data = {"node": "5", "prompt_id": "unknown"}
|
||||
assert inject_envelope(data, lookup) == data
|
||||
|
||||
def test_passthrough_when_payload_lacks_prompt_id(self):
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
|
||||
data = {"status": "ok"}
|
||||
assert inject_envelope(data, lookup) == data
|
||||
|
||||
def test_server_keys_win_on_collision_with_envelope(self):
|
||||
"""A misbehaving client cannot shadow server-emitted fields by
|
||||
stamping the same key in their submission envelope."""
|
||||
lookup = self._lookup({
|
||||
"p1": {"prompt_id": "client-claimed", "node": "spoofed", "workflow_id": "wf-1"}
|
||||
})
|
||||
result = inject_envelope({"prompt_id": "p1", "node": "5"}, lookup)
|
||||
assert result["prompt_id"] == "p1"
|
||||
assert result["node"] == "5"
|
||||
assert result["workflow_id"] == "wf-1"
|
||||
|
||||
def test_does_not_mutate_input_dict(self):
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
|
||||
original = {"node": "5", "prompt_id": "p1"}
|
||||
inject_envelope(original, lookup)
|
||||
assert "workflow_id" not in original
|
||||
|
||||
def test_does_not_mutate_envelope_dict(self):
|
||||
envelope = {"workflow_id": "wf-1"}
|
||||
lookup = self._lookup({"p1": envelope})
|
||||
inject_envelope({"prompt_id": "p1", "node": "5"}, lookup)
|
||||
assert envelope == {"workflow_id": "wf-1"}
|
||||
|
||||
def test_injects_into_inner_dict_of_preview_metadata_tuple(self):
|
||||
"""``PREVIEW_IMAGE_WITH_METADATA`` payloads arrive as
|
||||
``(preview_image, metadata_dict)``; the inner dict is the only
|
||||
place the envelope can attach."""
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
|
||||
preview_image = ("PNG", object(), 256)
|
||||
inner = {"node_id": "5", "prompt_id": "p1"}
|
||||
result = inject_envelope((preview_image, inner), lookup)
|
||||
assert isinstance(result, tuple)
|
||||
assert result[0] is preview_image
|
||||
assert result[1] == {
|
||||
"node_id": "5",
|
||||
"prompt_id": "p1",
|
||||
"workflow_id": "wf-1",
|
||||
}
|
||||
assert "workflow_id" not in inner
|
||||
|
||||
def test_preview_tuple_passthrough_when_no_envelope_registered(self):
|
||||
lookup = self._lookup({})
|
||||
preview_image = ("PNG", object(), 256)
|
||||
inner = {"node_id": "5", "prompt_id": "unknown"}
|
||||
result = inject_envelope((preview_image, inner), lookup)
|
||||
assert result == (preview_image, inner)
|
||||
|
||||
@pytest.mark.parametrize("payload", [b"raw-bytes", None, 42])
|
||||
def test_non_dict_non_tuple_payloads_passthrough(self, payload):
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
|
||||
assert inject_envelope(payload, lookup) == payload
|
||||
|
||||
def test_tuple_of_wrong_arity_passthrough(self):
|
||||
"""Only the 2-tuple ``(preview, metadata_dict)`` shape is
|
||||
special-cased. Other tuples must not be touched."""
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
|
||||
triple = (1, {"prompt_id": "p1"}, 3)
|
||||
assert inject_envelope(triple, lookup) is triple
|
||||
|
||||
def test_envelope_lookup_called_per_invocation(self):
|
||||
"""The lookup runs each time the function is called, so changes
|
||||
to the backing store are immediately visible."""
|
||||
store = {"p1": {"workflow_id": "wf-1"}}
|
||||
first = inject_envelope({"prompt_id": "p1"}, store.get)
|
||||
store["p1"] = {"workflow_id": "wf-2"}
|
||||
second = inject_envelope({"prompt_id": "p1"}, store.get)
|
||||
del store["p1"]
|
||||
third = inject_envelope({"prompt_id": "p1"}, store.get)
|
||||
assert first["workflow_id"] == "wf-1"
|
||||
assert second["workflow_id"] == "wf-2"
|
||||
assert "workflow_id" not in third
|
||||
|
||||
|
||||
class TestPromptMetadataStore:
|
||||
"""End-to-end wiring tests that exercise the full register/inject/
|
||||
unregister cycle the way ``PromptServer`` does."""
|
||||
|
||||
def test_register_inject_unregister_cycle(self):
|
||||
store = PromptMetadataStore()
|
||||
store.register(
|
||||
"p1", {"extra_pnginfo": {"workflow": {"id": "wf-1"}}}
|
||||
)
|
||||
injected = store.inject({"node": "5", "prompt_id": "p1"})
|
||||
assert injected == {
|
||||
"node": "5",
|
||||
"prompt_id": "p1",
|
||||
"workflow_id": "wf-1",
|
||||
}
|
||||
store.unregister("p1")
|
||||
passthrough = store.inject({"node": "5", "prompt_id": "p1"})
|
||||
assert "workflow_id" not in passthrough
|
||||
|
||||
def test_register_with_no_derivable_envelope_is_noop(self):
|
||||
store = PromptMetadataStore()
|
||||
store.register("p1", {})
|
||||
assert "p1" not in store
|
||||
data = {"prompt_id": "p1"}
|
||||
assert store.inject(data) == data
|
||||
|
||||
def test_register_with_oversized_envelope_is_noop(self):
|
||||
"""Sanitization rejection means nothing is registered — the
|
||||
store stays empty and inject is a passthrough."""
|
||||
store = PromptMetadataStore()
|
||||
store.register(
|
||||
"p1",
|
||||
{"metadata": {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}},
|
||||
)
|
||||
assert "p1" not in store
|
||||
|
||||
def test_unregister_unknown_prompt_is_silent(self):
|
||||
store = PromptMetadataStore()
|
||||
store.unregister("does-not-exist")
|
||||
|
||||
def test_fifo_eviction_when_capacity_exceeded(self):
|
||||
"""If cleanup hooks are ever bypassed, the store must shed the
|
||||
oldest entry rather than grow without bound."""
|
||||
store = PromptMetadataStore(capacity=3)
|
||||
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
|
||||
store.register("p2", {"metadata": {"workflow_id": "wf-2"}})
|
||||
store.register("p3", {"metadata": {"workflow_id": "wf-3"}})
|
||||
assert len(store) == 3
|
||||
|
||||
store.register("p4", {"metadata": {"workflow_id": "wf-4"}})
|
||||
assert len(store) == 3
|
||||
assert "p1" not in store
|
||||
assert "p4" in store
|
||||
|
||||
# The newer entries are still injectable.
|
||||
assert store.inject({"prompt_id": "p4"})["workflow_id"] == "wf-4"
|
||||
# The evicted one is gone.
|
||||
assert "workflow_id" not in store.inject({"prompt_id": "p1"})
|
||||
|
||||
def test_register_after_unregister_does_not_count_against_capacity(self):
|
||||
"""Normal lifecycle: register, unregister, register many — the
|
||||
store should not silently evict valid entries because of stale
|
||||
accounting."""
|
||||
store = PromptMetadataStore(capacity=2)
|
||||
for i in range(10):
|
||||
store.register(f"p{i}", {"metadata": {"workflow_id": f"wf-{i}"}})
|
||||
store.unregister(f"p{i}")
|
||||
assert len(store) == 0
|
||||
|
||||
def test_re_register_overwrites(self):
|
||||
store = PromptMetadataStore()
|
||||
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
|
||||
store.register("p1", {"metadata": {"workflow_id": "wf-2"}})
|
||||
assert store.inject({"prompt_id": "p1"})["workflow_id"] == "wf-2"
|
||||
|
||||
def test_inject_with_no_registrations_is_passthrough(self):
|
||||
store = PromptMetadataStore()
|
||||
data = {"prompt_id": "p1", "node": "5"}
|
||||
assert store.inject(data) == data
|
||||
|
||||
def test_inject_into_preview_tuple(self):
|
||||
store = PromptMetadataStore()
|
||||
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
|
||||
result = store.inject((b"image-bytes", {"prompt_id": "p1"}))
|
||||
assert result == (b"image-bytes", {
|
||||
"prompt_id": "p1",
|
||||
"workflow_id": "wf-1",
|
||||
})
|
||||
|
||||
def test_concurrent_access_does_not_corrupt_or_raise(self):
|
||||
"""Smoke test for the store's lock. ``register`` is called from
|
||||
the aiohttp event-loop thread, ``unregister`` from the worker
|
||||
thread, and ``inject`` fires on every ``send_sync`` from
|
||||
whichever thread emits the event. Run all three concurrently
|
||||
and assert no exception escapes and the store stays internally
|
||||
consistent (the FIFO cap is never exceeded)."""
|
||||
import threading
|
||||
|
||||
store = PromptMetadataStore(capacity=64)
|
||||
stop = threading.Event()
|
||||
errors: list[BaseException] = []
|
||||
|
||||
def registrar():
|
||||
i = 0
|
||||
try:
|
||||
while not stop.is_set():
|
||||
store.register(
|
||||
f"p{i % 100}",
|
||||
{"metadata": {"workflow_id": f"wf-{i}"}},
|
||||
)
|
||||
i += 1
|
||||
except BaseException as e:
|
||||
errors.append(e)
|
||||
|
||||
def canceller():
|
||||
i = 0
|
||||
try:
|
||||
while not stop.is_set():
|
||||
store.unregister(f"p{i % 100}")
|
||||
i += 1
|
||||
except BaseException as e:
|
||||
errors.append(e)
|
||||
|
||||
def injector():
|
||||
i = 0
|
||||
try:
|
||||
while not stop.is_set():
|
||||
store.inject({"prompt_id": f"p{i % 100}", "node": "5"})
|
||||
i += 1
|
||||
except BaseException as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=registrar),
|
||||
threading.Thread(target=registrar),
|
||||
threading.Thread(target=canceller),
|
||||
threading.Thread(target=injector),
|
||||
threading.Thread(target=injector),
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
# Brief burst — long enough to interleave many ops, short enough
|
||||
# not to slow CI.
|
||||
threading.Event().wait(0.1)
|
||||
stop.set()
|
||||
for t in threads:
|
||||
t.join(timeout=2.0)
|
||||
|
||||
assert errors == [], f"concurrent access raised: {errors[:3]}"
|
||||
assert len(store) <= 64, "FIFO cap was breached under contention"
|
||||
@ -1,146 +0,0 @@
|
||||
"""Tests for the custom node startup error tracking introduced for
|
||||
Comfy-Org/ComfyUI-Launcher#303.
|
||||
|
||||
Covers:
|
||||
- load_custom_node populates NODE_STARTUP_ERRORS with the correct source
|
||||
for each module_parent (custom_nodes / comfy_extras / comfy_api_nodes).
|
||||
- Composite keying prevents collisions between modules with the same name
|
||||
in different sources.
|
||||
- record_node_startup_error stores the expected fields.
|
||||
- pyproject.toml metadata is attached when present and omitted when absent.
|
||||
"""
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
import nodes
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_startup_errors():
|
||||
nodes.NODE_STARTUP_ERRORS.clear()
|
||||
yield
|
||||
nodes.NODE_STARTUP_ERRORS.clear()
|
||||
|
||||
|
||||
def _write_broken_module(tmp_path, name: str) -> str:
|
||||
path = tmp_path / f"{name}.py"
|
||||
path.write_text(textwrap.dedent("""\
|
||||
# Deliberately broken module to exercise startup-error tracking.
|
||||
raise RuntimeError("boom from " + __name__)
|
||||
"""))
|
||||
return str(path)
|
||||
|
||||
|
||||
def test_record_node_startup_error_fields(tmp_path):
|
||||
err = ValueError("kaboom")
|
||||
nodes.record_node_startup_error(
|
||||
module_path=str(tmp_path / "my_pack"),
|
||||
source="custom_nodes",
|
||||
phase="import",
|
||||
error=err,
|
||||
tb="traceback-text",
|
||||
)
|
||||
assert "custom_nodes:my_pack" in nodes.NODE_STARTUP_ERRORS
|
||||
entry = nodes.NODE_STARTUP_ERRORS["custom_nodes:my_pack"]
|
||||
assert entry["source"] == "custom_nodes"
|
||||
assert entry["module_name"] == "my_pack"
|
||||
assert entry["phase"] == "import"
|
||||
assert entry["error"] == "kaboom"
|
||||
assert entry["traceback"] == "traceback-text"
|
||||
assert entry["module_path"].endswith("my_pack")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"module_parent",
|
||||
["custom_nodes", "comfy_extras", "comfy_api_nodes"],
|
||||
)
|
||||
async def test_load_custom_node_records_source(tmp_path, module_parent):
|
||||
# `source` in the entry should be the same string as `module_parent`.
|
||||
module_path = _write_broken_module(tmp_path, "broken_pack")
|
||||
|
||||
success = await nodes.load_custom_node(module_path, module_parent=module_parent)
|
||||
assert success is False
|
||||
|
||||
key = f"{module_parent}:broken_pack"
|
||||
assert key in nodes.NODE_STARTUP_ERRORS, nodes.NODE_STARTUP_ERRORS
|
||||
entry = nodes.NODE_STARTUP_ERRORS[key]
|
||||
assert entry["source"] == module_parent
|
||||
assert entry["module_name"] == "broken_pack"
|
||||
assert entry["phase"] == "import"
|
||||
assert "boom from" in entry["error"]
|
||||
assert "RuntimeError" in entry["traceback"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_custom_node_collision_across_sources(tmp_path):
|
||||
# Same module name registered as both a custom node and a comfy_extra;
|
||||
# composite keying should keep both entries.
|
||||
cn_dir = tmp_path / "cn"
|
||||
extras_dir = tmp_path / "extras"
|
||||
cn_dir.mkdir()
|
||||
extras_dir.mkdir()
|
||||
cn_path = _write_broken_module(cn_dir, "nodes_audio")
|
||||
extras_path = _write_broken_module(extras_dir, "nodes_audio")
|
||||
|
||||
assert await nodes.load_custom_node(cn_path, module_parent="custom_nodes") is False
|
||||
assert await nodes.load_custom_node(extras_path, module_parent="comfy_extras") is False
|
||||
|
||||
assert "custom_nodes:nodes_audio" in nodes.NODE_STARTUP_ERRORS
|
||||
assert "comfy_extras:nodes_audio" in nodes.NODE_STARTUP_ERRORS
|
||||
assert (
|
||||
nodes.NODE_STARTUP_ERRORS["custom_nodes:nodes_audio"]["module_path"]
|
||||
!= nodes.NODE_STARTUP_ERRORS["comfy_extras:nodes_audio"]["module_path"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_custom_node_attaches_pyproject_metadata(tmp_path):
|
||||
pack_dir = tmp_path / "MyCoolPack"
|
||||
pack_dir.mkdir()
|
||||
(pack_dir / "__init__.py").write_text("raise RuntimeError('boom')\n")
|
||||
(pack_dir / "pyproject.toml").write_text(textwrap.dedent("""\
|
||||
[project]
|
||||
name = "comfyui-mycoolpack"
|
||||
version = "1.2.3"
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://github.com/example/comfyui-mycoolpack"
|
||||
|
||||
[tool.comfy]
|
||||
PublisherId = "example"
|
||||
DisplayName = "My Cool Pack"
|
||||
"""))
|
||||
|
||||
success = await nodes.load_custom_node(str(pack_dir), module_parent="custom_nodes")
|
||||
assert success is False
|
||||
|
||||
entry = nodes.NODE_STARTUP_ERRORS["custom_nodes:MyCoolPack"]
|
||||
assert "pyproject" in entry, entry
|
||||
py = entry["pyproject"]
|
||||
assert py["pack_id"] == "comfyui-mycoolpack"
|
||||
assert py["display_name"] == "My Cool Pack"
|
||||
assert py["publisher_id"] == "example"
|
||||
assert py["version"] == "1.2.3"
|
||||
assert py["repository"] == "https://github.com/example/comfyui-mycoolpack"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_custom_node_no_pyproject_skips_metadata(tmp_path):
|
||||
# Single-file extras-style module: no pyproject.toml exists alongside it,
|
||||
# so the entry must not contain a 'pyproject' key.
|
||||
module_path = _write_broken_module(tmp_path, "lonely")
|
||||
assert await nodes.load_custom_node(module_path, module_parent="comfy_extras") is False
|
||||
entry = nodes.NODE_STARTUP_ERRORS["comfy_extras:lonely"]
|
||||
assert "pyproject" not in entry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_custom_node_arbitrary_module_parent_passes_through(tmp_path):
|
||||
# `source` is a free-form string — an unknown module_parent (e.g. a future
|
||||
# node-source bucket) should be recorded as-is, not coerced or rejected.
|
||||
module_path = _write_broken_module(tmp_path, "future_pack")
|
||||
assert await nodes.load_custom_node(module_path, module_parent="future_source") is False
|
||||
entry = nodes.NODE_STARTUP_ERRORS["future_source:future_pack"]
|
||||
assert entry["source"] == "future_source"
|
||||
Reference in New Issue
Block a user