Compare commits


1 Commit

Author SHA1 Message Date
aa2c6a8492 [Partner Nodes] add ByteDance Seed LLM node
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-15 21:30:55 +03:00
9 changed files with 462 additions and 363 deletions

View File

@@ -22,25 +22,26 @@ class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
def __init__(self, tensor: torch.Tensor, patches_per_frame: int, per_frame: bool = False):
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
"""
tensor: [batch, num_tokens, feature_dim] (per-token, default) or
[batch, num_frames, feature_dim] (per_frame=True, already compressed).
patches_per_frame: spatial patches per frame; pass None to disable compression.
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
"""
self.batch_size, n, self.feature_dim = tensor.shape
if per_frame:
self.batch_size, num_tokens, self.feature_dim = tensor.shape
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
self.patches_per_frame = patches_per_frame
self.num_frames = n
self.data = tensor
elif patches_per_frame is not None and n >= patches_per_frame and n % patches_per_frame == 0:
self.patches_per_frame = patches_per_frame
self.num_frames = n // patches_per_frame
# All patches in a frame are identical — keep only the first.
self.data = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)[:, :, 0, :].contiguous()
self.num_frames = num_tokens // patches_per_frame
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
# All patches in a frame are identical, so we only keep the first one
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
else:
# Not divisible or too small - store directly without compression
self.patches_per_frame = 1
self.num_frames = n
self.num_frames = num_tokens
self.data = tensor
def expand(self):
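# Illustrative round-trip of the compression above (a sketch with made-up
# shapes; expand() presumably broadcasts each stored frame row back to
# patches_per_frame tokens, lossless because patches within a frame are identical):
import torch
batch, frames, patches, dim = 2, 8, 16, 64
tokens = torch.randn(batch, frames, 1, dim).expand(-1, -1, patches, -1).reshape(batch, frames * patches, dim)
compressed = tokens.view(batch, frames, patches, dim)[:, :, 0, :]  # [batch, frames, dim]
restored = compressed.unsqueeze(2).expand(-1, -1, patches, -1).reshape(batch, frames * patches, dim)
assert torch.equal(tokens, restored)  # 16x fewer rows stored, nothing lost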
@@ -715,35 +716,32 @@ class LTXAVModel(LTXVModel):
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
"""Prepare timestep embeddings."""
# TODO: some code reuse is needed here.
grid_mask = kwargs.get("grid_mask", None)
orig_shape = kwargs.get("orig_shape")
has_spatial_mask = kwargs.get("has_spatial_mask", None)
v_patches_per_frame = None
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
v_patches_per_frame = orig_shape[3] * orig_shape[4]
if grid_mask is not None:
timestep = timestep[:, grid_mask]
# Used by compute_prompt_timestep and the audio cross-attention paths.
timestep_scaled = (timestep[:, grid_mask] if grid_mask is not None else timestep) * self.timestep_scale_multiplier
# When patches in a frame share a timestep (no spatial mask), project one row per frame instead of one per token
per_frame_path = v_patches_per_frame is not None and (timestep.numel() // batch_size) % v_patches_per_frame == 0
if per_frame_path:
per_frame = timestep.reshape(batch_size, -1, v_patches_per_frame)[:, :, 0]
if grid_mask is not None:
# All-or-nothing per frame when has_spatial_mask=False.
per_frame = per_frame[:, grid_mask[::v_patches_per_frame]]
ts_input = per_frame * self.timestep_scale_multiplier
else:
ts_input = timestep_scaled
timestep_scaled = timestep * self.timestep_scale_multiplier
v_timestep, v_embedded_timestep = self.adaln_single(
ts_input.flatten(),
timestep_scaled.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
orig_shape = kwargs.get("orig_shape")
has_spatial_mask = kwargs.get("has_spatial_mask", None)
v_patches_per_frame = None
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
v_patches_per_frame = orig_shape[3] * orig_shape[4]
# Reshape to [batch_size, num_tokens, dim] and compress for storage
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
v_prompt_timestep = compute_prompt_timestep(
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
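# Worked example of the patches_per_frame bookkeeping (illustrative numbers):
# a 5-D latent [batch, channels, frames, height, width] gives height * width
# spatial patches per frame, and the video token count divides evenly.
orig_shape = (1, 128, 8, 16, 24)                     # hypothetical latent shape
v_patches_per_frame = orig_shape[3] * orig_shape[4]  # 16 * 24 = 384
num_tokens = orig_shape[2] * v_patches_per_frame     # 8 frames -> 3072 video tokens
assert num_tokens % v_patches_per_frame == 0         # so CompressedTimestep can compress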

View File

@@ -358,61 +358,6 @@ def apply_split_rotary_emb(input_tensor, cos, sin):
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
class GuideAttentionMask:
"""Holds the two per-group masks for LTXV guide self-attention.
_attention_with_guide_mask splits queries into noisy and tracked-guide
groups, so the largest mask is (1, 1, tracked_count, T).
"""
__slots__ = ("guide_start", "tracked_count", "noisy_mask", "tracked_mask")
def __init__(self, total_tokens, guide_start, tracked_count, tracked_weights):
device = tracked_weights.device
dtype = tracked_weights.dtype
finfo = torch.finfo(dtype)
pos = tracked_weights > 0
log_w = torch.full_like(tracked_weights, finfo.min)
log_w[pos] = torch.log(tracked_weights[pos].clamp(min=finfo.tiny))
self.guide_start = guide_start
self.tracked_count = tracked_count
self.noisy_mask = torch.zeros((1, 1, 1, total_tokens), device=device, dtype=dtype)
self.noisy_mask[:, :, :, guide_start:guide_start + tracked_count] = log_w.view(1, 1, 1, -1)
self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype)
self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1)
def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
"""Apply the guide mask by partitioning Q into noisy and tracked-guide
groups, so each group needs only its own sub-mask. Avoids materializing
the (1,1,T,T) dense mask.
"""
guide_start = guide_mask.guide_start
tracked_end = guide_start + guide_mask.tracked_count
out = torch.empty_like(q)
if guide_start > 0:  # In practice guides currently come after the noise tokens; guard kept in case this changes.
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
attn_precision=attn_precision, transformer_options=transformer_options,
low_precision_attention=False, # sageattn mask support is unreliable
)
out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
attn_precision=attn_precision, transformer_options=transformer_options,
low_precision_attention=False,
)
if tracked_end < q.shape[1]:  # Currently every guide token is tracked and nothing follows them; guard kept in case this changes.
out[:, tracked_end:, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, tracked_end:, :], k, v, heads,
attn_precision=attn_precision, transformer_options=transformer_options,
)
return out
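# Sketch of why the query partition is equivalent to one dense-masked pass,
# using plain SDPA in [batch, heads, tokens, dim] layout rather than comfy's
# optimized_attention (illustrative shapes and weights):
import torch
import torch.nn.functional as F
T, D, guide_start, tracked = 32, 8, 24, 6
q, k, v = (torch.randn(1, 1, T, D) for _ in range(3))
log_w = torch.log(torch.rand(tracked).clamp(min=1e-6))  # per-guide attenuation bias
dense = torch.zeros(1, 1, T, T)
dense[:, :, :guide_start, guide_start:guide_start + tracked] = log_w.view(1, 1, 1, -1)
dense[:, :, guide_start:guide_start + tracked, :guide_start] = log_w.view(1, 1, -1, 1)
full = F.scaled_dot_product_attention(q, k, v, attn_mask=dense)  # the (T, T) mask the split avoids
out = torch.empty_like(q)
out[:, :, :guide_start] = F.scaled_dot_product_attention(
    q[:, :, :guide_start], k, v, attn_mask=dense[:, :, :1, :])   # noisy rows share one mask row
out[:, :, guide_start:guide_start + tracked] = F.scaled_dot_product_attention(
    q[:, :, guide_start:guide_start + tracked], k, v,
    attn_mask=dense[:, :, guide_start:guide_start + tracked, :])  # (tracked, T) sub-mask
out[:, :, guide_start + tracked:] = F.scaled_dot_product_attention(
    q[:, :, guide_start + tracked:], k, v)                        # untracked tail: no mask
assert torch.allclose(full, out, atol=1e-5)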
class CrossAttention(nn.Module):
def __init__(
self,
@@ -467,10 +412,8 @@ class CrossAttention(nn.Module):
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
elif isinstance(mask, GuideAttentionMask):
out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
# Apply per-head gating if enabled
if self.to_gate_logits is not None:
@@ -1120,9 +1063,7 @@ class LTXVModel(LTXBaseModel):
additional_args["resolved_guide_entries"] = resolved_entries
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
if keyframe_idxs.shape[2] > 0: # Guard for the case of no keyframes surviving
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
# Total surviving guide tokens (all guides)
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
@@ -1158,12 +1099,12 @@ class LTXVModel(LTXBaseModel):
if not resolved_entries:
return None
# strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
needs_mask = any(
e["strength"] != 1.0 or e.get("pixel_mask") is not None
# Check if any attenuation is actually needed
needs_attenuation = any(
e["strength"] < 1.0 or e.get("pixel_mask") is not None
for e in resolved_entries
)
if not needs_mask:
if not needs_attenuation:
return None
# Build per-guide-token weights for all tracked guide tokens.
@@ -1218,11 +1159,16 @@ class LTXVModel(LTXBaseModel):
# Concatenate per-token weights for all tracked guides
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
# Skip when every weight is exactly 1.0 (additive bias would be 0).
if (tracked_weights == 1.0).all():
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
if (tracked_weights >= 1.0).all():
return None
return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)
# Build the mask: guide tokens are at the end of the sequence.
# Tracked guides come first (in order), untracked follow.
return self._build_self_attention_mask(
total_tokens, num_guide_tokens, total_tracked,
tracked_weights, guide_start, device, dtype,
)
@staticmethod
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
@@ -1288,6 +1234,45 @@ class LTXVModel(LTXBaseModel):
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
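# Shape check for the flatten above (illustrative sizes): a [b, 1, f, h, w]
# latent mask becomes one weight per video token in (frames * height * width) order.
import torch
from einops import rearrange
m = torch.rand(1, 1, 2, 3, 4)
assert rearrange(m, "b 1 f h w -> b (f h w)").shape == (1, 24)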
@staticmethod
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
tracked_weights, guide_start, device, dtype):
"""Build a log-space additive self-attention bias mask.
Attenuates attention between noisy tokens and tracked guide tokens.
Untracked guide tokens (at the end of the guide portion) keep full attention.
Args:
total_tokens: Total sequence length.
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
tracked_count: Number of tracked guide tokens (first in the guide portion).
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
guide_start: Index where guide tokens begin in the sequence.
device: Target device.
dtype: Target dtype.
Returns:
(1, 1, total_tokens, total_tokens) additive bias mask.
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
"""
finfo = torch.finfo(dtype)
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
tracked_end = guide_start + tracked_count
# Convert weights to log-space bias
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
log_w = torch.full_like(w, finfo.min)
positive_mask = w > 0
if positive_mask.any():
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
# noisy → tracked guides: each noisy row gets the same per-guide weight
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
return mask
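# Why log-space works (illustrative): adding log(w) to a key's logit multiplies
# its pre-softmax weight by exactly w, so w=1 gives a 0.0 bias (no-op) and
# w=0 maps to finfo.min, i.e. effectively fully masked.
import math
import torch
logits = torch.tensor([1.0, 2.0, 3.0])
w = 0.25                                             # attenuate key 1 to quarter strength
p = torch.softmax(logits, -1)
q = torch.softmax(logits + torch.tensor([0.0, math.log(w), 0.0]), -1)
assert torch.isclose(q[1] / q[0], w * p[1] / p[0])   # relative weight scaled by w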
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
"""Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {})

View File

@@ -0,0 +1,101 @@
"""Pydantic models for BytePlus ModelArk Responses API.
See: https://docs.byteplus.com/en/docs/ModelArk/1585128 (request)
https://docs.byteplus.com/en/docs/ModelArk/1783703 (response)
"""
from typing import Literal
from pydantic import BaseModel, Field
class BytePlusInputText(BaseModel):
type: Literal["input_text"] = "input_text"
text: str = Field(...)
class BytePlusInputImage(BaseModel):
type: Literal["input_image"] = "input_image"
image_url: str = Field(..., description="Image URL or `data:image/...;base64,...` payload")
detail: str = Field("auto", description="One of high, low, auto")
class BytePlusInputVideo(BaseModel):
type: Literal["input_video"] = "input_video"
video_url: str = Field(..., description="Video URL or `data:video/...;base64,...` payload")
fps: float | None = Field(None, ge=0.2, le=5.0)
BytePlusMessageContent = BytePlusInputText | BytePlusInputImage | BytePlusInputVideo
class BytePlusInputMessage(BaseModel):
type: Literal["message"] = "message"
role: str = Field(..., description="One of user, system, assistant, developer")
content: list[BytePlusMessageContent] = Field(...)
class BytePlusResponseCreateRequest(BaseModel):
model: str = Field(...)
input: list[BytePlusInputMessage] = Field(...)
instructions: str | None = Field(None)
max_output_tokens: int | None = Field(None, ge=1)
temperature: float | None = Field(None, ge=0.0, le=2.0)
store: bool | None = Field(False)
stream: bool | None = Field(False)
class BytePlusOutputText(BaseModel):
type: Literal["output_text"] = "output_text"
text: str = Field(...)
class BytePlusOutputRefusal(BaseModel):
type: Literal["refusal"] = "refusal"
refusal: str = Field(...)
class BytePlusOutputContent(BaseModel):
type: str = Field(...)
text: str | None = Field(None)
refusal: str | None = Field(None)
class BytePlusOutputMessage(BaseModel):
type: str = Field(...)
id: str | None = Field(None)
role: str | None = Field(None)
status: str | None = Field(None)
content: list[BytePlusOutputContent] | None = Field(None)
class BytePlusInputTokensDetails(BaseModel):
cached_tokens: int | None = Field(None)
class BytePlusOutputTokensDetails(BaseModel):
reasoning_tokens: int | None = Field(None)
class BytePlusResponseUsage(BaseModel):
input_tokens: int | None = Field(None)
output_tokens: int | None = Field(None)
total_tokens: int | None = Field(None)
input_tokens_details: BytePlusInputTokensDetails | None = Field(None)
output_tokens_details: BytePlusOutputTokensDetails | None = Field(None)
class BytePlusResponseError(BaseModel):
code: str = Field(...)
message: str = Field(...)
class BytePlusResponseObject(BaseModel):
id: str | None = Field(None)
object: str | None = Field(None)
created_at: int | None = Field(None)
model: str | None = Field(None)
status: str | None = Field(None)
error: BytePlusResponseError | None = Field(None)
output: list[BytePlusOutputMessage] | None = Field(None)
usage: BytePlusResponseUsage | None = Field(None)
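# Minimal sketch of building a request payload from these models (the image
# URL and prompt are placeholders; exclude_none keeps the JSON compact):
req = BytePlusResponseCreateRequest(
    model="seed-2-0-pro-260328",
    input=[BytePlusInputMessage(
        role="user",
        content=[
            BytePlusInputImage(image_url="https://example.com/cat.png", detail="low"),
            BytePlusInputText(text="Describe this image in one sentence."),
        ],
    )],
    temperature=0.7,
)
print(req.model_dump_json(exclude_none=True))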

View File

@@ -0,0 +1,271 @@
"""API Nodes for ByteDance Seed LLM via the BytePlus ModelArk Responses API.
See: https://docs.byteplus.com/en/docs/ModelArk/1585128
"""
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bytedance_llm import (
BytePlusInputImage,
BytePlusInputMessage,
BytePlusInputText,
BytePlusInputVideo,
BytePlusMessageContent,
BytePlusResponseCreateRequest,
BytePlusResponseObject,
)
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
sync_op,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
)
BYTEPLUS_RESPONSES_ENDPOINT = "/proxy/byteplus/api/v3/responses"
SEED_MAX_IMAGES = 20
SEED_MAX_VIDEOS = 4
SEED_MODELS: dict[str, str] = {
"Seed 2.0 Pro": "seed-2-0-pro-260328",
"Seed 2.0 Lite": "seed-2-0-lite-260228",
"Seed 2.0 Mini": "seed-2-0-mini-260215",
}
# USD per 1M tokens: (input, cache_hit_input, output)
_SEED_PRICES_PER_MILLION: dict[str, tuple[float, float, float]] = {
"seed-2-0-pro-260328": (0.50, 0.10, 3.00),
"seed-2-0-lite-260228": (0.25, 0.05, 2.00),
"seed-2-0-mini-260215": (0.10, 0.02, 0.40),
}
def _seed_model_inputs(max_images: int = SEED_MAX_IMAGES, max_videos: int = SEED_MAX_VIDEOS):
return [
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, max_images + 1)],
min=0,
),
tooltip=f"Optional image(s) to use as context for the model. Up to {max_images} images.",
),
IO.Autogrow.Input(
"videos",
template=IO.Autogrow.TemplateNames(
IO.Video.Input("video"),
names=[f"video_{i}" for i in range(1, max_videos + 1)],
min=0,
),
tooltip=f"Optional video(s) to use as context for the model. Up to {max_videos} videos.",
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.01,
tooltip="Controls randomness. 0.0 is deterministic, higher values are more random.",
advanced=True,
),
]
def _calculate_price(model_id: str, response: BytePlusResponseObject) -> float | None:
"""Compute approximate USD price from response usage."""
if not response.usage:
return None
rates = _SEED_PRICES_PER_MILLION.get(model_id)
if rates is None:
return None
input_rate, cache_hit_rate, output_rate = rates
input_tokens = response.usage.input_tokens or 0
output_tokens = response.usage.output_tokens or 0
cached = 0
if response.usage.input_tokens_details:
cached = response.usage.input_tokens_details.cached_tokens or 0
fresh_input = max(0, input_tokens - cached)
total = fresh_input * input_rate + cached * cache_hit_rate + output_tokens * output_rate
return total / 1_000_000.0
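# Worked example against the Seed 2.0 Pro rates above: 12,000 input tokens,
# of which 2,000 were cache hits, plus 500 output tokens (illustrative usage).
fresh, cached, output = 12_000 - 2_000, 2_000, 500
price = (fresh * 0.50 + cached * 0.10 + output * 3.00) / 1_000_000.0
assert abs(price - 0.0067) < 1e-12  # (5000 + 200 + 1500) / 1e6 dollars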
def _get_text_from_response(response: BytePlusResponseObject) -> str:
"""Extract concatenated text from all assistant message output_text blocks."""
if not response.output:
return ""
chunks: list[str] = []
for item in response.output:
if item.type != "message" or not item.content:
continue
for block in item.content:
if block.type == "output_text" and block.text:
chunks.append(block.text)
elif block.type == "refusal" and block.refusal:
raise ValueError(f"Model refused to respond: {block.refusal}")
return "\n".join(chunks)
async def _build_image_content_blocks(
cls: type[IO.ComfyNode],
image_tensors: list[Input.Image],
) -> list[BytePlusInputImage]:
urls = await upload_images_to_comfyapi(
cls,
image_tensors,
max_images=SEED_MAX_IMAGES,
wait_label="Uploading reference images",
)
return [BytePlusInputImage(image_url=url) for url in urls]
async def _build_video_content_blocks(
cls: type[IO.ComfyNode],
videos: list[Input.Video],
) -> list[BytePlusInputVideo]:
blocks: list[BytePlusInputVideo] = []
total = len(videos)
for idx, video in enumerate(videos):
label = "Uploading reference video"
if total > 1:
label = f"{label} ({idx + 1}/{total})"
url = await upload_video_to_comfyapi(cls, video, wait_label=label)
blocks.append(BytePlusInputVideo(video_url=url))
return blocks
class ByteDanceSeedNode(IO.ComfyNode):
"""Generate text responses from a ByteDance Seed 2.0 model."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ByteDanceSeedNode",
display_name="ByteDance Seed",
category="api node/text/ByteDance",
essentials_category="Text Generation",
description="Generate text responses with ByteDance's Seed 2.0 models. "
"Provide a text prompt and optionally one or more images or videos for multimodal context.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text input to the model.",
),
IO.DynamicCombo.Input(
"model",
options=[IO.DynamicCombo.Option(label, _seed_model_inputs()) for label in SEED_MODELS],
tooltip="The Seed model used to generate the response.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default="",
optional=True,
advanced=True,
tooltip="Foundational instructions that dictate the model's behavior.",
),
],
outputs=[IO.String.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$m := widgets.model;
$contains($m, "mini") ? {
"type": "list_usd",
"usd": [0.00025, 0.0009],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "lite") ? {
"type": "list_usd",
"usd": [0.0003, 0.002],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "pro") ? {
"type": "list_usd",
"usd": [0.0005, 0.003],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: {"type":"text", "text":"Token-based"}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
model_label = model["model"]
temperature = model["temperature"]
model_id = SEED_MODELS[model_label]
image_tensors: list[Input.Image] = [t for t in (model.get("images") or {}).values() if t is not None]
if sum(get_number_of_images(t) for t in image_tensors) > SEED_MAX_IMAGES:
raise ValueError(f"Up to {SEED_MAX_IMAGES} images are supported per request.")
video_inputs: list[Input.Video] = [v for v in (model.get("videos") or {}).values() if v is not None]
if len(video_inputs) > SEED_MAX_VIDEOS:
raise ValueError(f"Up to {SEED_MAX_VIDEOS} videos are supported per request.")
content: list[BytePlusMessageContent] = []
if image_tensors:
content.extend(await _build_image_content_blocks(cls, image_tensors))
if video_inputs:
content.extend(await _build_video_content_blocks(cls, video_inputs))
content.append(BytePlusInputText(text=prompt))
response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_RESPONSES_ENDPOINT, method="POST"),
response_model=BytePlusResponseObject,
data=BytePlusResponseCreateRequest(
model=model_id,
input=[BytePlusInputMessage(role="user", content=content)],
instructions=system_prompt or None,
temperature=temperature,
store=False,
stream=False,
),
price_extractor=lambda r: _calculate_price(model_id, r),
)
if response.error:
raise ValueError(f"Seed API error ({response.error.code}): {response.error.message}")
result = _get_text_from_response(response)
if not result:
raise ValueError("Empty response from Seed model.")
return IO.NodeOutput(result)
class ByteDanceLLMExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [ByteDanceSeedNode]
async def comfy_entrypoint() -> ByteDanceLLMExtension:
return ByteDanceLLMExtension()

View File

@@ -219,7 +219,7 @@ class LTXVAddGuide(io.ComfyNode):
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
@@ -298,7 +298,7 @@ class LTXVAddGuide(io.ComfyNode):
else:
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,
)
@@ -318,7 +318,7 @@ class LTXVAddGuide(io.ComfyNode):
mask = torch.full(
(noise_mask.shape[0], 1, cond_length, 1, 1),
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,
)
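# With strength now limited to [0, 1] by the input widget, the fill value is
# simply 1.0 - strength and the previous max(0.0, ...) clamp becomes redundant:
# strength=1.0 pins guide frames exactly (mask 0.0), strength=0.0 leaves them
# fully noised (mask 1.0). Illustrative check:
import torch
for strength in (0.0, 0.5, 1.0):
    mask = torch.full((1, 1, 4, 1, 1), 1.0 - strength)
    assert 0.0 <= float(mask.min()) <= float(mask.max()) <= 1.0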

View File

@@ -27,7 +27,6 @@ from utils.mime_types import init_mime_types
import faulthandler
import logging
import sys
import traceback
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags
@@ -149,14 +148,6 @@ def execute_prestartup_script():
return True
except Exception as e:
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
from nodes import record_node_startup_error
record_node_startup_error(
module_path=os.path.dirname(script_path),
source="custom_nodes",
phase="prestartup",
error=e,
tb=traceback.format_exc(),
)
return False
node_paths = folder_paths.get_folder_paths("custom_nodes")

View File

@@ -2154,71 +2154,6 @@ EXTENSION_WEB_DIRS = {}
# Dictionary of successfully loaded module names and associated directories.
LOADED_MODULE_DIRS = {}
# Dictionary of custom node startup errors, keyed by "<source>:<module_name>"
# so that name collisions across custom_nodes / comfy_extras / comfy_api_nodes
# do not overwrite each other. Each value contains: source, module_name,
# module_path, error, traceback, phase.
#
# `source` is the same string as the internal `module_parent` used at load
# time (e.g. "custom_nodes", "comfy_extras", "comfy_api_nodes"). It is
# intentionally a free-form string rather than a fixed enum so the contract
# survives node-source layouts evolving (e.g. comfy_api_nodes eventually
# moving out of core). Consumers should treat any new value as a new bucket
# rather than rejecting it.
NODE_STARTUP_ERRORS: dict[str, dict] = {}
def _read_pyproject_metadata(module_path: str) -> dict | None:
"""Best-effort extraction of node-pack identity from pyproject.toml.
Returns a dict with the Comfy Registry-style identity (pack_id,
display_name, publisher_id, version, repository) when the module
directory contains a pyproject.toml. Returns None when no toml is
present or parsing fails for any reason — startup-error tracking
must never itself raise.
"""
if not module_path or not os.path.isdir(module_path):
return None
toml_path = os.path.join(module_path, "pyproject.toml")
if not os.path.isfile(toml_path):
return None
try:
from comfy_config import config_parser
cfg = config_parser.extract_node_configuration(module_path)
if cfg is None:
return None
meta = {
"pack_id": cfg.project.name or None,
"display_name": cfg.tool_comfy.display_name or None,
"publisher_id": cfg.tool_comfy.publisher_id or None,
"version": cfg.project.version or None,
"repository": cfg.project.urls.repository or None,
}
# Drop empty fields so the API payload stays compact.
return {k: v for k, v in meta.items() if v}
except Exception:
return None
def record_node_startup_error(
*, module_path: str, source: str, phase: str, error: BaseException, tb: str
) -> None:
"""Record a startup error for a node module so it can be exposed via the API."""
module_name = get_module_name(module_path)
entry = {
"source": source,
"module_name": module_name,
"module_path": module_path,
"error": str(error),
"traceback": tb,
"phase": phase,
}
pyproject = _read_pyproject_metadata(module_path)
if pyproject:
entry["pyproject"] = pyproject
NODE_STARTUP_ERRORS[f"{source}:{module_name}"] = entry
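# Illustrative shape of one recorded entry (values hypothetical; the
# "pyproject" block is present only when a pyproject.toml parsed cleanly):
# NODE_STARTUP_ERRORS["custom_nodes:my_pack"] == {
#     "source": "custom_nodes",
#     "module_name": "my_pack",
#     "module_path": "/path/to/custom_nodes/my_pack",
#     "error": "kaboom",
#     "traceback": "Traceback (most recent call last): ...",
#     "phase": "import",
#     "pyproject": {"pack_id": "comfyui-mycoolpack", "display_name": "My Cool Pack"},
# }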
def get_module_name(module_path: str) -> str:
"""
@@ -2328,30 +2263,14 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
return True
except Exception as e:
tb = traceback.format_exc()
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
record_node_startup_error(
module_path=module_path,
source=module_parent,
phase="entrypoint",
error=e,
tb=tb,
)
return False
else:
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or comfy_entrypoint (need one).")
return False
except Exception as e:
tb = traceback.format_exc()
logging.warning(tb)
logging.warning(traceback.format_exc())
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
record_node_startup_error(
module_path=module_path,
source=module_parent,
phase="import",
error=e,
tb=tb,
)
return False
async def init_external_custom_nodes():

View File

@@ -765,26 +765,6 @@ class PromptServer():
out[node_class] = node_info(node_class)
return web.json_response(out)
@routes.get("/node_startup_errors")
async def get_node_startup_errors(request):
# Group errors by source so the frontend/Manager can render them
# in distinct sections. `source` is the same string as the
# module_parent used at load time (e.g. "custom_nodes",
# "comfy_extras", "comfy_api_nodes") and is left as a free-form
# string so the contract survives node-source layouts evolving.
# The response only contains source buckets that actually had a
# failure; consumers should not assume any particular set of keys
# is always present.
#
# `module_path` is stripped because the absolute on-disk path is
# internal detail that the frontend has no use for.
grouped: dict[str, dict[str, dict]] = {}
for entry in nodes.NODE_STARTUP_ERRORS.values():
source = entry.get("source", "custom_nodes")
public_entry = {k: v for k, v in entry.items() if k != "module_path"}
grouped.setdefault(source, {})[entry["module_name"]] = public_entry
return web.json_response(grouped)
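# Hypothetical response body for the route above: one bucket per failing
# source, keyed by module name, with module_path already stripped.
# {
#     "custom_nodes": {
#         "my_pack": {
#             "source": "custom_nodes",
#             "module_name": "my_pack",
#             "phase": "import",
#             "error": "kaboom",
#             "traceback": "Traceback (most recent call last): ...",
#         }
#     }
# }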
@routes.get("/api/jobs")
async def get_jobs(request):
"""List all jobs with filtering, sorting, and pagination.

View File

@@ -1,146 +0,0 @@
"""Tests for the custom node startup error tracking introduced for
Comfy-Org/ComfyUI-Launcher#303.
Covers:
- load_custom_node populates NODE_STARTUP_ERRORS with the correct source
for each module_parent (custom_nodes / comfy_extras / comfy_api_nodes).
- Composite keying prevents collisions between modules with the same name
in different sources.
- record_node_startup_error stores the expected fields.
- pyproject.toml metadata is attached when present and omitted when absent.
"""
import textwrap
import pytest
import nodes
@pytest.fixture(autouse=True)
def _clear_startup_errors():
nodes.NODE_STARTUP_ERRORS.clear()
yield
nodes.NODE_STARTUP_ERRORS.clear()
def _write_broken_module(tmp_path, name: str) -> str:
path = tmp_path / f"{name}.py"
path.write_text(textwrap.dedent("""\
# Deliberately broken module to exercise startup-error tracking.
raise RuntimeError("boom from " + __name__)
"""))
return str(path)
def test_record_node_startup_error_fields(tmp_path):
err = ValueError("kaboom")
nodes.record_node_startup_error(
module_path=str(tmp_path / "my_pack"),
source="custom_nodes",
phase="import",
error=err,
tb="traceback-text",
)
assert "custom_nodes:my_pack" in nodes.NODE_STARTUP_ERRORS
entry = nodes.NODE_STARTUP_ERRORS["custom_nodes:my_pack"]
assert entry["source"] == "custom_nodes"
assert entry["module_name"] == "my_pack"
assert entry["phase"] == "import"
assert entry["error"] == "kaboom"
assert entry["traceback"] == "traceback-text"
assert entry["module_path"].endswith("my_pack")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"module_parent",
["custom_nodes", "comfy_extras", "comfy_api_nodes"],
)
async def test_load_custom_node_records_source(tmp_path, module_parent):
# `source` in the entry should be the same string as `module_parent`.
module_path = _write_broken_module(tmp_path, "broken_pack")
success = await nodes.load_custom_node(module_path, module_parent=module_parent)
assert success is False
key = f"{module_parent}:broken_pack"
assert key in nodes.NODE_STARTUP_ERRORS, nodes.NODE_STARTUP_ERRORS
entry = nodes.NODE_STARTUP_ERRORS[key]
assert entry["source"] == module_parent
assert entry["module_name"] == "broken_pack"
assert entry["phase"] == "import"
assert "boom from" in entry["error"]
assert "RuntimeError" in entry["traceback"]
@pytest.mark.asyncio
async def test_load_custom_node_collision_across_sources(tmp_path):
# Same module name registered as both a custom node and a comfy_extra;
# composite keying should keep both entries.
cn_dir = tmp_path / "cn"
extras_dir = tmp_path / "extras"
cn_dir.mkdir()
extras_dir.mkdir()
cn_path = _write_broken_module(cn_dir, "nodes_audio")
extras_path = _write_broken_module(extras_dir, "nodes_audio")
assert await nodes.load_custom_node(cn_path, module_parent="custom_nodes") is False
assert await nodes.load_custom_node(extras_path, module_parent="comfy_extras") is False
assert "custom_nodes:nodes_audio" in nodes.NODE_STARTUP_ERRORS
assert "comfy_extras:nodes_audio" in nodes.NODE_STARTUP_ERRORS
assert (
nodes.NODE_STARTUP_ERRORS["custom_nodes:nodes_audio"]["module_path"]
!= nodes.NODE_STARTUP_ERRORS["comfy_extras:nodes_audio"]["module_path"]
)
@pytest.mark.asyncio
async def test_load_custom_node_attaches_pyproject_metadata(tmp_path):
pack_dir = tmp_path / "MyCoolPack"
pack_dir.mkdir()
(pack_dir / "__init__.py").write_text("raise RuntimeError('boom')\n")
(pack_dir / "pyproject.toml").write_text(textwrap.dedent("""\
[project]
name = "comfyui-mycoolpack"
version = "1.2.3"
[project.urls]
Repository = "https://github.com/example/comfyui-mycoolpack"
[tool.comfy]
PublisherId = "example"
DisplayName = "My Cool Pack"
"""))
success = await nodes.load_custom_node(str(pack_dir), module_parent="custom_nodes")
assert success is False
entry = nodes.NODE_STARTUP_ERRORS["custom_nodes:MyCoolPack"]
assert "pyproject" in entry, entry
py = entry["pyproject"]
assert py["pack_id"] == "comfyui-mycoolpack"
assert py["display_name"] == "My Cool Pack"
assert py["publisher_id"] == "example"
assert py["version"] == "1.2.3"
assert py["repository"] == "https://github.com/example/comfyui-mycoolpack"
@pytest.mark.asyncio
async def test_load_custom_node_no_pyproject_skips_metadata(tmp_path):
# Single-file extras-style module: no pyproject.toml exists alongside it,
# so the entry must not contain a 'pyproject' key.
module_path = _write_broken_module(tmp_path, "lonely")
assert await nodes.load_custom_node(module_path, module_parent="comfy_extras") is False
entry = nodes.NODE_STARTUP_ERRORS["comfy_extras:lonely"]
assert "pyproject" not in entry
@pytest.mark.asyncio
async def test_load_custom_node_arbitrary_module_parent_passes_through(tmp_path):
# `source` is a free-form string — an unknown module_parent (e.g. a future
# node-source bucket) should be recorded as-is, not coerced or rejected.
module_path = _write_broken_module(tmp_path, "future_pack")
assert await nodes.load_custom_node(module_path, module_parent="future_source") is False
entry = nodes.NODE_STARTUP_ERRORS["future_source:future_pack"]
assert entry["source"] == "future_source"