Compare commits

..

3 Commits

Author SHA1 Message Date
465d1a95da Merge branch 'master' into fix/validate-node-executable 2026-07-03 09:28:39 +08:00
bf00c39705 Don't instantiate nodes during validation
Addresses review feedback: the V1 executability check fell back to
constructing the node (class_def()) when the FUNCTION method wasn't found on
the class. That runs __init__ during validation, so a constructor's side
effects or failure could be misreported as invalid_node_definition for an
otherwise valid node.

Inspect only the class. No core/extra node defines its FUNCTION method on the
instance, so this loses no real coverage while removing the side-effect risk.

Replace the instance-fallback test with one asserting a node with a raising
__init__ but a valid class-level method still passes validation (i.e. it is
never instantiated).
2026-06-26 16:04:29 -07:00
82c954bd2a Validate that a node is executable before running the prompt
A node whose FUNCTION points at a method that does not exist (e.g. a typo in
a custom node), or a V3 node missing its execute override, was only detected
once that node ran -- after every upstream node had already executed. In a
multi-node workflow the user waited for the whole graph to run up to the
broken node before seeing the error.

validate_prompt already walks every node before execution; add an
executability check there so the error is reported up front and attributed
to the offending node (returned in node_errors), and nothing runs.

The check resolves the V1 FUNCTION method on the class (the common case) and
falls back to an instance, since the runtime invokes it on an instance and a
node may define FUNCTION or its method in __init__. V3 nodes are checked via
their existing VALIDATE_CLASS.

Add tests for V1 typo, V3 typo, good nodes, and a node whose method is
defined in __init__ (must not be falsely rejected).
2026-06-26 15:53:34 -07:00
23 changed files with 1292 additions and 1203 deletions

View File

@ -171,30 +171,16 @@
- Reuse existing model classes, blocks, ops, and helper modules when appropriate.
Before implementing a new version of a model component, search the existing
model code for a class or helper that already provides the behavior.
- Model detection code that inspects linear weight shapes should only use the
first dimension. The second dimension may be half the original size for
NVFP4 or other 4-bit quantized models.
- Avoid adding `einops` usage in core inference code. Use native torch tensor
ops such as `reshape`, `view`, `permute`, `transpose`, `flatten`, `unflatten`,
`unsqueeze`, and `squeeze` instead.
- Do not use tensors as general-purpose Python data structures. Keep metadata,
bookkeeping, counters, flags, shape math, padding math, index planning, memory
estimates, and control-flow decisions in plain Python values unless the data
must participate directly in tensor computation. Do not create tensors for
structural metadata that is only used for Python-side control flow. Sequence
lengths, cumulative offsets, split indices, window counts, slice boundaries,
and repeat counts should be kept as Python ints/lists from the point they are
computed. Do not build them as CPU/GPU tensors and then cast, move, validate,
or convert them back to Python for `split`, `tensor_split`, indexing plans,
loops, or cache keys. Avoid creating temporary tensors just to use tensor
methods for scalar or structural calculations.
must participate directly in tensor computation. Avoid creating temporary
tensors just to use tensor methods for scalar or structural calculations.
- Avoid unnecessary casts and transfers. Preserve the intended compute dtype,
storage dtype, bias dtype, and original tensor shape metadata.
- Keep model-native latent layout handling inside the model or latent-format
owner, not in helper nodes. Do not collapse, expand, pack, or unpack latent
dimensions in nodes or other caller-side adapters just to satisfy a model
forward; the model path should consume and return the native latent shape for
that model family.
- Assume inputs to the main model forward are already in the compute dtype by
default, except integer inputs such as some model timestep tensors. Do not add
defensive or convenience casts in model code; it is better for invalid dtype
@ -258,14 +244,6 @@
- Model implementations should add the minimal number of ComfyUI nodes required
to run the model. Reuse existing nodes as much as possible; adapting the model
to work with existing nodes is strongly preferred over creating new nodes.
- Nodes should output only values they own. Do not add pass-through outputs for
workflow convenience unless the node is explicitly an output node. Existing
models, latents, conditioning, or other inputs should flow directly to the
next consumer instead of being re-emitted unchanged.
- Nodes should expose only inputs they actually read to produce current
behavior. Do not add placeholder, pass-through, compatibility, or
workflow-shaping inputs that are ignored or could flow directly to another
node.
- Node-level code must not patch model code directly. Any node behavior that
modifies, wraps, hooks, or changes model behavior must go through the model
patcher class instead of reaching into model internals.

View File

@ -1 +0,0 @@
AGENTS.md

View File

@ -306,15 +306,12 @@ async def download_asset_content(request: web.Request) -> web.Response:
404, "FILE_NOT_FOUND", "Underlying file not found on disk."
)
# User-controlled asset content must never render inline in the app origin
# (stored XSS via SVG/HTML/XML). Force dangerous types to download and
# override any requested inline disposition. Centralised through
# folder_paths.is_dangerous_content_type so this can't drift from /view and
# /userdata (the previous inline set here omitted image/svg+xml and missed
# the charset/casing/+xml-dialect bypasses).
if folder_paths.is_dangerous_content_type(content_type):
_DANGEROUS_MIME_TYPES = {
"text/html", "text/html-sandboxed", "application/xhtml+xml",
"text/javascript", "text/css",
}
if content_type in _DANGEROUS_MIME_TYPES:
content_type = "application/octet-stream"
disposition = "attachment"
safe_name = (filename or "").replace("\r", "").replace("\n", "")
encoded = urllib.parse.quote(safe_name)

View File

@ -50,45 +50,21 @@ class ModelFileManager:
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
async def get_model_preview(request):
folder_name = request.match_info.get("folder", None)
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)
if folder_name not in folder_paths.folder_names_and_paths:
return web.Response(status=404)
# The "{filename:.*}" capture also matches the empty string, which
# would resolve to the folder itself; reject it explicitly.
if not filename:
return web.Response(status=400)
try:
path_index = int(request.match_info.get("path_index", None))
except (TypeError, ValueError):
return web.Response(status=400)
folders = folder_paths.folder_names_and_paths[folder_name]
if path_index < 0 or path_index >= len(folders[0]):
return web.Response(status=404)
folder = folders[0][path_index]
full_filename = os.path.normpath(os.path.join(folder, filename))
# Prevent path traversal: the requested file must stay within the
# configured model folder. `filename` is an unrestricted ".*" capture,
# so values like "../../../../etc/passwd" would otherwise escape it.
if not folder_paths.is_within_directory(folder, full_filename):
return web.Response(status=403)
full_filename = os.path.join(folder, filename)
previews = self.get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
# The preview is selected by a glob inside get_model_previews, so a
# companion file (e.g. "model.preview.png") could itself be a symlink
# resolving outside the model folder. Re-validate the file actually
# opened: is_within_directory realpaths it, catching symlink escape.
if isinstance(default_preview, str) and not folder_paths.is_within_directory(folder, default_preview):
return web.Response(status=403)
try:
with Image.open(default_preview) as img:
img_bytes = BytesIO()

View File

@ -6,7 +6,6 @@ import glob
import shutil
import logging
import tempfile
import mimetypes
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
@ -337,20 +336,7 @@ class UserManager():
if not isinstance(path, str):
return path
# User data files are arbitrary user-supplied content and are never
# meant to render inline. Disable MIME sniffing and force a download
# so uploaded markup/scripts can't execute in the app origin (stored
# XSS). Content-Disposition: attachment is the load-bearing guard;
# the content-type override and nosniff are defence in depth.
content_type = mimetypes.guess_type(path)[0] or 'application/octet-stream'
if folder_paths.is_dangerous_content_type(content_type):
content_type = 'application/octet-stream'
return web.FileResponse(path, headers={
"Content-Type": content_type,
"X-Content-Type-Options": "nosniff",
"Content-Disposition": "attachment",
})
return web.FileResponse(path)
@routes.post("/userdata/{file}")
async def post_userdata(request):

View File

@ -1,4 +1,4 @@
from typing import Any, Literal
from typing import Literal
from pydantic import BaseModel, Field
@ -316,36 +316,3 @@ VIDEO_TASKS_EXECUTION_TIME = {
"1080p": 150,
},
}
class SeedAudioConfig(BaseModel):
format: str = Field(default="mp3")
sample_rate: int = Field(default=24000)
speech_rate: int = Field(default=0)
loudness_rate: int = Field(default=0)
pitch_rate: int = Field(default=0)
class SeedAudioReference(BaseModel):
speaker: str | None = Field(default=None)
audio_data: str | None = Field(default=None)
audio_url: str | None = Field(default=None)
image_data: str | None = Field(default=None)
image_url: str | None = Field(default=None)
class SeedAudioRequest(BaseModel):
model: str = Field(default="seed-audio-1.0")
text_prompt: str = Field(...)
references: list[SeedAudioReference] | None = Field(default=None)
audio_config: SeedAudioConfig = Field(default_factory=SeedAudioConfig)
watermark: dict[str, Any] = Field(default_factory=dict)
class SeedAudioResponse(BaseModel):
audio: str | None = Field(default=None)
url: str | None = Field(default=None)
duration: float | None = Field(default=None)
original_duration: float | None = Field(default=None)
code: int | None = Field(default=None)
message: str | None = Field(default=None)

View File

@ -0,0 +1,147 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field, confloat
class StabilityFormat(str, Enum):
png = 'png'
jpeg = 'jpeg'
webp = 'webp'
class StabilityAspectRatio(str, Enum):
ratio_1_1 = "1:1"
ratio_16_9 = "16:9"
ratio_9_16 = "9:16"
ratio_3_2 = "3:2"
ratio_2_3 = "2:3"
ratio_5_4 = "5:4"
ratio_4_5 = "4:5"
ratio_21_9 = "21:9"
ratio_9_21 = "9:21"
def get_stability_style_presets(include_none=True):
presets = []
if include_none:
presets.append("None")
return presets + [x.value for x in StabilityStylePreset]
class StabilityStylePreset(str, Enum):
_3d_model = "3d-model"
analog_film = "analog-film"
anime = "anime"
cinematic = "cinematic"
comic_book = "comic-book"
digital_art = "digital-art"
enhance = "enhance"
fantasy_art = "fantasy-art"
isometric = "isometric"
line_art = "line-art"
low_poly = "low-poly"
modeling_compound = "modeling-compound"
neon_punk = "neon-punk"
origami = "origami"
photographic = "photographic"
pixel_art = "pixel-art"
tile_texture = "tile-texture"
class Stability_SD3_5_Model(str, Enum):
sd3_5_large = "sd3.5-large"
# sd3_5_large_turbo = "sd3.5-large-turbo"
sd3_5_medium = "sd3.5-medium"
class Stability_SD3_5_GenerationMode(str, Enum):
text_to_image = "text-to-image"
image_to_image = "image-to-image"
class StabilityStable3_5Request(BaseModel):
model: str = Field(...)
mode: str = Field(...)
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
aspect_ratio: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
output_format: Optional[str] = Field(StabilityFormat.png.value)
image: Optional[str] = Field(None)
style_preset: Optional[str] = Field(None)
cfg_scale: float = Field(...)
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
class StabilityUpscaleConservativeRequest(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
output_format: Optional[str] = Field(StabilityFormat.png.value)
image: Optional[str] = Field(None)
creativity: Optional[confloat(ge=0.2, le=0.5)] = Field(None)
class StabilityUpscaleCreativeRequest(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
output_format: Optional[str] = Field(StabilityFormat.png.value)
image: Optional[str] = Field(None)
creativity: Optional[confloat(ge=0.1, le=0.5)] = Field(None)
style_preset: Optional[str] = Field(None)
class StabilityStableUltraRequest(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
aspect_ratio: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
output_format: Optional[str] = Field(StabilityFormat.png.value)
image: Optional[str] = Field(None)
style_preset: Optional[str] = Field(None)
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
class StabilityStableUltraResponse(BaseModel):
image: Optional[str] = Field(None)
finish_reason: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class StabilityResultsGetResponse(BaseModel):
image: Optional[str] = Field(None)
finish_reason: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
id: Optional[str] = Field(None)
name: Optional[str] = Field(None)
errors: Optional[list[str]] = Field(None)
status: Optional[str] = Field(None)
result: Optional[str] = Field(None)
class StabilityAsyncResponse(BaseModel):
id: Optional[str] = Field(None)
class StabilityTextToAudioRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
duration: int = Field(190, ge=1, le=190)
seed: int = Field(0, ge=0, le=4294967294)
steps: int = Field(8, ge=4, le=8)
output_format: str = Field("wav")
class StabilityAudioToAudioRequest(StabilityTextToAudioRequest):
strength: float = Field(0.01, ge=0.01, le=1.0)
class StabilityAudioInpaintRequest(StabilityTextToAudioRequest):
mask_start: int = Field(30, ge=0, le=190)
mask_end: int = Field(190, ge=0, le=190)
class StabilityAudioResponse(BaseModel):
audio: Optional[str] = Field(None)

View File

@ -1,4 +1,3 @@
import base64
import hashlib
import logging
import math
@ -21,10 +20,6 @@ from comfy_api_nodes.apis.bytedance import (
GetAssetResponse,
Image2VideoTaskCreationRequest,
ImageTaskCreationResponse,
SeedAudioConfig,
SeedAudioReference,
SeedAudioRequest,
SeedAudioResponse,
Seedance2TaskCreationRequest,
SeedanceCreateAssetRequest,
SeedanceCreateAssetResponse,
@ -48,8 +43,6 @@ from comfy_api_nodes.apis.bytedance import (
)
from comfy_api_nodes.util import (
ApiEndpoint,
audio_bytes_to_audio_input,
audio_input_to_mp3,
download_url_to_image_tensor,
download_url_to_video_output,
downscale_image_tensor_by_max_side,
@ -58,14 +51,11 @@ from comfy_api_nodes.util import (
image_tensor_pair_to_batch,
poll_op,
sync_op,
tensor_to_base64_string,
upload_audio_to_comfyapi,
upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
upscale_image_tensor_to_min_pixels,
upscale_video_to_min_pixels,
validate_audio_duration,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_string,
@ -2484,311 +2474,6 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode):
return IO.NodeOutput(asset_id, resolved_group)
MODE_TEXT = "text only"
MODE_AUDIO = "audio reference"
MODE_IMAGE = "image reference"
MODE_SPEAKER = "preset voice"
# (speaker_id, display_label) for built-in TTS 2.0 voices; resolvable ids are account-scoped.
SEED_AUDIO_PRESET_VOICES: list[tuple[str, str]] = [
("zh_female_vv_uranus_bigtts", "Vivi (Female, multilingual)"),
("zh_female_xiaohe_uranus_bigtts", "Mindy (Female, multilingual)"),
("en_female_stokie_uranus_bigtts", "Stokie (Female, English)"),
("en_female_dacey_uranus_bigtts", "Dacey (Female, English)"),
("en_male_tim_uranus_bigtts", "Tim (Male, English)"),
("zh_male_m191_uranus_bigtts", "Kian (Male, multilingual)"),
("zh_male_taocheng_uranus_bigtts", "Cedric (Male, multilingual)"),
("zh_male_sophie_uranus_bigtts", "Sophie (Female, multilingual)"),
("zh_female_yingyujiaoxue_uranus_bigtts", "Jean (Female, multilingual)"),
("zh_male_dayi_uranus_bigtts", "Magnus (Male, multilingual)"),
("zh_female_mizai_uranus_bigtts", "Mabel (Female, multilingual)"),
("zh_female_jitangnv_uranus_bigtts", "Nadia (Female, multilingual)"),
("zh_female_meilinvyou_uranus_bigtts", "Opal (Female, multilingual)"),
("zh_female_liuchangnv_uranus_bigtts", "Pearl (Female, multilingual)"),
("zh_male_ruyayichen_uranus_bigtts", "Quentin (Male, multilingual)"),
("zh_female_vivo_uranus_bigtts", "Vienna (Female, multilingual)"),
("zh_female_xiaoai_uranus_bigtts", "Alina (Female, multilingual)"),
("zh_female_cancan_uranus_bigtts", "Corinne (Female, multilingual)"),
("zh_female_tianmeixiaoyuan_uranus_bigtts", "Esther (Female, multilingual)"),
("zh_female_tianmeitaozi_uranus_bigtts", "Freya (Female, multilingual)"),
("zh_female_shuangkuaisisi_uranus_bigtts", "Gigi (Female, multilingual)"),
("zh_female_peiqi_uranus_bigtts", "Holly (Female, multilingual)"),
("zh_female_xiaoxue_uranus_bigtts", "Lyla (Female, multilingual)"),
("zh_female_yuanqi_uranus_bigtts", "Daisy (Female, multilingual)"),
("zh_female_kefunvsheng_uranus_bigtts", "Tracy (Female, multilingual)"),
("zh_male_shaonianzixin_uranus_bigtts", "Jess (Male, multilingual)"),
("zh_female_linjianvhai_uranus_bigtts", "Pinky (Female, multilingual)"),
("zh_female_kiwi_uranus_bigtts", "Sweety (Female, multilingual)"),
("zh_female_sajiaoxuemei_uranus_bigtts", "Sandy (Female, multilingual)"),
("de_male_seven_uranus_bigtts", "Sven (Male, German)"),
("jp_female_minimi_uranus_bigtts", "Minimi (Female, Japanese)"),
("fr_male_usseau_uranus_bigtts", "Usseau (Male, French)"),
("es_male_felipe_uranus_bigtts", "Felipe (Male, Spanish)"),
("id_male_han_uranus_bigtts", "Han (Male, Indonesian)"),
("pt_male_martins_uranus_bigtts", "Martins (Male, Portuguese)"),
("it_male_enzo_uranus_bigtts", "Enzo (Male, Italian)"),
("kr_male_shane_uranus_bigtts", "Shane (Male, Korean)"),
("zh_male_liufei_uranus_bigtts", "Felix (Male, Chinese)"),
("zh_female_qingxinnvsheng_uranus_bigtts", "Celeste (Female, Chinese)"),
("zh_male_sunwukong_uranus_bigtts", "Monkey King (Male, Chinese)"),
]
SEED_AUDIO_VOICE_OPTIONS = [label for _, label in SEED_AUDIO_PRESET_VOICES]
SEED_AUDIO_VOICE_MAP = {label: speaker_id for speaker_id, label in SEED_AUDIO_PRESET_VOICES}
_AUDIO_TAG_RE = re.compile(r"@Audio(\d+)", re.IGNORECASE)
def max_audio_tag(prompt: str) -> int:
"""Highest N referenced as @AudioN in the prompt (0 if none)."""
nums = [int(m) for m in _AUDIO_TAG_RE.findall(prompt or "")]
return max(nums) if nums else 0
def connected_audio_indices(reference_mode: dict) -> list[int]:
"""Indices (1-based) of connected reference_audio sockets, in order."""
return [
i
for i in range(1, 3 + 1)
if reference_mode.get(f"reference_audio_{i}") is not None
]
def validate_seed_audio_inputs(
text_prompt: str,
mode: str,
audio_indices: list[int],
has_image: bool,
preset_voice: str | None = None,
) -> None:
validate_string(text_prompt, field_name="text_prompt", min_length=1, max_length=3000)
max_tag = max_audio_tag(text_prompt)
if mode == MODE_TEXT:
if max_tag:
raise ValueError(
f"The prompt references @Audio{max_tag}, but reference mode is '{MODE_TEXT}'. "
f"Switch to '{MODE_AUDIO}' and connect the reference clip(s)."
)
elif mode == MODE_AUDIO:
if not audio_indices:
raise ValueError(
f"Reference mode '{MODE_AUDIO}' requires at least one reference_audio input "
f"(or switch to '{MODE_TEXT}')."
)
if audio_indices != list(range(1, len(audio_indices) + 1)):
raise ValueError(
"Connect reference_audio inputs in order without gaps: reference_audio_1, then _2, then _3."
)
if max_tag > len(audio_indices):
raise ValueError(
f"The prompt references @Audio{max_tag}, but only {len(audio_indices)} "
f"reference audio(s) are connected."
)
elif mode == MODE_IMAGE:
if not has_image:
raise ValueError(f"Reference mode '{MODE_IMAGE}' requires a reference_image input.")
if max_tag:
raise ValueError(
f"@AudioN tags are not used in '{MODE_IMAGE}' mode; the prompt should contain "
f"only the text to synthesize."
)
elif mode == MODE_SPEAKER:
if not preset_voice or preset_voice not in SEED_AUDIO_VOICE_MAP:
raise ValueError(f"Reference mode '{MODE_SPEAKER}' requires selecting a preset voice.")
if max_tag > 1:
raise ValueError(
f"'{MODE_SPEAKER}' mode uses a single voice, so @Audio{max_tag} is out of range. "
f"Remove the @AudioN tags — the whole prompt is read in the selected voice."
)
else:
raise ValueError(f"Unknown reference mode: {mode!r}")
class ByteDanceSeedAudioNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="ByteDanceSeedAudio",
display_name="ByteDance Seed Audio 1.0",
category="partner/audio/ByteDance",
description=(
"Generate speech, music, sound effects and multi-speaker dialogue from a single prompt "
"with ByteDance Seed Audio 1.0. Describe the voice(s), emotion, ambience, background music "
"and sound effects in the prompt, and include the lines to speak. Optionally pick a built-in "
"preset voice, clone voices from up to 3 reference clips (tagged @Audio1-3 in the prompt), "
"or derive a voice from a character image. Up to 2 minutes of audio per run."
),
inputs=[
IO.String.Input(
"text_prompt",
multiline=True,
default="",
tooltip=(
"Describe the voice(s), emotion, pacing, ambience, background music and sound "
"effects, and include the lines to speak (name characters inline for dialogue). "
"In 'audio reference' mode, refer to connected clips by order as @Audio1, @Audio2, "
"@Audio3. Maximum 3000 characters."
),
),
IO.DynamicCombo.Input(
"reference_mode",
options=[
IO.DynamicCombo.Option(MODE_TEXT, []),
IO.DynamicCombo.Option(
MODE_AUDIO,
[
IO.Audio.Input(
"reference_audio_1",
optional=True,
tooltip="Reference clip for voice cloning, tagged @Audio1 in the prompt. "
"Up to 30s.",
),
IO.Audio.Input(
"reference_audio_2",
optional=True,
tooltip="Reference clip tagged @Audio2 in the prompt. Up to 30s.",
),
IO.Audio.Input(
"reference_audio_3",
optional=True,
tooltip="Reference clip tagged @Audio3 in the prompt. Up to 30s.",
),
],
),
IO.DynamicCombo.Option(
MODE_IMAGE,
[
IO.Image.Input(
"reference_image",
optional=True,
tooltip="A single character image; the model derives a voice from it. "
"Cannot be combined with reference audio.",
),
],
),
IO.DynamicCombo.Option(
MODE_SPEAKER,
[
IO.Combo.Input(
"preset_voice",
options=SEED_AUDIO_VOICE_OPTIONS,
default=SEED_AUDIO_VOICE_OPTIONS[0],
tooltip="A built-in TTS 2.0 voice that reads the prompt. No reference "
"clip needed, and @AudioN tags are not used in this mode.",
),
],
),
],
tooltip=(
"How to condition the voice: 'text only' (describe everything in the prompt), "
"'audio reference' (clone up to 3 voices, tagged @Audio1-3), 'image reference' "
"(derive a voice from one character image), or 'preset voice' (pick a built-in "
"named voice that reads the prompt)."
),
),
IO.Combo.Input(
"sample_rate",
options=["8000", "16000", "24000", "32000", "44100", "48000"],
default="24000",
tooltip="Output sample rate in Hz.",
),
IO.Int.Input(
"speech_rate",
default=0,
min=-50,
max=100,
tooltip="Speaking speed. 0 = normal, 100 = 2.0x, -50 = 0.5x.",
),
IO.Int.Input(
"loudness_rate",
default=0,
min=-50,
max=100,
tooltip="Loudness. 0 = normal, 100 = 2.0x, -50 = 0.5x.",
),
IO.Int.Input(
"pitch_rate",
default=0,
min=-12,
max=12,
tooltip="Pitch shift in semitones (-12 to 12).",
),
IO.Int.Input(
"seed",
default=42,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Audio.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd": 0.2145, "format":{"suffix":"/minute","approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
text_prompt: str,
reference_mode: dict,
sample_rate: str,
speech_rate: int,
loudness_rate: int,
pitch_rate: int,
seed: int,
) -> IO.NodeOutput:
mode = reference_mode["reference_mode"]
audio_indices = connected_audio_indices(reference_mode)
image = reference_mode.get("reference_image")
preset_voice = reference_mode.get("preset_voice")
validate_seed_audio_inputs(text_prompt, mode, audio_indices, image is not None, preset_voice)
references: list[SeedAudioReference] | None = None
if mode == MODE_AUDIO:
references = []
for i in audio_indices:
clip = reference_mode[f"reference_audio_{i}"]
validate_audio_duration(clip, max_duration=30.0)
mp3_bytes = audio_input_to_mp3(clip).getvalue()
references.append(SeedAudioReference(audio_data=base64.b64encode(mp3_bytes).decode("utf-8")))
elif mode == MODE_IMAGE:
image = upscale_image_tensor_to_min_pixels(image, 160_000)
references = [SeedAudioReference(image_data=tensor_to_base64_string(image, mime_type="image/png"))]
elif mode == MODE_SPEAKER:
references = [SeedAudioReference(speaker=SEED_AUDIO_VOICE_MAP[preset_voice])]
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/byteplus/api/v3/tts/create", method="POST"),
response_model=SeedAudioResponse,
data=SeedAudioRequest(
text_prompt=text_prompt,
references=references,
audio_config=SeedAudioConfig(
sample_rate=int(sample_rate),
speech_rate=speech_rate,
loudness_rate=loudness_rate,
pitch_rate=pitch_rate,
),
),
)
if not response.audio:
raise Exception(
f"Seed Audio returned no audio (code={response.code}): {response.message}"
)
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response.audio)))
class ByteDanceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -2805,7 +2490,6 @@ class ByteDanceExtension(ComfyExtension):
ByteDance2ReferenceNode,
ByteDanceCreateImageAsset,
ByteDanceCreateVideoAsset,
ByteDanceSeedAudioNode,
]

View File

@ -0,0 +1,932 @@
from inspect import cleandoc
from typing import Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, Input, IO
from comfy_api_nodes.apis.stability import (
StabilityUpscaleConservativeRequest,
StabilityUpscaleCreativeRequest,
StabilityAsyncResponse,
StabilityResultsGetResponse,
StabilityStable3_5Request,
StabilityStableUltraRequest,
StabilityStableUltraResponse,
StabilityAspectRatio,
Stability_SD3_5_Model,
Stability_SD3_5_GenerationMode,
get_stability_style_presets,
StabilityTextToAudioRequest,
StabilityAudioToAudioRequest,
StabilityAudioInpaintRequest,
StabilityAudioResponse,
)
from comfy_api_nodes.util import (
validate_audio_duration,
validate_string,
audio_input_to_mp3,
bytesio_to_image_tensor,
tensor_to_bytesio,
audio_bytes_to_audio_input,
sync_op,
poll_op,
ApiEndpoint,
)
import torch
import base64
from io import BytesIO
from enum import Enum
class StabilityPollStatus(str, Enum):
finished = "finished"
in_progress = "in_progress"
failed = "failed"
def get_async_dummy_status(x: StabilityResultsGetResponse):
if x.name is not None or x.errors is not None:
return StabilityPollStatus.failed
elif x.finish_reason is not None:
return StabilityPollStatus.finished
return StabilityPollStatus.in_progress
class StabilityStableImageUltraNode(IO.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityStableImageUltraNode",
display_name="Stability AI Stable Image Ultra",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
"elements, colors, and subjects will lead to better results. " +
"To control the weight of a given word use the format `(word:weight)`," +
"where `word` is the word you'd like to control the weight of and `weight`" +
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
"would convey a sky that was blue and green, but more green than blue.",
),
IO.Combo.Input(
"aspect_ratio",
options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.",
),
IO.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
advanced=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.Image.Input(
"image",
optional=True,
),
IO.String.Input(
"negative_prompt",
default="",
tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
advanced=True,
),
IO.Float.Input(
"image_denoise",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.08}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
aspect_ratio: str,
style_preset: str,
seed: int,
image: Optional[torch.Tensor] = None,
negative_prompt: str = "",
image_denoise: Optional[float] = 0.5,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
if image is not None:
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
else:
image_denoise = None
if not negative_prompt:
negative_prompt = None
if style_preset == "None":
style_preset = None
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityStableUltraRequest(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
seed=seed,
strength=image_denoise,
style_preset=style_preset,
),
files=files,
content_type="multipart/form-data",
)
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityStableImageSD_3_5Node(IO.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityStableImageSD_3_5Node",
display_name="Stability AI Stable Diffusion 3.5 Image",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
IO.Combo.Input(
"model",
options=Stability_SD3_5_Model,
),
IO.Combo.Input(
"aspect_ratio",
options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.",
),
IO.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
advanced=True,
),
IO.Float.Input(
"cfg_scale",
default=4.0,
min=1.0,
max=10.0,
step=0.1,
tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.Image.Input(
"image",
optional=True,
),
IO.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
advanced=True,
),
IO.Float.Input(
"image_denoise",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$contains(widgets.model,"large")
? {"type":"usd","usd":0.065}
: {"type":"usd","usd":0.035}
)
""",
),
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
aspect_ratio: str,
style_preset: str,
seed: int,
cfg_scale: float,
image: Optional[torch.Tensor] = None,
negative_prompt: str = "",
image_denoise: Optional[float] = 0.5,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
mode = Stability_SD3_5_GenerationMode.text_to_image
if image is not None:
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
mode = Stability_SD3_5_GenerationMode.image_to_image
aspect_ratio = None
else:
image_denoise = None
if not negative_prompt:
negative_prompt = None
if style_preset == "None":
style_preset = None
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityStable3_5Request(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
seed=seed,
strength=image_denoise,
style_preset=style_preset,
cfg_scale=cfg_scale,
model=model,
mode=mode,
),
files=files,
content_type="multipart/form-data",
)
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityUpscaleConservativeNode(IO.ComfyNode):
"""
Upscale image with minimal alterations to 4K resolution.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityUpscaleConservativeNode",
display_name="Stability AI Upscale Conservative",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
IO.Float.Input(
"creativity",
default=0.35,
min=0.2,
max=0.5,
step=0.01,
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
advanced=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt: str,
creativity: float,
seed: int,
negative_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
if not negative_prompt:
negative_prompt = None
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityUpscaleConservativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
seed=seed,
),
files=files,
content_type="multipart/form-data",
)
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityUpscaleCreativeNode(IO.ComfyNode):
"""
Upscale image with minimal alterations to 4K resolution.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityUpscaleCreativeNode",
display_name="Stability AI Upscale Creative",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
IO.Float.Input(
"creativity",
default=0.3,
min=0.1,
max=0.5,
step=0.01,
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
),
IO.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
advanced=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
advanced=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.6}""",
),
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt: str,
creativity: float,
style_preset: str,
seed: int,
negative_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
if not negative_prompt:
negative_prompt = None
if style_preset == "None":
style_preset = None
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
response_model=StabilityAsyncResponse,
data=StabilityUpscaleCreativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
style_preset=style_preset,
seed=seed,
),
files=files,
content_type="multipart/form-data",
)
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
response_model=StabilityResultsGetResponse,
poll_interval=3,
status_extractor=lambda x: get_async_dummy_status(x),
)
if response_poll.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
image_data = base64.b64decode(response_poll.result)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityUpscaleFastNode(IO.ComfyNode):
"""
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityUpscaleFastNode",
display_name="Stability AI Upscale Fast",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.02}""",
),
)
@classmethod
async def execute(cls, image: torch.Tensor) -> IO.NodeOutput:
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
response_model=StabilityStableUltraResponse,
files=files,
content_type="multipart/form-data",
)
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityTextToAudio(IO.ComfyNode):
"""Generates high-quality music and sound effects from text descriptions."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityTextToAudio",
display_name="Stability AI Text To Audio",
category="partner/audio/Stability AI",
essentials_category="Audio",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
IO.String.Input("prompt", multiline=True, default=""),
IO.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
IO.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
advanced=True,
),
],
outputs=[
IO.Audio.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
)
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityAudioToAudio(IO.ComfyNode):
"""Transforms existing audio samples into new high-quality compositions using text instructions."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityAudioToAudio",
display_name="Stability AI Audio To Audio",
category="partner/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
IO.String.Input("prompt", multiline=True, default=""),
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
IO.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
IO.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
advanced=True,
),
IO.Float.Input(
"strength",
default=1,
min=0.01,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Parameter controls how much influence the audio parameter has on the generated audio.",
optional=True,
),
],
outputs=[
IO.Audio.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod
async def execute(
cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float
) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
validate_audio_duration(audio, 6, 190)
payload = StabilityAudioToAudioRequest(
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
)
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
)
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityAudioInpaint(IO.ComfyNode):
"""Transforms part of existing audio sample using text instructions."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityAudioInpaint",
display_name="Stability AI Audio Inpaint",
category="partner/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
IO.String.Input("prompt", multiline=True, default=""),
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
IO.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
IO.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
advanced=True,
),
IO.Int.Input(
"mask_start",
default=30,
min=0,
max=190,
step=1,
optional=True,
advanced=True,
),
IO.Int.Input(
"mask_end",
default=190,
min=0,
max=190,
step=1,
optional=True,
advanced=True,
),
],
outputs=[
IO.Audio.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
audio: Input.Audio,
duration: int,
seed: int,
steps: int,
mask_start: int,
mask_end: int,
) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
if mask_end <= mask_start:
raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})")
validate_audio_duration(audio, 6, 190)
payload = StabilityAudioInpaintRequest(
prompt=prompt,
model=model,
duration=duration,
seed=seed,
steps=steps,
mask_start=mask_start,
mask_end=mask_end,
)
response_api = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
)
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
StabilityStableImageUltraNode,
StabilityStableImageSD_3_5Node,
StabilityUpscaleConservativeNode,
StabilityUpscaleCreativeNode,
StabilityUpscaleFastNode,
StabilityTextToAudio,
StabilityAudioToAudio,
StabilityAudioInpaint,
]
async def comfy_entrypoint() -> StabilityExtension:
return StabilityExtension()

View File

@ -26,7 +26,6 @@ from .conversions import (
text_filepath_to_base64_string,
text_filepath_to_data_uri,
trim_video,
upscale_image_tensor_to_min_pixels,
upscale_video_to_min_pixels,
video_to_base64_string,
)
@ -100,7 +99,6 @@ __all__ = [
"text_filepath_to_base64_string",
"text_filepath_to_data_uri",
"trim_video",
"upscale_image_tensor_to_min_pixels",
"upscale_video_to_min_pixels",
"video_to_base64_string",
# Validation utilities

View File

@ -448,15 +448,6 @@ def _compute_upscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[in
return new_w, new_h
def upscale_image_tensor_to_min_pixels(image: torch.Tensor, total_pixels: int) -> torch.Tensor:
samples = image.movedim(-1, 1)
dims = _compute_upscale_dims(samples.shape[3], samples.shape[2], int(total_pixels))
if dims is None:
return image
new_w, new_h = dims
return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1)
def upscale_video_to_min_pixels(video: Input.Video, min_pixels: int) -> Input.Video:
"""Upscale a video to meet at least ``min_pixels`` (w * h), preserving aspect ratio.

View File

@ -1113,6 +1113,32 @@ def full_type_name(klass):
return klass.__qualname__
return module + '.' + klass.__qualname__
def node_not_executable_reason(class_def, class_type):
"""Return a human-readable reason the node cannot be executed, or None if it's fine.
Catches a node whose declared entry point doesn't resolve to a real method
(e.g. a V1 ``FUNCTION = "invert"`` where the method is misspelled, or a V3 node
missing its ``execute`` override). Running this during validation surfaces the
problem before execution starts, instead of after upstream nodes have run.
Only the class is inspected; the node is never instantiated here, so a node's
``__init__`` side effects cannot run (or fail) during validation.
"""
try:
if issubclass(class_def, _ComfyNodeInternal):
# V3: validates that execute()/define_schema() overrides exist.
class_def.VALIDATE_CLASS()
return None
# V1: FUNCTION names the method to call; it must exist on the class.
function_name = getattr(class_def, "FUNCTION", None)
if function_name is None:
return f"'{class_type}' does not define FUNCTION"
if not callable(getattr(class_def, function_name, None)):
return f"'{class_type}' has no method '{function_name}' (declared in FUNCTION)"
return None
except Exception as ex:
return str(ex)
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
outputs = set()
for x in prompt:
@ -1148,6 +1174,35 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[
}
return (False, error, [], {})
# Make sure the node is actually executable (its FUNCTION/execute entry
# point resolves to a real method) before we touch any schema-derived
# attributes below or start execution. Catches code typos up front and
# attributes the error to the offending node.
not_executable = node_not_executable_reason(class_, class_type)
if not_executable is not None:
node_title = prompt[x].get('_meta', {}).get('title', class_type)
error = {
"type": "invalid_node_definition",
"message": "Node is not executable",
"details": f"{not_executable} (Node ID '#{x}')",
"extra_info": {
"node_id": x,
"class_type": class_type,
"node_title": node_title,
}
}
node_errors = {x: {
"errors": [{
"type": "invalid_node_definition",
"message": "Node is not executable",
"details": not_executable,
"extra_info": {},
}],
"dependent_outputs": [],
"class_type": class_type,
}}
return (False, error, [], node_errors)
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
if partial_execution_list is None or x in partial_execution_list:
outputs.add(x)

View File

@ -264,59 +264,6 @@ def annotated_filepath(name: str) -> tuple[str, str | None]:
return name, base_dir
# Content types a browser may execute or render inline. File endpoints that
# serve user-controlled content must force these to download (and ideally set
# Content-Disposition: attachment) to avoid stored XSS. Centralised here so the
# /view and /userdata handlers can't drift apart. mimetypes.guess_type may
# return either the text/* or application/* spelling depending on platform, so
# both are listed.
DANGEROUS_CONTENT_TYPES = {
'text/html', 'text/html-sandboxed', 'application/xhtml+xml',
'text/javascript', 'application/javascript', 'application/x-javascript',
'application/ecmascript', 'text/css',
'image/svg+xml', 'application/xml', 'text/xml',
# message/rfc822 (.mht/.mhtml) can carry script in some browsers.
'message/rfc822',
}
def is_dangerous_content_type(content_type: str | None) -> bool:
"""Return True if a browser may execute or render `content_type` inline.
Normalises before matching so the check can't be slipped past with a
charset/boundary parameter (``text/html; charset=utf-8``) or casing
(``TEXT/HTML``). Any XML dialect (``*+xml`` or ``*/xml``) is treated as
dangerous because XML can carry inline script via stylesheet/entity tricks,
which also covers the ``application/{xslt,rss,atom,rdf}+xml`` family without
enumerating each one. Endpoints serving user-controlled content should route
a dangerous type to ``application/octet-stream`` + ``Content-Disposition:
attachment`` + ``X-Content-Type-Options: nosniff``.
"""
if not content_type:
return False
normalized = content_type.split(';', 1)[0].strip().lower()
if normalized in DANGEROUS_CONTENT_TYPES:
return True
return normalized.endswith('+xml') or normalized.endswith('/xml')
def is_within_directory(directory: str, target: str) -> bool:
"""Return True if `target` resolves to a path inside `directory`.
Uses realpath on both operands so that a symlink placed inside `directory`
that points elsewhere cannot escape the containment check at open time.
"""
try:
directory = os.path.realpath(directory)
target = os.path.realpath(target)
return os.path.commonpath((directory, target)) == directory
except ValueError:
# ValueError is raised by realpath() on a path with an embedded null
# byte, and by commonpath() on Windows when the paths are on different
# drives. In either case the target is not safely within the directory.
return False
def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
name, base_dir = annotated_filepath(name)
@ -326,12 +273,7 @@ def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
else:
base_dir = get_input_directory() # fallback path
filepath = os.path.abspath(os.path.join(base_dir, name))
# Prevent path traversal: the resolved path must stay within base_dir.
# repr() the name in the message so a crafted value can't inject log lines.
if not is_within_directory(base_dir, filepath):
raise ValueError("Invalid file path: {!r}".format(name))
return filepath
return os.path.join(base_dir, name)
def exists_annotated_filepath(name) -> bool:
@ -340,10 +282,7 @@ def exists_annotated_filepath(name) -> bool:
if base_dir is None:
base_dir = get_input_directory() # fallback path
filepath = os.path.abspath(os.path.join(base_dir, name))
# Treat traversal attempts as non-existent rather than probing the filesystem.
if not is_within_directory(base_dir, filepath):
return False
filepath = os.path.join(base_dir, name)
return os.path.exists(filepath)

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.45.20
comfyui-workflow-templates==0.11.2
comfyui-workflow-templates==0.11.1
comfyui-embedded-docs==0.5.6
torch
torchsde

View File

@ -127,7 +127,6 @@ def create_cors_middleware(allowed_origin: str):
return cors_middleware
def is_loopback(host):
if host is None:
return False
@ -617,30 +616,15 @@ class PromptServer():
or 'application/octet-stream'
)
# For security, force renderable/active types (HTML, JS,
# CSS, SVG, XML — anything that can carry inline <script>
# and execute in the page origin) to download instead of
# displaying inline, preventing stored XSS. The
# attachment disposition is the load-bearing guard: a
# bare filename= hint does not force a download per
# RFC 6266, so we only attach it on the dangerous branch
# to avoid breaking inline display of legitimate images.
# Escape backslash/quote per RFC 6266 quoted-string so a
# filename containing a double quote (which passes the
# ".."/leading-slash filter above) can't break out of the
# header's quoted-string and malform the disposition.
safe_filename = filename.replace("\\", "\\\\").replace('"', '\\"')
disposition = f"filename=\"{safe_filename}\""
if folder_paths.is_dangerous_content_type(content_type):
content_type = 'application/octet-stream'
disposition = f"attachment; filename=\"{safe_filename}\""
# For security, force certain mimetypes to download instead of display
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
content_type = 'application/octet-stream' # Forces download
return web.FileResponse(
file,
headers={
"Content-Disposition": disposition,
"Content-Type": content_type,
"X-Content-Type-Options": "nosniff"
"Content-Disposition": f"filename=\"{filename}\"",
"Content-Type": content_type
}
)

View File

@ -1,5 +1,3 @@
import contextlib
import json
import time
import uuid
from datetime import datetime
@ -11,40 +9,6 @@ import requests
from helpers import get_asset_filename, trigger_sync_seed_assets
def test_download_svg_forced_to_attachment(http: requests.Session, api_base: str):
"""GHSA-779p-m5rp-r4h4 CISA-5 (sibling route): an uploaded SVG must never be
served inline from GET /api/assets/{id}/content, or an inline <script> runs
in the app origin (stored XSS). Even with disposition=inline requested, a
dangerous content type must be forced to application/octet-stream +
Content-Disposition: attachment + nosniff. Regression guard for the stale
inline blocklist that previously omitted image/svg+xml and ignored the
centralized folder_paths.is_dangerous_content_type check.
"""
svg = b'<svg xmlns="http://www.w3.org/2000/svg"><script>alert(1)</script></svg>'
files = {"file": ("evil.svg", svg, "image/svg+xml")}
form_data = {
"tags": json.dumps(["models", "checkpoints", "unit-tests", "svgxss"]),
"name": "evil.svg",
}
up = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
body = up.json()
assert up.status_code in (200, 201), body
aid = body["id"]
try:
r = http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline", timeout=120)
r.content
assert r.status_code == 200
ct = r.headers.get("Content-Type", "").lower()
cd = r.headers.get("Content-Disposition", "").lower()
assert "svg" not in ct, f"SVG served with a renderable content type: {ct!r}"
assert ct.startswith("application/octet-stream"), f"expected octet-stream, got {ct!r}"
assert "attachment" in cd, f"inline disposition not overridden to attachment: {cd!r}"
assert r.headers.get("X-Content-Type-Options", "").lower() == "nosniff"
finally:
with contextlib.suppress(Exception):
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict):
aid = seeded_asset["id"]

View File

@ -53,11 +53,8 @@ def test_annotated_filepath():
def test_get_annotated_filepath():
default_dir = "/default/dir"
# get_annotated_filepath now normalizes with os.path.abspath (part of the
# GHSA-779p traversal hardening), so compare against the normalized form —
# on Windows abspath also prepends the current drive letter.
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.abspath(os.path.join(default_dir, "test.txt"))
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.abspath(os.path.join(folder_paths.get_output_directory(), "test.txt"))
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
def test_add_model_folder_path_append(clear_folder_paths):
folder_paths.add_model_folder_path("test_folder", "/default/path", is_default=True)

View File

@ -0,0 +1,137 @@
"""Tests for pre-execution validation that a node is actually executable.
validate_prompt rejects a node whose declared entry point does not resolve to a
real method (a V1 FUNCTION typo, or a V3 node missing its execute override) before
any node runs, attributing the error to the offending node.
"""
import asyncio
import nodes
from comfy_api.latest import io
from execution import node_not_executable_reason, validate_prompt
class _GoodV1Node:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "run"
OUTPUT_NODE = True
CATEGORY = "Test"
def run(self):
return (None,)
class _TypoV1Node:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "invert" # method below is misspelled
OUTPUT_NODE = True
CATEGORY = "Test"
def invvert(self):
return (None,)
class _SideEffectInitV1Node:
"""Valid class-level method, but a constructor that must never run in validation."""
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "run"
OUTPUT_NODE = True
CATEGORY = "Test"
def __init__(self):
raise RuntimeError("__init__ must not run during validation")
def run(self):
return (None,)
def _v3_schema(node_id):
return io.Schema(
node_id=node_id,
display_name=node_id,
category="Test",
inputs=[],
outputs=[io.Image.Output()],
is_output_node=True,
)
class _GoodV3Node(io.ComfyNode):
@classmethod
def define_schema(cls):
return _v3_schema("GoodV3Node")
@classmethod
def execute(cls):
return io.NodeOutput(None)
class _TypoV3Node(io.ComfyNode):
@classmethod
def define_schema(cls):
return _v3_schema("TypoV3Node")
@classmethod
def exicute(cls): # typo: should be "execute"
return io.NodeOutput(None)
def _register(class_type, class_def):
nodes.NODE_CLASS_MAPPINGS[class_type] = class_def
def _validate(class_type):
prompt = {"1": {"class_type": class_type, "inputs": {}}}
return asyncio.run(validate_prompt("pid", prompt, None))
def test_good_node_passes():
_register("GoodV1Node", _GoodV1Node)
assert node_not_executable_reason(_GoodV1Node, "GoodV1Node") is None
valid, _, _, _ = _validate("GoodV1Node")
assert valid is True
def test_typo_node_rejected_with_node_error():
_register("TypoV1Node", _TypoV1Node)
valid, error, _, node_errors = _validate("TypoV1Node")
assert valid is False
assert error["type"] == "invalid_node_definition"
assert node_errors["1"]["class_type"] == "TypoV1Node"
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
assert "invert" in node_errors["1"]["errors"][0]["details"]
def test_validation_does_not_instantiate_node():
"""A valid node is not constructed during validation, so __init__ never runs."""
_register("SideEffectInitV1Node", _SideEffectInitV1Node)
assert node_not_executable_reason(_SideEffectInitV1Node, "SideEffectInitV1Node") is None
valid, _, _, _ = _validate("SideEffectInitV1Node")
assert valid is True
def test_good_v3_node_passes():
_register("GoodV3Node", _GoodV3Node)
assert node_not_executable_reason(_GoodV3Node, "GoodV3Node") is None
valid, _, _, _ = _validate("GoodV3Node")
assert valid is True
def test_typo_v3_node_rejected_with_node_error():
_register("TypoV3Node", _TypoV3Node)
valid, error, _, node_errors = _validate("TypoV3Node")
assert valid is False
assert error["type"] == "invalid_node_definition"
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"

View File

@ -1,192 +0,0 @@
"""CI unit tests for FIX #2 of GHSA-779p-m5rp-r4h4.
Path traversal / hardening in app/model_manager.py get_model_preview
(route /experiment/models/preview/{folder}/{path_index}/{filename:.*}).
Reference: https://github.com/Comfy-Org/ComfyUI/security/advisories/GHSA-779p-m5rp-r4h4
"""
import pytest
import yarl
from io import BytesIO
from PIL import Image
from aiohttp import web
from unittest.mock import patch
from app.model_manager import ModelFileManager
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def model_manager():
return ModelFileManager()
@pytest.fixture
def app(model_manager):
app = web.Application()
routes = web.RouteTableDef()
model_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_legit_preview_returns_200(aiohttp_client, app, tmp_path):
"""Sanity: a real preview PNG inside the model folder is served as webp 200."""
img = Image.new('RGB', (16, 16), color=(255, 0, 128))
img.save(tmp_path / "test_model.png", format='PNG')
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/0/test_model.png')
assert response.status == 200
assert response.content_type == 'image/webp'
img_bytes = BytesIO(await response.read())
served = Image.open(img_bytes)
assert served.format
assert served.format.lower() == 'webp'
served.close()
async def test_non_integer_path_index_returns_400(aiohttp_client, app, tmp_path):
"""A non-integer path_index segment must be rejected with 400."""
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/abc/test_model.png')
assert response.status == 400
async def test_out_of_range_path_index_returns_404(aiohttp_client, app, tmp_path):
"""A path_index beyond the configured folder list must return 404."""
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/99/test_model.png')
assert response.status == 404
async def test_empty_filename_returns_400(aiohttp_client, app, tmp_path):
"""The "{filename:.*}" capture also matches the empty string (trailing
slash). It would resolve to the folder itself and must be rejected with 400."""
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/0/')
assert response.status == 400
async def test_path_traversal_in_filename_returns_403(aiohttp_client, app, tmp_path):
"""Path traversal in {filename} must be rejected with 403 and must NOT read
a file outside the configured model directory.
GOTCHA: aiohttp/yarl collapses literal ``../`` dot-segments out of the URL
path before it reaches the handler, which would make this test vacuously
pass (the request would hit a different/non-existent route). We percent-encode
the dots and slashes (``%2e%2e%2f``) and send the URL with
``yarl.URL(..., encoded=True)`` so the bytes survive client-side normalization
untouched; aiohttp's router then percent-decodes them into ``match_info``,
delivering the literal ``../`` traversal to the handler's ``{filename:.*}``
capture.
Without the fix the handler computes
``os.path.normpath(os.path.join(folder, "../../../../etc/hosts"))``, which
escapes ``tmp_path`` and would be passed straight to get_model_previews ->
Image.open, serving bytes from outside the model dir (200/served bytes). The
is_within_directory() containment check is the load-bearing fix that turns
that escape into a 403.
"""
# Sanity-anchor: a legit preview exists inside tmp_path, so a 200 path is
# genuinely reachable — proving the 403 below is the containment check
# firing, not an unrelated 404.
img = Image.new('RGB', (16, 16), color=(255, 0, 128))
img.save(tmp_path / "test_model.png", format='PNG')
# Percent-encoded "../../../../etc/hosts" so yarl does not collapse the
# dot-segments before the request leaves the client.
encoded_traversal = '%2e%2e%2f' * 4 + 'etc%2fhosts'
raw_path = '/experiment/models/preview/test_folder/0/' + encoded_traversal
url = yarl.URL(raw_path, encoded=True)
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get(url)
# Confirm the traversal actually reached the handler intact: a 200 here
# would mean either normalization stripped the ``../`` (vacuous pass) or
# the containment check failed open and served outside-dir bytes.
assert response.status == 403, (
f"expected 403 from is_within_directory() containment check, "
f"got {response.status}; traversal may have been normalized away "
f"or the fix failed open"
)
body = await response.read()
assert body == b"", "403 response must not carry any file bytes"
async def test_symlink_companion_preview_returns_403(aiohttp_client, app, tmp_path):
"""A companion preview file is selected by a glob inside get_model_previews
and then opened. If that companion is a symlink whose path is in-dir but
whose target escapes the model folder, it must be rejected with 403 — not
served. The requested path itself stays in-dir (so the first containment
check passes); the load-bearing fix is the SECOND is_within_directory check
on the file actually opened.
"""
model_dir = tmp_path / "models"
model_dir.mkdir()
secret_dir = tmp_path / "secret"
secret_dir.mkdir()
# A real image OUTSIDE the model dir — valid, so without the fix Image.open
# would succeed and its bytes would be served (200).
secret = secret_dir / "secret.png"
Image.new('RGB', (8, 8), color=(0, 0, 0)).save(secret, format='PNG')
# Companion preview, in-dir by name but a symlink escaping the model dir.
# (No real model file is needed — get_model_previews globs companions by
# basename, and omitting a .safetensors avoids the metadata-header read.)
companion = model_dir / "model.preview.png"
try:
companion.symlink_to(secret)
except (OSError, NotImplementedError):
pytest.skip("symlinks not supported on this platform/filesystem")
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(model_dir)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/0/model.safetensors')
assert response.status == 403, (
f"expected 403 — the globbed companion preview is a symlink resolving "
f"outside the model dir and must not be served; got {response.status}"
)
assert await response.read() == b""
async def test_null_byte_in_filename_no_500(aiohttp_client, app, tmp_path):
"""A NUL byte in the filename must yield a clean client rejection, not a 500
from an uncaught ValueError in is_within_directory's realpath() call."""
raw_path = '/experiment/models/preview/test_folder/0/' + 'a%00b'
url = yarl.URL(raw_path, encoded=True)
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get(url)
assert response.status != 500, (
f"NUL byte produced a 500 (uncaught ValueError); expected a clean "
f"4xx rejection, got {response.status}"
)
assert 400 <= response.status < 500

View File

@ -1,165 +0,0 @@
"""Security tests for GHSA-779p-m5rp-r4h4 — FIX #3.
Path traversal in folder_paths.get_annotated_filepath / exists_annotated_filepath,
plus the shared is_within_directory() containment helper.
These are pure-function tests (no running server). The input/output/temp
directories are pointed at tmp_path via the folder_paths setters, so a crafted
name containing `../`, an absolute path, or a symlink that escapes the base
directory must be rejected.
Reference: https://github.com/Comfy-Org/ComfyUI/security/advisories/GHSA-779p-m5rp-r4h4
"""
import os
import pytest
import folder_paths
from comfy.options import enable_args_parsing
enable_args_parsing()
@pytest.fixture
def sandbox(tmp_path):
"""Point folder_paths' input/output/temp dirs at a real temp sandbox.
Yields the realpath'd base, input, output and temp directories. The original
directory values are restored afterward so tests stay isolated.
"""
base = os.path.realpath(str(tmp_path))
input_dir = os.path.join(base, "input")
output_dir = os.path.join(base, "output")
temp_dir = os.path.join(base, "temp")
for d in (input_dir, output_dir, temp_dir):
os.makedirs(d, exist_ok=True)
orig_input = folder_paths.get_input_directory()
orig_output = folder_paths.get_output_directory()
orig_temp = folder_paths.get_temp_directory()
folder_paths.set_input_directory(input_dir)
folder_paths.set_output_directory(output_dir)
folder_paths.set_temp_directory(temp_dir)
yield {
"base": base,
"input": input_dir,
"output": output_dir,
"temp": temp_dir,
}
folder_paths.set_input_directory(orig_input)
folder_paths.set_output_directory(orig_output)
folder_paths.set_temp_directory(orig_temp)
# ---------------------------------------------------------------------------
# is_within_directory() — the shared containment helper
# ---------------------------------------------------------------------------
def test_is_within_directory_legit_child(sandbox):
base = sandbox["input"]
child = os.path.join(base, "sub", "image.png")
assert folder_paths.is_within_directory(base, child) is True
def test_is_within_directory_dotdot_escape(sandbox):
base = sandbox["input"]
escape = os.path.join(base, "..", "..", "etc", "passwd")
assert folder_paths.is_within_directory(base, escape) is False
def test_is_within_directory_symlink_escape(sandbox):
"""A symlink created INSIDE base that points OUTSIDE base must not pass.
This is the key new hardening: is_within_directory realpath()s both operands,
so a symlink planted in the base directory can't be used to read files
elsewhere. We create a real on-disk symlink and a real secret target to
verify the check actually resolves the link.
"""
base = sandbox["input"]
# A directory living outside the base, holding a secret file.
outside = os.path.join(sandbox["base"], "outside_secret_dir")
os.makedirs(outside, exist_ok=True)
secret = os.path.join(outside, "secret.txt")
with open(secret, "w") as f:
f.write("top secret")
# Plant a symlink inside base that points at the outside directory.
# symlink creation can require elevated privileges / Developer Mode on
# Windows, so skip cleanly where it isn't available (same guard as the
# sibling test in test_ghsa_779p_02_preview_traversal.py).
link = os.path.join(base, "escape_link")
try:
os.symlink(outside, link)
except (OSError, NotImplementedError):
pytest.skip("symlinks not supported on this platform/filesystem")
# Accessing the secret "through" the in-base symlink must be rejected.
target_via_link = os.path.join(link, "secret.txt")
assert folder_paths.is_within_directory(base, target_via_link) is False
# ---------------------------------------------------------------------------
# get_annotated_filepath()
# ---------------------------------------------------------------------------
def test_get_annotated_filepath_legit_name(sandbox):
result = folder_paths.get_annotated_filepath("image.png")
assert result == os.path.join(sandbox["input"], "image.png")
assert folder_paths.is_within_directory(sandbox["input"], result)
def test_get_annotated_filepath_input_annotation(sandbox):
result = folder_paths.get_annotated_filepath("image.png [input]")
assert result == os.path.join(sandbox["input"], "image.png")
def test_get_annotated_filepath_output_annotation(sandbox):
result = folder_paths.get_annotated_filepath("image.png [output]")
assert result == os.path.join(sandbox["output"], "image.png")
def test_get_annotated_filepath_temp_annotation(sandbox):
result = folder_paths.get_annotated_filepath("image.png [temp]")
assert result == os.path.join(sandbox["temp"], "image.png")
def test_get_annotated_filepath_dotdot_raises(sandbox):
with pytest.raises(ValueError):
folder_paths.get_annotated_filepath("../etc/passwd")
def test_get_annotated_filepath_dotdot_with_annotation_raises(sandbox):
with pytest.raises(ValueError):
folder_paths.get_annotated_filepath("../../etc/passwd [output]")
def test_get_annotated_filepath_absolute_escape_raises(sandbox):
with pytest.raises(ValueError):
folder_paths.get_annotated_filepath("/etc/passwd")
# ---------------------------------------------------------------------------
# exists_annotated_filepath()
# ---------------------------------------------------------------------------
def test_exists_annotated_filepath_existing_legit_file(sandbox):
real = os.path.join(sandbox["input"], "real.png")
with open(real, "w") as f:
f.write("data")
assert folder_paths.exists_annotated_filepath("real.png") is True
def test_exists_annotated_filepath_traversal_returns_false(sandbox):
"""A traversal name must return False without raising and without probing
outside the base directory (must never reach os.path.exists for the escape).
"""
# /etc/passwd exists on POSIX; the function must still report False because
# the resolved path escapes the input directory.
assert folder_paths.exists_annotated_filepath("../../../../../../etc/passwd") is False
def test_exists_annotated_filepath_absolute_returns_false(sandbox):
assert folder_paths.exists_annotated_filepath("/etc/passwd") is False

View File

@ -1,147 +0,0 @@
"""
CI unit tests for FIX #4 of GHSA-779p-m5rp-r4h4.
Stored-XSS hardening on GET /userdata/{file} in app/user_manager.py.
User data files are arbitrary user-supplied content and must never render
inline in the app origin. The getuserdata handler:
- forces Content-Type to application/octet-stream for any type in
folder_paths.DANGEROUS_CONTENT_TYPES (text/html, image/svg+xml,
text/javascript, ...),
- sets X-Content-Type-Options: nosniff,
- sets Content-Disposition: attachment.
These tests pre-create files in tmp_path and GET them back, asserting the
secure response headers. They mirror the aiohttp_client pattern in
tests-unit/prompt_server_test/user_manager_test.py.
"""
import pytest
import os
from aiohttp import web
from app.user_manager import UserManager
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def user_manager(tmp_path):
um = UserManager()
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
tmp_path, file
) if file else tmp_path
return um
@pytest.fixture
def app(user_manager):
app = web.Application()
routes = web.RouteTableDef()
user_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_html_served_as_octet_stream(aiohttp_client, app, tmp_path):
(tmp_path / "evil.html").write_text(
"<script>console.log('xss-marker-ghsa-779p')</script>"
)
client = await aiohttp_client(app)
resp = await client.get("/userdata/evil.html")
assert resp.status == 200
ct = resp.headers.get("Content-Type", "")
# The load-bearing assertion: a .html file must NOT be served as text/html.
assert "text/html" not in ct.lower(), (
f"Content-Type {ct!r} would let a browser render/execute the file (stored XSS)."
)
assert ct == "application/octet-stream"
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
assert "attachment" in resp.headers.get("Content-Disposition", "")
async def test_svg_served_as_octet_stream(aiohttp_client, app, tmp_path):
(tmp_path / "evil.svg").write_text(
'<?xml version="1.0"?>'
'<svg xmlns="http://www.w3.org/2000/svg">'
'<script>console.log("xss-marker-ghsa-779p")</script>'
"</svg>"
)
client = await aiohttp_client(app)
resp = await client.get("/userdata/evil.svg")
assert resp.status == 200
ct = resp.headers.get("Content-Type", "")
# SVG can carry inline <script>; it must not be served as image/svg+xml.
assert "svg" not in ct.lower(), (
f"Content-Type {ct!r} would let a browser render the SVG and execute embedded scripts."
)
assert ct == "application/octet-stream"
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
assert "attachment" in resp.headers.get("Content-Disposition", "")
async def test_js_served_as_octet_stream(aiohttp_client, app, tmp_path):
(tmp_path / "evil.js").write_text("alert('xss-marker-ghsa-779p')")
client = await aiohttp_client(app)
resp = await client.get("/userdata/evil.js")
assert resp.status == 200
ct = resp.headers.get("Content-Type", "").lower()
# Must not be served as any executable JavaScript content type.
assert "javascript" not in ct, (
f"Content-Type {ct!r} is an executable JS type."
)
assert "ecmascript" not in ct, (
f"Content-Type {ct!r} is an executable JS type."
)
assert ct == "application/octet-stream"
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
assert "attachment" in resp.headers.get("Content-Disposition", "")
async def test_xml_dialect_served_as_octet_stream(aiohttp_client, app, tmp_path):
"""An XML dialect outside the original blocklist (.xslt -> application/xslt+xml)
must still be forced to download. This pins the normalised *+xml family rule
in folder_paths.is_dangerous_content_type(); a plain set-membership test would
have served this inline."""
(tmp_path / "evil.xslt").write_text(
'<?xml version="1.0"?>'
'<xsl:stylesheet version="1.0" '
'xmlns:xsl="http://www.w3.org/1999/XSL/Transform">'
"<!-- xss-marker-ghsa-779p -->"
"</xsl:stylesheet>"
)
client = await aiohttp_client(app)
resp = await client.get("/userdata/evil.xslt")
assert resp.status == 200
ct = resp.headers.get("Content-Type", "")
assert ct == "application/octet-stream", (
f"Content-Type {ct!r}: an *+xml dialect must be forced to octet-stream "
f"(it can carry inline script via stylesheet/entity tricks)."
)
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
assert "attachment" in resp.headers.get("Content-Disposition", "")
async def test_benign_txt_still_served(aiohttp_client, app, tmp_path):
(tmp_path / "note.txt").write_text("just a harmless note")
client = await aiohttp_client(app)
resp = await client.get("/userdata/note.txt")
assert resp.status == 200
assert await resp.text() == "just a harmless note"
ct = resp.headers.get("Content-Type", "")
# text/plain is not in the dangerous set, so it is acceptable here. The
# defence-in-depth headers must still be present regardless.
assert "text/plain" in ct.lower()
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
assert "attachment" in resp.headers.get("Content-Disposition", "")

View File

@ -1,138 +0,0 @@
"""CI unit guard for FIX #5 of GHSA-779p-m5rp-r4h4 — the /view forced-download set.
Vuln #5 was stored XSS via SVG upload: the /view endpoint's Content-Type
blocklist covered text/html, text/javascript, etc. but was missing
image/svg+xml, so an uploaded SVG carrying an inline <script> was served as
image/svg+xml and executed in the page origin when rendered.
The /view forced-download decision lives in the view_image closure registered by
server.PromptServer.add_routes (server.py ~line 596), which calls
`folder_paths.is_dangerous_content_type(content_type)` — a normalising check that
strips charset/boundary parameters and casing and folds in the whole */xml and
*+xml dialect family — rather than a bypassable raw
`content_type in folder_paths.DANGEROUS_CONTENT_TYPES` membership test. On a match
it rewrites the response to application/octet-stream with a
Content-Disposition: attachment header. server.py cannot be imported in a unit
test (importing it spins up the full PromptServer/aiohttp app and its global side
effects), so these tests pin the underlying dangerous-content data
(folder_paths.DANGEROUS_CONTENT_TYPES) and the normalising is_dangerous_content_type()
helper that the closure actually calls.
The end-to-end /view assertion (upload an SVG, GET /view, confirm the response
is not served as image/svg+xml) lives in the live POC at
.security/pocs/test_security_ghsa_779p.py::TestViewSvgContentType, which
requires a running server. This file is the fast, server-free CI guard on the
set contents so the blocklist can't silently regress.
"""
import folder_paths
# Active/renderable content types that must be forced to download. Each of these
# can carry an inline <script> (or otherwise execute) in the page origin if a
# browser renders it. image/svg+xml is the original missing item that caused
# vuln #5.
DANGEROUS = [
'image/svg+xml',
'application/xml',
'text/xml',
'text/html',
'text/html-sandboxed',
'application/xhtml+xml',
'text/javascript',
'application/javascript',
'application/x-javascript',
'application/ecmascript',
'text/css',
]
# Benign image types that browsers display inline and that must keep rendering;
# forcing these to download would break legitimate previews.
BENIGN_INLINE_IMAGES = [
'image/png',
'image/jpeg',
'image/webp',
'image/gif',
]
def test_dangerous_content_types_is_a_set():
assert isinstance(folder_paths.DANGEROUS_CONTENT_TYPES, set)
def test_svg_is_in_the_blocklist():
"""The specific item whose absence caused vuln #5."""
assert 'image/svg+xml' in folder_paths.DANGEROUS_CONTENT_TYPES, (
"image/svg+xml missing from DANGEROUS_CONTENT_TYPES — this is exactly "
"the regression that reopens GHSA-779p-m5rp-r4h4 vuln #5 (stored XSS "
"via SVG upload on /view)."
)
def test_all_dangerous_types_present():
missing = [ct for ct in DANGEROUS if ct not in folder_paths.DANGEROUS_CONTENT_TYPES]
assert not missing, (
f"DANGEROUS_CONTENT_TYPES is missing required active/renderable types: "
f"{missing}. The /view closure only forces a download for content types "
f"in this set; anything missing here is served inline and can execute."
)
def test_benign_inline_image_types_absent():
leaked = [ct for ct in BENIGN_INLINE_IMAGES if ct in folder_paths.DANGEROUS_CONTENT_TYPES]
assert not leaked, (
f"Benign inline-displayable image types found in DANGEROUS_CONTENT_TYPES: "
f"{leaked}. Forcing these to download would break legitimate image "
f"previews in /view — they must keep rendering inline."
)
# ---------------------------------------------------------------------------
# is_dangerous_content_type() — the normalising check the /view and /userdata
# handlers now call instead of a raw `in DANGEROUS_CONTENT_TYPES` membership
# test. An exact-string membership test was bypassable with a charset parameter
# or odd casing, and missed the wider XML dialect family; these tests pin the
# normalisation so that bypass can't reopen.
# ---------------------------------------------------------------------------
def test_function_matches_plain_dangerous_types():
for ct in DANGEROUS:
assert folder_paths.is_dangerous_content_type(ct) is True, ct
def test_function_strips_parameters_and_casing():
"""A charset/boundary parameter or casing must not slip a type past the check.
This is the bypass surfaced by review: the /view blake3 branch can serve an
attacker-controlled, unvalidated asset mime_type like 'text/html; charset=utf-8',
which an exact-string set test missed.
"""
for ct in (
'text/html; charset=utf-8',
'TEXT/HTML',
'Text/HTML; charset=UTF-8',
'image/svg+xml; charset=utf-8',
' text/html ',
):
assert folder_paths.is_dangerous_content_type(ct) is True, ct
def test_function_covers_xml_dialect_family():
"""Any *+xml / */xml dialect is dangerous without enumerating each one."""
for ct in (
'application/xslt+xml',
'application/rss+xml',
'application/atom+xml',
'application/rdf+xml',
'application/mathml+xml',
'message/rfc822',
):
assert folder_paths.is_dangerous_content_type(ct) is True, ct
def test_function_allows_benign_and_empty():
for ct in BENIGN_INLINE_IMAGES + ['application/octet-stream', 'text/plain']:
assert folder_paths.is_dangerous_content_type(ct) is False, ct
# None / empty (mimetypes.guess_type miss) must not be treated as dangerous.
assert folder_paths.is_dangerous_content_type(None) is False
assert folder_paths.is_dangerous_content_type('') is False