mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-16 22:57:09 +08:00
Compare commits
1 Commits
feature/cu
...
feat/api-n
| Author | SHA1 | Date | |
|---|---|---|---|
| aa2c6a8492 |
@ -22,25 +22,26 @@ class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int, per_frame: bool = False):
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
"""
|
||||
tensor: [batch, num_tokens, feature_dim] (per-token, default) or
|
||||
[batch, num_frames, feature_dim] (per_frame=True, already compressed).
|
||||
patches_per_frame: spatial patches per frame; pass None to disable compression.
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
|
||||
"""
|
||||
self.batch_size, n, self.feature_dim = tensor.shape
|
||||
if per_frame:
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = n
|
||||
self.data = tensor
|
||||
elif patches_per_frame is not None and n >= patches_per_frame and n % patches_per_frame == 0:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = n // patches_per_frame
|
||||
# All patches in a frame are identical — keep only the first.
|
||||
self.data = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)[:, :, 0, :].contiguous()
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
|
||||
# All patches in a frame are identical, so we only keep the first one
|
||||
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
|
||||
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
|
||||
else:
|
||||
# Not divisible or too small - store directly without compression
|
||||
self.patches_per_frame = 1
|
||||
self.num_frames = n
|
||||
self.num_frames = num_tokens
|
||||
self.data = tensor
|
||||
|
||||
def expand(self):
|
||||
@ -715,35 +716,32 @@ class LTXAVModel(LTXVModel):
|
||||
|
||||
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||
"""Prepare timestep embeddings."""
|
||||
# TODO: some code reuse is needed here.
|
||||
grid_mask = kwargs.get("grid_mask", None)
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
has_spatial_mask = kwargs.get("has_spatial_mask", None)
|
||||
v_patches_per_frame = None
|
||||
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
# Used by compute_prompt_timestep and the audio cross-attention paths.
|
||||
timestep_scaled = (timestep[:, grid_mask] if grid_mask is not None else timestep) * self.timestep_scale_multiplier
|
||||
|
||||
# When patches in a frame share a timestep (no spatial mask), project one row per frame instead of one per token
|
||||
per_frame_path = v_patches_per_frame is not None and (timestep.numel() // batch_size) % v_patches_per_frame == 0
|
||||
if per_frame_path:
|
||||
per_frame = timestep.reshape(batch_size, -1, v_patches_per_frame)[:, :, 0]
|
||||
if grid_mask is not None:
|
||||
# All-or-nothing per frame when has_spatial_mask=False.
|
||||
per_frame = per_frame[:, grid_mask[::v_patches_per_frame]]
|
||||
ts_input = per_frame * self.timestep_scale_multiplier
|
||||
else:
|
||||
ts_input = timestep_scaled
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
|
||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||
ts_input.flatten(),
|
||||
timestep_scaled.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
|
||||
|
||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
has_spatial_mask = kwargs.get("has_spatial_mask", None)
|
||||
v_patches_per_frame = None
|
||||
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
|
||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
|
||||
# Reshape to [batch_size, num_tokens, dim] and compress for storage
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||
|
||||
v_prompt_timestep = compute_prompt_timestep(
|
||||
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||
|
||||
@ -358,61 +358,6 @@ def apply_split_rotary_emb(input_tensor, cos, sin):
|
||||
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
|
||||
|
||||
|
||||
class GuideAttentionMask:
    """Pair of additive log-space masks for LTXV guide self-attention.

    Two small tensors are stored instead of one dense (1, 1, T, T) bias:

    * ``noisy_mask``: shape (1, 1, 1, total_tokens). Columns
      ``guide_start : guide_start + tracked_count`` hold ``log(weight)``.
    * ``tracked_mask``: shape (1, 1, tracked_count, total_tokens). In row
      ``i`` the columns ``: guide_start`` hold ``log(weight[i])``.

    All other entries are 0. A weight that is not strictly positive maps to
    ``torch.finfo(dtype).min``.
    """

    __slots__ = ("guide_start", "tracked_count", "noisy_mask", "tracked_mask")

    def __init__(self, total_tokens, guide_start, tracked_count, tracked_weights):
        """Build both masks on the device and dtype of ``tracked_weights``.

        Args:
            total_tokens: Size of the last dimension of both masks.
            guide_start: First column index of the tracked guide tokens.
            tracked_count: Number of tracked guide tokens.
            tracked_weights: Floating-point tensor holding ``tracked_count``
                weights; it is reshaped with ``view`` when written.
        """
        weights = tracked_weights
        limits = torch.finfo(weights.dtype)

        # log(w) where w > 0; the clamp keeps the log finite, and non-positive
        # weights collapse to the most negative representable value.
        log_bias = torch.where(
            weights > 0,
            torch.log(weights.clamp(min=limits.tiny)),
            torch.full_like(weights, limits.min),
        )

        tracked_end = guide_start + tracked_count
        self.guide_start = guide_start
        self.tracked_count = tracked_count

        noisy = torch.zeros((1, 1, 1, total_tokens), device=weights.device, dtype=weights.dtype)
        noisy[..., guide_start:tracked_end] = log_bias.view(1, 1, 1, -1)
        self.noisy_mask = noisy

        tracked = torch.zeros((1, 1, tracked_count, total_tokens), device=weights.device, dtype=weights.dtype)
        tracked[..., :guide_start] = log_bias.view(1, 1, -1, 1)
        self.tracked_mask = tracked
|
||||
|
||||
|
||||
def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
    """Run attention in up to three query groups so no (1, 1, T, T) mask is built.

    The queries are sliced along dim 1 into:

    1. ``[:guide_start]``: attended with ``guide_mask.noisy_mask``.
    2. ``[guide_start:guide_start + tracked_count]``: attended with
       ``guide_mask.tracked_mask``.
    3. ``[tracked_end:]``: attended without a mask.

    Groups 1 and 3 are skipped when their slice would be empty. All groups
    attend over the full ``k`` and ``v``.

    Returns:
        A tensor shaped like ``q`` holding the per-group results.
    """
    attend = comfy.ldm.modules.attention.optimized_attention
    noisy_end = guide_mask.guide_start
    tracked_end = noisy_end + guide_mask.tracked_count
    shared = dict(attn_precision=attn_precision, transformer_options=transformer_options)

    out = torch.empty_like(q)

    # In practice guides currently always come after noise; guard in case that changes.
    if noisy_end > 0:
        out[:, :noisy_end, :] = attend(
            q[:, :noisy_end, :], k, v, heads,
            mask=guide_mask.noisy_mask,
            low_precision_attention=False,  # sageattn mask support is unreliable
            **shared,
        )

    out[:, noisy_end:tracked_end, :] = attend(
        q[:, noisy_end:tracked_end, :], k, v, heads,
        mask=guide_mask.tracked_mask,
        low_precision_attention=False,
        **shared,
    )

    # Every guide token is tracked and nothing follows them today; guard in case that changes.
    if tracked_end < q.shape[1]:
        out[:, tracked_end:, :] = attend(q[:, tracked_end:, :], k, v, heads, **shared)

    return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -467,10 +412,8 @@ class CrossAttention(nn.Module):
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
elif isinstance(mask, GuideAttentionMask):
|
||||
out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
|
||||
# Apply per-head gating if enabled
|
||||
if self.to_gate_logits is not None:
|
||||
@ -1120,9 +1063,7 @@ class LTXVModel(LTXBaseModel):
|
||||
additional_args["resolved_guide_entries"] = resolved_entries
|
||||
|
||||
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||
|
||||
if keyframe_idxs.shape[2] > 0: # Guard for the case of no keyframes surviving
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
|
||||
# Total surviving guide tokens (all guides)
|
||||
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
|
||||
@ -1158,12 +1099,12 @@ class LTXVModel(LTXBaseModel):
|
||||
if not resolved_entries:
|
||||
return None
|
||||
|
||||
# strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
|
||||
needs_mask = any(
|
||||
e["strength"] != 1.0 or e.get("pixel_mask") is not None
|
||||
# Check if any attenuation is actually needed
|
||||
needs_attenuation = any(
|
||||
e["strength"] < 1.0 or e.get("pixel_mask") is not None
|
||||
for e in resolved_entries
|
||||
)
|
||||
if not needs_mask:
|
||||
if not needs_attenuation:
|
||||
return None
|
||||
|
||||
# Build per-guide-token weights for all tracked guide tokens.
|
||||
@ -1218,11 +1159,16 @@ class LTXVModel(LTXBaseModel):
|
||||
# Concatenate per-token weights for all tracked guides
|
||||
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
|
||||
|
||||
# Skip when every weight is exactly 1.0 (additive bias would be 0).
|
||||
if (tracked_weights == 1.0).all():
|
||||
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
|
||||
if (tracked_weights >= 1.0).all():
|
||||
return None
|
||||
|
||||
return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)
|
||||
# Build the mask: guide tokens are at the end of the sequence.
|
||||
# Tracked guides come first (in order), untracked follow.
|
||||
return self._build_self_attention_mask(
|
||||
total_tokens, num_guide_tokens, total_tracked,
|
||||
tracked_weights, guide_start, device, dtype,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
|
||||
@ -1288,6 +1234,45 @@ class LTXVModel(LTXBaseModel):
|
||||
|
||||
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
|
||||
|
||||
@staticmethod
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
                               tracked_weights, guide_start, device, dtype):
    """Build a dense additive (log-space) self-attention bias.

    Only two rectangles of the mask are non-zero:

    * rows ``[:guide_start]`` x columns ``[guide_start:guide_start + tracked_count]``
      where every row carries the per-guide ``log(weight)``;
    * rows ``[guide_start:guide_start + tracked_count]`` x columns ``[:guide_start]``
      where row ``i`` carries ``log(weight[i])`` in every column.

    Args:
        total_tokens: Total sequence length.
        num_guide_tokens: Total guide tokens at the end of the sequence
            (not read by this function).
        tracked_count: Number of tracked guide tokens.
        tracked_weights: Tensor holding ``tracked_count`` weights.
        guide_start: Index where guide tokens begin in the sequence.
        device: Target device.
        dtype: Target floating-point dtype.

    Returns:
        (1, 1, total_tokens, total_tokens) additive bias mask.
        0.0 = full attention, negative = attenuated, finfo.min = effectively masked.
    """
    limits = torch.finfo(dtype)
    weights = tracked_weights.to(device=device, dtype=dtype)

    # log(w) for strictly positive weights, finfo.min otherwise.
    log_bias = torch.where(
        weights > 0,
        torch.log(weights.clamp(min=limits.tiny)),
        torch.full_like(weights, limits.min),
    )

    tracked_end = guide_start + tracked_count
    bias = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
    # noisy -> tracked guides: the same per-guide weight on every noisy row
    bias[..., :guide_start, guide_start:tracked_end] = log_bias.view(1, 1, 1, -1)
    # tracked guides -> noisy: each guide row broadcasts its weight over noisy columns
    bias[..., guide_start:tracked_end, :guide_start] = log_bias.view(1, 1, -1, 1)
    return bias
|
||||
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
101
comfy_api_nodes/apis/bytedance_llm.py
Normal file
101
comfy_api_nodes/apis/bytedance_llm.py
Normal file
@ -0,0 +1,101 @@
|
||||
"""Pydantic models for BytePlus ModelArk Responses API.
|
||||
|
||||
See: https://docs.byteplus.com/en/docs/ModelArk/1585128 (request)
|
||||
https://docs.byteplus.com/en/docs/ModelArk/1783703 (response)
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BytePlusInputText(BaseModel):
    # Text content part of an input message; `type` is pinned to "input_text".
    type: Literal["input_text"] = "input_text"
    # Required text payload.
    text: str
|
||||
|
||||
|
||||
class BytePlusInputImage(BaseModel):
    # Image content part of an input message; `type` is pinned to "input_image".
    type: Literal["input_image"] = "input_image"
    # Required image reference.
    image_url: str = Field(
        default=...,
        description="Image URL or `data:image/...;base64,...` payload",
    )
    # Free-form string here; the accepted values are only documented, not validated.
    detail: str = Field(default="auto", description="One of high, low, auto")
|
||||
|
||||
|
||||
class BytePlusInputVideo(BaseModel):
    # Video content part of an input message; `type` is pinned to "input_video".
    type: Literal["input_video"] = "input_video"
    # Required video reference.
    video_url: str = Field(
        default=...,
        description="Video URL or `data:video/...;base64,...` payload",
    )
    # Optional; when given it must lie in [0.2, 5.0].
    fps: float | None = Field(default=None, ge=0.2, le=5.0)
|
||||
|
||||
|
||||
# Union of the content-part models that may appear in an input message's `content` list.
BytePlusMessageContent = BytePlusInputText | BytePlusInputImage | BytePlusInputVideo
|
||||
|
||||
|
||||
class BytePlusInputMessage(BaseModel):
    # One input message; `type` is pinned to "message".
    type: Literal["message"] = "message"
    # Required; the accepted roles are only documented, not validated.
    role: str = Field(default=..., description="One of user, system, assistant, developer")
    # Required list of text / image / video content parts.
    content: list[BytePlusMessageContent]
|
||||
|
||||
|
||||
class BytePlusResponseCreateRequest(BaseModel):
    # Request body for creating a response.
    # Required fields.
    model: str
    input: list[BytePlusInputMessage]
    # Optional fields; None is the default for each.
    instructions: str | None = None
    max_output_tokens: int | None = Field(default=None, ge=1)
    temperature: float | None = Field(default=None, ge=0.0, le=2.0)
    # Both flags default to False.
    store: bool | None = False
    stream: bool | None = False
|
||||
|
||||
|
||||
class BytePlusOutputText(BaseModel):
    # Output content part carrying text; `type` is pinned to "output_text".
    type: Literal["output_text"] = "output_text"
    # Required text payload.
    text: str
|
||||
|
||||
|
||||
class BytePlusOutputRefusal(BaseModel):
    # Output content part carrying a refusal; `type` is pinned to "refusal".
    type: Literal["refusal"] = "refusal"
    # Required refusal message.
    refusal: str
|
||||
|
||||
|
||||
class BytePlusOutputContent(BaseModel):
    # Loosely typed output content part: `type` is any string, and both
    # payload fields are optional.
    type: str
    text: str | None = None
    refusal: str | None = None
|
||||
|
||||
|
||||
class BytePlusOutputMessage(BaseModel):
    # One item of a response's `output` list. Only `type` is required.
    type: str
    id: str | None = None
    role: str | None = None
    status: str | None = None
    content: list[BytePlusOutputContent] | None = None
|
||||
|
||||
|
||||
class BytePlusInputTokensDetails(BaseModel):
    # Breakdown of input tokens; the single field is optional.
    cached_tokens: int | None = None
|
||||
|
||||
|
||||
class BytePlusOutputTokensDetails(BaseModel):
    # Breakdown of output tokens; the single field is optional.
    reasoning_tokens: int | None = None
|
||||
|
||||
|
||||
class BytePlusResponseUsage(BaseModel):
    # Token usage attached to a response; every field is optional.
    input_tokens: int | None = None
    output_tokens: int | None = None
    total_tokens: int | None = None
    input_tokens_details: BytePlusInputTokensDetails | None = None
    output_tokens_details: BytePlusOutputTokensDetails | None = None
|
||||
|
||||
|
||||
class BytePlusResponseError(BaseModel):
    # Error payload of a response; both fields are required.
    code: str
    message: str
|
||||
|
||||
|
||||
class BytePlusResponseObject(BaseModel):
    # Top-level response body. Every field is optional, so a partial
    # payload still validates.
    id: str | None = None
    object: str | None = None
    created_at: int | None = None
    model: str | None = None
    status: str | None = None
    error: BytePlusResponseError | None = None
    output: list[BytePlusOutputMessage] | None = None
    usage: BytePlusResponseUsage | None = None
|
||||
271
comfy_api_nodes/nodes_bytedance_llm.py
Normal file
271
comfy_api_nodes/nodes_bytedance_llm.py
Normal file
@ -0,0 +1,271 @@
|
||||
"""API Nodes for ByteDance Seed LLM via the BytePlus ModelArk Responses API.
|
||||
|
||||
See: https://docs.byteplus.com/en/docs/ModelArk/1585128
|
||||
"""
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bytedance_llm import (
|
||||
BytePlusInputImage,
|
||||
BytePlusInputMessage,
|
||||
BytePlusInputText,
|
||||
BytePlusInputVideo,
|
||||
BytePlusMessageContent,
|
||||
BytePlusResponseCreateRequest,
|
||||
BytePlusResponseObject,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
# Path that Responses requests are posted to.
BYTEPLUS_RESPONSES_ENDPOINT = "/proxy/byteplus/api/v3/responses"

# Per-request attachment limits.
SEED_MAX_IMAGES = 20
SEED_MAX_VIDEOS = 4

# Display label -> model id sent in the request.
SEED_MODELS: dict[str, str] = {
    "Seed 2.0 Pro": "seed-2-0-pro-260328",
    "Seed 2.0 Lite": "seed-2-0-lite-260228",
    "Seed 2.0 Mini": "seed-2-0-mini-260215",
}

# USD per 1M tokens: (input, cache_hit_input, output)
# Keyed by model id; ids are taken from SEED_MODELS so they are spelled once.
_SEED_PRICES_PER_MILLION: dict[str, tuple[float, float, float]] = {
    SEED_MODELS["Seed 2.0 Pro"]: (0.50, 0.10, 3.00),
    SEED_MODELS["Seed 2.0 Lite"]: (0.25, 0.05, 2.00),
    SEED_MODELS["Seed 2.0 Mini"]: (0.10, 0.02, 0.40),
}
|
||||
|
||||
|
||||
def _seed_model_inputs(max_images: int = SEED_MAX_IMAGES, max_videos: int = SEED_MAX_VIDEOS):
    """Return the inputs nested under each model option.

    The list holds, in order: an auto-growing ``images`` group with slots
    ``image_1`` .. ``image_<max_images>``, an auto-growing ``videos`` group
    with slots ``video_1`` .. ``video_<max_videos>``, and an advanced
    ``temperature`` float in [0.0, 2.0] defaulting to 1.0.
    """
    images_input = IO.Autogrow.Input(
        "images",
        template=IO.Autogrow.TemplateNames(
            IO.Image.Input("image"),
            names=[f"image_{slot + 1}" for slot in range(max_images)],
            min=0,
        ),
        tooltip=f"Optional image(s) to use as context for the model. Up to {max_images} images.",
    )

    videos_input = IO.Autogrow.Input(
        "videos",
        template=IO.Autogrow.TemplateNames(
            IO.Video.Input("video"),
            names=[f"video_{slot + 1}" for slot in range(max_videos)],
            min=0,
        ),
        tooltip=f"Optional video(s) to use as context for the model. Up to {max_videos} videos.",
    )

    temperature_input = IO.Float.Input(
        "temperature",
        default=1.0,
        min=0.0,
        max=2.0,
        step=0.01,
        tooltip="Controls randomness. 0.0 is deterministic, higher values are more random.",
        advanced=True,
    )

    return [images_input, videos_input, temperature_input]
|
||||
|
||||
|
||||
def _calculate_price(model_id: str, response: BytePlusResponseObject) -> float | None:
    """Estimate the USD cost of a response from its token usage.

    Cached input tokens are billed at the cache-hit rate, the remaining input
    tokens at the input rate, and output tokens at the output rate. All rates
    are per million tokens.

    Returns:
        The estimated price, or ``None`` when the response has no usage data
        or ``model_id`` has no entry in the price table.
    """
    usage = response.usage
    if not usage:
        return None

    rates = _SEED_PRICES_PER_MILLION.get(model_id)
    if rates is None:
        return None
    per_input, per_cached, per_output = rates

    details = usage.input_tokens_details
    cached_tokens = (details.cached_tokens or 0) if details else 0
    # Clamp so a cached count larger than the input count never goes negative.
    uncached_tokens = max(0, (usage.input_tokens or 0) - cached_tokens)
    produced_tokens = usage.output_tokens or 0

    cost = uncached_tokens * per_input + cached_tokens * per_cached + produced_tokens * per_output
    return cost / 1_000_000.0
|
||||
|
||||
|
||||
def _get_text_from_response(response: BytePlusResponseObject) -> str:
    """Join the text of every ``output_text`` part found in ``message`` items.

    Output items whose ``type`` is not ``"message"`` or that have no content
    are skipped. Parts with empty text are ignored.

    Returns:
        The collected texts joined with newlines; ``""`` when there are none.

    Raises:
        ValueError: On the first ``refusal`` part with a non-empty refusal.
    """
    texts: list[str] = []
    for message in response.output or ():
        if message.type != "message":
            continue
        for part in message.content or ():
            if part.type == "output_text":
                if part.text:
                    texts.append(part.text)
            elif part.type == "refusal" and part.refusal:
                raise ValueError(f"Model refused to respond: {part.refusal}")
    return "\n".join(texts)
|
||||
|
||||
|
||||
async def _build_image_content_blocks(
    cls: type[IO.ComfyNode],
    image_tensors: list[Input.Image],
) -> list[BytePlusInputImage]:
    """Upload the images and wrap each returned URL in a ``BytePlusInputImage``.

    The upload helper is called once for the whole list, with
    ``max_images=SEED_MAX_IMAGES``.
    """
    uploaded_urls = await upload_images_to_comfyapi(
        cls,
        image_tensors,
        max_images=SEED_MAX_IMAGES,
        wait_label="Uploading reference images",
    )
    return [BytePlusInputImage(image_url=uploaded) for uploaded in uploaded_urls]
|
||||
|
||||
|
||||
async def _build_video_content_blocks(
    cls: type[IO.ComfyNode],
    videos: list[Input.Video],
) -> list[BytePlusInputVideo]:
    """Upload the videos one at a time and wrap each URL in a ``BytePlusInputVideo``.

    When more than one video is given, the wait label carries a
    ``(position/total)`` suffix.
    """
    count = len(videos)
    result: list[BytePlusInputVideo] = []
    for position, clip in enumerate(videos, start=1):
        wait_label = (
            f"Uploading reference video ({position}/{count})"
            if count > 1
            else "Uploading reference video"
        )
        uploaded_url = await upload_video_to_comfyapi(cls, clip, wait_label=wait_label)
        result.append(BytePlusInputVideo(video_url=uploaded_url))
    return result
|
||||
|
||||
|
||||
class ByteDanceSeedNode(IO.ComfyNode):
    """Generate text responses from a ByteDance Seed 2.0 model."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs, output, hidden fields and price badge."""
        prompt_input = IO.String.Input(
            "prompt",
            multiline=True,
            default="",
            tooltip="Text input to the model.",
        )
        # Every model option carries its own images / videos / temperature inputs.
        model_input = IO.DynamicCombo.Input(
            "model",
            options=[IO.DynamicCombo.Option(label, _seed_model_inputs()) for label in SEED_MODELS],
            tooltip="The Seed model used to generate the response.",
        )
        seed_input = IO.Int.Input(
            "seed",
            default=0,
            min=0,
            max=2147483647,
            control_after_generate=True,
            tooltip="Seed controls whether the node should re-run; "
            "results are non-deterministic regardless of seed.",
        )
        system_prompt_input = IO.String.Input(
            "system_prompt",
            multiline=True,
            default="",
            optional=True,
            advanced=True,
            tooltip="Foundational instructions that dictate the model's behavior.",
        )
        # NOTE(review): the "mini" and "lite" ranges below do not equal the
        # _SEED_PRICES_PER_MILLION input/output rates divided by 1000
        # (mini would be 0.0001-0.0004, lite would start at 0.00025) -- confirm
        # which set of numbers is authoritative.
        price_badge = IO.PriceBadge(
            depends_on=IO.PriceBadgeDepends(widgets=["model"]),
            expr="""
            (
              $m := widgets.model;
              $contains($m, "mini") ? {
                "type": "list_usd",
                "usd": [0.00025, 0.0009],
                "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
              }
              : $contains($m, "lite") ? {
                "type": "list_usd",
                "usd": [0.0003, 0.002],
                "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
              }
              : $contains($m, "pro") ? {
                "type": "list_usd",
                "usd": [0.0005, 0.003],
                "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
              }
              : {"type":"text", "text":"Token-based"}
            )
            """,
        )
        return IO.Schema(
            node_id="ByteDanceSeedNode",
            display_name="ByteDance Seed",
            category="api node/text/ByteDance",
            essentials_category="Text Generation",
            description="Generate text responses with ByteDance's Seed 2.0 models. "
            "Provide a text prompt and optionally one or more images or videos for multimodal context.",
            inputs=[prompt_input, model_input, seed_input, system_prompt_input],
            outputs=[IO.String.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=price_badge,
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        model: dict,
        seed: int,
        system_prompt: str = "",
    ) -> IO.NodeOutput:
        """Send the prompt and any attached media to the model and return its text.

        Args:
            prompt: User text; checked with ``validate_string`` (``min_length=1``
                after stripping whitespace).
            model: Selected option; reads the keys ``"model"`` (a label from
                ``SEED_MODELS``), ``"temperature"`` and, if present,
                ``"images"`` / ``"videos"`` (slot name -> value mappings).
            seed: Not read by this method.
            system_prompt: Sent as ``instructions``; an empty string is sent as ``None``.

        Raises:
            ValueError: Too many images or videos, an error reported in the
                response, or an empty text result.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1)
        selected_label = model["model"]
        temperature = model["temperature"]
        model_id = SEED_MODELS[selected_label]

        # Unconnected slots arrive as None and are dropped.
        images: list[Input.Image] = [
            img for img in (model.get("images") or {}).values() if img is not None
        ]
        # A single image input may be a batch, so count images rather than inputs.
        if sum(get_number_of_images(img) for img in images) > SEED_MAX_IMAGES:
            raise ValueError(f"Up to {SEED_MAX_IMAGES} images are supported per request.")

        clips: list[Input.Video] = [
            clip for clip in (model.get("videos") or {}).values() if clip is not None
        ]
        if len(clips) > SEED_MAX_VIDEOS:
            raise ValueError(f"Up to {SEED_MAX_VIDEOS} videos are supported per request.")

        # Content order: images, then videos, then the text prompt.
        parts: list[BytePlusMessageContent] = []
        if images:
            parts.extend(await _build_image_content_blocks(cls, images))
        if clips:
            parts.extend(await _build_video_content_blocks(cls, clips))
        parts.append(BytePlusInputText(text=prompt))

        request = BytePlusResponseCreateRequest(
            model=model_id,
            input=[BytePlusInputMessage(role="user", content=parts)],
            instructions=system_prompt or None,
            temperature=temperature,
            store=False,
            stream=False,
        )
        response = await sync_op(
            cls,
            ApiEndpoint(path=BYTEPLUS_RESPONSES_ENDPOINT, method="POST"),
            response_model=BytePlusResponseObject,
            data=request,
            price_extractor=lambda r: _calculate_price(model_id, r),
        )

        if response.error:
            raise ValueError(f"Seed API error ({response.error.code}): {response.error.message}")
        text = _get_text_from_response(response)
        if not text:
            raise ValueError("Empty response from Seed model.")
        return IO.NodeOutput(text)
|
||||
|
||||
|
||||
class ByteDanceLLMExtension(ComfyExtension):
    """Extension that exposes the ByteDance Seed node."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes: list[type[IO.ComfyNode]] = [ByteDanceSeedNode]
        return node_classes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ByteDanceLLMExtension:
    """Create and return a new ``ByteDanceLLMExtension``."""
    extension = ByteDanceLLMExtension()
    return extension
|
||||
@ -219,7 +219,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
|
||||
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
|
||||
),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
@ -298,7 +298,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
else:
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
@ -318,7 +318,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, cond_length, 1, 1),
|
||||
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
|
||||
9
main.py
9
main.py
@ -27,7 +27,6 @@ from utils.mime_types import init_mime_types
|
||||
import faulthandler
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_api import feature_flags
|
||||
@ -149,14 +148,6 @@ def execute_prestartup_script():
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
|
||||
from nodes import record_node_startup_error
|
||||
record_node_startup_error(
|
||||
module_path=os.path.dirname(script_path),
|
||||
source="custom_nodes",
|
||||
phase="prestartup",
|
||||
error=e,
|
||||
tb=traceback.format_exc(),
|
||||
)
|
||||
return False
|
||||
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
|
||||
83
nodes.py
83
nodes.py
@ -2154,71 +2154,6 @@ EXTENSION_WEB_DIRS = {}
|
||||
# Dictionary of successfully loaded module names and associated directories.
|
||||
LOADED_MODULE_DIRS = {}
|
||||
|
||||
# Dictionary of custom node startup errors, keyed by "<source>:<module_name>"
|
||||
# so that name collisions across custom_nodes / comfy_extras / comfy_api_nodes
|
||||
# do not overwrite each other. Each value contains: source, module_name,
|
||||
# module_path, error, traceback, phase.
|
||||
#
|
||||
# `source` is the same string as the internal `module_parent` used at load
|
||||
# time (e.g. "custom_nodes", "comfy_extras", "comfy_api_nodes"). It is
|
||||
# intentionally a free-form string rather than a fixed enum so the contract
|
||||
# survives node-source layouts evolving (e.g. comfy_api_nodes eventually
|
||||
# moving out of core). Consumers should treat any new value as a new bucket
|
||||
# rather than rejecting it.
|
||||
NODE_STARTUP_ERRORS: dict[str, dict] = {}
|
||||
|
||||
|
||||
def _read_pyproject_metadata(module_path: str) -> dict | None:
|
||||
"""Best-effort extraction of node-pack identity from pyproject.toml.
|
||||
|
||||
Returns a dict with the Comfy Registry-style identity (pack_id,
|
||||
display_name, publisher_id, version, repository) when the module
|
||||
directory contains a pyproject.toml. Returns None when no toml is
|
||||
present or parsing fails for any reason — startup-error tracking
|
||||
must never itself raise.
|
||||
"""
|
||||
if not module_path or not os.path.isdir(module_path):
|
||||
return None
|
||||
toml_path = os.path.join(module_path, "pyproject.toml")
|
||||
if not os.path.isfile(toml_path):
|
||||
return None
|
||||
try:
|
||||
from comfy_config import config_parser
|
||||
|
||||
cfg = config_parser.extract_node_configuration(module_path)
|
||||
if cfg is None:
|
||||
return None
|
||||
meta = {
|
||||
"pack_id": cfg.project.name or None,
|
||||
"display_name": cfg.tool_comfy.display_name or None,
|
||||
"publisher_id": cfg.tool_comfy.publisher_id or None,
|
||||
"version": cfg.project.version or None,
|
||||
"repository": cfg.project.urls.repository or None,
|
||||
}
|
||||
# Drop empty fields so the API payload stays compact.
|
||||
return {k: v for k, v in meta.items() if v}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def record_node_startup_error(
    *, module_path: str, source: str, phase: str, error: BaseException, tb: str
) -> None:
    """Store a startup failure for a node module in NODE_STARTUP_ERRORS.

    The record is keyed by ``"<source>:<module_name>"`` and holds the
    source, module name, module path, stringified error, traceback text and
    phase. When pyproject metadata can be read for ``module_path`` it is
    attached under the ``"pyproject"`` key. An existing record with the
    same key is overwritten.
    """
    name = get_module_name(module_path)
    record = dict(
        source=source,
        module_name=name,
        module_path=module_path,
        error=str(error),
        traceback=tb,
        phase=phase,
    )
    # Only attach the metadata when there is something to report.
    if metadata := _read_pyproject_metadata(module_path):
        record["pyproject"] = metadata
    NODE_STARTUP_ERRORS[f"{source}:{name}"] = record
|
||||
|
||||
|
||||
def get_module_name(module_path: str) -> str:
|
||||
"""
|
||||
@ -2328,30 +2263,14 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
|
||||
return True
|
||||
except Exception as e:
|
||||
tb = traceback.format_exc()
|
||||
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
|
||||
record_node_startup_error(
|
||||
module_path=module_path,
|
||||
source=module_parent,
|
||||
phase="entrypoint",
|
||||
error=e,
|
||||
tb=tb,
|
||||
)
|
||||
return False
|
||||
else:
|
||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or comfy_entrypoint (need one).")
|
||||
return False
|
||||
except Exception as e:
|
||||
tb = traceback.format_exc()
|
||||
logging.warning(tb)
|
||||
logging.warning(traceback.format_exc())
|
||||
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
|
||||
record_node_startup_error(
|
||||
module_path=module_path,
|
||||
source=module_parent,
|
||||
phase="import",
|
||||
error=e,
|
||||
tb=tb,
|
||||
)
|
||||
return False
|
||||
|
||||
async def init_external_custom_nodes():
|
||||
|
||||
20
server.py
20
server.py
@ -765,26 +765,6 @@ class PromptServer():
|
||||
out[node_class] = node_info(node_class)
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/node_startup_errors")
async def get_node_startup_errors(request):
    """Return the recorded node startup errors, grouped by source.

    The JSON response maps each source string to a mapping of module name
    to error record. A record without a "source" field is filed under
    "custom_nodes". Only sources that have at least one recorded error
    appear, so consumers should not assume any particular set of keys.

    "module_path" is removed from every record before it is returned: the
    absolute on-disk path is an internal detail.
    """
    by_source: dict[str, dict[str, dict]] = {}
    for record in nodes.NODE_STARTUP_ERRORS.values():
        # Copy before stripping so the stored record keeps its module_path.
        exposed = dict(record)
        exposed.pop("module_path", None)
        bucket = by_source.setdefault(record.get("source", "custom_nodes"), {})
        bucket[record["module_name"]] = exposed
    return web.json_response(by_source)
|
||||
|
||||
@routes.get("/api/jobs")
|
||||
async def get_jobs(request):
|
||||
"""List all jobs with filtering, sorting, and pagination.
|
||||
|
||||
@ -1,146 +0,0 @@
|
||||
"""Tests for the custom node startup error tracking introduced for
|
||||
Comfy-Org/ComfyUI-Launcher#303.
|
||||
|
||||
Covers:
|
||||
- load_custom_node populates NODE_STARTUP_ERRORS with the correct source
|
||||
for each module_parent (custom_nodes / comfy_extras / comfy_api_nodes).
|
||||
- Composite keying prevents collisions between modules with the same name
|
||||
in different sources.
|
||||
- record_node_startup_error stores the expected fields.
|
||||
- pyproject.toml metadata is attached when present and omitted when absent.
|
||||
"""
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
import nodes
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _clear_startup_errors():
    """Empty nodes.NODE_STARTUP_ERRORS before and after every test."""
    # Start from a clean registry so earlier tests cannot leak entries in.
    nodes.NODE_STARTUP_ERRORS.clear()

    yield

    # Leave nothing behind for whichever test runs next.
    nodes.NODE_STARTUP_ERRORS.clear()
|
||||
|
||||
|
||||
def _write_broken_module(tmp_path, name: str) -> str:
    """Write ``<name>.py`` under *tmp_path* that raises on import.

    Returns the path of the written file as a string.
    """
    source = textwrap.dedent("""\
        # Deliberately broken module to exercise startup-error tracking.
        raise RuntimeError("boom from " + __name__)
        """)
    target = tmp_path / f"{name}.py"
    target.write_text(source)
    return str(target)
|
||||
|
||||
|
||||
def test_record_node_startup_error_fields(tmp_path):
    """record_node_startup_error stores each field under the composite key."""
    nodes.record_node_startup_error(
        module_path=str(tmp_path / "my_pack"),
        source="custom_nodes",
        phase="import",
        error=ValueError("kaboom"),
        tb="traceback-text",
    )

    assert "custom_nodes:my_pack" in nodes.NODE_STARTUP_ERRORS
    recorded = nodes.NODE_STARTUP_ERRORS["custom_nodes:my_pack"]

    expected_fields = {
        "source": "custom_nodes",
        "module_name": "my_pack",
        "phase": "import",
        "error": "kaboom",
        "traceback": "traceback-text",
    }
    for field, expected in expected_fields.items():
        assert recorded[field] == expected
    # The path prefix is whatever tmp_path is, so only the tail is checked.
    assert recorded["module_path"].endswith("my_pack")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "module_parent",
    ["custom_nodes", "comfy_extras", "comfy_api_nodes"],
)
async def test_load_custom_node_records_source(tmp_path, module_parent):
    """A failed import is recorded with `source` equal to `module_parent`."""
    broken_path = _write_broken_module(tmp_path, "broken_pack")

    loaded = await nodes.load_custom_node(broken_path, module_parent=module_parent)
    assert loaded is False

    composite_key = f"{module_parent}:broken_pack"
    assert composite_key in nodes.NODE_STARTUP_ERRORS, nodes.NODE_STARTUP_ERRORS

    recorded = nodes.NODE_STARTUP_ERRORS[composite_key]
    assert recorded["source"] == module_parent
    assert recorded["module_name"] == "broken_pack"
    assert recorded["phase"] == "import"
    assert "boom from" in recorded["error"]
    assert "RuntimeError" in recorded["traceback"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_load_custom_node_collision_across_sources(tmp_path):
    """Two sources with the same module name must keep separate entries."""
    # One "nodes_audio" module per source, each in its own directory.
    module_paths = {}
    for source, dirname in (("custom_nodes", "cn"), ("comfy_extras", "extras")):
        directory = tmp_path / dirname
        directory.mkdir()
        module_paths[source] = _write_broken_module(directory, "nodes_audio")

    for source, path in module_paths.items():
        assert await nodes.load_custom_node(path, module_parent=source) is False

    errors = nodes.NODE_STARTUP_ERRORS
    assert "custom_nodes:nodes_audio" in errors
    assert "comfy_extras:nodes_audio" in errors
    # Composite keying kept both records; they point at different files.
    custom_record = errors["custom_nodes:nodes_audio"]
    extras_record = errors["comfy_extras:nodes_audio"]
    assert custom_record["module_path"] != extras_record["module_path"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_load_custom_node_attaches_pyproject_metadata(tmp_path):
    """A failing pack that ships a pyproject.toml gets its metadata attached."""
    toml_text = textwrap.dedent("""\
        [project]
        name = "comfyui-mycoolpack"
        version = "1.2.3"

        [project.urls]
        Repository = "https://github.com/example/comfyui-mycoolpack"

        [tool.comfy]
        PublisherId = "example"
        DisplayName = "My Cool Pack"
        """)
    pack_dir = tmp_path / "MyCoolPack"
    pack_dir.mkdir()
    (pack_dir / "__init__.py").write_text("raise RuntimeError('boom')\n")
    (pack_dir / "pyproject.toml").write_text(toml_text)

    loaded = await nodes.load_custom_node(str(pack_dir), module_parent="custom_nodes")
    assert loaded is False

    recorded = nodes.NODE_STARTUP_ERRORS["custom_nodes:MyCoolPack"]
    assert "pyproject" in recorded, recorded

    metadata = recorded["pyproject"]
    assert metadata["pack_id"] == "comfyui-mycoolpack"
    assert metadata["display_name"] == "My Cool Pack"
    assert metadata["publisher_id"] == "example"
    assert metadata["version"] == "1.2.3"
    assert metadata["repository"] == "https://github.com/example/comfyui-mycoolpack"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_load_custom_node_no_pyproject_skips_metadata(tmp_path):
    """A single-file module with no pyproject.toml gets no 'pyproject' key."""
    lonely_path = _write_broken_module(tmp_path, "lonely")

    loaded = await nodes.load_custom_node(lonely_path, module_parent="comfy_extras")
    assert loaded is False

    recorded = nodes.NODE_STARTUP_ERRORS["comfy_extras:lonely"]
    assert "pyproject" not in recorded
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_load_custom_node_arbitrary_module_parent_passes_through(tmp_path):
    """An unknown module_parent is recorded verbatim as the entry's source.

    `source` is a free-form string, so a future node-source bucket must be
    stored as-is rather than coerced or rejected.
    """
    future_path = _write_broken_module(tmp_path, "future_pack")

    loaded = await nodes.load_custom_node(future_path, module_parent="future_source")
    assert loaded is False

    recorded = nodes.NODE_STARTUP_ERRORS["future_source:future_pack"]
    assert recorded["source"] == "future_source"
|
||||
Reference in New Issue
Block a user