mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-04 11:38:33 +08:00
Compare commits
10 Commits
feat/core/
...
temp_pr
| Author | SHA1 | Date | |
|---|---|---|---|
| a78019266f | |||
| f5c4bb1f02 | |||
| 1073a74976 | |||
| de1b8f3e8d | |||
| 77917ed3a6 | |||
| a04ebe05c2 | |||
| 9764381998 | |||
| 1e04ced089 | |||
| 96e0e3585b | |||
| 35c1470935 |
36
AGENTS.md
36
AGENTS.md
@ -8,6 +8,8 @@
|
||||
directly required.
|
||||
- Prefer practical fixes over broad architecture work. Add abstractions only
|
||||
when they remove real repeated logic or match an existing ComfyUI pattern.
|
||||
- Prefer fewer dependencies. Do not add new dependencies to ComfyUI unless they
|
||||
are absolutely necessary.
|
||||
- Delete obsolete code aggressively when newer infrastructure makes it useless.
|
||||
Remove dead fallbacks, migration paths, unused options, debug prints, and
|
||||
compatibility branches that are no longer needed. Do not leave dead branches,
|
||||
@ -111,6 +113,11 @@
|
||||
- Do not add freeze, unfreeze, or trainability toggles to model classes. ComfyUI
|
||||
models are always treated as frozen for inference, so explicit freeze
|
||||
functionality is redundant and should not be added.
|
||||
- Remove training-only behavior such as dropout from inference model code, but
|
||||
preserve checkpoint and state-dict compatibility when doing so. If deleting a
|
||||
module would change state-dict keys, module ordering, or checkpoint loading
|
||||
behavior, replace it with a no-op such as `nn.Identity` instead of removing the
|
||||
slot outright.
|
||||
|
||||
## Python Style
|
||||
|
||||
@ -164,16 +171,30 @@
|
||||
- Reuse existing model classes, blocks, ops, and helper modules when appropriate.
|
||||
Before implementing a new version of a model component, search the existing
|
||||
model code for a class or helper that already provides the behavior.
|
||||
- Model detection code that inspects linear weight shapes should only use the
|
||||
first dimension. The second dimension may be half the original size for
|
||||
NVFP4 or other 4-bit quantized models.
|
||||
- Avoid adding `einops` usage in core inference code. Use native torch tensor
|
||||
ops such as `reshape`, `view`, `permute`, `transpose`, `flatten`, `unflatten`,
|
||||
`unsqueeze`, and `squeeze` instead.
|
||||
- Do not use tensors as general-purpose Python data structures. Keep metadata,
|
||||
bookkeeping, counters, flags, shape math, padding math, index planning, memory
|
||||
estimates, and control-flow decisions in plain Python values unless the data
|
||||
must participate directly in tensor computation. Avoid creating temporary
|
||||
tensors just to use tensor methods for scalar or structural calculations.
|
||||
must participate directly in tensor computation. Do not create tensors for
|
||||
structural metadata that is only used for Python-side control flow. Sequence
|
||||
lengths, cumulative offsets, split indices, window counts, slice boundaries,
|
||||
and repeat counts should be kept as Python ints/lists from the point they are
|
||||
computed. Do not build them as CPU/GPU tensors and then cast, move, validate,
|
||||
or convert them back to Python for `split`, `tensor_split`, indexing plans,
|
||||
loops, or cache keys. Avoid creating temporary tensors just to use tensor
|
||||
methods for scalar or structural calculations.
|
||||
- Avoid unnecessary casts and transfers. Preserve the intended compute dtype,
|
||||
storage dtype, bias dtype, and original tensor shape metadata.
|
||||
- Keep model-native latent layout handling inside the model or latent-format
|
||||
owner, not in helper nodes. Do not collapse, expand, pack, or unpack latent
|
||||
dimensions in nodes or other caller-side adapters just to satisfy a model
|
||||
forward; the model path should consume and return the native latent shape for
|
||||
that model family.
|
||||
- Assume inputs to the main model forward are already in the compute dtype by
|
||||
default, except integer inputs such as some model timestep tensors. Do not add
|
||||
defensive or convenience casts in model code; it is better for invalid dtype
|
||||
@ -234,6 +255,17 @@
|
||||
`CATEGORY`, and registration through the local mapping used by that file.
|
||||
- Keep node changes backward compatible by default. Add inputs with sensible
|
||||
defaults and avoid changing output types unless the request requires it.
|
||||
- Model implementations should add the minimal number of ComfyUI nodes required
|
||||
to run the model. Reuse existing nodes as much as possible; adapting the model
|
||||
to work with existing nodes is strongly preferred over creating new nodes.
|
||||
- Nodes should output only values they own. Do not add pass-through outputs for
|
||||
workflow convenience unless the node is explicitly an output node. Existing
|
||||
models, latents, conditioning, or other inputs should flow directly to the
|
||||
next consumer instead of being re-emitted unchanged.
|
||||
- Nodes should expose only inputs they actually read to produce current
|
||||
behavior. Do not add placeholder, pass-through, compatibility, or
|
||||
workflow-shaping inputs that are ignored or could flow directly to another
|
||||
node.
|
||||
- Node-level code must not patch model code directly. Any node behavior that
|
||||
modifies, wraps, hooks, or changes model behavior must go through the model
|
||||
patcher class instead of reaching into model internals.
|
||||
|
||||
@ -306,12 +306,15 @@ async def download_asset_content(request: web.Request) -> web.Response:
|
||||
404, "FILE_NOT_FOUND", "Underlying file not found on disk."
|
||||
)
|
||||
|
||||
_DANGEROUS_MIME_TYPES = {
|
||||
"text/html", "text/html-sandboxed", "application/xhtml+xml",
|
||||
"text/javascript", "text/css",
|
||||
}
|
||||
if content_type in _DANGEROUS_MIME_TYPES:
|
||||
# User-controlled asset content must never render inline in the app origin
|
||||
# (stored XSS via SVG/HTML/XML). Force dangerous types to download and
|
||||
# override any requested inline disposition. Centralised through
|
||||
# folder_paths.is_dangerous_content_type so this can't drift from /view and
|
||||
# /userdata (the previous inline set here omitted image/svg+xml and missed
|
||||
# the charset/casing/+xml-dialect bypasses).
|
||||
if folder_paths.is_dangerous_content_type(content_type):
|
||||
content_type = "application/octet-stream"
|
||||
disposition = "attachment"
|
||||
|
||||
safe_name = (filename or "").replace("\r", "").replace("\n", "")
|
||||
encoded = urllib.parse.quote(safe_name)
|
||||
|
||||
@ -50,21 +50,45 @@ class ModelFileManager:
|
||||
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
||||
async def get_model_preview(request):
|
||||
folder_name = request.match_info.get("folder", None)
|
||||
path_index = int(request.match_info.get("path_index", None))
|
||||
filename = request.match_info.get("filename", None)
|
||||
|
||||
if folder_name not in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
|
||||
# The "{filename:.*}" capture also matches the empty string, which
|
||||
# would resolve to the folder itself; reject it explicitly.
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
try:
|
||||
path_index = int(request.match_info.get("path_index", None))
|
||||
except (TypeError, ValueError):
|
||||
return web.Response(status=400)
|
||||
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
if path_index < 0 or path_index >= len(folders[0]):
|
||||
return web.Response(status=404)
|
||||
folder = folders[0][path_index]
|
||||
full_filename = os.path.join(folder, filename)
|
||||
full_filename = os.path.normpath(os.path.join(folder, filename))
|
||||
|
||||
# Prevent path traversal: the requested file must stay within the
|
||||
# configured model folder. `filename` is an unrestricted ".*" capture,
|
||||
# so values like "../../../../etc/passwd" would otherwise escape it.
|
||||
if not folder_paths.is_within_directory(folder, full_filename):
|
||||
return web.Response(status=403)
|
||||
|
||||
previews = self.get_model_previews(full_filename)
|
||||
default_preview = previews[0] if len(previews) > 0 else None
|
||||
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
|
||||
return web.Response(status=404)
|
||||
|
||||
# The preview is selected by a glob inside get_model_previews, so a
|
||||
# companion file (e.g. "model.preview.png") could itself be a symlink
|
||||
# resolving outside the model folder. Re-validate the file actually
|
||||
# opened: is_within_directory realpaths it, catching symlink escape.
|
||||
if isinstance(default_preview, str) and not folder_paths.is_within_directory(folder, default_preview):
|
||||
return web.Response(status=403)
|
||||
|
||||
try:
|
||||
with Image.open(default_preview) as img:
|
||||
img_bytes = BytesIO()
|
||||
|
||||
@ -6,6 +6,7 @@ import glob
|
||||
import shutil
|
||||
import logging
|
||||
import tempfile
|
||||
import mimetypes
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
@ -336,7 +337,20 @@ class UserManager():
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
return web.FileResponse(path)
|
||||
# User data files are arbitrary user-supplied content and are never
|
||||
# meant to render inline. Disable MIME sniffing and force a download
|
||||
# so uploaded markup/scripts can't execute in the app origin (stored
|
||||
# XSS). Content-Disposition: attachment is the load-bearing guard;
|
||||
# the content-type override and nosniff are defence in depth.
|
||||
content_type = mimetypes.guess_type(path)[0] or 'application/octet-stream'
|
||||
if folder_paths.is_dangerous_content_type(content_type):
|
||||
content_type = 'application/octet-stream'
|
||||
|
||||
return web.FileResponse(path, headers={
|
||||
"Content-Type": content_type,
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Content-Disposition": "attachment",
|
||||
})
|
||||
|
||||
@routes.post("/userdata/{file}")
|
||||
async def post_userdata(request):
|
||||
|
||||
@ -1651,15 +1651,6 @@ class Schema:
|
||||
Use this for nodes with interactive/operable UI regions that produce intermediate outputs
|
||||
(e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
|
||||
"""
|
||||
lazy_outputs: bool=False
|
||||
"""When True, cache will invalidate when output connections change, and expected_outputs will be available.
|
||||
|
||||
Use this for nodes that can skip computing outputs that aren't connected downstream.
|
||||
Check `comfy_execution.utils.is_output_needed(i)` inside execute() - False means output i is definitely unused
|
||||
and safe to skip. Only nodes with this flag receive expected_outputs; all others see None.
|
||||
|
||||
Limitation: consumers must exist before this node runs - a subgraph expansion that
|
||||
hand-builds a link to a pre-existing node's already-skipped output reads a stale value."""
|
||||
|
||||
def validate(self):
|
||||
'''Validate the schema:
|
||||
@ -2117,14 +2108,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls.GET_SCHEMA()
|
||||
return cls._ACCEPT_ALL_INPUTS
|
||||
|
||||
_LAZY_OUTPUTS = None
|
||||
@final
|
||||
@classproperty
|
||||
def LAZY_OUTPUTS(cls): # noqa
|
||||
if cls._LAZY_OUTPUTS is None:
|
||||
cls.GET_SCHEMA()
|
||||
return cls._LAZY_OUTPUTS
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> dict[str, dict]:
|
||||
@ -2169,8 +2152,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls._NOT_IDEMPOTENT = schema.not_idempotent
|
||||
if cls._ACCEPT_ALL_INPUTS is None:
|
||||
cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs
|
||||
if cls._LAZY_OUTPUTS is None:
|
||||
cls._LAZY_OUTPUTS = schema.lazy_outputs
|
||||
|
||||
if cls._RETURN_TYPES is None:
|
||||
output = []
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -316,3 +316,36 @@ VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"1080p": 150,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class SeedAudioConfig(BaseModel):
|
||||
format: str = Field(default="mp3")
|
||||
sample_rate: int = Field(default=24000)
|
||||
speech_rate: int = Field(default=0)
|
||||
loudness_rate: int = Field(default=0)
|
||||
pitch_rate: int = Field(default=0)
|
||||
|
||||
|
||||
class SeedAudioReference(BaseModel):
|
||||
speaker: str | None = Field(default=None)
|
||||
audio_data: str | None = Field(default=None)
|
||||
audio_url: str | None = Field(default=None)
|
||||
image_data: str | None = Field(default=None)
|
||||
image_url: str | None = Field(default=None)
|
||||
|
||||
|
||||
class SeedAudioRequest(BaseModel):
|
||||
model: str = Field(default="seed-audio-1.0")
|
||||
text_prompt: str = Field(...)
|
||||
references: list[SeedAudioReference] | None = Field(default=None)
|
||||
audio_config: SeedAudioConfig = Field(default_factory=SeedAudioConfig)
|
||||
watermark: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SeedAudioResponse(BaseModel):
|
||||
audio: str | None = Field(default=None)
|
||||
url: str | None = Field(default=None)
|
||||
duration: float | None = Field(default=None)
|
||||
original_duration: float | None = Field(default=None)
|
||||
code: int | None = Field(default=None)
|
||||
message: str | None = Field(default=None)
|
||||
|
||||
@ -1,147 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, confloat
|
||||
|
||||
|
||||
class StabilityFormat(str, Enum):
|
||||
png = 'png'
|
||||
jpeg = 'jpeg'
|
||||
webp = 'webp'
|
||||
|
||||
|
||||
class StabilityAspectRatio(str, Enum):
|
||||
ratio_1_1 = "1:1"
|
||||
ratio_16_9 = "16:9"
|
||||
ratio_9_16 = "9:16"
|
||||
ratio_3_2 = "3:2"
|
||||
ratio_2_3 = "2:3"
|
||||
ratio_5_4 = "5:4"
|
||||
ratio_4_5 = "4:5"
|
||||
ratio_21_9 = "21:9"
|
||||
ratio_9_21 = "9:21"
|
||||
|
||||
|
||||
def get_stability_style_presets(include_none=True):
|
||||
presets = []
|
||||
if include_none:
|
||||
presets.append("None")
|
||||
return presets + [x.value for x in StabilityStylePreset]
|
||||
|
||||
|
||||
class StabilityStylePreset(str, Enum):
|
||||
_3d_model = "3d-model"
|
||||
analog_film = "analog-film"
|
||||
anime = "anime"
|
||||
cinematic = "cinematic"
|
||||
comic_book = "comic-book"
|
||||
digital_art = "digital-art"
|
||||
enhance = "enhance"
|
||||
fantasy_art = "fantasy-art"
|
||||
isometric = "isometric"
|
||||
line_art = "line-art"
|
||||
low_poly = "low-poly"
|
||||
modeling_compound = "modeling-compound"
|
||||
neon_punk = "neon-punk"
|
||||
origami = "origami"
|
||||
photographic = "photographic"
|
||||
pixel_art = "pixel-art"
|
||||
tile_texture = "tile-texture"
|
||||
|
||||
|
||||
class Stability_SD3_5_Model(str, Enum):
|
||||
sd3_5_large = "sd3.5-large"
|
||||
# sd3_5_large_turbo = "sd3.5-large-turbo"
|
||||
sd3_5_medium = "sd3.5-medium"
|
||||
|
||||
|
||||
class Stability_SD3_5_GenerationMode(str, Enum):
|
||||
text_to_image = "text-to-image"
|
||||
image_to_image = "image-to-image"
|
||||
|
||||
|
||||
class StabilityStable3_5Request(BaseModel):
|
||||
model: str = Field(...)
|
||||
mode: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
aspect_ratio: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
cfg_scale: float = Field(...)
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
|
||||
|
||||
|
||||
class StabilityUpscaleConservativeRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
creativity: Optional[confloat(ge=0.2, le=0.5)] = Field(None)
|
||||
|
||||
|
||||
class StabilityUpscaleCreativeRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
creativity: Optional[confloat(ge=0.1, le=0.5)] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityStableUltraRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
aspect_ratio: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
|
||||
|
||||
|
||||
class StabilityStableUltraResponse(BaseModel):
|
||||
image: Optional[str] = Field(None)
|
||||
finish_reason: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class StabilityResultsGetResponse(BaseModel):
|
||||
image: Optional[str] = Field(None)
|
||||
finish_reason: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
id: Optional[str] = Field(None)
|
||||
name: Optional[str] = Field(None)
|
||||
errors: Optional[list[str]] = Field(None)
|
||||
status: Optional[str] = Field(None)
|
||||
result: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityAsyncResponse(BaseModel):
|
||||
id: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityTextToAudioRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
duration: int = Field(190, ge=1, le=190)
|
||||
seed: int = Field(0, ge=0, le=4294967294)
|
||||
steps: int = Field(8, ge=4, le=8)
|
||||
output_format: str = Field("wav")
|
||||
|
||||
|
||||
class StabilityAudioToAudioRequest(StabilityTextToAudioRequest):
|
||||
strength: float = Field(0.01, ge=0.01, le=1.0)
|
||||
|
||||
|
||||
class StabilityAudioInpaintRequest(StabilityTextToAudioRequest):
|
||||
mask_start: int = Field(30, ge=0, le=190)
|
||||
mask_end: int = Field(190, ge=0, le=190)
|
||||
|
||||
|
||||
class StabilityAudioResponse(BaseModel):
|
||||
audio: Optional[str] = Field(None)
|
||||
@ -1,3 +1,4 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import math
|
||||
@ -20,6 +21,10 @@ from comfy_api_nodes.apis.bytedance import (
|
||||
GetAssetResponse,
|
||||
Image2VideoTaskCreationRequest,
|
||||
ImageTaskCreationResponse,
|
||||
SeedAudioConfig,
|
||||
SeedAudioReference,
|
||||
SeedAudioRequest,
|
||||
SeedAudioResponse,
|
||||
Seedance2TaskCreationRequest,
|
||||
SeedanceCreateAssetRequest,
|
||||
SeedanceCreateAssetResponse,
|
||||
@ -43,6 +48,8 @@ from comfy_api_nodes.apis.bytedance import (
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
audio_bytes_to_audio_input,
|
||||
audio_input_to_mp3,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
downscale_image_tensor_by_max_side,
|
||||
@ -51,11 +58,14 @@ from comfy_api_nodes.util import (
|
||||
image_tensor_pair_to_batch,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
upscale_image_tensor_to_min_pixels,
|
||||
upscale_video_to_min_pixels,
|
||||
validate_audio_duration,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
@ -2474,6 +2484,311 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode):
|
||||
return IO.NodeOutput(asset_id, resolved_group)
|
||||
|
||||
|
||||
MODE_TEXT = "text only"
|
||||
MODE_AUDIO = "audio reference"
|
||||
MODE_IMAGE = "image reference"
|
||||
MODE_SPEAKER = "preset voice"
|
||||
|
||||
# (speaker_id, display_label) for built-in TTS 2.0 voices; resolvable ids are account-scoped.
|
||||
SEED_AUDIO_PRESET_VOICES: list[tuple[str, str]] = [
|
||||
("zh_female_vv_uranus_bigtts", "Vivi (Female, multilingual)"),
|
||||
("zh_female_xiaohe_uranus_bigtts", "Mindy (Female, multilingual)"),
|
||||
("en_female_stokie_uranus_bigtts", "Stokie (Female, English)"),
|
||||
("en_female_dacey_uranus_bigtts", "Dacey (Female, English)"),
|
||||
("en_male_tim_uranus_bigtts", "Tim (Male, English)"),
|
||||
("zh_male_m191_uranus_bigtts", "Kian (Male, multilingual)"),
|
||||
("zh_male_taocheng_uranus_bigtts", "Cedric (Male, multilingual)"),
|
||||
("zh_male_sophie_uranus_bigtts", "Sophie (Female, multilingual)"),
|
||||
("zh_female_yingyujiaoxue_uranus_bigtts", "Jean (Female, multilingual)"),
|
||||
("zh_male_dayi_uranus_bigtts", "Magnus (Male, multilingual)"),
|
||||
("zh_female_mizai_uranus_bigtts", "Mabel (Female, multilingual)"),
|
||||
("zh_female_jitangnv_uranus_bigtts", "Nadia (Female, multilingual)"),
|
||||
("zh_female_meilinvyou_uranus_bigtts", "Opal (Female, multilingual)"),
|
||||
("zh_female_liuchangnv_uranus_bigtts", "Pearl (Female, multilingual)"),
|
||||
("zh_male_ruyayichen_uranus_bigtts", "Quentin (Male, multilingual)"),
|
||||
("zh_female_vivo_uranus_bigtts", "Vienna (Female, multilingual)"),
|
||||
("zh_female_xiaoai_uranus_bigtts", "Alina (Female, multilingual)"),
|
||||
("zh_female_cancan_uranus_bigtts", "Corinne (Female, multilingual)"),
|
||||
("zh_female_tianmeixiaoyuan_uranus_bigtts", "Esther (Female, multilingual)"),
|
||||
("zh_female_tianmeitaozi_uranus_bigtts", "Freya (Female, multilingual)"),
|
||||
("zh_female_shuangkuaisisi_uranus_bigtts", "Gigi (Female, multilingual)"),
|
||||
("zh_female_peiqi_uranus_bigtts", "Holly (Female, multilingual)"),
|
||||
("zh_female_xiaoxue_uranus_bigtts", "Lyla (Female, multilingual)"),
|
||||
("zh_female_yuanqi_uranus_bigtts", "Daisy (Female, multilingual)"),
|
||||
("zh_female_kefunvsheng_uranus_bigtts", "Tracy (Female, multilingual)"),
|
||||
("zh_male_shaonianzixin_uranus_bigtts", "Jess (Male, multilingual)"),
|
||||
("zh_female_linjianvhai_uranus_bigtts", "Pinky (Female, multilingual)"),
|
||||
("zh_female_kiwi_uranus_bigtts", "Sweety (Female, multilingual)"),
|
||||
("zh_female_sajiaoxuemei_uranus_bigtts", "Sandy (Female, multilingual)"),
|
||||
("de_male_seven_uranus_bigtts", "Sven (Male, German)"),
|
||||
("jp_female_minimi_uranus_bigtts", "Minimi (Female, Japanese)"),
|
||||
("fr_male_usseau_uranus_bigtts", "Usseau (Male, French)"),
|
||||
("es_male_felipe_uranus_bigtts", "Felipe (Male, Spanish)"),
|
||||
("id_male_han_uranus_bigtts", "Han (Male, Indonesian)"),
|
||||
("pt_male_martins_uranus_bigtts", "Martins (Male, Portuguese)"),
|
||||
("it_male_enzo_uranus_bigtts", "Enzo (Male, Italian)"),
|
||||
("kr_male_shane_uranus_bigtts", "Shane (Male, Korean)"),
|
||||
("zh_male_liufei_uranus_bigtts", "Felix (Male, Chinese)"),
|
||||
("zh_female_qingxinnvsheng_uranus_bigtts", "Celeste (Female, Chinese)"),
|
||||
("zh_male_sunwukong_uranus_bigtts", "Monkey King (Male, Chinese)"),
|
||||
]
|
||||
SEED_AUDIO_VOICE_OPTIONS = [label for _, label in SEED_AUDIO_PRESET_VOICES]
|
||||
SEED_AUDIO_VOICE_MAP = {label: speaker_id for speaker_id, label in SEED_AUDIO_PRESET_VOICES}
|
||||
|
||||
_AUDIO_TAG_RE = re.compile(r"@Audio(\d+)", re.IGNORECASE)
|
||||
|
||||
|
||||
def max_audio_tag(prompt: str) -> int:
|
||||
"""Highest N referenced as @AudioN in the prompt (0 if none)."""
|
||||
nums = [int(m) for m in _AUDIO_TAG_RE.findall(prompt or "")]
|
||||
return max(nums) if nums else 0
|
||||
|
||||
|
||||
def connected_audio_indices(reference_mode: dict) -> list[int]:
|
||||
"""Indices (1-based) of connected reference_audio sockets, in order."""
|
||||
return [
|
||||
i
|
||||
for i in range(1, 3 + 1)
|
||||
if reference_mode.get(f"reference_audio_{i}") is not None
|
||||
]
|
||||
|
||||
|
||||
def validate_seed_audio_inputs(
|
||||
text_prompt: str,
|
||||
mode: str,
|
||||
audio_indices: list[int],
|
||||
has_image: bool,
|
||||
preset_voice: str | None = None,
|
||||
) -> None:
|
||||
validate_string(text_prompt, field_name="text_prompt", min_length=1, max_length=3000)
|
||||
max_tag = max_audio_tag(text_prompt)
|
||||
|
||||
if mode == MODE_TEXT:
|
||||
if max_tag:
|
||||
raise ValueError(
|
||||
f"The prompt references @Audio{max_tag}, but reference mode is '{MODE_TEXT}'. "
|
||||
f"Switch to '{MODE_AUDIO}' and connect the reference clip(s)."
|
||||
)
|
||||
elif mode == MODE_AUDIO:
|
||||
if not audio_indices:
|
||||
raise ValueError(
|
||||
f"Reference mode '{MODE_AUDIO}' requires at least one reference_audio input "
|
||||
f"(or switch to '{MODE_TEXT}')."
|
||||
)
|
||||
if audio_indices != list(range(1, len(audio_indices) + 1)):
|
||||
raise ValueError(
|
||||
"Connect reference_audio inputs in order without gaps: reference_audio_1, then _2, then _3."
|
||||
)
|
||||
if max_tag > len(audio_indices):
|
||||
raise ValueError(
|
||||
f"The prompt references @Audio{max_tag}, but only {len(audio_indices)} "
|
||||
f"reference audio(s) are connected."
|
||||
)
|
||||
elif mode == MODE_IMAGE:
|
||||
if not has_image:
|
||||
raise ValueError(f"Reference mode '{MODE_IMAGE}' requires a reference_image input.")
|
||||
if max_tag:
|
||||
raise ValueError(
|
||||
f"@AudioN tags are not used in '{MODE_IMAGE}' mode; the prompt should contain "
|
||||
f"only the text to synthesize."
|
||||
)
|
||||
elif mode == MODE_SPEAKER:
|
||||
if not preset_voice or preset_voice not in SEED_AUDIO_VOICE_MAP:
|
||||
raise ValueError(f"Reference mode '{MODE_SPEAKER}' requires selecting a preset voice.")
|
||||
if max_tag > 1:
|
||||
raise ValueError(
|
||||
f"'{MODE_SPEAKER}' mode uses a single voice, so @Audio{max_tag} is out of range. "
|
||||
f"Remove the @AudioN tags — the whole prompt is read in the selected voice."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown reference mode: {mode!r}")
|
||||
|
||||
|
||||
class ByteDanceSeedAudioNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedAudio",
|
||||
display_name="ByteDance Seed Audio 1.0",
|
||||
category="partner/audio/ByteDance",
|
||||
description=(
|
||||
"Generate speech, music, sound effects and multi-speaker dialogue from a single prompt "
|
||||
"with ByteDance Seed Audio 1.0. Describe the voice(s), emotion, ambience, background music "
|
||||
"and sound effects in the prompt, and include the lines to speak. Optionally pick a built-in "
|
||||
"preset voice, clone voices from up to 3 reference clips (tagged @Audio1-3 in the prompt), "
|
||||
"or derive a voice from a character image. Up to 2 minutes of audio per run."
|
||||
),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"text_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip=(
|
||||
"Describe the voice(s), emotion, pacing, ambience, background music and sound "
|
||||
"effects, and include the lines to speak (name characters inline for dialogue). "
|
||||
"In 'audio reference' mode, refer to connected clips by order as @Audio1, @Audio2, "
|
||||
"@Audio3. Maximum 3000 characters."
|
||||
),
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"reference_mode",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(MODE_TEXT, []),
|
||||
IO.DynamicCombo.Option(
|
||||
MODE_AUDIO,
|
||||
[
|
||||
IO.Audio.Input(
|
||||
"reference_audio_1",
|
||||
optional=True,
|
||||
tooltip="Reference clip for voice cloning, tagged @Audio1 in the prompt. "
|
||||
"Up to 30s.",
|
||||
),
|
||||
IO.Audio.Input(
|
||||
"reference_audio_2",
|
||||
optional=True,
|
||||
tooltip="Reference clip tagged @Audio2 in the prompt. Up to 30s.",
|
||||
),
|
||||
IO.Audio.Input(
|
||||
"reference_audio_3",
|
||||
optional=True,
|
||||
tooltip="Reference clip tagged @Audio3 in the prompt. Up to 30s.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
MODE_IMAGE,
|
||||
[
|
||||
IO.Image.Input(
|
||||
"reference_image",
|
||||
optional=True,
|
||||
tooltip="A single character image; the model derives a voice from it. "
|
||||
"Cannot be combined with reference audio.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
MODE_SPEAKER,
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"preset_voice",
|
||||
options=SEED_AUDIO_VOICE_OPTIONS,
|
||||
default=SEED_AUDIO_VOICE_OPTIONS[0],
|
||||
tooltip="A built-in TTS 2.0 voice that reads the prompt. No reference "
|
||||
"clip needed, and @AudioN tags are not used in this mode.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip=(
|
||||
"How to condition the voice: 'text only' (describe everything in the prompt), "
|
||||
"'audio reference' (clone up to 3 voices, tagged @Audio1-3), 'image reference' "
|
||||
"(derive a voice from one character image), or 'preset voice' (pick a built-in "
|
||||
"named voice that reads the prompt)."
|
||||
),
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"sample_rate",
|
||||
options=["8000", "16000", "24000", "32000", "44100", "48000"],
|
||||
default="24000",
|
||||
tooltip="Output sample rate in Hz.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"speech_rate",
|
||||
default=0,
|
||||
min=-50,
|
||||
max=100,
|
||||
tooltip="Speaking speed. 0 = normal, 100 = 2.0x, -50 = 0.5x.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"loudness_rate",
|
||||
default=0,
|
||||
min=-50,
|
||||
max=100,
|
||||
tooltip="Loudness. 0 = normal, 100 = 2.0x, -50 = 0.5x.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"pitch_rate",
|
||||
default=0,
|
||||
min=-12,
|
||||
max=12,
|
||||
tooltip="Pitch shift in semitones (-12 to 12).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd": 0.2145, "format":{"suffix":"/minute","approximate":true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
text_prompt: str,
|
||||
reference_mode: dict,
|
||||
sample_rate: str,
|
||||
speech_rate: int,
|
||||
loudness_rate: int,
|
||||
pitch_rate: int,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
mode = reference_mode["reference_mode"]
|
||||
audio_indices = connected_audio_indices(reference_mode)
|
||||
image = reference_mode.get("reference_image")
|
||||
preset_voice = reference_mode.get("preset_voice")
|
||||
validate_seed_audio_inputs(text_prompt, mode, audio_indices, image is not None, preset_voice)
|
||||
|
||||
references: list[SeedAudioReference] | None = None
|
||||
if mode == MODE_AUDIO:
|
||||
references = []
|
||||
for i in audio_indices:
|
||||
clip = reference_mode[f"reference_audio_{i}"]
|
||||
validate_audio_duration(clip, max_duration=30.0)
|
||||
mp3_bytes = audio_input_to_mp3(clip).getvalue()
|
||||
references.append(SeedAudioReference(audio_data=base64.b64encode(mp3_bytes).decode("utf-8")))
|
||||
elif mode == MODE_IMAGE:
|
||||
image = upscale_image_tensor_to_min_pixels(image, 160_000)
|
||||
references = [SeedAudioReference(image_data=tensor_to_base64_string(image, mime_type="image/png"))]
|
||||
elif mode == MODE_SPEAKER:
|
||||
references = [SeedAudioReference(speaker=SEED_AUDIO_VOICE_MAP[preset_voice])]
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/byteplus/api/v3/tts/create", method="POST"),
|
||||
response_model=SeedAudioResponse,
|
||||
data=SeedAudioRequest(
|
||||
text_prompt=text_prompt,
|
||||
references=references,
|
||||
audio_config=SeedAudioConfig(
|
||||
sample_rate=int(sample_rate),
|
||||
speech_rate=speech_rate,
|
||||
loudness_rate=loudness_rate,
|
||||
pitch_rate=pitch_rate,
|
||||
),
|
||||
),
|
||||
)
|
||||
if not response.audio:
|
||||
raise Exception(
|
||||
f"Seed Audio returned no audio (code={response.code}): {response.message}"
|
||||
)
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response.audio)))
|
||||
|
||||
|
||||
class ByteDanceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -2490,6 +2805,7 @@ class ByteDanceExtension(ComfyExtension):
|
||||
ByteDance2ReferenceNode,
|
||||
ByteDanceCreateImageAsset,
|
||||
ByteDanceCreateVideoAsset,
|
||||
ByteDanceSeedAudioNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -1,932 +0,0 @@
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, Input, IO
|
||||
from comfy_api_nodes.apis.stability import (
|
||||
StabilityUpscaleConservativeRequest,
|
||||
StabilityUpscaleCreativeRequest,
|
||||
StabilityAsyncResponse,
|
||||
StabilityResultsGetResponse,
|
||||
StabilityStable3_5Request,
|
||||
StabilityStableUltraRequest,
|
||||
StabilityStableUltraResponse,
|
||||
StabilityAspectRatio,
|
||||
Stability_SD3_5_Model,
|
||||
Stability_SD3_5_GenerationMode,
|
||||
get_stability_style_presets,
|
||||
StabilityTextToAudioRequest,
|
||||
StabilityAudioToAudioRequest,
|
||||
StabilityAudioInpaintRequest,
|
||||
StabilityAudioResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
validate_audio_duration,
|
||||
validate_string,
|
||||
audio_input_to_mp3,
|
||||
bytesio_to_image_tensor,
|
||||
tensor_to_bytesio,
|
||||
audio_bytes_to_audio_input,
|
||||
sync_op,
|
||||
poll_op,
|
||||
ApiEndpoint,
|
||||
)
|
||||
|
||||
import torch
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class StabilityPollStatus(str, Enum):
|
||||
finished = "finished"
|
||||
in_progress = "in_progress"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
def get_async_dummy_status(x: StabilityResultsGetResponse):
|
||||
if x.name is not None or x.errors is not None:
|
||||
return StabilityPollStatus.failed
|
||||
elif x.finish_reason is not None:
|
||||
return StabilityPollStatus.finished
|
||||
return StabilityPollStatus.in_progress
|
||||
|
||||
|
||||
class StabilityStableImageUltraNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityStableImageUltraNode",
|
||||
display_name="Stability AI Stable Image Ultra",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
|
||||
"elements, colors, and subjects will lead to better results. " +
|
||||
"To control the weight of a given word use the format `(word:weight)`," +
|
||||
"where `word` is the word you'd like to control the weight of and `weight`" +
|
||||
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
|
||||
"would convey a sky that was blue and green, but more green than blue.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=StabilityAspectRatio,
|
||||
default=StabilityAspectRatio.ratio_1_1,
|
||||
tooltip="Aspect ratio of generated image.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"image_denoise",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.08}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
style_preset: str,
|
||||
seed: int,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
negative_prompt: str = "",
|
||||
image_denoise: Optional[float] = 0.5,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
if image is not None:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
|
||||
else:
|
||||
image_denoise = None
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityStableUltraRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
seed=seed,
|
||||
strength=image_denoise,
|
||||
style_preset=style_preset,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityStableImageSD_3_5Node(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityStableImageSD_3_5Node",
|
||||
display_name="Stability AI Stable Diffusion 3.5 Image",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Stability_SD3_5_Model,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=StabilityAspectRatio,
|
||||
default=StabilityAspectRatio.ratio_1_1,
|
||||
tooltip="Aspect ratio of generated image.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"cfg_scale",
|
||||
default=4.0,
|
||||
min=1.0,
|
||||
max=10.0,
|
||||
step=0.1,
|
||||
tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"image_denoise",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$contains(widgets.model,"large")
|
||||
? {"type":"usd","usd":0.065}
|
||||
: {"type":"usd","usd":0.035}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
style_preset: str,
|
||||
seed: int,
|
||||
cfg_scale: float,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
negative_prompt: str = "",
|
||||
image_denoise: Optional[float] = 0.5,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
mode = Stability_SD3_5_GenerationMode.text_to_image
|
||||
if image is not None:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
|
||||
mode = Stability_SD3_5_GenerationMode.image_to_image
|
||||
aspect_ratio = None
|
||||
else:
|
||||
image_denoise = None
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityStable3_5Request(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
seed=seed,
|
||||
strength=image_denoise,
|
||||
style_preset=style_preset,
|
||||
cfg_scale=cfg_scale,
|
||||
model=model,
|
||||
mode=mode,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleConservativeNode(IO.ComfyNode):
|
||||
"""
|
||||
Upscale image with minimal alterations to 4K resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleConservativeNode",
|
||||
display_name="Stability AI Upscale Conservative",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"creativity",
|
||||
default=0.35,
|
||||
min=0.2,
|
||||
max=0.5,
|
||||
step=0.01,
|
||||
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
creativity: float,
|
||||
seed: int,
|
||||
negative_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityUpscaleConservativeRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
creativity=round(creativity,2),
|
||||
seed=seed,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleCreativeNode(IO.ComfyNode):
|
||||
"""
|
||||
Upscale image with minimal alterations to 4K resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleCreativeNode",
|
||||
display_name="Stability AI Upscale Creative",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"creativity",
|
||||
default=0.3,
|
||||
min=0.1,
|
||||
max=0.5,
|
||||
step=0.01,
|
||||
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.6}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
creativity: float,
|
||||
style_preset: str,
|
||||
seed: int,
|
||||
negative_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
|
||||
response_model=StabilityAsyncResponse,
|
||||
data=StabilityUpscaleCreativeRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
creativity=round(creativity,2),
|
||||
style_preset=style_preset,
|
||||
seed=seed,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
|
||||
response_model=StabilityResultsGetResponse,
|
||||
poll_interval=3,
|
||||
status_extractor=lambda x: get_async_dummy_status(x),
|
||||
)
|
||||
|
||||
if response_poll.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_poll.result)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleFastNode(IO.ComfyNode):
|
||||
"""
|
||||
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleFastNode",
|
||||
display_name="Stability AI Upscale Fast",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.02}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, image: torch.Tensor) -> IO.NodeOutput:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityTextToAudio(IO.ComfyNode):
|
||||
"""Generates high-quality music and sound effects from text descriptions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityTextToAudio",
|
||||
display_name="Stability AI Text To Audio",
|
||||
category="partner/audio/Stability AI",
|
||||
essentials_category="Audio",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
max=190,
|
||||
step=1,
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="Controls the number of sampling steps.",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.2}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityAudioToAudio(IO.ComfyNode):
|
||||
"""Transforms existing audio samples into new high-quality compositions using text instructions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityAudioToAudio",
|
||||
display_name="Stability AI Audio To Audio",
|
||||
category="partner/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
max=190,
|
||||
step=1,
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="Controls the number of sampling steps.",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"strength",
|
||||
default=1,
|
||||
min=0.01,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Parameter controls how much influence the audio parameter has on the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.2}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
validate_audio_duration(audio, 6, 190)
|
||||
payload = StabilityAudioToAudioRequest(
|
||||
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
|
||||
)
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
files={"audio": audio_input_to_mp3(audio)},
|
||||
)
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityAudioInpaint(IO.ComfyNode):
|
||||
"""Transforms part of existing audio sample using text instructions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityAudioInpaint",
|
||||
display_name="Stability AI Audio Inpaint",
|
||||
category="partner/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
max=190,
|
||||
step=1,
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="Controls the number of sampling steps.",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"mask_start",
|
||||
default=30,
|
||||
min=0,
|
||||
max=190,
|
||||
step=1,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"mask_end",
|
||||
default=190,
|
||||
min=0,
|
||||
max=190,
|
||||
step=1,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.2}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
audio: Input.Audio,
|
||||
duration: int,
|
||||
seed: int,
|
||||
steps: int,
|
||||
mask_start: int,
|
||||
mask_end: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
if mask_end <= mask_start:
|
||||
raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})")
|
||||
validate_audio_duration(audio, 6, 190)
|
||||
|
||||
payload = StabilityAudioInpaintRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
steps=steps,
|
||||
mask_start=mask_start,
|
||||
mask_end=mask_end,
|
||||
)
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
files={"audio": audio_input_to_mp3(audio)},
|
||||
)
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
StabilityStableImageUltraNode,
|
||||
StabilityStableImageSD_3_5Node,
|
||||
StabilityUpscaleConservativeNode,
|
||||
StabilityUpscaleCreativeNode,
|
||||
StabilityUpscaleFastNode,
|
||||
StabilityTextToAudio,
|
||||
StabilityAudioToAudio,
|
||||
StabilityAudioInpaint,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> StabilityExtension:
|
||||
return StabilityExtension()
|
||||
@ -26,6 +26,7 @@ from .conversions import (
|
||||
text_filepath_to_base64_string,
|
||||
text_filepath_to_data_uri,
|
||||
trim_video,
|
||||
upscale_image_tensor_to_min_pixels,
|
||||
upscale_video_to_min_pixels,
|
||||
video_to_base64_string,
|
||||
)
|
||||
@ -99,6 +100,7 @@ __all__ = [
|
||||
"text_filepath_to_base64_string",
|
||||
"text_filepath_to_data_uri",
|
||||
"trim_video",
|
||||
"upscale_image_tensor_to_min_pixels",
|
||||
"upscale_video_to_min_pixels",
|
||||
"video_to_base64_string",
|
||||
# Validation utilities
|
||||
|
||||
@ -448,6 +448,15 @@ def _compute_upscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[in
|
||||
return new_w, new_h
|
||||
|
||||
|
||||
def upscale_image_tensor_to_min_pixels(image: torch.Tensor, total_pixels: int) -> torch.Tensor:
|
||||
samples = image.movedim(-1, 1)
|
||||
dims = _compute_upscale_dims(samples.shape[3], samples.shape[2], int(total_pixels))
|
||||
if dims is None:
|
||||
return image
|
||||
new_w, new_h = dims
|
||||
return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1)
|
||||
|
||||
|
||||
def upscale_video_to_min_pixels(video: Input.Video, min_pixels: int) -> Input.Video:
|
||||
"""Upscale a video to meet at least ``min_pixels`` (w * h), preserving aspect ratio.
|
||||
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
import asyncio
|
||||
import bisect
|
||||
import gc
|
||||
import itertools
|
||||
import psutil
|
||||
import time
|
||||
import torch
|
||||
from typing import Sequence, Mapping, Dict
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import nodes
|
||||
@ -116,10 +117,6 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
signature = [class_type, await self.is_changed_cache.get(node_id)]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
# Include expected_outputs in cache key for nodes that opt in via LAZY_OUTPUTS
|
||||
if hasattr(class_def, 'LAZY_OUTPUTS') and class_def.LAZY_OUTPUTS:
|
||||
expected = get_expected_outputs_for_node(dynprompt, node_id)
|
||||
signature.append(("expected_outputs", tuple(sorted(expected))))
|
||||
inputs = node["inputs"]
|
||||
for key in sorted(inputs.keys()):
|
||||
if is_link(inputs[key]):
|
||||
@ -532,6 +529,38 @@ class RAMPressureCache(LRUCache):
|
||||
if psutil.virtual_memory().available >= target:
|
||||
return
|
||||
|
||||
def remove_cache_key(key):
|
||||
del self.cache[key]
|
||||
self.used_generation.pop(key, None)
|
||||
self.timestamps.pop(key, None)
|
||||
self.children.pop(key, None)
|
||||
|
||||
def has_old_model_patcher(outputs):
|
||||
if outputs is None:
|
||||
return False
|
||||
for output in outputs:
|
||||
if isinstance(output, (list, tuple)):
|
||||
if has_old_model_patcher(output):
|
||||
return True
|
||||
elif isinstance(output, ModelPatcher):
|
||||
return True
|
||||
return False
|
||||
|
||||
old_modelpatcher_keys = []
|
||||
for key, cache_entry in self.cache.items():
|
||||
if self.used_generation[key] == self.generation:
|
||||
continue
|
||||
if has_old_model_patcher(cache_entry.outputs):
|
||||
old_modelpatcher_keys.append(key)
|
||||
|
||||
for key in old_modelpatcher_keys:
|
||||
remove_cache_key(key)
|
||||
|
||||
if old_modelpatcher_keys:
|
||||
gc.collect()
|
||||
if psutil.virtual_memory().available >= target:
|
||||
return
|
||||
|
||||
clean_list = []
|
||||
|
||||
for key, cache_entry in self.cache.items():
|
||||
@ -549,19 +578,17 @@ class RAMPressureCache(LRUCache):
|
||||
scan_list_for_ram_usage(output)
|
||||
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
|
||||
ram_usage += output.numel() * output.element_size()
|
||||
elif isinstance(output, ModelPatcher) and self.used_generation[key] != self.generation:
|
||||
#old ModelPatchers are the first to go
|
||||
ram_usage = 1e30
|
||||
scan_list_for_ram_usage(cache_entry.outputs)
|
||||
|
||||
oom_score *= ram_usage
|
||||
#In the case where we have no information on the node ram usage at all,
|
||||
#break OOM score ties on the last touch timestamp (pure LRU)
|
||||
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
|
||||
bisect.insort(clean_list, (oom_score, self.timestamps[key], ram_usage, key))
|
||||
|
||||
while psutil.virtual_memory().available < target and clean_list:
|
||||
_, _, key = clean_list.pop()
|
||||
del self.cache[key]
|
||||
self.used_generation.pop(key, None)
|
||||
self.timestamps.pop(key, None)
|
||||
self.children.pop(key, None)
|
||||
to_free = target - psutil.virtual_memory().available
|
||||
while to_free > 0 and clean_list:
|
||||
_, _, ram_usage, key = clean_list.pop()
|
||||
remove_cache_key(key)
|
||||
to_free -= ram_usage
|
||||
|
||||
gc.collect()
|
||||
|
||||
@ -18,18 +18,6 @@ class NodeInputError(Exception):
|
||||
class NodeNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
|
||||
"""Get the set of output indices that are connected downstream.
|
||||
Returns outputs that MIGHT be used.
|
||||
Outputs NOT in this set are DEFINITELY not used and safe to skip
|
||||
(see Schema.lazy_outputs for the one expansion-related limitation).
|
||||
|
||||
Includes input links and consumers registered via add_output_consumer.
|
||||
"""
|
||||
return dynprompt.get_expected_outputs_map().get(node_id, frozenset())
|
||||
|
||||
|
||||
class DynamicPrompt:
|
||||
def __init__(self, original_prompt):
|
||||
# The original prompt provided by the user
|
||||
@ -38,9 +26,6 @@ class DynamicPrompt:
|
||||
self.ephemeral_prompt = {}
|
||||
self.ephemeral_parents = {}
|
||||
self.ephemeral_display = {}
|
||||
# Output sockets consumed outside of input links (subgraph expansions)
|
||||
self._external_output_consumers = {}
|
||||
self._expected_outputs_map = None
|
||||
|
||||
def get_node(self, node_id):
|
||||
if node_id in self.ephemeral_prompt:
|
||||
@ -56,7 +41,6 @@ class DynamicPrompt:
|
||||
self.ephemeral_prompt[node_id] = node_info
|
||||
self.ephemeral_parents[node_id] = parent_id
|
||||
self.ephemeral_display[node_id] = display_id
|
||||
self._expected_outputs_map = None
|
||||
|
||||
def get_real_node_id(self, node_id):
|
||||
while node_id in self.ephemeral_parents:
|
||||
@ -74,29 +58,6 @@ class DynamicPrompt:
|
||||
def all_node_ids(self):
|
||||
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
|
||||
|
||||
def add_output_consumer(self, node_id, socket):
|
||||
"""Record an output socket consumed outside of input links, e.g. a subgraph
|
||||
expansion mapping its parent's output to this node's output."""
|
||||
self._external_output_consumers.setdefault(node_id, set()).add(socket)
|
||||
self._expected_outputs_map = None
|
||||
|
||||
def _build_expected_outputs_map(self):
|
||||
result = {}
|
||||
for node_id in self.all_node_ids():
|
||||
node_data = self.get_node(node_id)
|
||||
for value in node_data.get("inputs", {}).values():
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
result.setdefault(from_node_id, set()).add(from_socket)
|
||||
for node_id, sockets in self._external_output_consumers.items():
|
||||
result.setdefault(node_id, set()).update(sockets)
|
||||
self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()}
|
||||
|
||||
def get_expected_outputs_map(self):
|
||||
if self._expected_outputs_map is None:
|
||||
self._build_expected_outputs_map()
|
||||
return self._expected_outputs_map
|
||||
|
||||
def get_original_prompt(self):
|
||||
return self.original_prompt
|
||||
|
||||
|
||||
@ -1,45 +1,23 @@
|
||||
import contextvars
|
||||
from typing import NamedTuple, FrozenSet
|
||||
from typing import Optional, NamedTuple
|
||||
|
||||
class ExecutionContext(NamedTuple):
|
||||
"""
|
||||
Context information about the currently executing node.
|
||||
|
||||
Attributes:
|
||||
prompt_id: The ID of the current prompt execution
|
||||
node_id: The ID of the currently executing node
|
||||
list_index: The index in a list being processed (for operations on batches/lists)
|
||||
expected_outputs: Set of output indices that might be used downstream.
|
||||
Outputs NOT in this set are definitely unused (safe to skip).
|
||||
None means the information is not available.
|
||||
"""
|
||||
prompt_id: str
|
||||
node_id: str
|
||||
list_index: int | None
|
||||
expected_outputs: FrozenSet[int] | None = None
|
||||
list_index: Optional[int]
|
||||
|
||||
current_executing_context: contextvars.ContextVar[ExecutionContext | None] = contextvars.ContextVar("current_executing_context", default=None)
|
||||
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
|
||||
|
||||
def get_executing_context() -> ExecutionContext | None:
|
||||
def get_executing_context() -> Optional[ExecutionContext]:
|
||||
return current_executing_context.get(None)
|
||||
|
||||
|
||||
def is_output_needed(output_index: int) -> bool:
|
||||
"""Check if an output at the given index is connected downstream.
|
||||
|
||||
Returns True if the output might be used (should be computed).
|
||||
Returns False if the output is definitely not connected (safe to skip).
|
||||
|
||||
Only meaningful for LAZY_OUTPUTS nodes; for all others expected_outputs is
|
||||
None and this always returns True (skipping without the flag would not be
|
||||
reflected in the cache key).
|
||||
"""
|
||||
ctx = get_executing_context()
|
||||
if ctx is None or ctx.expected_outputs is None:
|
||||
return True
|
||||
return output_index in ctx.expected_outputs
|
||||
|
||||
|
||||
class CurrentNodeContext:
|
||||
"""
|
||||
Context manager for setting the current executing node context.
|
||||
@ -47,22 +25,15 @@ class CurrentNodeContext:
|
||||
Sets the current_executing_context on enter and resets it on exit.
|
||||
|
||||
Example:
|
||||
with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0):
|
||||
with CurrentNodeContext(node_id="123", list_index=0):
|
||||
# Code that should run with the current node context set
|
||||
process_image()
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
prompt_id: str,
|
||||
node_id: str,
|
||||
list_index: int | None = None,
|
||||
expected_outputs: FrozenSet[int] | None = None,
|
||||
):
|
||||
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
|
||||
self.context = ExecutionContext(
|
||||
prompt_id=prompt_id,
|
||||
node_id=node_id,
|
||||
list_index=list_index,
|
||||
expected_outputs=expected_outputs,
|
||||
prompt_id= prompt_id,
|
||||
node_id= node_id,
|
||||
list_index= list_index
|
||||
)
|
||||
self.token = None
|
||||
|
||||
|
||||
44
execution.py
44
execution.py
@ -35,7 +35,6 @@ from comfy_execution.graph import (
|
||||
ExecutionBlocker,
|
||||
ExecutionList,
|
||||
get_input_info,
|
||||
get_expected_outputs_for_node,
|
||||
)
|
||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
@ -238,18 +237,7 @@ async def resolve_map_node_over_list_results(results):
|
||||
raise exc
|
||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||
|
||||
async def _async_map_node_over_list(
|
||||
prompt_id,
|
||||
unique_id,
|
||||
obj,
|
||||
input_data_all,
|
||||
func,
|
||||
allow_interrupt=False,
|
||||
execution_block_cb=None,
|
||||
pre_execute_cb=None,
|
||||
v3_data=None,
|
||||
expected_outputs=None,
|
||||
):
|
||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
|
||||
# check if node wants the lists
|
||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||
|
||||
@ -299,12 +287,10 @@ async def _async_map_node_over_list(
|
||||
else:
|
||||
f = getattr(obj, func)
|
||||
if inspect.iscoroutinefunction(f):
|
||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args, expected_outputs):
|
||||
with CurrentNodeContext(prompt_id, unique_id, list_index, expected_outputs):
|
||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
||||
with CurrentNodeContext(prompt_id, unique_id, list_index):
|
||||
return await f(**args)
|
||||
task = asyncio.create_task(
|
||||
async_wrapper(f, prompt_id, unique_id, index, args=inputs, expected_outputs=expected_outputs)
|
||||
)
|
||||
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
|
||||
# Give the task a chance to execute without yielding
|
||||
await asyncio.sleep(0)
|
||||
if task.done():
|
||||
@ -313,7 +299,7 @@ async def _async_map_node_over_list(
|
||||
else:
|
||||
results.append(task)
|
||||
else:
|
||||
with CurrentNodeContext(prompt_id, unique_id, index, expected_outputs):
|
||||
with CurrentNodeContext(prompt_id, unique_id, index):
|
||||
result = f(**inputs)
|
||||
results.append(result)
|
||||
else:
|
||||
@ -351,17 +337,8 @@ def merge_result_data(results, obj):
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
async def get_output_data(
|
||||
prompt_id,
|
||||
unique_id,
|
||||
obj,
|
||||
input_data_all,
|
||||
execution_block_cb=None,
|
||||
pre_execute_cb=None,
|
||||
v3_data=None,
|
||||
expected_outputs=None,
|
||||
):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
|
||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||
if has_pending_task:
|
||||
return return_values, {}, False, has_pending_task
|
||||
@ -561,12 +538,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||
|
||||
if getattr(class_def, "LAZY_OUTPUTS", False):
|
||||
expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
|
||||
else:
|
||||
expected_outputs = None
|
||||
try:
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||
finally:
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
if args.verbose == "DEBUG":
|
||||
@ -623,7 +596,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
if is_link(node_outputs[i]):
|
||||
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
|
||||
new_output_links.append((from_node_id, from_socket))
|
||||
dynprompt.add_output_consumer(from_node_id, from_socket)
|
||||
cached_outputs.append((True, node_outputs))
|
||||
new_node_ids = set(new_node_ids)
|
||||
for cache in caches.all:
|
||||
|
||||
@ -264,6 +264,59 @@ def annotated_filepath(name: str) -> tuple[str, str | None]:
|
||||
return name, base_dir
|
||||
|
||||
|
||||
# Content types a browser may execute or render inline. File endpoints that
|
||||
# serve user-controlled content must force these to download (and ideally set
|
||||
# Content-Disposition: attachment) to avoid stored XSS. Centralised here so the
|
||||
# /view and /userdata handlers can't drift apart. mimetypes.guess_type may
|
||||
# return either the text/* or application/* spelling depending on platform, so
|
||||
# both are listed.
|
||||
DANGEROUS_CONTENT_TYPES = {
|
||||
'text/html', 'text/html-sandboxed', 'application/xhtml+xml',
|
||||
'text/javascript', 'application/javascript', 'application/x-javascript',
|
||||
'application/ecmascript', 'text/css',
|
||||
'image/svg+xml', 'application/xml', 'text/xml',
|
||||
# message/rfc822 (.mht/.mhtml) can carry script in some browsers.
|
||||
'message/rfc822',
|
||||
}
|
||||
|
||||
|
||||
def is_dangerous_content_type(content_type: str | None) -> bool:
|
||||
"""Return True if a browser may execute or render `content_type` inline.
|
||||
|
||||
Normalises before matching so the check can't be slipped past with a
|
||||
charset/boundary parameter (``text/html; charset=utf-8``) or casing
|
||||
(``TEXT/HTML``). Any XML dialect (``*+xml`` or ``*/xml``) is treated as
|
||||
dangerous because XML can carry inline script via stylesheet/entity tricks,
|
||||
which also covers the ``application/{xslt,rss,atom,rdf}+xml`` family without
|
||||
enumerating each one. Endpoints serving user-controlled content should route
|
||||
a dangerous type to ``application/octet-stream`` + ``Content-Disposition:
|
||||
attachment`` + ``X-Content-Type-Options: nosniff``.
|
||||
"""
|
||||
if not content_type:
|
||||
return False
|
||||
normalized = content_type.split(';', 1)[0].strip().lower()
|
||||
if normalized in DANGEROUS_CONTENT_TYPES:
|
||||
return True
|
||||
return normalized.endswith('+xml') or normalized.endswith('/xml')
|
||||
|
||||
|
||||
def is_within_directory(directory: str, target: str) -> bool:
|
||||
"""Return True if `target` resolves to a path inside `directory`.
|
||||
|
||||
Uses realpath on both operands so that a symlink placed inside `directory`
|
||||
that points elsewhere cannot escape the containment check at open time.
|
||||
"""
|
||||
try:
|
||||
directory = os.path.realpath(directory)
|
||||
target = os.path.realpath(target)
|
||||
return os.path.commonpath((directory, target)) == directory
|
||||
except ValueError:
|
||||
# ValueError is raised by realpath() on a path with an embedded null
|
||||
# byte, and by commonpath() on Windows when the paths are on different
|
||||
# drives. In either case the target is not safely within the directory.
|
||||
return False
|
||||
|
||||
|
||||
def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
|
||||
name, base_dir = annotated_filepath(name)
|
||||
|
||||
@ -273,7 +326,12 @@ def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
|
||||
else:
|
||||
base_dir = get_input_directory() # fallback path
|
||||
|
||||
return os.path.join(base_dir, name)
|
||||
filepath = os.path.abspath(os.path.join(base_dir, name))
|
||||
# Prevent path traversal: the resolved path must stay within base_dir.
|
||||
# repr() the name in the message so a crafted value can't inject log lines.
|
||||
if not is_within_directory(base_dir, filepath):
|
||||
raise ValueError("Invalid file path: {!r}".format(name))
|
||||
return filepath
|
||||
|
||||
|
||||
def exists_annotated_filepath(name) -> bool:
|
||||
@ -282,7 +340,10 @@ def exists_annotated_filepath(name) -> bool:
|
||||
if base_dir is None:
|
||||
base_dir = get_input_directory() # fallback path
|
||||
|
||||
filepath = os.path.join(base_dir, name)
|
||||
filepath = os.path.abspath(os.path.join(base_dir, name))
|
||||
# Treat traversal attempts as non-existent rather than probing the filesystem.
|
||||
if not is_within_directory(base_dir, filepath):
|
||||
return False
|
||||
return os.path.exists(filepath)
|
||||
|
||||
|
||||
|
||||
2
main.py
2
main.py
@ -314,7 +314,7 @@ def prompt_worker(q, server_instance):
|
||||
cache_ram = 0
|
||||
cache_ram_inactive = 0
|
||||
if not args.cache_classic and not args.cache_none and args.cache_lru <= 0:
|
||||
cache_ram = min(10.0, max(2.0, comfy.model_management.total_ram * 0.10 / 1024.0))
|
||||
cache_ram = min(10.0, max(1.5, comfy.model_management.total_ram * 0.05 / 1024.0))
|
||||
cache_ram_inactive = min(96.0, comfy.model_management.total_ram / 1024.0)
|
||||
if len(args.cache_ram) > 0:
|
||||
cache_ram = args.cache_ram[0]
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.45.20
|
||||
comfyui-workflow-templates==0.11.1
|
||||
comfyui-workflow-templates==0.11.2
|
||||
comfyui-embedded-docs==0.5.6
|
||||
torch
|
||||
torchsde
|
||||
|
||||
26
server.py
26
server.py
@ -127,6 +127,7 @@ def create_cors_middleware(allowed_origin: str):
|
||||
|
||||
return cors_middleware
|
||||
|
||||
|
||||
def is_loopback(host):
|
||||
if host is None:
|
||||
return False
|
||||
@ -616,15 +617,30 @@ class PromptServer():
|
||||
or 'application/octet-stream'
|
||||
)
|
||||
|
||||
# For security, force certain mimetypes to download instead of display
|
||||
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
|
||||
content_type = 'application/octet-stream' # Forces download
|
||||
# For security, force renderable/active types (HTML, JS,
|
||||
# CSS, SVG, XML — anything that can carry inline <script>
|
||||
# and execute in the page origin) to download instead of
|
||||
# displaying inline, preventing stored XSS. The
|
||||
# attachment disposition is the load-bearing guard: a
|
||||
# bare filename= hint does not force a download per
|
||||
# RFC 6266, so we only attach it on the dangerous branch
|
||||
# to avoid breaking inline display of legitimate images.
|
||||
# Escape backslash/quote per RFC 6266 quoted-string so a
|
||||
# filename containing a double quote (which passes the
|
||||
# ".."/leading-slash filter above) can't break out of the
|
||||
# header's quoted-string and malform the disposition.
|
||||
safe_filename = filename.replace("\\", "\\\\").replace('"', '\\"')
|
||||
disposition = f"filename=\"{safe_filename}\""
|
||||
if folder_paths.is_dangerous_content_type(content_type):
|
||||
content_type = 'application/octet-stream'
|
||||
disposition = f"attachment; filename=\"{safe_filename}\""
|
||||
|
||||
return web.FileResponse(
|
||||
file,
|
||||
headers={
|
||||
"Content-Disposition": f"filename=\"{filename}\"",
|
||||
"Content-Type": content_type
|
||||
"Content-Disposition": disposition,
|
||||
"Content-Type": content_type,
|
||||
"X-Content-Type-Options": "nosniff"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import contextlib
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
@ -9,6 +11,40 @@ import requests
|
||||
from helpers import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
def test_download_svg_forced_to_attachment(http: requests.Session, api_base: str):
|
||||
"""GHSA-779p-m5rp-r4h4 CISA-5 (sibling route): an uploaded SVG must never be
|
||||
served inline from GET /api/assets/{id}/content, or an inline <script> runs
|
||||
in the app origin (stored XSS). Even with disposition=inline requested, a
|
||||
dangerous content type must be forced to application/octet-stream +
|
||||
Content-Disposition: attachment + nosniff. Regression guard for the stale
|
||||
inline blocklist that previously omitted image/svg+xml and ignored the
|
||||
centralized folder_paths.is_dangerous_content_type check.
|
||||
"""
|
||||
svg = b'<svg xmlns="http://www.w3.org/2000/svg"><script>alert(1)</script></svg>'
|
||||
files = {"file": ("evil.svg", svg, "image/svg+xml")}
|
||||
form_data = {
|
||||
"tags": json.dumps(["models", "checkpoints", "unit-tests", "svgxss"]),
|
||||
"name": "evil.svg",
|
||||
}
|
||||
up = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||
body = up.json()
|
||||
assert up.status_code in (200, 201), body
|
||||
aid = body["id"]
|
||||
try:
|
||||
r = http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline", timeout=120)
|
||||
r.content
|
||||
assert r.status_code == 200
|
||||
ct = r.headers.get("Content-Type", "").lower()
|
||||
cd = r.headers.get("Content-Disposition", "").lower()
|
||||
assert "svg" not in ct, f"SVG served with a renderable content type: {ct!r}"
|
||||
assert ct.startswith("application/octet-stream"), f"expected octet-stream, got {ct!r}"
|
||||
assert "attachment" in cd, f"inline disposition not overridden to attachment: {cd!r}"
|
||||
assert r.headers.get("X-Content-Type-Options", "").lower() == "nosniff"
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
|
||||
|
||||
|
||||
def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
|
||||
@ -53,8 +53,11 @@ def test_annotated_filepath():
|
||||
|
||||
def test_get_annotated_filepath():
|
||||
default_dir = "/default/dir"
|
||||
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
|
||||
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
|
||||
# get_annotated_filepath now normalizes with os.path.abspath (part of the
|
||||
# GHSA-779p traversal hardening), so compare against the normalized form —
|
||||
# on Windows abspath also prepends the current drive letter.
|
||||
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.abspath(os.path.join(default_dir, "test.txt"))
|
||||
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.abspath(os.path.join(folder_paths.get_output_directory(), "test.txt"))
|
||||
|
||||
def test_add_model_folder_path_append(clear_folder_paths):
|
||||
folder_paths.add_model_folder_path("test_folder", "/default/path", is_default=True)
|
||||
|
||||
@ -1,361 +0,0 @@
|
||||
"""Unit tests for the expected_outputs feature.
|
||||
|
||||
This feature allows nodes to know at runtime which outputs are connected downstream,
|
||||
enabling them to skip computing outputs that aren't needed.
|
||||
"""
|
||||
|
||||
from comfy_api.latest import IO
|
||||
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
|
||||
from comfy_execution.utils import (
|
||||
CurrentNodeContext,
|
||||
ExecutionContext,
|
||||
get_executing_context,
|
||||
is_output_needed,
|
||||
)
|
||||
|
||||
|
||||
class TestGetExpectedOutputsForNode:
|
||||
"""Tests for get_expected_outputs_for_node() function."""
|
||||
|
||||
def test_single_output_connected(self):
|
||||
"""Test node with single output connected to one downstream node."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected == frozenset({0})
|
||||
|
||||
def test_multiple_outputs_partial_connected(self):
|
||||
"""Test node with multiple outputs, only some connected."""
|
||||
prompt = {
|
||||
"1": {"class_type": "MultiOutputNode", "inputs": {}},
|
||||
"2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
|
||||
# Output 1 is not connected
|
||||
"3": {"class_type": "ConsumerC", "inputs": {"input": ["1", 2]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected == frozenset({0, 2})
|
||||
assert 1 not in expected # Output 1 is definitely unused
|
||||
|
||||
def test_no_outputs_connected(self):
|
||||
"""Test node with no outputs connected."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {"class_type": "OtherNode", "inputs": {}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected == frozenset()
|
||||
|
||||
def test_same_output_connected_multiple_times(self):
|
||||
"""Test same output connected to multiple downstream nodes."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
|
||||
"3": {"class_type": "ConsumerB", "inputs": {"input": ["1", 0]}},
|
||||
"4": {"class_type": "ConsumerC", "inputs": {"input": ["1", 0]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected == frozenset({0})
|
||||
|
||||
def test_node_not_in_prompt(self):
|
||||
"""Test getting expected outputs for a node not in the prompt."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
expected = get_expected_outputs_for_node(dynprompt, "999")
|
||||
assert expected == frozenset()
|
||||
|
||||
def test_chained_nodes(self):
|
||||
"""Test expected outputs in a chain of nodes."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {"class_type": "MiddleNode", "inputs": {"input": ["1", 0]}},
|
||||
"3": {"class_type": "EndNode", "inputs": {"input": ["2", 0]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
|
||||
# Node 1's output 0 is connected to node 2
|
||||
expected_1 = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected_1 == frozenset({0})
|
||||
|
||||
# Node 2's output 0 is connected to node 3
|
||||
expected_2 = get_expected_outputs_for_node(dynprompt, "2")
|
||||
assert expected_2 == frozenset({0})
|
||||
|
||||
# Node 3 has no downstream connections
|
||||
expected_3 = get_expected_outputs_for_node(dynprompt, "3")
|
||||
assert expected_3 == frozenset()
|
||||
|
||||
def test_complex_graph(self):
|
||||
"""Test expected outputs in a complex graph with multiple connections."""
|
||||
prompt = {
|
||||
"1": {"class_type": "MultiOutputNode", "inputs": {}},
|
||||
"2": {"class_type": "ProcessorA", "inputs": {"image": ["1", 0], "mask": ["1", 1]}},
|
||||
"3": {"class_type": "ProcessorB", "inputs": {"data": ["1", 2]}},
|
||||
"4": {"class_type": "Combiner", "inputs": {"a": ["2", 0], "b": ["3", 0]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
|
||||
# Node 1 has outputs 0, 1, 2 all connected
|
||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected == frozenset({0, 1, 2})
|
||||
|
||||
def test_constant_inputs_ignored(self):
|
||||
"""Test that constant (non-link) inputs don't affect expected outputs."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {
|
||||
"class_type": "ConsumerNode",
|
||||
"inputs": {
|
||||
"image": ["1", 0],
|
||||
"value": 42,
|
||||
"name": "test",
|
||||
},
|
||||
},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected == frozenset({0})
|
||||
|
||||
def test_ephemeral_node_invalidates_cache(self):
|
||||
"""Test that adding ephemeral nodes updates expected outputs."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
|
||||
# Initially only output 0 is connected
|
||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected == frozenset({0})
|
||||
|
||||
# Add an ephemeral node that connects to output 1
|
||||
dynprompt.add_ephemeral_node(
|
||||
"eph_1",
|
||||
{"class_type": "EphemeralNode", "inputs": {"data": ["1", 1]}},
|
||||
parent_id="2",
|
||||
display_id="2",
|
||||
)
|
||||
|
||||
# Now both outputs 0 and 1 should be expected
|
||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected == frozenset({0, 1})
|
||||
|
||||
|
||||
class TestExternalOutputConsumers:
|
||||
"""Tests for DynamicPrompt.add_output_consumer() — out-of-band consumers
|
||||
(subgraph expansion output mappings) that have no input link in the prompt."""
|
||||
|
||||
def test_external_consumer_only(self):
|
||||
"""A socket consumed only externally must appear in expected outputs."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset()
|
||||
|
||||
dynprompt.add_output_consumer("1", 1)
|
||||
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({1})
|
||||
|
||||
def test_external_consumer_merges_with_links(self):
|
||||
"""External consumers merge with input-link consumers."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
dynprompt.add_output_consumer("1", 2)
|
||||
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0, 2})
|
||||
|
||||
def test_external_consumer_invalidates_cached_map(self):
|
||||
"""Registering after the map was built must invalidate the cache."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
# Build (and cache) the map first
|
||||
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0})
|
||||
|
||||
dynprompt.add_output_consumer("1", 1)
|
||||
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0, 1})
|
||||
|
||||
|
||||
class TestExecutionContext:
|
||||
"""Tests for ExecutionContext with expected_outputs field."""
|
||||
|
||||
def test_context_with_expected_outputs(self):
|
||||
"""Test creating ExecutionContext with expected_outputs."""
|
||||
ctx = ExecutionContext(
|
||||
prompt_id="prompt-123", node_id="node-456", list_index=0, expected_outputs=frozenset({0, 2})
|
||||
)
|
||||
assert ctx.prompt_id == "prompt-123"
|
||||
assert ctx.node_id == "node-456"
|
||||
assert ctx.list_index == 0
|
||||
assert ctx.expected_outputs == frozenset({0, 2})
|
||||
|
||||
def test_context_without_expected_outputs(self):
|
||||
"""Test ExecutionContext defaults to None for expected_outputs."""
|
||||
ctx = ExecutionContext(prompt_id="prompt-123", node_id="node-456", list_index=0)
|
||||
assert ctx.expected_outputs is None
|
||||
|
||||
def test_context_empty_expected_outputs(self):
|
||||
"""Test ExecutionContext with empty expected_outputs set."""
|
||||
ctx = ExecutionContext(
|
||||
prompt_id="prompt-123", node_id="node-456", list_index=None, expected_outputs=frozenset()
|
||||
)
|
||||
assert ctx.expected_outputs == frozenset()
|
||||
assert len(ctx.expected_outputs) == 0
|
||||
|
||||
|
||||
class TestCurrentNodeContext:
|
||||
"""Tests for CurrentNodeContext context manager with expected_outputs."""
|
||||
|
||||
def test_context_manager_with_expected_outputs(self):
|
||||
"""Test CurrentNodeContext sets and resets context correctly."""
|
||||
assert get_executing_context() is None
|
||||
|
||||
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 1})):
|
||||
ctx = get_executing_context()
|
||||
assert ctx is not None
|
||||
assert ctx.prompt_id == "prompt-1"
|
||||
assert ctx.node_id == "node-1"
|
||||
assert ctx.list_index == 0
|
||||
assert ctx.expected_outputs == frozenset({0, 1})
|
||||
|
||||
assert get_executing_context() is None
|
||||
|
||||
def test_context_manager_without_expected_outputs(self):
|
||||
"""Test CurrentNodeContext works without expected_outputs (backwards compatible)."""
|
||||
with CurrentNodeContext("prompt-1", "node-1"):
|
||||
ctx = get_executing_context()
|
||||
assert ctx is not None
|
||||
assert ctx.expected_outputs is None
|
||||
|
||||
def test_nested_context_managers(self):
|
||||
"""Test nested CurrentNodeContext managers."""
|
||||
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0})):
|
||||
ctx1 = get_executing_context()
|
||||
assert ctx1.expected_outputs == frozenset({0})
|
||||
|
||||
with CurrentNodeContext("prompt-1", "node-2", 0, frozenset({1, 2})):
|
||||
ctx2 = get_executing_context()
|
||||
assert ctx2.expected_outputs == frozenset({1, 2})
|
||||
assert ctx2.node_id == "node-2"
|
||||
|
||||
# After inner context exits, should be back to outer context
|
||||
ctx1_again = get_executing_context()
|
||||
assert ctx1_again.expected_outputs == frozenset({0})
|
||||
assert ctx1_again.node_id == "node-1"
|
||||
|
||||
def test_output_check_pattern(self):
|
||||
"""Test the typical pattern nodes will use to check expected outputs."""
|
||||
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
|
||||
ctx = get_executing_context()
|
||||
|
||||
# Typical usage pattern
|
||||
if ctx and ctx.expected_outputs is not None:
|
||||
should_compute_0 = 0 in ctx.expected_outputs
|
||||
should_compute_1 = 1 in ctx.expected_outputs
|
||||
should_compute_2 = 2 in ctx.expected_outputs
|
||||
else:
|
||||
# Fallback when info not available
|
||||
should_compute_0 = should_compute_1 = should_compute_2 = True
|
||||
|
||||
assert should_compute_0 is True
|
||||
assert should_compute_1 is False # Not in expected_outputs
|
||||
assert should_compute_2 is True
|
||||
|
||||
|
||||
class TestSchemaLazyOutputs:
|
||||
"""Tests for lazy_outputs in V3 Schema."""
|
||||
|
||||
def test_schema_lazy_outputs_default(self):
|
||||
"""Test that lazy_outputs defaults to False."""
|
||||
schema = IO.Schema(
|
||||
node_id="TestNode",
|
||||
inputs=[],
|
||||
outputs=[IO.Float.Output()],
|
||||
)
|
||||
assert schema.lazy_outputs is False
|
||||
|
||||
def test_schema_lazy_outputs_true(self):
|
||||
"""Test setting lazy_outputs to True."""
|
||||
schema = IO.Schema(
|
||||
node_id="TestNode",
|
||||
lazy_outputs=True,
|
||||
inputs=[],
|
||||
outputs=[IO.Float.Output()],
|
||||
)
|
||||
assert schema.lazy_outputs is True
|
||||
|
||||
def test_v3_node_lazy_outputs_property(self):
|
||||
"""Test that LAZY_OUTPUTS property works on V3 nodes."""
|
||||
|
||||
class TestNodeWithLazyOutputs(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TestNodeWithLazyOutputs",
|
||||
lazy_outputs=True,
|
||||
inputs=[],
|
||||
outputs=[IO.Float.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls):
|
||||
return IO.NodeOutput(1.0)
|
||||
|
||||
assert TestNodeWithLazyOutputs.LAZY_OUTPUTS is True
|
||||
|
||||
def test_v3_node_lazy_outputs_default(self):
|
||||
"""Test that LAZY_OUTPUTS defaults to False on V3 nodes."""
|
||||
|
||||
class TestNodeWithoutLazyOutputs(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TestNodeWithoutLazyOutputs",
|
||||
inputs=[],
|
||||
outputs=[IO.Float.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls):
|
||||
return IO.NodeOutput(1.0)
|
||||
|
||||
assert TestNodeWithoutLazyOutputs.LAZY_OUTPUTS is False
|
||||
|
||||
|
||||
class TestIsOutputNeeded:
|
||||
"""Tests for is_output_needed() helper function."""
|
||||
|
||||
def test_output_needed_when_in_expected(self):
|
||||
"""Test that output is needed when in expected_outputs."""
|
||||
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
|
||||
assert is_output_needed(0) is True
|
||||
assert is_output_needed(2) is True
|
||||
|
||||
def test_output_not_needed_when_not_in_expected(self):
|
||||
"""Test that output is not needed when not in expected_outputs."""
|
||||
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
|
||||
assert is_output_needed(1) is False
|
||||
assert is_output_needed(3) is False
|
||||
|
||||
def test_output_needed_when_no_context(self):
|
||||
"""Test that output is needed when no context."""
|
||||
assert get_executing_context() is None
|
||||
assert is_output_needed(0) is True
|
||||
assert is_output_needed(1) is True
|
||||
|
||||
def test_output_needed_when_expected_outputs_is_none(self):
|
||||
"""Test that output is needed when expected_outputs is None."""
|
||||
with CurrentNodeContext("prompt-1", "node-1", 0, None):
|
||||
assert is_output_needed(0) is True
|
||||
assert is_output_needed(1) is True
|
||||
0
tests-unit/security_test/__init__.py
Normal file
0
tests-unit/security_test/__init__.py
Normal file
192
tests-unit/security_test/test_ghsa_779p_02_preview_traversal.py
Normal file
192
tests-unit/security_test/test_ghsa_779p_02_preview_traversal.py
Normal file
@ -0,0 +1,192 @@
|
||||
"""CI unit tests for FIX #2 of GHSA-779p-m5rp-r4h4.
|
||||
|
||||
Path traversal / hardening in app/model_manager.py get_model_preview
|
||||
(route /experiment/models/preview/{folder}/{path_index}/{filename:.*}).
|
||||
|
||||
Reference: https://github.com/Comfy-Org/ComfyUI/security/advisories/GHSA-779p-m5rp-r4h4
|
||||
"""
|
||||
import pytest
|
||||
import yarl
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from aiohttp import web
|
||||
from unittest.mock import patch
|
||||
from app.model_manager import ModelFileManager
|
||||
|
||||
pytestmark = (
|
||||
pytest.mark.asyncio
|
||||
) # This applies the asyncio mark to all test functions in the module
|
||||
|
||||
@pytest.fixture
|
||||
def model_manager():
|
||||
return ModelFileManager()
|
||||
|
||||
@pytest.fixture
|
||||
def app(model_manager):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
model_manager.add_routes(routes)
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
|
||||
async def test_legit_preview_returns_200(aiohttp_client, app, tmp_path):
|
||||
"""Sanity: a real preview PNG inside the model folder is served as webp 200."""
|
||||
img = Image.new('RGB', (16, 16), color=(255, 0, 128))
|
||||
img.save(tmp_path / "test_model.png", format='PNG')
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/0/test_model.png')
|
||||
|
||||
assert response.status == 200
|
||||
assert response.content_type == 'image/webp'
|
||||
|
||||
img_bytes = BytesIO(await response.read())
|
||||
served = Image.open(img_bytes)
|
||||
assert served.format
|
||||
assert served.format.lower() == 'webp'
|
||||
served.close()
|
||||
|
||||
|
||||
async def test_non_integer_path_index_returns_400(aiohttp_client, app, tmp_path):
|
||||
"""A non-integer path_index segment must be rejected with 400."""
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/abc/test_model.png')
|
||||
|
||||
assert response.status == 400
|
||||
|
||||
|
||||
async def test_out_of_range_path_index_returns_404(aiohttp_client, app, tmp_path):
|
||||
"""A path_index beyond the configured folder list must return 404."""
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/99/test_model.png')
|
||||
|
||||
assert response.status == 404
|
||||
|
||||
|
||||
async def test_empty_filename_returns_400(aiohttp_client, app, tmp_path):
|
||||
"""The "{filename:.*}" capture also matches the empty string (trailing
|
||||
slash). It would resolve to the folder itself and must be rejected with 400."""
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/0/')
|
||||
|
||||
assert response.status == 400
|
||||
|
||||
|
||||
async def test_path_traversal_in_filename_returns_403(aiohttp_client, app, tmp_path):
|
||||
"""Path traversal in {filename} must be rejected with 403 and must NOT read
|
||||
a file outside the configured model directory.
|
||||
|
||||
GOTCHA: aiohttp/yarl collapses literal ``../`` dot-segments out of the URL
|
||||
path before it reaches the handler, which would make this test vacuously
|
||||
pass (the request would hit a different/non-existent route). We percent-encode
|
||||
the dots and slashes (``%2e%2e%2f``) and send the URL with
|
||||
``yarl.URL(..., encoded=True)`` so the bytes survive client-side normalization
|
||||
untouched; aiohttp's router then percent-decodes them into ``match_info``,
|
||||
delivering the literal ``../`` traversal to the handler's ``{filename:.*}``
|
||||
capture.
|
||||
|
||||
Without the fix the handler computes
|
||||
``os.path.normpath(os.path.join(folder, "../../../../etc/hosts"))``, which
|
||||
escapes ``tmp_path`` and would be passed straight to get_model_previews ->
|
||||
Image.open, serving bytes from outside the model dir (200/served bytes). The
|
||||
is_within_directory() containment check is the load-bearing fix that turns
|
||||
that escape into a 403.
|
||||
"""
|
||||
# Sanity-anchor: a legit preview exists inside tmp_path, so a 200 path is
|
||||
# genuinely reachable — proving the 403 below is the containment check
|
||||
# firing, not an unrelated 404.
|
||||
img = Image.new('RGB', (16, 16), color=(255, 0, 128))
|
||||
img.save(tmp_path / "test_model.png", format='PNG')
|
||||
|
||||
# Percent-encoded "../../../../etc/hosts" so yarl does not collapse the
|
||||
# dot-segments before the request leaves the client.
|
||||
encoded_traversal = '%2e%2e%2f' * 4 + 'etc%2fhosts'
|
||||
raw_path = '/experiment/models/preview/test_folder/0/' + encoded_traversal
|
||||
url = yarl.URL(raw_path, encoded=True)
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get(url)
|
||||
|
||||
# Confirm the traversal actually reached the handler intact: a 200 here
|
||||
# would mean either normalization stripped the ``../`` (vacuous pass) or
|
||||
# the containment check failed open and served outside-dir bytes.
|
||||
assert response.status == 403, (
|
||||
f"expected 403 from is_within_directory() containment check, "
|
||||
f"got {response.status}; traversal may have been normalized away "
|
||||
f"or the fix failed open"
|
||||
)
|
||||
body = await response.read()
|
||||
assert body == b"", "403 response must not carry any file bytes"
|
||||
|
||||
|
||||
async def test_symlink_companion_preview_returns_403(aiohttp_client, app, tmp_path):
|
||||
"""A companion preview file is selected by a glob inside get_model_previews
|
||||
and then opened. If that companion is a symlink whose path is in-dir but
|
||||
whose target escapes the model folder, it must be rejected with 403 — not
|
||||
served. The requested path itself stays in-dir (so the first containment
|
||||
check passes); the load-bearing fix is the SECOND is_within_directory check
|
||||
on the file actually opened.
|
||||
"""
|
||||
model_dir = tmp_path / "models"
|
||||
model_dir.mkdir()
|
||||
secret_dir = tmp_path / "secret"
|
||||
secret_dir.mkdir()
|
||||
# A real image OUTSIDE the model dir — valid, so without the fix Image.open
|
||||
# would succeed and its bytes would be served (200).
|
||||
secret = secret_dir / "secret.png"
|
||||
Image.new('RGB', (8, 8), color=(0, 0, 0)).save(secret, format='PNG')
|
||||
# Companion preview, in-dir by name but a symlink escaping the model dir.
|
||||
# (No real model file is needed — get_model_previews globs companions by
|
||||
# basename, and omitting a .safetensors avoids the metadata-header read.)
|
||||
companion = model_dir / "model.preview.png"
|
||||
try:
|
||||
companion.symlink_to(secret)
|
||||
except (OSError, NotImplementedError):
|
||||
pytest.skip("symlinks not supported on this platform/filesystem")
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(model_dir)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/0/model.safetensors')
|
||||
|
||||
assert response.status == 403, (
|
||||
f"expected 403 — the globbed companion preview is a symlink resolving "
|
||||
f"outside the model dir and must not be served; got {response.status}"
|
||||
)
|
||||
assert await response.read() == b""
|
||||
|
||||
|
||||
async def test_null_byte_in_filename_no_500(aiohttp_client, app, tmp_path):
|
||||
"""A NUL byte in the filename must yield a clean client rejection, not a 500
|
||||
from an uncaught ValueError in is_within_directory's realpath() call."""
|
||||
raw_path = '/experiment/models/preview/test_folder/0/' + 'a%00b'
|
||||
url = yarl.URL(raw_path, encoded=True)
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get(url)
|
||||
|
||||
assert response.status != 500, (
|
||||
f"NUL byte produced a 500 (uncaught ValueError); expected a clean "
|
||||
f"4xx rejection, got {response.status}"
|
||||
)
|
||||
assert 400 <= response.status < 500
|
||||
@ -0,0 +1,165 @@
|
||||
"""Security tests for GHSA-779p-m5rp-r4h4 — FIX #3.
|
||||
|
||||
Path traversal in folder_paths.get_annotated_filepath / exists_annotated_filepath,
|
||||
plus the shared is_within_directory() containment helper.
|
||||
|
||||
These are pure-function tests (no running server). The input/output/temp
|
||||
directories are pointed at tmp_path via the folder_paths setters, so a crafted
|
||||
name containing `../`, an absolute path, or a symlink that escapes the base
|
||||
directory must be rejected.
|
||||
|
||||
Reference: https://github.com/Comfy-Org/ComfyUI/security/advisories/GHSA-779p-m5rp-r4h4
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
import folder_paths
|
||||
from comfy.options import enable_args_parsing
|
||||
enable_args_parsing()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sandbox(tmp_path):
|
||||
"""Point folder_paths' input/output/temp dirs at a real temp sandbox.
|
||||
|
||||
Yields the realpath'd base, input, output and temp directories. The original
|
||||
directory values are restored afterward so tests stay isolated.
|
||||
"""
|
||||
base = os.path.realpath(str(tmp_path))
|
||||
input_dir = os.path.join(base, "input")
|
||||
output_dir = os.path.join(base, "output")
|
||||
temp_dir = os.path.join(base, "temp")
|
||||
for d in (input_dir, output_dir, temp_dir):
|
||||
os.makedirs(d, exist_ok=True)
|
||||
|
||||
orig_input = folder_paths.get_input_directory()
|
||||
orig_output = folder_paths.get_output_directory()
|
||||
orig_temp = folder_paths.get_temp_directory()
|
||||
|
||||
folder_paths.set_input_directory(input_dir)
|
||||
folder_paths.set_output_directory(output_dir)
|
||||
folder_paths.set_temp_directory(temp_dir)
|
||||
|
||||
yield {
|
||||
"base": base,
|
||||
"input": input_dir,
|
||||
"output": output_dir,
|
||||
"temp": temp_dir,
|
||||
}
|
||||
|
||||
folder_paths.set_input_directory(orig_input)
|
||||
folder_paths.set_output_directory(orig_output)
|
||||
folder_paths.set_temp_directory(orig_temp)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_within_directory() — the shared containment helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_is_within_directory_legit_child(sandbox):
|
||||
base = sandbox["input"]
|
||||
child = os.path.join(base, "sub", "image.png")
|
||||
assert folder_paths.is_within_directory(base, child) is True
|
||||
|
||||
|
||||
def test_is_within_directory_dotdot_escape(sandbox):
|
||||
base = sandbox["input"]
|
||||
escape = os.path.join(base, "..", "..", "etc", "passwd")
|
||||
assert folder_paths.is_within_directory(base, escape) is False
|
||||
|
||||
|
||||
def test_is_within_directory_symlink_escape(sandbox):
|
||||
"""A symlink created INSIDE base that points OUTSIDE base must not pass.
|
||||
|
||||
This is the key new hardening: is_within_directory realpath()s both operands,
|
||||
so a symlink planted in the base directory can't be used to read files
|
||||
elsewhere. We create a real on-disk symlink and a real secret target to
|
||||
verify the check actually resolves the link.
|
||||
"""
|
||||
base = sandbox["input"]
|
||||
|
||||
# A directory living outside the base, holding a secret file.
|
||||
outside = os.path.join(sandbox["base"], "outside_secret_dir")
|
||||
os.makedirs(outside, exist_ok=True)
|
||||
secret = os.path.join(outside, "secret.txt")
|
||||
with open(secret, "w") as f:
|
||||
f.write("top secret")
|
||||
|
||||
# Plant a symlink inside base that points at the outside directory.
|
||||
# symlink creation can require elevated privileges / Developer Mode on
|
||||
# Windows, so skip cleanly where it isn't available (same guard as the
|
||||
# sibling test in test_ghsa_779p_02_preview_traversal.py).
|
||||
link = os.path.join(base, "escape_link")
|
||||
try:
|
||||
os.symlink(outside, link)
|
||||
except (OSError, NotImplementedError):
|
||||
pytest.skip("symlinks not supported on this platform/filesystem")
|
||||
|
||||
# Accessing the secret "through" the in-base symlink must be rejected.
|
||||
target_via_link = os.path.join(link, "secret.txt")
|
||||
assert folder_paths.is_within_directory(base, target_via_link) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_annotated_filepath()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_annotated_filepath_legit_name(sandbox):
|
||||
result = folder_paths.get_annotated_filepath("image.png")
|
||||
assert result == os.path.join(sandbox["input"], "image.png")
|
||||
assert folder_paths.is_within_directory(sandbox["input"], result)
|
||||
|
||||
|
||||
def test_get_annotated_filepath_input_annotation(sandbox):
|
||||
result = folder_paths.get_annotated_filepath("image.png [input]")
|
||||
assert result == os.path.join(sandbox["input"], "image.png")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_output_annotation(sandbox):
|
||||
result = folder_paths.get_annotated_filepath("image.png [output]")
|
||||
assert result == os.path.join(sandbox["output"], "image.png")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_temp_annotation(sandbox):
|
||||
result = folder_paths.get_annotated_filepath("image.png [temp]")
|
||||
assert result == os.path.join(sandbox["temp"], "image.png")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_dotdot_raises(sandbox):
|
||||
with pytest.raises(ValueError):
|
||||
folder_paths.get_annotated_filepath("../etc/passwd")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_dotdot_with_annotation_raises(sandbox):
|
||||
with pytest.raises(ValueError):
|
||||
folder_paths.get_annotated_filepath("../../etc/passwd [output]")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_absolute_escape_raises(sandbox):
|
||||
with pytest.raises(ValueError):
|
||||
folder_paths.get_annotated_filepath("/etc/passwd")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# exists_annotated_filepath()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_exists_annotated_filepath_existing_legit_file(sandbox):
|
||||
real = os.path.join(sandbox["input"], "real.png")
|
||||
with open(real, "w") as f:
|
||||
f.write("data")
|
||||
assert folder_paths.exists_annotated_filepath("real.png") is True
|
||||
|
||||
|
||||
def test_exists_annotated_filepath_traversal_returns_false(sandbox):
|
||||
"""A traversal name must return False without raising and without probing
|
||||
outside the base directory (must never reach os.path.exists for the escape).
|
||||
"""
|
||||
# /etc/passwd exists on POSIX; the function must still report False because
|
||||
# the resolved path escapes the input directory.
|
||||
assert folder_paths.exists_annotated_filepath("../../../../../../etc/passwd") is False
|
||||
|
||||
|
||||
def test_exists_annotated_filepath_absolute_returns_false(sandbox):
|
||||
assert folder_paths.exists_annotated_filepath("/etc/passwd") is False
|
||||
147
tests-unit/security_test/test_ghsa_779p_04_userdata_xss.py
Normal file
147
tests-unit/security_test/test_ghsa_779p_04_userdata_xss.py
Normal file
@ -0,0 +1,147 @@
|
||||
"""
|
||||
CI unit tests for FIX #4 of GHSA-779p-m5rp-r4h4.
|
||||
|
||||
Stored-XSS hardening on GET /userdata/{file} in app/user_manager.py.
|
||||
|
||||
User data files are arbitrary user-supplied content and must never render
|
||||
inline in the app origin. The getuserdata handler:
|
||||
- forces Content-Type to application/octet-stream for any type in
|
||||
folder_paths.DANGEROUS_CONTENT_TYPES (text/html, image/svg+xml,
|
||||
text/javascript, ...),
|
||||
- sets X-Content-Type-Options: nosniff,
|
||||
- sets Content-Disposition: attachment.
|
||||
|
||||
These tests pre-create files in tmp_path and GET them back, asserting the
|
||||
secure response headers. They mirror the aiohttp_client pattern in
|
||||
tests-unit/prompt_server_test/user_manager_test.py.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
from aiohttp import web
|
||||
from app.user_manager import UserManager
|
||||
|
||||
pytestmark = (
|
||||
pytest.mark.asyncio
|
||||
) # This applies the asyncio mark to all test functions in the module
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_manager(tmp_path):
|
||||
um = UserManager()
|
||||
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
|
||||
tmp_path, file
|
||||
) if file else tmp_path
|
||||
return um
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(user_manager):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
user_manager.add_routes(routes)
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
|
||||
async def test_html_served_as_octet_stream(aiohttp_client, app, tmp_path):
|
||||
(tmp_path / "evil.html").write_text(
|
||||
"<script>console.log('xss-marker-ghsa-779p')</script>"
|
||||
)
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/evil.html")
|
||||
|
||||
assert resp.status == 200
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
# The load-bearing assertion: a .html file must NOT be served as text/html.
|
||||
assert "text/html" not in ct.lower(), (
|
||||
f"Content-Type {ct!r} would let a browser render/execute the file (stored XSS)."
|
||||
)
|
||||
assert ct == "application/octet-stream"
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
|
||||
|
||||
async def test_svg_served_as_octet_stream(aiohttp_client, app, tmp_path):
|
||||
(tmp_path / "evil.svg").write_text(
|
||||
'<?xml version="1.0"?>'
|
||||
'<svg xmlns="http://www.w3.org/2000/svg">'
|
||||
'<script>console.log("xss-marker-ghsa-779p")</script>'
|
||||
"</svg>"
|
||||
)
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/evil.svg")
|
||||
|
||||
assert resp.status == 200
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
# SVG can carry inline <script>; it must not be served as image/svg+xml.
|
||||
assert "svg" not in ct.lower(), (
|
||||
f"Content-Type {ct!r} would let a browser render the SVG and execute embedded scripts."
|
||||
)
|
||||
assert ct == "application/octet-stream"
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
|
||||
|
||||
async def test_js_served_as_octet_stream(aiohttp_client, app, tmp_path):
|
||||
(tmp_path / "evil.js").write_text("alert('xss-marker-ghsa-779p')")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/evil.js")
|
||||
|
||||
assert resp.status == 200
|
||||
ct = resp.headers.get("Content-Type", "").lower()
|
||||
# Must not be served as any executable JavaScript content type.
|
||||
assert "javascript" not in ct, (
|
||||
f"Content-Type {ct!r} is an executable JS type."
|
||||
)
|
||||
assert "ecmascript" not in ct, (
|
||||
f"Content-Type {ct!r} is an executable JS type."
|
||||
)
|
||||
assert ct == "application/octet-stream"
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
|
||||
|
||||
async def test_xml_dialect_served_as_octet_stream(aiohttp_client, app, tmp_path):
|
||||
"""An XML dialect outside the original blocklist (.xslt -> application/xslt+xml)
|
||||
must still be forced to download. This pins the normalised *+xml family rule
|
||||
in folder_paths.is_dangerous_content_type(); a plain set-membership test would
|
||||
have served this inline."""
|
||||
(tmp_path / "evil.xslt").write_text(
|
||||
'<?xml version="1.0"?>'
|
||||
'<xsl:stylesheet version="1.0" '
|
||||
'xmlns:xsl="http://www.w3.org/1999/XSL/Transform">'
|
||||
"<!-- xss-marker-ghsa-779p -->"
|
||||
"</xsl:stylesheet>"
|
||||
)
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/evil.xslt")
|
||||
|
||||
assert resp.status == 200
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
assert ct == "application/octet-stream", (
|
||||
f"Content-Type {ct!r}: an *+xml dialect must be forced to octet-stream "
|
||||
f"(it can carry inline script via stylesheet/entity tricks)."
|
||||
)
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
|
||||
|
||||
async def test_benign_txt_still_served(aiohttp_client, app, tmp_path):
|
||||
(tmp_path / "note.txt").write_text("just a harmless note")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/note.txt")
|
||||
|
||||
assert resp.status == 200
|
||||
assert await resp.text() == "just a harmless note"
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
# text/plain is not in the dangerous set, so it is acceptable here. The
|
||||
# defence-in-depth headers must still be present regardless.
|
||||
assert "text/plain" in ct.lower()
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
@ -0,0 +1,138 @@
|
||||
"""CI unit guard for FIX #5 of GHSA-779p-m5rp-r4h4 — the /view forced-download set.
|
||||
|
||||
Vuln #5 was stored XSS via SVG upload: the /view endpoint's Content-Type
|
||||
blocklist covered text/html, text/javascript, etc. but was missing
|
||||
image/svg+xml, so an uploaded SVG carrying an inline <script> was served as
|
||||
image/svg+xml and executed in the page origin when rendered.
|
||||
|
||||
The /view forced-download decision lives in the view_image closure registered by
|
||||
server.PromptServer.add_routes (server.py ~line 596), which calls
|
||||
`folder_paths.is_dangerous_content_type(content_type)` — a normalising check that
|
||||
strips charset/boundary parameters and casing and folds in the whole */xml and
|
||||
*+xml dialect family — rather than a bypassable raw
|
||||
`content_type in folder_paths.DANGEROUS_CONTENT_TYPES` membership test. On a match
|
||||
it rewrites the response to application/octet-stream with a
|
||||
Content-Disposition: attachment header. server.py cannot be imported in a unit
|
||||
test (importing it spins up the full PromptServer/aiohttp app and its global side
|
||||
effects), so these tests pin the underlying dangerous-content data
|
||||
(folder_paths.DANGEROUS_CONTENT_TYPES) and the normalising is_dangerous_content_type()
|
||||
helper that the closure actually calls.
|
||||
|
||||
The end-to-end /view assertion (upload an SVG, GET /view, confirm the response
|
||||
is not served as image/svg+xml) lives in the live POC at
|
||||
.security/pocs/test_security_ghsa_779p.py::TestViewSvgContentType, which
|
||||
requires a running server. This file is the fast, server-free CI guard on the
|
||||
set contents so the blocklist can't silently regress.
|
||||
"""
|
||||
|
||||
import folder_paths
|
||||
|
||||
|
||||
# Active/renderable content types that must be forced to download. Each of these
|
||||
# can carry an inline <script> (or otherwise execute) in the page origin if a
|
||||
# browser renders it. image/svg+xml is the original missing item that caused
|
||||
# vuln #5.
|
||||
DANGEROUS = [
|
||||
'image/svg+xml',
|
||||
'application/xml',
|
||||
'text/xml',
|
||||
'text/html',
|
||||
'text/html-sandboxed',
|
||||
'application/xhtml+xml',
|
||||
'text/javascript',
|
||||
'application/javascript',
|
||||
'application/x-javascript',
|
||||
'application/ecmascript',
|
||||
'text/css',
|
||||
]
|
||||
|
||||
# Benign image types that browsers display inline and that must keep rendering;
|
||||
# forcing these to download would break legitimate previews.
|
||||
BENIGN_INLINE_IMAGES = [
|
||||
'image/png',
|
||||
'image/jpeg',
|
||||
'image/webp',
|
||||
'image/gif',
|
||||
]
|
||||
|
||||
|
||||
def test_dangerous_content_types_is_a_set():
|
||||
assert isinstance(folder_paths.DANGEROUS_CONTENT_TYPES, set)
|
||||
|
||||
|
||||
def test_svg_is_in_the_blocklist():
|
||||
"""The specific item whose absence caused vuln #5."""
|
||||
assert 'image/svg+xml' in folder_paths.DANGEROUS_CONTENT_TYPES, (
|
||||
"image/svg+xml missing from DANGEROUS_CONTENT_TYPES — this is exactly "
|
||||
"the regression that reopens GHSA-779p-m5rp-r4h4 vuln #5 (stored XSS "
|
||||
"via SVG upload on /view)."
|
||||
)
|
||||
|
||||
|
||||
def test_all_dangerous_types_present():
|
||||
missing = [ct for ct in DANGEROUS if ct not in folder_paths.DANGEROUS_CONTENT_TYPES]
|
||||
assert not missing, (
|
||||
f"DANGEROUS_CONTENT_TYPES is missing required active/renderable types: "
|
||||
f"{missing}. The /view closure only forces a download for content types "
|
||||
f"in this set; anything missing here is served inline and can execute."
|
||||
)
|
||||
|
||||
|
||||
def test_benign_inline_image_types_absent():
|
||||
leaked = [ct for ct in BENIGN_INLINE_IMAGES if ct in folder_paths.DANGEROUS_CONTENT_TYPES]
|
||||
assert not leaked, (
|
||||
f"Benign inline-displayable image types found in DANGEROUS_CONTENT_TYPES: "
|
||||
f"{leaked}. Forcing these to download would break legitimate image "
|
||||
f"previews in /view — they must keep rendering inline."
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_dangerous_content_type() — the normalising check the /view and /userdata
|
||||
# handlers now call instead of a raw `in DANGEROUS_CONTENT_TYPES` membership
|
||||
# test. An exact-string membership test was bypassable with a charset parameter
|
||||
# or odd casing, and missed the wider XML dialect family; these tests pin the
|
||||
# normalisation so that bypass can't reopen.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_function_matches_plain_dangerous_types():
|
||||
for ct in DANGEROUS:
|
||||
assert folder_paths.is_dangerous_content_type(ct) is True, ct
|
||||
|
||||
|
||||
def test_function_strips_parameters_and_casing():
|
||||
"""A charset/boundary parameter or casing must not slip a type past the check.
|
||||
|
||||
This is the bypass surfaced by review: the /view blake3 branch can serve an
|
||||
attacker-controlled, unvalidated asset mime_type like 'text/html; charset=utf-8',
|
||||
which an exact-string set test missed.
|
||||
"""
|
||||
for ct in (
|
||||
'text/html; charset=utf-8',
|
||||
'TEXT/HTML',
|
||||
'Text/HTML; charset=UTF-8',
|
||||
'image/svg+xml; charset=utf-8',
|
||||
' text/html ',
|
||||
):
|
||||
assert folder_paths.is_dangerous_content_type(ct) is True, ct
|
||||
|
||||
|
||||
def test_function_covers_xml_dialect_family():
|
||||
"""Any *+xml / */xml dialect is dangerous without enumerating each one."""
|
||||
for ct in (
|
||||
'application/xslt+xml',
|
||||
'application/rss+xml',
|
||||
'application/atom+xml',
|
||||
'application/rdf+xml',
|
||||
'application/mathml+xml',
|
||||
'message/rfc822',
|
||||
):
|
||||
assert folder_paths.is_dangerous_content_type(ct) is True, ct
|
||||
|
||||
|
||||
def test_function_allows_benign_and_empty():
|
||||
for ct in BENIGN_INLINE_IMAGES + ['application/octet-stream', 'text/plain']:
|
||||
assert folder_paths.is_dangerous_content_type(ct) is False, ct
|
||||
# None / empty (mimetypes.guess_type miss) must not be treated as dangerous.
|
||||
assert folder_paths.is_dangerous_content_type(None) is False
|
||||
assert folder_paths.is_dangerous_content_type('') is False
|
||||
@ -573,144 +573,6 @@ class TestExecution:
|
||||
else:
|
||||
assert result.did_run(test_node), "The execution should have been re-run"
|
||||
|
||||
def test_expected_outputs_all_connected(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that expected_outputs contains all connected outputs."""
|
||||
g = builder
|
||||
# Create a node with 3 outputs, all connected
|
||||
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
|
||||
|
||||
# Connect all 3 outputs to preview nodes
|
||||
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
|
||||
output1 = g.node("PreviewImage", images=expected_outputs_node.out(1))
|
||||
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# All outputs should be white (255) since all are connected
|
||||
images0 = result.get_images(output0)
|
||||
images1 = result.get_images(output1)
|
||||
images2 = result.get_images(output2)
|
||||
|
||||
assert len(images0) == 1, "Should have 1 image for output0"
|
||||
assert len(images1) == 1, "Should have 1 image for output1"
|
||||
assert len(images2) == 1, "Should have 1 image for output2"
|
||||
|
||||
# White pixels = 255, meaning output was in expected_outputs
|
||||
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)"
|
||||
assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)"
|
||||
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)"
|
||||
|
||||
def test_expected_outputs_partial_connected(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that expected_outputs only contains connected outputs."""
|
||||
g = builder
|
||||
# Create a node with 3 outputs, only some connected
|
||||
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
|
||||
|
||||
# Only connect outputs 0 and 2, leave output 1 disconnected
|
||||
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
|
||||
# output1 is intentionally not connected
|
||||
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Connected outputs should be white (255)
|
||||
images0 = result.get_images(output0)
|
||||
images2 = result.get_images(output2)
|
||||
|
||||
assert len(images0) == 1, "Should have 1 image for output0"
|
||||
assert len(images2) == 1, "Should have 1 image for output2"
|
||||
|
||||
# White = expected, output 1 is not connected so we can't verify it directly but outputs 0 and 2 should be white
|
||||
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)"
|
||||
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)"
|
||||
|
||||
def test_expected_outputs_single_connected(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that expected_outputs works with single connected output."""
|
||||
g = builder
|
||||
# Create a node with 3 outputs, only one connected
|
||||
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
|
||||
|
||||
# Only connect output 1
|
||||
output1 = g.node("PreviewImage", images=expected_outputs_node.out(1))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
images1 = result.get_images(output1)
|
||||
assert len(images1) == 1, "Should have 1 image for output1"
|
||||
|
||||
# Output 1 should be white (connected), others are not visible in this test
|
||||
assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)"
|
||||
|
||||
def test_expected_outputs_cache_invalidation(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
"""Test that cache invalidates when output connections change."""
|
||||
g = builder
|
||||
# Use unique dimensions to avoid cache collision with other expected_outputs tests
|
||||
expected_outputs_node = g.node("TestExpectedOutputs", height=32, width=32)
|
||||
|
||||
# First run: only connect output 0
|
||||
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
|
||||
|
||||
result1 = client.run(g)
|
||||
assert result1.did_run(expected_outputs_node), "First run should execute the node"
|
||||
|
||||
# Second run: same connections, should be cached
|
||||
result2 = client.run(g)
|
||||
if server["should_cache_results"]:
|
||||
assert not result2.did_run(expected_outputs_node), "Second run should be cached"
|
||||
|
||||
# Third run: add connection to output 2
|
||||
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
|
||||
|
||||
result3 = client.run(g)
|
||||
# Because LAZY_OUTPUTS=True, changing connections should invalidate cache
|
||||
if server["should_cache_results"]:
|
||||
assert result3.did_run(expected_outputs_node), "Adding output connection should invalidate cache"
|
||||
|
||||
# Verify both outputs are now white
|
||||
images0 = result3.get_images(output0)
|
||||
images2 = result3.get_images(output2)
|
||||
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white"
|
||||
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white"
|
||||
|
||||
def test_expected_outputs_expansion_output_mapping(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""A socket consumed only via an expansion's parent-output mapping must still
|
||||
be in the inner LAZY_OUTPUTS node's expected_outputs (white, not black)."""
|
||||
g = builder
|
||||
expander = g.node("TestExpectedOutputsExpansion", height=80, width=80)
|
||||
output = g.node("PreviewImage", images=expander.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 1, "Should have 1 image"
|
||||
assert numpy.array(images[0]).min() == 255, (
|
||||
"Inner node skipped an output that is consumed via the expansion's "
|
||||
"parent-output mapping (expected white, got black)"
|
||||
)
|
||||
|
||||
def test_expected_outputs_requires_opt_in(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
"""Nodes without LAZY_OUTPUTS must see expected_outputs=None: their cache key
|
||||
ignores topology, so a skipped output would be served stale after rewiring."""
|
||||
g = builder
|
||||
node = g.node("TestExpectedOutputsNotOptedIn", height=96, width=96)
|
||||
output0 = g.node("PreviewImage", images=node.out(0))
|
||||
|
||||
# Only output 0 connected: correct gating -> node sees None, computes all
|
||||
result1 = client.run(g)
|
||||
assert numpy.array(result1.get_images(output0)[0]).min() == 255
|
||||
|
||||
# Connect output 1: key unchanged -> cache hit must still serve correct data
|
||||
output1 = g.node("PreviewImage", images=node.out(1))
|
||||
result2 = client.run(g)
|
||||
|
||||
if server["should_cache_results"]:
|
||||
assert not result2.did_run(node), "Node should be a cache hit (key ignores topology)"
|
||||
images1 = result2.get_images(output1)
|
||||
assert len(images1) == 1, "Should have 1 image for output1"
|
||||
assert numpy.array(images1[0]).min() == 255, (
|
||||
"Non-opted-in node observed expected_outputs and skipped output 1; "
|
||||
"the stale skipped value was then served from cache"
|
||||
)
|
||||
|
||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
|
||||
@ -6,7 +6,6 @@ from .tools import VariantSupport
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||
from comfy.comfy_types import IO
|
||||
from comfy_execution.utils import get_executing_context, is_output_needed
|
||||
|
||||
class TestLazyMixImages:
|
||||
@classmethod
|
||||
@ -483,106 +482,6 @@ class TestOutputNodeWithSocketOutput:
|
||||
result = image * value
|
||||
return (result,)
|
||||
|
||||
|
||||
class TestExpectedOutputs:
|
||||
"""Test node for the expected_outputs feature.
|
||||
|
||||
This node has 3 IMAGE outputs that encode which outputs were expected:
|
||||
- White image (255) if the output was in expected_outputs
|
||||
- Black image (0) if the output was NOT in expected_outputs
|
||||
|
||||
This allows integration tests to verify which outputs were expected by checking pixel values.
|
||||
"""
|
||||
LAZY_OUTPUTS = True # Opt into cache invalidation on output connection changes
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"height": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||
"width": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
|
||||
RETURN_NAMES = ("output0", "output1", "output2")
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def execute(self, height, width):
|
||||
# Return white image if expected, black if not
|
||||
# This allows tests to verify which outputs were expected via pixel values
|
||||
white = torch.ones(1, height, width, 3)
|
||||
black = torch.zeros(1, height, width, 3)
|
||||
|
||||
return (
|
||||
white if is_output_needed(0) else black,
|
||||
white if is_output_needed(1) else black,
|
||||
white if is_output_needed(2) else black,
|
||||
)
|
||||
|
||||
|
||||
class TestExpectedOutputsExpansion:
|
||||
"""Expands into an inner LAZY_OUTPUTS node whose output 1 is consumed ONLY via
|
||||
the parent-output mapping (no input link anywhere). If that mapping is not part
|
||||
of the expected-outputs map, the inner node wrongly skips it -> black not white.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"height": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||
"width": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def execute(self, height, width):
|
||||
g = GraphBuilder()
|
||||
inner = g.node("TestExpectedOutputs", height=height, width=width)
|
||||
return {"result": (inner.out(1),), "expand": g.finalize()}
|
||||
|
||||
|
||||
class TestExpectedOutputsNotOptedIn:
|
||||
"""Reads expected_outputs WITHOUT declaring LAZY_OUTPUTS; the executor must pass
|
||||
None (such nodes have no cache-key protection against output rewiring). Outputs
|
||||
are white when the node correctly sees None, otherwise they encode membership.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"height": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||
"width": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "IMAGE")
|
||||
RETURN_NAMES = ("output0", "output1")
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def execute(self, height, width):
|
||||
# Raw context access (not is_output_needed): must distinguish None from a set
|
||||
ctx = get_executing_context()
|
||||
expected = ctx.expected_outputs if ctx is not None else None
|
||||
|
||||
white = torch.ones(1, height, width, 3)
|
||||
black = torch.zeros(1, height, width, 3)
|
||||
|
||||
if expected is None:
|
||||
return (white, white.clone())
|
||||
return (
|
||||
white if 0 in expected else black,
|
||||
white if 1 in expected else black,
|
||||
)
|
||||
|
||||
|
||||
TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestLazyMixImages": TestLazyMixImages,
|
||||
"TestVariadicAverage": TestVariadicAverage,
|
||||
@ -599,9 +498,6 @@ TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestSleep": TestSleep,
|
||||
"TestParallelSleep": TestParallelSleep,
|
||||
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
|
||||
"TestExpectedOutputs": TestExpectedOutputs,
|
||||
"TestExpectedOutputsExpansion": TestExpectedOutputsExpansion,
|
||||
"TestExpectedOutputsNotOptedIn": TestExpectedOutputsNotOptedIn,
|
||||
}
|
||||
|
||||
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@ -620,7 +516,4 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestSleep": "Test Sleep",
|
||||
"TestParallelSleep": "Test Parallel Sleep",
|
||||
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
|
||||
"TestExpectedOutputs": "Test Expected Outputs",
|
||||
"TestExpectedOutputsExpansion": "Test Expected Outputs Expansion",
|
||||
"TestExpectedOutputsNotOptedIn": "Test Expected Outputs Not Opted In",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user