mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-05 03:56:47 +08:00
Compare commits
1 Commits
ListInput
...
feat/creat
| Author | SHA1 | Date | |
|---|---|---|---|
| b82c8c7eb3 |
130
AGENTS.md
130
AGENTS.md
@ -8,13 +8,10 @@
|
||||
directly required.
|
||||
- Prefer practical fixes over broad architecture work. Add abstractions only
|
||||
when they remove real repeated logic or match an existing ComfyUI pattern.
|
||||
- Prefer fewer dependencies. Do not add new dependencies to ComfyUI unless they
|
||||
are absolutely necessary.
|
||||
- Delete obsolete code aggressively when newer infrastructure makes it useless.
|
||||
Remove dead fallbacks, migration paths, unused options, debug prints, and
|
||||
compatibility branches that are no longer needed. Do not leave dead branches,
|
||||
unreachable code, or functions that are never called. If code is not
|
||||
necessary for the current behavior, remove it.
|
||||
unreachable code, or functions that are never called.
|
||||
- Revert or disable problematic behavior quickly when it breaks users. It is
|
||||
better to remove a broken feature path than keep a complicated partial fix.
|
||||
- Preserve existing APIs, node names, model-loading behavior, file layout, and
|
||||
@ -88,14 +85,6 @@
|
||||
not change a shared method to return extra values, alternate shapes, or
|
||||
sentinel wrappers for one implementation unless the shared interface is
|
||||
explicitly updated.
|
||||
- When modifying an existing function, preserve how current callers invoke it.
|
||||
Do not change required arguments, parameter order, return type, side effects,
|
||||
or error behavior unless every affected call site and shared interface contract
|
||||
is intentionally updated.
|
||||
- Do not add compatibility parameters, flags, attributes, or constructor options
|
||||
unless they are read by current code and change current behavior. Remove
|
||||
pass-through or stored-but-unused values instead of preserving upstream or
|
||||
deprecated API baggage.
|
||||
- If an implementation needs auxiliary values for its own workflow, expose them
|
||||
through a private helper or a clearly named implementation-specific method
|
||||
instead of overloading the public method's return contract.
|
||||
@ -113,11 +102,6 @@
|
||||
- Do not add freeze, unfreeze, or trainability toggles to model classes. ComfyUI
|
||||
models are always treated as frozen for inference, so explicit freeze
|
||||
functionality is redundant and should not be added.
|
||||
- Remove training-only behavior such as dropout from inference model code, but
|
||||
preserve checkpoint and state-dict compatibility when doing so. If deleting a
|
||||
module would change state-dict keys, module ordering, or checkpoint loading
|
||||
behavior, replace it with a no-op such as `nn.Identity` instead of removing the
|
||||
slot outright.
|
||||
|
||||
## Python Style
|
||||
|
||||
@ -127,11 +111,6 @@
|
||||
- Do not add unnecessary `try`/`except` blocks. Use them for optional dependency,
|
||||
platform, or backend capability detection only when the program has a useful
|
||||
fallback. Prefer specific exception types when changing new code.
|
||||
- Remove any workarounds for PyTorch versions that ComfyUI no longer officially
|
||||
supports. Deprecated workarounds include catching an exception and rerunning
|
||||
the same op with the input cast to float. If a workaround does not have a
|
||||
comment naming the exact PyTorch version or versions that still need it,
|
||||
remove it.
|
||||
- Let unsupported model formats, invalid quantization metadata, and bad states
|
||||
fail with clear errors instead of silently producing lower quality output.
|
||||
- Match the existing local style in the file you edit. This codebase tolerates
|
||||
@ -150,101 +129,8 @@
|
||||
adding parallel code paths. Use `comfy.quant_ops`, `comfy.model_management`,
|
||||
`comfy.memory_management`, `comfy.pinned_memory`, `comfy_aimdo`, and
|
||||
`comfy-kitchen` helpers where they already solve the problem.
|
||||
- Use optimized comfy-kitchen ops in places where they improve performance
|
||||
without changing the expected dtype, device, memory, or interface behavior.
|
||||
- All models should use the optimized attention function selected by ComfyUI.
|
||||
Treat optimized backend functions, dispatch helpers, and capability-selected
|
||||
callables as opaque. Higher-level code must not inspect function identity,
|
||||
names, modules, or implementation details to decide behavior.
|
||||
- Apply the same opacity rule to similar patterns beyond attention: callers
|
||||
should depend on the documented interface and result contract, not on which
|
||||
backend implementation was selected underneath.
|
||||
- Do not use custom inference ops that only duplicate an existing op while
|
||||
upcasting to float32, such as custom RMSNorm variants. Use the generic ComfyUI
|
||||
ops and/or native torch ops instead.
|
||||
- If a model class `__init__` has an `operations` parameter, assume
|
||||
`operations` is never `None`. Do not add fallback branches or default torch
|
||||
ops for a missing `operations` object.
|
||||
- Do not add unnecessary parameters to model, model block, or model ops related
|
||||
classes. Constructor and forward signatures should carry only values that are
|
||||
actually needed by that object for inference.
|
||||
- Reuse existing model classes, blocks, ops, and helper modules when appropriate.
|
||||
Before implementing a new version of a model component, search the existing
|
||||
model code for a class or helper that already provides the behavior.
|
||||
- Model detection code that inspects linear weight shapes should only use the
|
||||
first dimension. The second dimension may be half the original size for
|
||||
NVFP4 or other 4-bit quantized models.
|
||||
- Avoid adding `einops` usage in core inference code. Use native torch tensor
|
||||
ops such as `reshape`, `view`, `permute`, `transpose`, `flatten`, `unflatten`,
|
||||
`unsqueeze`, and `squeeze` instead.
|
||||
- Do not use tensors as general-purpose Python data structures. Keep metadata,
|
||||
bookkeeping, counters, flags, shape math, padding math, index planning, memory
|
||||
estimates, and control-flow decisions in plain Python values unless the data
|
||||
must participate directly in tensor computation. Do not create tensors for
|
||||
structural metadata that is only used for Python-side control flow. Sequence
|
||||
lengths, cumulative offsets, split indices, window counts, slice boundaries,
|
||||
and repeat counts should be kept as Python ints/lists from the point they are
|
||||
computed. Do not build them as CPU/GPU tensors and then cast, move, validate,
|
||||
or convert them back to Python for `split`, `tensor_split`, indexing plans,
|
||||
loops, or cache keys. Avoid creating temporary tensors just to use tensor
|
||||
methods for scalar or structural calculations.
|
||||
- Avoid unnecessary casts and transfers. Preserve the intended compute dtype,
|
||||
storage dtype, bias dtype, and original tensor shape metadata.
|
||||
- Keep model-native latent layout handling inside the model or latent-format
|
||||
owner, not in helper nodes. Do not collapse, expand, pack, or unpack latent
|
||||
dimensions in nodes or other caller-side adapters just to satisfy a model
|
||||
forward; the model path should consume and return the native latent shape for
|
||||
that model family.
|
||||
- Assume inputs to the main model forward are already in the compute dtype by
|
||||
default, except integer inputs such as some model timestep tensors. Do not add
|
||||
defensive or convenience casts in model code; it is better for invalid dtype
|
||||
plumbing to error clearly than to hide it with unnecessary casts.
|
||||
- Raw model parameters that are not owned by an op and may be initialized in a
|
||||
dtype different from the compute dtype should be cast at use in forward or
|
||||
inference code with `comfy.ops.cast_to_input` or
|
||||
`comfy.model_management.cast_to` to avoid dtype mismatches.
|
||||
- Model code should not care what dtype it is initialized in, and model
|
||||
`__init__` methods should not contain workarounds for specific dtypes. Dtype
|
||||
workaround code, such as making a model work with fp16 compute, belongs in the
|
||||
execution or model-management layer that owns compute policy.
|
||||
- Model code should not perform unnecessary device-to-CPU or CPU-to-device
|
||||
transfers. New allocations must be created on the correct device and dtype;
|
||||
never allocate on CPU and then move to GPU, or allocate in one dtype and then
|
||||
convert to another.
|
||||
- Model code itself should not perform memory management. Loading, unloading,
|
||||
offloading, device movement, VRAM policy, cache lifetime, and cleanup belong
|
||||
in the relevant model-management and execution layers, not inside model
|
||||
implementations.
|
||||
- Do not add global, module-level, class-level, singleton, or model-owned stores
|
||||
for tensors or other large memory that persist across executions. Temporary
|
||||
caches must be scoped to a single execution or forward/encode/decode call:
|
||||
allocate them in the owning top-level call, pass them explicitly through the
|
||||
call stack, and let them be discarded when that call returns.
|
||||
- Follow the Wan VAE temporal cache pattern for temporary caches: create a local
|
||||
cache such as `feat_map` for the encode/decode operation, pass it into the
|
||||
blocks that need it, and do not retain it on the model or in global state.
|
||||
- In model init code, prefer `torch.empty` for parameter/buffer placeholders
|
||||
that are populated from the model state dict instead of zero-initializing with
|
||||
`torch.zeros` or similar. If an allocation is not loaded from the state dict
|
||||
and is useless for inference, do not include it.
|
||||
- `nn.Parameter` tensors that are stored in and populated from the model state
|
||||
dict should be initialized with `torch.empty`, not with zero, random, or
|
||||
otherwise meaningful initialization.
|
||||
- Model initialization should describe module structure, not fabricate
|
||||
checkpoint-owned tensor contents. Parameters and buffers that are loaded from
|
||||
the state dict must not be manually initialized, reassigned, or filled with
|
||||
fallback values unless that value is actually used when no checkpoint key
|
||||
exists.
|
||||
- When slicing large tensors, copy the slice if the sliced tensor's lifetime
|
||||
exceeds the current function scope. Do not keep a long-lived view into a large
|
||||
backing tensor when a smaller copy would release memory sooner.
|
||||
- Use fused or compound torch operations such as `addcmul` when they naturally
|
||||
match the math. Reducing Python and torch dispatch overhead is a valid
|
||||
optimization when it does not obscure the code or change dtype/device
|
||||
behavior.
|
||||
- Avoid caches that persist across different executions as much as possible.
|
||||
Persistent caches are acceptable only when they use a very minimal amount of
|
||||
memory and have a clear ownership and invalidation story.
|
||||
- When optimizing, favor small measurable changes: fewer allocations, fewer
|
||||
device transfers, less peak memory, better batching, or use of a faster
|
||||
existing backend op.
|
||||
@ -255,20 +141,6 @@
|
||||
`CATEGORY`, and registration through the local mapping used by that file.
|
||||
- Keep node changes backward compatible by default. Add inputs with sensible
|
||||
defaults and avoid changing output types unless the request requires it.
|
||||
- Model implementations should add the minimal number of ComfyUI nodes required
|
||||
to run the model. Reuse existing nodes as much as possible; adapting the model
|
||||
to work with existing nodes is strongly preferred over creating new nodes.
|
||||
- Nodes should output only values they own. Do not add pass-through outputs for
|
||||
workflow convenience unless the node is explicitly an output node. Existing
|
||||
models, latents, conditioning, or other inputs should flow directly to the
|
||||
next consumer instead of being re-emitted unchanged.
|
||||
- Nodes should expose only inputs they actually read to produce current
|
||||
behavior. Do not add placeholder, pass-through, compatibility, or
|
||||
workflow-shaping inputs that are ignored or could flow directly to another
|
||||
node.
|
||||
- Node-level code must not patch model code directly. Any node behavior that
|
||||
modifies, wraps, hooks, or changes model behavior must go through the model
|
||||
patcher class instead of reaching into model internals.
|
||||
- The official mascot of ComfyUI is a very cute anime girl with massive fennec
|
||||
ears, a big fluffy tail, long blonde wavy hair, and blue eyes. Feel free to
|
||||
use her in ComfyUI materials, UI text, examples, tests, generated assets, or
|
||||
|
||||
@ -306,15 +306,12 @@ async def download_asset_content(request: web.Request) -> web.Response:
|
||||
404, "FILE_NOT_FOUND", "Underlying file not found on disk."
|
||||
)
|
||||
|
||||
# User-controlled asset content must never render inline in the app origin
|
||||
# (stored XSS via SVG/HTML/XML). Force dangerous types to download and
|
||||
# override any requested inline disposition. Centralised through
|
||||
# folder_paths.is_dangerous_content_type so this can't drift from /view and
|
||||
# /userdata (the previous inline set here omitted image/svg+xml and missed
|
||||
# the charset/casing/+xml-dialect bypasses).
|
||||
if folder_paths.is_dangerous_content_type(content_type):
|
||||
_DANGEROUS_MIME_TYPES = {
|
||||
"text/html", "text/html-sandboxed", "application/xhtml+xml",
|
||||
"text/javascript", "text/css",
|
||||
}
|
||||
if content_type in _DANGEROUS_MIME_TYPES:
|
||||
content_type = "application/octet-stream"
|
||||
disposition = "attachment"
|
||||
|
||||
safe_name = (filename or "").replace("\r", "").replace("\n", "")
|
||||
encoded = urllib.parse.quote(safe_name)
|
||||
|
||||
@ -50,45 +50,21 @@ class ModelFileManager:
|
||||
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
||||
async def get_model_preview(request):
|
||||
folder_name = request.match_info.get("folder", None)
|
||||
path_index = int(request.match_info.get("path_index", None))
|
||||
filename = request.match_info.get("filename", None)
|
||||
|
||||
if folder_name not in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
|
||||
# The "{filename:.*}" capture also matches the empty string, which
|
||||
# would resolve to the folder itself; reject it explicitly.
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
try:
|
||||
path_index = int(request.match_info.get("path_index", None))
|
||||
except (TypeError, ValueError):
|
||||
return web.Response(status=400)
|
||||
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
if path_index < 0 or path_index >= len(folders[0]):
|
||||
return web.Response(status=404)
|
||||
folder = folders[0][path_index]
|
||||
full_filename = os.path.normpath(os.path.join(folder, filename))
|
||||
|
||||
# Prevent path traversal: the requested file must stay within the
|
||||
# configured model folder. `filename` is an unrestricted ".*" capture,
|
||||
# so values like "../../../../etc/passwd" would otherwise escape it.
|
||||
if not folder_paths.is_within_directory(folder, full_filename):
|
||||
return web.Response(status=403)
|
||||
full_filename = os.path.join(folder, filename)
|
||||
|
||||
previews = self.get_model_previews(full_filename)
|
||||
default_preview = previews[0] if len(previews) > 0 else None
|
||||
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
|
||||
return web.Response(status=404)
|
||||
|
||||
# The preview is selected by a glob inside get_model_previews, so a
|
||||
# companion file (e.g. "model.preview.png") could itself be a symlink
|
||||
# resolving outside the model folder. Re-validate the file actually
|
||||
# opened: is_within_directory realpaths it, catching symlink escape.
|
||||
if isinstance(default_preview, str) and not folder_paths.is_within_directory(folder, default_preview):
|
||||
return web.Response(status=403)
|
||||
|
||||
try:
|
||||
with Image.open(default_preview) as img:
|
||||
img_bytes = BytesIO()
|
||||
|
||||
@ -6,7 +6,6 @@ import glob
|
||||
import shutil
|
||||
import logging
|
||||
import tempfile
|
||||
import mimetypes
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
@ -337,20 +336,7 @@ class UserManager():
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
# User data files are arbitrary user-supplied content and are never
|
||||
# meant to render inline. Disable MIME sniffing and force a download
|
||||
# so uploaded markup/scripts can't execute in the app origin (stored
|
||||
# XSS). Content-Disposition: attachment is the load-bearing guard;
|
||||
# the content-type override and nosniff are defence in depth.
|
||||
content_type = mimetypes.guess_type(path)[0] or 'application/octet-stream'
|
||||
if folder_paths.is_dangerous_content_type(content_type):
|
||||
content_type = 'application/octet-stream'
|
||||
|
||||
return web.FileResponse(path, headers={
|
||||
"Content-Type": content_type,
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Content-Disposition": "attachment",
|
||||
})
|
||||
return web.FileResponse(path)
|
||||
|
||||
@routes.post("/userdata/{file}")
|
||||
async def post_userdata(request):
|
||||
|
||||
@ -543,24 +543,18 @@ class SDTokenizer:
|
||||
def _try_get_embedding(self, embedding_name:str):
|
||||
'''
|
||||
Takes a potential embedding name and tries to retrieve it.
|
||||
Returns a Tuple consisting of the embedding, the cleaned embedding name, and any leftover string, embedding can be None.
|
||||
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
|
||||
'''
|
||||
split_embed = embedding_name.split()
|
||||
embedding_name = split_embed[0]
|
||||
leftover = ' '.join(split_embed[1:])
|
||||
|
||||
match = re.search(r'[<\[]', embedding_name)
|
||||
if match is not None:
|
||||
leftover = embedding_name[match.start():] + (" " + leftover if leftover else "")
|
||||
embedding_name = embedding_name[:match.start()]
|
||||
|
||||
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||
if embed is None:
|
||||
stripped = embedding_name.strip(',')
|
||||
if len(stripped) < len(embedding_name):
|
||||
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||
return (embed, embedding_name, "{} {}".format(embedding_name[len(stripped):], leftover))
|
||||
return (embed, embedding_name, leftover)
|
||||
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
|
||||
return (embed, leftover)
|
||||
|
||||
def pad_tokens(self, tokens, amount):
|
||||
if self.pad_left:
|
||||
@ -591,7 +585,7 @@ class SDTokenizer:
|
||||
tokens = []
|
||||
for weighted_segment, weight in parsed_weights:
|
||||
to_tokenize = unescape_important(weighted_segment)
|
||||
split = re.split(r'(?<=\s){}'.format(re.escape(self.embedding_identifier)), to_tokenize)
|
||||
split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
|
||||
to_tokenize = [split[0]]
|
||||
for i in range(1, len(split)):
|
||||
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
|
||||
@ -601,7 +595,7 @@ class SDTokenizer:
|
||||
# if we find an embedding, deal with the embedding
|
||||
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
|
||||
embedding_name = word[len(self.embedding_identifier):].strip('\n')
|
||||
embed, embedding_name, leftover = self._try_get_embedding(embedding_name)
|
||||
embed, leftover = self._try_get_embedding(embedding_name)
|
||||
if embed is None:
|
||||
logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
|
||||
else:
|
||||
|
||||
@ -167,7 +167,7 @@ class Qwen3VLTokenizer(sd1_clip.SD1Tokenizer):
|
||||
embed_count = 0
|
||||
for r in tokens[key_name]:
|
||||
for i in range(len(r)):
|
||||
if isinstance(r[i][0], (int, float)) and r[i][0] == 151655: # <|image_pad|>
|
||||
if r[i][0] == 151655: # <|image_pad|>
|
||||
if len(images) > embed_count:
|
||||
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
|
||||
@ -1261,158 +1261,6 @@ class DynamicSlot(ComfyTypeI):
|
||||
out_dict[input_type][finalized_id] = value
|
||||
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||
|
||||
@comfytype(io_type="COMFY_DYNAMICGROUP_V3")
|
||||
class DynamicGroup(ComfyTypeI):
|
||||
"""A repeatable group of widget inputs (e.g. lora_name + strength stacked into N rows).
|
||||
|
||||
At execution time the node receives a ``list[dict]`` where each element is a row.
|
||||
|
||||
Example::
|
||||
|
||||
io.DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[
|
||||
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras")),
|
||||
io.Float.Input("strength", default=1.0, min=-100, max=100, step=0.01),
|
||||
],
|
||||
min=0,
|
||||
max=50,
|
||||
)
|
||||
# execute receives: loras: list[dict] = [{"lora_name": "x.safetensors", "strength": 1.0}, ...]
|
||||
"""
|
||||
|
||||
Type = list[dict[str, Any]]
|
||||
_MaxRows = 100
|
||||
|
||||
class Input(DynamicInput):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
template: list["Input"],
|
||||
min: int = 0,
|
||||
max: int = 50,
|
||||
display_name: str = None,
|
||||
optional: bool = False,
|
||||
tooltip: str = None,
|
||||
lazy: bool = None,
|
||||
extra_dict=None,
|
||||
group_name: str = "Group",
|
||||
):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
# Validate template entries: only WidgetInput subclasses, no nesting
|
||||
assert len(template) > 0, "DynamicGroup template must have at least one field."
|
||||
for t in template:
|
||||
assert isinstance(t, WidgetInput), (
|
||||
f"DynamicGroup template field '{t.id}' must be a WidgetInput subclass "
|
||||
f"(Combo, Float, Int, String, Boolean, Color). Got {type(t).__name__}."
|
||||
)
|
||||
assert not isinstance(t, DynamicInput), (
|
||||
f"DynamicGroup template field '{t.id}' must not be a DynamicInput. "
|
||||
"Nesting dynamic inputs inside DynamicGroup is not supported."
|
||||
)
|
||||
# Enforce unique field ids within template
|
||||
field_ids = [t.id for t in template]
|
||||
assert len(field_ids) == len(set(field_ids)), (
|
||||
f"DynamicGroup template field ids must be unique within a row. Got: {field_ids}"
|
||||
)
|
||||
# Reject "." in group id and template field ids: slot_id encoding uses "." as a
|
||||
# delimiter (<group_id>.<row>.<field_id>), so any "." in these names would cause
|
||||
# path.split(".") to produce the wrong number of segments during decoding.
|
||||
assert "." not in id, (
|
||||
f"DynamicGroup id must not contain '.'. Got: '{id}'"
|
||||
)
|
||||
for t in template:
|
||||
assert "." not in t.id, (
|
||||
f"DynamicGroup template field id must not contain '.'. Got: '{t.id}'"
|
||||
)
|
||||
assert min >= 0, "DynamicGroup min must be >= 0."
|
||||
assert max >= 1, "DynamicGroup max must be >= 1."
|
||||
assert max <= DynamicGroup._MaxRows, f"DynamicGroup max must be <= {DynamicGroup._MaxRows}."
|
||||
assert min <= max, "DynamicGroup min must be <= max."
|
||||
self.template = template
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.group_name = group_name
|
||||
|
||||
def get_all(self) -> list["Input"]:
|
||||
return [self] + list(self.template)
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": create_input_dict_v1(self.template),
|
||||
"min": self.min,
|
||||
"max": self.max,
|
||||
"group_name": self.group_name,
|
||||
})
|
||||
|
||||
def validate(self):
|
||||
for t in self.template:
|
||||
t.validate()
|
||||
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(
|
||||
out_dict: dict[str, Any],
|
||||
live_inputs: dict[str, Any],
|
||||
value: tuple[str, dict[str, Any]],
|
||||
input_type: str,
|
||||
curr_prefix: list[str] | None,
|
||||
):
|
||||
info = value[1]
|
||||
min_rows: int = info.get("min", 0)
|
||||
max_rows: int = info.get("max", DynamicGroup._MaxRows)
|
||||
template: dict[str, Any] = info.get("template", {})
|
||||
|
||||
# Collect all template field specs across required/optional sections
|
||||
field_specs: list[tuple[str, tuple[str, dict[str, Any]], bool]] = []
|
||||
for field_required_key in ("required", "optional"):
|
||||
section = template.get(field_required_key, {})
|
||||
is_required_field = field_required_key == "required"
|
||||
for field_id, field_value in section.items():
|
||||
field_specs.append((field_id, field_value, is_required_field))
|
||||
|
||||
# Determine how many rows are currently present by scanning live_inputs
|
||||
finalized_prefix = finalize_prefix(curr_prefix)
|
||||
present_rows = 0
|
||||
for live_key in live_inputs:
|
||||
# Keys look like "<prefix>.<row>.<field_id>"
|
||||
if live_key.startswith(finalized_prefix + "."):
|
||||
remainder = live_key[len(finalized_prefix) + 1:]
|
||||
parts = remainder.split(".", 1)
|
||||
if len(parts) >= 1:
|
||||
try:
|
||||
row_idx = int(parts[0])
|
||||
present_rows = max(present_rows, row_idx + 1)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if present_rows > max_rows:
|
||||
raise ValueError(
|
||||
f"DynamicGroup input '{finalized_prefix}' received {present_rows} rows but max is {max_rows}."
|
||||
)
|
||||
row_count = max(min_rows, present_rows)
|
||||
|
||||
for row in range(row_count):
|
||||
for field_id, field_value, is_required_field in field_specs:
|
||||
slot_id = f"{finalized_prefix}.{row}.{field_id}"
|
||||
# The first `min_rows` rows are required if the field itself is required
|
||||
if row < min_rows and is_required_field:
|
||||
out_dict["required"][slot_id] = field_value
|
||||
else:
|
||||
out_dict["optional"][slot_id] = field_value
|
||||
# Register into dynamic_paths so build_nested_inputs places value at the right path
|
||||
out_dict["dynamic_paths"][slot_id] = slot_id
|
||||
|
||||
# Track the list root path so build_nested_inputs can convert the index dict to a list
|
||||
out_dict.setdefault("list_paths", set()).add(finalized_prefix)
|
||||
|
||||
# Handle the empty case (0 rows) – emit an empty-list default for the parent.
|
||||
# This must only fire when there are genuinely no rows; otherwise the parent
|
||||
# path would clobber the per-row dict built from the slot ids above.
|
||||
if row_count == 0:
|
||||
out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix
|
||||
out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_LIST
|
||||
|
||||
|
||||
@comfytype(io_type="IMAGECOMPARE")
|
||||
class ImageCompare(ComfyTypeI):
|
||||
Type = dict
|
||||
@ -1570,8 +1418,6 @@ def setup_dynamic_input_funcs():
|
||||
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
|
||||
# DynamicSlot.Input
|
||||
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
|
||||
# DynamicGroup.Input
|
||||
register_dynamic_input_func(DynamicGroup.io_type, DynamicGroup._expand_schema_for_dynamic)
|
||||
|
||||
if len(DYNAMIC_INPUT_LOOKUP) == 0:
|
||||
setup_dynamic_input_funcs()
|
||||
@ -1583,8 +1429,6 @@ class V3Data(TypedDict):
|
||||
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||
dynamic_paths_default_value: dict[str, Any]
|
||||
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
|
||||
list_paths: set[str]
|
||||
'Set of top-level keys whose index-keyed dict values should be converted to a sorted list[dict] after build_nested_inputs runs.'
|
||||
create_dynamic_tuple: bool
|
||||
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||
|
||||
@ -1926,7 +1770,6 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
"optional": {},
|
||||
"dynamic_paths": {},
|
||||
"dynamic_paths_default_value": {},
|
||||
"list_paths": set(),
|
||||
}
|
||||
d = d.copy()
|
||||
# ignore hidden for parsing
|
||||
@ -1942,10 +1785,6 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
|
||||
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
|
||||
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
|
||||
# list_paths: keys whose nested dict should be post-converted to a sorted list[dict]
|
||||
list_paths = out_dict.pop("list_paths", None)
|
||||
if list_paths:
|
||||
v3_data["list_paths"] = list_paths
|
||||
return out_dict, hidden, v3_data
|
||||
|
||||
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||
@ -1981,12 +1820,10 @@ def add_to_dict_v1(i: Input, d: dict):
|
||||
|
||||
class DynamicPathsDefaultValue:
|
||||
EMPTY_DICT = "empty_dict"
|
||||
EMPTY_LIST = "empty_list"
|
||||
|
||||
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
paths = v3_data.get("dynamic_paths", None)
|
||||
default_value_dict = v3_data.get("dynamic_paths_default_value", {})
|
||||
list_paths: set[str] = v3_data.get("list_paths", set()) or set()
|
||||
if paths is None:
|
||||
return values
|
||||
values = values.copy()
|
||||
@ -2009,8 +1846,6 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
default_option = default_value_dict.get(key, None)
|
||||
if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
|
||||
value = {}
|
||||
elif default_option == DynamicPathsDefaultValue.EMPTY_LIST:
|
||||
value = []
|
||||
if create_tuple:
|
||||
value = (value, key)
|
||||
current[p] = value
|
||||
@ -2018,34 +1853,6 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
current = current.setdefault(p, {})
|
||||
|
||||
values.update(result)
|
||||
|
||||
# Post-pass: convert index-keyed dicts to sorted lists for io.DynamicGroup fields
|
||||
for list_path in list_paths:
|
||||
parts = list_path.split(".")
|
||||
# Navigate to the parent container, then convert the leaf
|
||||
container = values
|
||||
for part in parts[:-1]:
|
||||
if not isinstance(container, dict) or part not in container:
|
||||
container = None
|
||||
break
|
||||
container = container[part]
|
||||
if container is None:
|
||||
continue
|
||||
leaf_key = parts[-1]
|
||||
leaf = container.get(leaf_key, None)
|
||||
if isinstance(leaf, dict):
|
||||
try:
|
||||
sorted_rows = [leaf[k] for k in sorted(leaf.keys(), key=int)]
|
||||
container[leaf_key] = sorted_rows
|
||||
except (ValueError, TypeError):
|
||||
# Keys are not all integers; leave as-is
|
||||
pass
|
||||
elif isinstance(leaf, list):
|
||||
# Already a list (e.g. the EMPTY_LIST default was applied above)
|
||||
pass
|
||||
elif leaf is None:
|
||||
container[leaf_key] = []
|
||||
|
||||
return values
|
||||
|
||||
|
||||
@ -2610,9 +2417,7 @@ __all__ = [
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
"DynamicCombo",
|
||||
"DynamicSlot",
|
||||
"Autogrow",
|
||||
"DynamicGroup",
|
||||
# Other classes
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -316,36 +316,3 @@ VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"1080p": 150,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class SeedAudioConfig(BaseModel):
|
||||
format: str = Field(default="mp3")
|
||||
sample_rate: int = Field(default=24000)
|
||||
speech_rate: int = Field(default=0)
|
||||
loudness_rate: int = Field(default=0)
|
||||
pitch_rate: int = Field(default=0)
|
||||
|
||||
|
||||
class SeedAudioReference(BaseModel):
|
||||
speaker: str | None = Field(default=None)
|
||||
audio_data: str | None = Field(default=None)
|
||||
audio_url: str | None = Field(default=None)
|
||||
image_data: str | None = Field(default=None)
|
||||
image_url: str | None = Field(default=None)
|
||||
|
||||
|
||||
class SeedAudioRequest(BaseModel):
|
||||
model: str = Field(default="seed-audio-1.0")
|
||||
text_prompt: str = Field(...)
|
||||
references: list[SeedAudioReference] | None = Field(default=None)
|
||||
audio_config: SeedAudioConfig = Field(default_factory=SeedAudioConfig)
|
||||
watermark: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SeedAudioResponse(BaseModel):
|
||||
audio: str | None = Field(default=None)
|
||||
url: str | None = Field(default=None)
|
||||
duration: float | None = Field(default=None)
|
||||
original_duration: float | None = Field(default=None)
|
||||
code: int | None = Field(default=None)
|
||||
message: str | None = Field(default=None)
|
||||
|
||||
@ -33,6 +33,53 @@ class IdeogramColorPalette(
|
||||
)
|
||||
|
||||
|
||||
class ImageRequest(BaseModel):
|
||||
aspect_ratio: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.",
|
||||
)
|
||||
color_palette: Optional[Dict[str, Any]] = Field(
|
||||
None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.'
|
||||
)
|
||||
magic_prompt_option: Optional[str] = Field(
|
||||
None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')."
|
||||
)
|
||||
model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')")
|
||||
negative_prompt: Optional[str] = Field(
|
||||
None,
|
||||
description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.',
|
||||
)
|
||||
num_images: Optional[int] = Field(
|
||||
1,
|
||||
description='Optional. Number of images to generate (1-8). Defaults to 1.',
|
||||
ge=1,
|
||||
le=8,
|
||||
)
|
||||
prompt: str = Field(
|
||||
..., description='Required. The prompt to use to generate the image.'
|
||||
)
|
||||
resolution: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.",
|
||||
)
|
||||
seed: Optional[int] = Field(
|
||||
None,
|
||||
description='Optional. A number between 0 and 2147483647.',
|
||||
ge=0,
|
||||
le=2147483647,
|
||||
)
|
||||
style_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.",
|
||||
)
|
||||
|
||||
|
||||
class IdeogramGenerateRequest(BaseModel):
|
||||
image_request: ImageRequest = Field(
|
||||
..., description='The image generation request parameters.'
|
||||
)
|
||||
|
||||
|
||||
class Datum(BaseModel):
|
||||
is_image_safe: Optional[bool] = Field(
|
||||
None, description='Indicates whether the image is considered safe.'
|
||||
@ -66,6 +113,20 @@ class StyleCode(RootModel[str]):
|
||||
root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$')
|
||||
|
||||
|
||||
class Datum1(BaseModel):
|
||||
is_image_safe: Optional[bool] = None
|
||||
prompt: Optional[str] = None
|
||||
resolution: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
style_type: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
class IdeogramV3IdeogramResponse(BaseModel):
|
||||
created: Optional[datetime] = None
|
||||
data: Optional[List[Datum1]] = None
|
||||
|
||||
|
||||
class RenderingSpeed1(str, Enum):
|
||||
TURBO = 'TURBO'
|
||||
DEFAULT = 'DEFAULT'
|
||||
|
||||
147
comfy_api_nodes/apis/stability.py
Normal file
147
comfy_api_nodes/apis/stability.py
Normal file
@ -0,0 +1,147 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, confloat
|
||||
|
||||
|
||||
class StabilityFormat(str, Enum):
|
||||
png = 'png'
|
||||
jpeg = 'jpeg'
|
||||
webp = 'webp'
|
||||
|
||||
|
||||
class StabilityAspectRatio(str, Enum):
|
||||
ratio_1_1 = "1:1"
|
||||
ratio_16_9 = "16:9"
|
||||
ratio_9_16 = "9:16"
|
||||
ratio_3_2 = "3:2"
|
||||
ratio_2_3 = "2:3"
|
||||
ratio_5_4 = "5:4"
|
||||
ratio_4_5 = "4:5"
|
||||
ratio_21_9 = "21:9"
|
||||
ratio_9_21 = "9:21"
|
||||
|
||||
|
||||
def get_stability_style_presets(include_none=True):
|
||||
presets = []
|
||||
if include_none:
|
||||
presets.append("None")
|
||||
return presets + [x.value for x in StabilityStylePreset]
|
||||
|
||||
|
||||
class StabilityStylePreset(str, Enum):
|
||||
_3d_model = "3d-model"
|
||||
analog_film = "analog-film"
|
||||
anime = "anime"
|
||||
cinematic = "cinematic"
|
||||
comic_book = "comic-book"
|
||||
digital_art = "digital-art"
|
||||
enhance = "enhance"
|
||||
fantasy_art = "fantasy-art"
|
||||
isometric = "isometric"
|
||||
line_art = "line-art"
|
||||
low_poly = "low-poly"
|
||||
modeling_compound = "modeling-compound"
|
||||
neon_punk = "neon-punk"
|
||||
origami = "origami"
|
||||
photographic = "photographic"
|
||||
pixel_art = "pixel-art"
|
||||
tile_texture = "tile-texture"
|
||||
|
||||
|
||||
class Stability_SD3_5_Model(str, Enum):
|
||||
sd3_5_large = "sd3.5-large"
|
||||
# sd3_5_large_turbo = "sd3.5-large-turbo"
|
||||
sd3_5_medium = "sd3.5-medium"
|
||||
|
||||
|
||||
class Stability_SD3_5_GenerationMode(str, Enum):
|
||||
text_to_image = "text-to-image"
|
||||
image_to_image = "image-to-image"
|
||||
|
||||
|
||||
class StabilityStable3_5Request(BaseModel):
|
||||
model: str = Field(...)
|
||||
mode: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
aspect_ratio: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
cfg_scale: float = Field(...)
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
|
||||
|
||||
|
||||
class StabilityUpscaleConservativeRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
creativity: Optional[confloat(ge=0.2, le=0.5)] = Field(None)
|
||||
|
||||
|
||||
class StabilityUpscaleCreativeRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
creativity: Optional[confloat(ge=0.1, le=0.5)] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityStableUltraRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
aspect_ratio: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
|
||||
|
||||
|
||||
class StabilityStableUltraResponse(BaseModel):
|
||||
image: Optional[str] = Field(None)
|
||||
finish_reason: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class StabilityResultsGetResponse(BaseModel):
|
||||
image: Optional[str] = Field(None)
|
||||
finish_reason: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
id: Optional[str] = Field(None)
|
||||
name: Optional[str] = Field(None)
|
||||
errors: Optional[list[str]] = Field(None)
|
||||
status: Optional[str] = Field(None)
|
||||
result: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityAsyncResponse(BaseModel):
|
||||
id: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityTextToAudioRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
duration: int = Field(190, ge=1, le=190)
|
||||
seed: int = Field(0, ge=0, le=4294967294)
|
||||
steps: int = Field(8, ge=4, le=8)
|
||||
output_format: str = Field("wav")
|
||||
|
||||
|
||||
class StabilityAudioToAudioRequest(StabilityTextToAudioRequest):
|
||||
strength: float = Field(0.01, ge=0.01, le=1.0)
|
||||
|
||||
|
||||
class StabilityAudioInpaintRequest(StabilityTextToAudioRequest):
|
||||
mask_start: int = Field(30, ge=0, le=190)
|
||||
mask_end: int = Field(190, ge=0, le=190)
|
||||
|
||||
|
||||
class StabilityAudioResponse(BaseModel):
|
||||
audio: Optional[str] = Field(None)
|
||||
@ -1,4 +1,3 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import math
|
||||
@ -21,10 +20,6 @@ from comfy_api_nodes.apis.bytedance import (
|
||||
GetAssetResponse,
|
||||
Image2VideoTaskCreationRequest,
|
||||
ImageTaskCreationResponse,
|
||||
SeedAudioConfig,
|
||||
SeedAudioReference,
|
||||
SeedAudioRequest,
|
||||
SeedAudioResponse,
|
||||
Seedance2TaskCreationRequest,
|
||||
SeedanceCreateAssetRequest,
|
||||
SeedanceCreateAssetResponse,
|
||||
@ -48,8 +43,6 @@ from comfy_api_nodes.apis.bytedance import (
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
audio_bytes_to_audio_input,
|
||||
audio_input_to_mp3,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
downscale_image_tensor_by_max_side,
|
||||
@ -58,14 +51,11 @@ from comfy_api_nodes.util import (
|
||||
image_tensor_pair_to_batch,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
upscale_image_tensor_to_min_pixels,
|
||||
upscale_video_to_min_pixels,
|
||||
validate_audio_duration,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
@ -2484,311 +2474,6 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode):
|
||||
return IO.NodeOutput(asset_id, resolved_group)
|
||||
|
||||
|
||||
MODE_TEXT = "text only"
|
||||
MODE_AUDIO = "audio reference"
|
||||
MODE_IMAGE = "image reference"
|
||||
MODE_SPEAKER = "preset voice"
|
||||
|
||||
# (speaker_id, display_label) for built-in TTS 2.0 voices; resolvable ids are account-scoped.
|
||||
SEED_AUDIO_PRESET_VOICES: list[tuple[str, str]] = [
|
||||
("zh_female_vv_uranus_bigtts", "Vivi (Female, multilingual)"),
|
||||
("zh_female_xiaohe_uranus_bigtts", "Mindy (Female, multilingual)"),
|
||||
("en_female_stokie_uranus_bigtts", "Stokie (Female, English)"),
|
||||
("en_female_dacey_uranus_bigtts", "Dacey (Female, English)"),
|
||||
("en_male_tim_uranus_bigtts", "Tim (Male, English)"),
|
||||
("zh_male_m191_uranus_bigtts", "Kian (Male, multilingual)"),
|
||||
("zh_male_taocheng_uranus_bigtts", "Cedric (Male, multilingual)"),
|
||||
("zh_male_sophie_uranus_bigtts", "Sophie (Female, multilingual)"),
|
||||
("zh_female_yingyujiaoxue_uranus_bigtts", "Jean (Female, multilingual)"),
|
||||
("zh_male_dayi_uranus_bigtts", "Magnus (Male, multilingual)"),
|
||||
("zh_female_mizai_uranus_bigtts", "Mabel (Female, multilingual)"),
|
||||
("zh_female_jitangnv_uranus_bigtts", "Nadia (Female, multilingual)"),
|
||||
("zh_female_meilinvyou_uranus_bigtts", "Opal (Female, multilingual)"),
|
||||
("zh_female_liuchangnv_uranus_bigtts", "Pearl (Female, multilingual)"),
|
||||
("zh_male_ruyayichen_uranus_bigtts", "Quentin (Male, multilingual)"),
|
||||
("zh_female_vivo_uranus_bigtts", "Vienna (Female, multilingual)"),
|
||||
("zh_female_xiaoai_uranus_bigtts", "Alina (Female, multilingual)"),
|
||||
("zh_female_cancan_uranus_bigtts", "Corinne (Female, multilingual)"),
|
||||
("zh_female_tianmeixiaoyuan_uranus_bigtts", "Esther (Female, multilingual)"),
|
||||
("zh_female_tianmeitaozi_uranus_bigtts", "Freya (Female, multilingual)"),
|
||||
("zh_female_shuangkuaisisi_uranus_bigtts", "Gigi (Female, multilingual)"),
|
||||
("zh_female_peiqi_uranus_bigtts", "Holly (Female, multilingual)"),
|
||||
("zh_female_xiaoxue_uranus_bigtts", "Lyla (Female, multilingual)"),
|
||||
("zh_female_yuanqi_uranus_bigtts", "Daisy (Female, multilingual)"),
|
||||
("zh_female_kefunvsheng_uranus_bigtts", "Tracy (Female, multilingual)"),
|
||||
("zh_male_shaonianzixin_uranus_bigtts", "Jess (Male, multilingual)"),
|
||||
("zh_female_linjianvhai_uranus_bigtts", "Pinky (Female, multilingual)"),
|
||||
("zh_female_kiwi_uranus_bigtts", "Sweety (Female, multilingual)"),
|
||||
("zh_female_sajiaoxuemei_uranus_bigtts", "Sandy (Female, multilingual)"),
|
||||
("de_male_seven_uranus_bigtts", "Sven (Male, German)"),
|
||||
("jp_female_minimi_uranus_bigtts", "Minimi (Female, Japanese)"),
|
||||
("fr_male_usseau_uranus_bigtts", "Usseau (Male, French)"),
|
||||
("es_male_felipe_uranus_bigtts", "Felipe (Male, Spanish)"),
|
||||
("id_male_han_uranus_bigtts", "Han (Male, Indonesian)"),
|
||||
("pt_male_martins_uranus_bigtts", "Martins (Male, Portuguese)"),
|
||||
("it_male_enzo_uranus_bigtts", "Enzo (Male, Italian)"),
|
||||
("kr_male_shane_uranus_bigtts", "Shane (Male, Korean)"),
|
||||
("zh_male_liufei_uranus_bigtts", "Felix (Male, Chinese)"),
|
||||
("zh_female_qingxinnvsheng_uranus_bigtts", "Celeste (Female, Chinese)"),
|
||||
("zh_male_sunwukong_uranus_bigtts", "Monkey King (Male, Chinese)"),
|
||||
]
|
||||
SEED_AUDIO_VOICE_OPTIONS = [label for _, label in SEED_AUDIO_PRESET_VOICES]
|
||||
SEED_AUDIO_VOICE_MAP = {label: speaker_id for speaker_id, label in SEED_AUDIO_PRESET_VOICES}
|
||||
|
||||
_AUDIO_TAG_RE = re.compile(r"@Audio(\d+)", re.IGNORECASE)
|
||||
|
||||
|
||||
def max_audio_tag(prompt: str) -> int:
|
||||
"""Highest N referenced as @AudioN in the prompt (0 if none)."""
|
||||
nums = [int(m) for m in _AUDIO_TAG_RE.findall(prompt or "")]
|
||||
return max(nums) if nums else 0
|
||||
|
||||
|
||||
def connected_audio_indices(reference_mode: dict) -> list[int]:
|
||||
"""Indices (1-based) of connected reference_audio sockets, in order."""
|
||||
return [
|
||||
i
|
||||
for i in range(1, 3 + 1)
|
||||
if reference_mode.get(f"reference_audio_{i}") is not None
|
||||
]
|
||||
|
||||
|
||||
def validate_seed_audio_inputs(
|
||||
text_prompt: str,
|
||||
mode: str,
|
||||
audio_indices: list[int],
|
||||
has_image: bool,
|
||||
preset_voice: str | None = None,
|
||||
) -> None:
|
||||
validate_string(text_prompt, field_name="text_prompt", min_length=1, max_length=3000)
|
||||
max_tag = max_audio_tag(text_prompt)
|
||||
|
||||
if mode == MODE_TEXT:
|
||||
if max_tag:
|
||||
raise ValueError(
|
||||
f"The prompt references @Audio{max_tag}, but reference mode is '{MODE_TEXT}'. "
|
||||
f"Switch to '{MODE_AUDIO}' and connect the reference clip(s)."
|
||||
)
|
||||
elif mode == MODE_AUDIO:
|
||||
if not audio_indices:
|
||||
raise ValueError(
|
||||
f"Reference mode '{MODE_AUDIO}' requires at least one reference_audio input "
|
||||
f"(or switch to '{MODE_TEXT}')."
|
||||
)
|
||||
if audio_indices != list(range(1, len(audio_indices) + 1)):
|
||||
raise ValueError(
|
||||
"Connect reference_audio inputs in order without gaps: reference_audio_1, then _2, then _3."
|
||||
)
|
||||
if max_tag > len(audio_indices):
|
||||
raise ValueError(
|
||||
f"The prompt references @Audio{max_tag}, but only {len(audio_indices)} "
|
||||
f"reference audio(s) are connected."
|
||||
)
|
||||
elif mode == MODE_IMAGE:
|
||||
if not has_image:
|
||||
raise ValueError(f"Reference mode '{MODE_IMAGE}' requires a reference_image input.")
|
||||
if max_tag:
|
||||
raise ValueError(
|
||||
f"@AudioN tags are not used in '{MODE_IMAGE}' mode; the prompt should contain "
|
||||
f"only the text to synthesize."
|
||||
)
|
||||
elif mode == MODE_SPEAKER:
|
||||
if not preset_voice or preset_voice not in SEED_AUDIO_VOICE_MAP:
|
||||
raise ValueError(f"Reference mode '{MODE_SPEAKER}' requires selecting a preset voice.")
|
||||
if max_tag > 1:
|
||||
raise ValueError(
|
||||
f"'{MODE_SPEAKER}' mode uses a single voice, so @Audio{max_tag} is out of range. "
|
||||
f"Remove the @AudioN tags — the whole prompt is read in the selected voice."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown reference mode: {mode!r}")
|
||||
|
||||
|
||||
class ByteDanceSeedAudioNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedAudio",
|
||||
display_name="ByteDance Seed Audio 1.0",
|
||||
category="partner/audio/ByteDance",
|
||||
description=(
|
||||
"Generate speech, music, sound effects and multi-speaker dialogue from a single prompt "
|
||||
"with ByteDance Seed Audio 1.0. Describe the voice(s), emotion, ambience, background music "
|
||||
"and sound effects in the prompt, and include the lines to speak. Optionally pick a built-in "
|
||||
"preset voice, clone voices from up to 3 reference clips (tagged @Audio1-3 in the prompt), "
|
||||
"or derive a voice from a character image. Up to 2 minutes of audio per run."
|
||||
),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"text_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip=(
|
||||
"Describe the voice(s), emotion, pacing, ambience, background music and sound "
|
||||
"effects, and include the lines to speak (name characters inline for dialogue). "
|
||||
"In 'audio reference' mode, refer to connected clips by order as @Audio1, @Audio2, "
|
||||
"@Audio3. Maximum 3000 characters."
|
||||
),
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"reference_mode",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(MODE_TEXT, []),
|
||||
IO.DynamicCombo.Option(
|
||||
MODE_AUDIO,
|
||||
[
|
||||
IO.Audio.Input(
|
||||
"reference_audio_1",
|
||||
optional=True,
|
||||
tooltip="Reference clip for voice cloning, tagged @Audio1 in the prompt. "
|
||||
"Up to 30s.",
|
||||
),
|
||||
IO.Audio.Input(
|
||||
"reference_audio_2",
|
||||
optional=True,
|
||||
tooltip="Reference clip tagged @Audio2 in the prompt. Up to 30s.",
|
||||
),
|
||||
IO.Audio.Input(
|
||||
"reference_audio_3",
|
||||
optional=True,
|
||||
tooltip="Reference clip tagged @Audio3 in the prompt. Up to 30s.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
MODE_IMAGE,
|
||||
[
|
||||
IO.Image.Input(
|
||||
"reference_image",
|
||||
optional=True,
|
||||
tooltip="A single character image; the model derives a voice from it. "
|
||||
"Cannot be combined with reference audio.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
MODE_SPEAKER,
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"preset_voice",
|
||||
options=SEED_AUDIO_VOICE_OPTIONS,
|
||||
default=SEED_AUDIO_VOICE_OPTIONS[0],
|
||||
tooltip="A built-in TTS 2.0 voice that reads the prompt. No reference "
|
||||
"clip needed, and @AudioN tags are not used in this mode.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip=(
|
||||
"How to condition the voice: 'text only' (describe everything in the prompt), "
|
||||
"'audio reference' (clone up to 3 voices, tagged @Audio1-3), 'image reference' "
|
||||
"(derive a voice from one character image), or 'preset voice' (pick a built-in "
|
||||
"named voice that reads the prompt)."
|
||||
),
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"sample_rate",
|
||||
options=["8000", "16000", "24000", "32000", "44100", "48000"],
|
||||
default="24000",
|
||||
tooltip="Output sample rate in Hz.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"speech_rate",
|
||||
default=0,
|
||||
min=-50,
|
||||
max=100,
|
||||
tooltip="Speaking speed. 0 = normal, 100 = 2.0x, -50 = 0.5x.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"loudness_rate",
|
||||
default=0,
|
||||
min=-50,
|
||||
max=100,
|
||||
tooltip="Loudness. 0 = normal, 100 = 2.0x, -50 = 0.5x.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"pitch_rate",
|
||||
default=0,
|
||||
min=-12,
|
||||
max=12,
|
||||
tooltip="Pitch shift in semitones (-12 to 12).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd": 0.2145, "format":{"suffix":"/minute","approximate":true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
text_prompt: str,
|
||||
reference_mode: dict,
|
||||
sample_rate: str,
|
||||
speech_rate: int,
|
||||
loudness_rate: int,
|
||||
pitch_rate: int,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
mode = reference_mode["reference_mode"]
|
||||
audio_indices = connected_audio_indices(reference_mode)
|
||||
image = reference_mode.get("reference_image")
|
||||
preset_voice = reference_mode.get("preset_voice")
|
||||
validate_seed_audio_inputs(text_prompt, mode, audio_indices, image is not None, preset_voice)
|
||||
|
||||
references: list[SeedAudioReference] | None = None
|
||||
if mode == MODE_AUDIO:
|
||||
references = []
|
||||
for i in audio_indices:
|
||||
clip = reference_mode[f"reference_audio_{i}"]
|
||||
validate_audio_duration(clip, max_duration=30.0)
|
||||
mp3_bytes = audio_input_to_mp3(clip).getvalue()
|
||||
references.append(SeedAudioReference(audio_data=base64.b64encode(mp3_bytes).decode("utf-8")))
|
||||
elif mode == MODE_IMAGE:
|
||||
image = upscale_image_tensor_to_min_pixels(image, 160_000)
|
||||
references = [SeedAudioReference(image_data=tensor_to_base64_string(image, mime_type="image/png"))]
|
||||
elif mode == MODE_SPEAKER:
|
||||
references = [SeedAudioReference(speaker=SEED_AUDIO_VOICE_MAP[preset_voice])]
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/byteplus/api/v3/tts/create", method="POST"),
|
||||
response_model=SeedAudioResponse,
|
||||
data=SeedAudioRequest(
|
||||
text_prompt=text_prompt,
|
||||
references=references,
|
||||
audio_config=SeedAudioConfig(
|
||||
sample_rate=int(sample_rate),
|
||||
speech_rate=speech_rate,
|
||||
loudness_rate=loudness_rate,
|
||||
pitch_rate=pitch_rate,
|
||||
),
|
||||
),
|
||||
)
|
||||
if not response.audio:
|
||||
raise Exception(
|
||||
f"Seed Audio returned no audio (code={response.code}): {response.message}"
|
||||
)
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response.audio)))
|
||||
|
||||
|
||||
class ByteDanceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -2805,7 +2490,6 @@ class ByteDanceExtension(ComfyExtension):
|
||||
ByteDance2ReferenceNode,
|
||||
ByteDanceCreateImageAsset,
|
||||
ByteDanceCreateVideoAsset,
|
||||
ByteDanceSeedAudioNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -5,7 +5,9 @@ from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from comfy_api_nodes.apis.ideogram import (
|
||||
IdeogramGenerateRequest,
|
||||
IdeogramGenerateResponse,
|
||||
ImageRequest,
|
||||
IdeogramV3Request,
|
||||
IdeogramV3EditRequest,
|
||||
IdeogramV4Request,
|
||||
@ -19,6 +21,101 @@ from comfy_api_nodes.util import (
|
||||
validate_string,
|
||||
)
|
||||
|
||||
V1_V1_RES_MAP = {
|
||||
"Auto":"AUTO",
|
||||
"512 x 1536":"RESOLUTION_512_1536",
|
||||
"576 x 1408":"RESOLUTION_576_1408",
|
||||
"576 x 1472":"RESOLUTION_576_1472",
|
||||
"576 x 1536":"RESOLUTION_576_1536",
|
||||
"640 x 1024":"RESOLUTION_640_1024",
|
||||
"640 x 1344":"RESOLUTION_640_1344",
|
||||
"640 x 1408":"RESOLUTION_640_1408",
|
||||
"640 x 1472":"RESOLUTION_640_1472",
|
||||
"640 x 1536":"RESOLUTION_640_1536",
|
||||
"704 x 1152":"RESOLUTION_704_1152",
|
||||
"704 x 1216":"RESOLUTION_704_1216",
|
||||
"704 x 1280":"RESOLUTION_704_1280",
|
||||
"704 x 1344":"RESOLUTION_704_1344",
|
||||
"704 x 1408":"RESOLUTION_704_1408",
|
||||
"704 x 1472":"RESOLUTION_704_1472",
|
||||
"720 x 1280":"RESOLUTION_720_1280",
|
||||
"736 x 1312":"RESOLUTION_736_1312",
|
||||
"768 x 1024":"RESOLUTION_768_1024",
|
||||
"768 x 1088":"RESOLUTION_768_1088",
|
||||
"768 x 1152":"RESOLUTION_768_1152",
|
||||
"768 x 1216":"RESOLUTION_768_1216",
|
||||
"768 x 1232":"RESOLUTION_768_1232",
|
||||
"768 x 1280":"RESOLUTION_768_1280",
|
||||
"768 x 1344":"RESOLUTION_768_1344",
|
||||
"832 x 960":"RESOLUTION_832_960",
|
||||
"832 x 1024":"RESOLUTION_832_1024",
|
||||
"832 x 1088":"RESOLUTION_832_1088",
|
||||
"832 x 1152":"RESOLUTION_832_1152",
|
||||
"832 x 1216":"RESOLUTION_832_1216",
|
||||
"832 x 1248":"RESOLUTION_832_1248",
|
||||
"864 x 1152":"RESOLUTION_864_1152",
|
||||
"896 x 960":"RESOLUTION_896_960",
|
||||
"896 x 1024":"RESOLUTION_896_1024",
|
||||
"896 x 1088":"RESOLUTION_896_1088",
|
||||
"896 x 1120":"RESOLUTION_896_1120",
|
||||
"896 x 1152":"RESOLUTION_896_1152",
|
||||
"960 x 832":"RESOLUTION_960_832",
|
||||
"960 x 896":"RESOLUTION_960_896",
|
||||
"960 x 1024":"RESOLUTION_960_1024",
|
||||
"960 x 1088":"RESOLUTION_960_1088",
|
||||
"1024 x 640":"RESOLUTION_1024_640",
|
||||
"1024 x 768":"RESOLUTION_1024_768",
|
||||
"1024 x 832":"RESOLUTION_1024_832",
|
||||
"1024 x 896":"RESOLUTION_1024_896",
|
||||
"1024 x 960":"RESOLUTION_1024_960",
|
||||
"1024 x 1024":"RESOLUTION_1024_1024",
|
||||
"1088 x 768":"RESOLUTION_1088_768",
|
||||
"1088 x 832":"RESOLUTION_1088_832",
|
||||
"1088 x 896":"RESOLUTION_1088_896",
|
||||
"1088 x 960":"RESOLUTION_1088_960",
|
||||
"1120 x 896":"RESOLUTION_1120_896",
|
||||
"1152 x 704":"RESOLUTION_1152_704",
|
||||
"1152 x 768":"RESOLUTION_1152_768",
|
||||
"1152 x 832":"RESOLUTION_1152_832",
|
||||
"1152 x 864":"RESOLUTION_1152_864",
|
||||
"1152 x 896":"RESOLUTION_1152_896",
|
||||
"1216 x 704":"RESOLUTION_1216_704",
|
||||
"1216 x 768":"RESOLUTION_1216_768",
|
||||
"1216 x 832":"RESOLUTION_1216_832",
|
||||
"1232 x 768":"RESOLUTION_1232_768",
|
||||
"1248 x 832":"RESOLUTION_1248_832",
|
||||
"1280 x 704":"RESOLUTION_1280_704",
|
||||
"1280 x 720":"RESOLUTION_1280_720",
|
||||
"1280 x 768":"RESOLUTION_1280_768",
|
||||
"1280 x 800":"RESOLUTION_1280_800",
|
||||
"1312 x 736":"RESOLUTION_1312_736",
|
||||
"1344 x 640":"RESOLUTION_1344_640",
|
||||
"1344 x 704":"RESOLUTION_1344_704",
|
||||
"1344 x 768":"RESOLUTION_1344_768",
|
||||
"1408 x 576":"RESOLUTION_1408_576",
|
||||
"1408 x 640":"RESOLUTION_1408_640",
|
||||
"1408 x 704":"RESOLUTION_1408_704",
|
||||
"1472 x 576":"RESOLUTION_1472_576",
|
||||
"1472 x 640":"RESOLUTION_1472_640",
|
||||
"1472 x 704":"RESOLUTION_1472_704",
|
||||
"1536 x 512":"RESOLUTION_1536_512",
|
||||
"1536 x 576":"RESOLUTION_1536_576",
|
||||
"1536 x 640":"RESOLUTION_1536_640",
|
||||
}
|
||||
|
||||
V1_V2_RATIO_MAP = {
|
||||
"1:1":"ASPECT_1_1",
|
||||
"4:3":"ASPECT_4_3",
|
||||
"3:4":"ASPECT_3_4",
|
||||
"16:9":"ASPECT_16_9",
|
||||
"9:16":"ASPECT_9_16",
|
||||
"2:1":"ASPECT_2_1",
|
||||
"1:2":"ASPECT_1_2",
|
||||
"3:2":"ASPECT_3_2",
|
||||
"2:3":"ASPECT_2_3",
|
||||
"4:5":"ASPECT_4_5",
|
||||
"5:4":"ASPECT_5_4",
|
||||
}
|
||||
|
||||
V3_RATIO_MAP = {
|
||||
"1:3":"1x3",
|
||||
@ -132,6 +229,298 @@ async def download_and_process_images(image_urls):
|
||||
return stacked_tensors
|
||||
|
||||
|
||||
class IdeogramV1(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV1",
|
||||
display_name="Ideogram V1",
|
||||
category="partner/image/Ideogram",
|
||||
description="Generates images using the Ideogram V1 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"turbo",
|
||||
default=False,
|
||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V1_V2_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of what to exclude from the image",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]),
|
||||
expr="""
|
||||
(
|
||||
$n := widgets.num_images;
|
||||
$base := (widgets.turbo = true) ? 0.0286 : 0.0858;
|
||||
{"type":"usd","usd": $round($base * $n, 2)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
turbo=False,
|
||||
aspect_ratio="1:1",
|
||||
magic_prompt_option="AUTO",
|
||||
seed=0,
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
):
|
||||
# Determine the model based on turbo setting
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
model = "V_1_TURBO" if turbo else "V_1"
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=IdeogramGenerateRequest(
|
||||
image_request=ImageRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
num_images=num_images,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
|
||||
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
)
|
||||
),
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramV2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV2",
|
||||
display_name="Ideogram V2",
|
||||
category="partner/image/Ideogram",
|
||||
description="Generates images using the Ideogram V2 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"turbo",
|
||||
default=False,
|
||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V1_V2_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=list(V1_V1_RES_MAP.keys()),
|
||||
default="Auto",
|
||||
tooltip="The resolution for image generation. "
|
||||
"If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"style_type",
|
||||
options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||
default="NONE",
|
||||
tooltip="Style type for generation (V2 only)",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of what to exclude from the image",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
#"color_palette": (
|
||||
# IO.STRING,
|
||||
# {
|
||||
# "multiline": False,
|
||||
# "default": "",
|
||||
# "tooltip": "Color palette preset name or hex colors with weights",
|
||||
# },
|
||||
#),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]),
|
||||
expr="""
|
||||
(
|
||||
$n := widgets.num_images;
|
||||
$base := (widgets.turbo = true) ? 0.0715 : 0.1144;
|
||||
{"type":"usd","usd": $round($base * $n, 2)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
turbo=False,
|
||||
aspect_ratio="1:1",
|
||||
resolution="Auto",
|
||||
magic_prompt_option="AUTO",
|
||||
seed=0,
|
||||
style_type="NONE",
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
color_palette="",
|
||||
):
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
resolution = V1_V1_RES_MAP.get(resolution, None)
|
||||
# Determine the model based on turbo setting
|
||||
model = "V_2_TURBO" if turbo else "V_2"
|
||||
|
||||
# Handle resolution vs aspect_ratio logic
|
||||
# If resolution is not AUTO, it overrides aspect_ratio
|
||||
final_resolution = None
|
||||
final_aspect_ratio = None
|
||||
|
||||
if resolution != "AUTO":
|
||||
final_resolution = resolution
|
||||
else:
|
||||
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=IdeogramGenerateRequest(
|
||||
image_request=ImageRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
num_images=num_images,
|
||||
seed=seed,
|
||||
aspect_ratio=final_aspect_ratio,
|
||||
resolution=final_resolution,
|
||||
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
|
||||
style_type=style_type if style_type != "NONE" else None,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
color_palette=color_palette if color_palette else None,
|
||||
)
|
||||
),
|
||||
max_retries=1,
|
||||
)
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramV3(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -528,6 +917,8 @@ class IdeogramExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
IdeogramV1,
|
||||
IdeogramV2,
|
||||
IdeogramV3,
|
||||
IdeogramV4,
|
||||
]
|
||||
|
||||
932
comfy_api_nodes/nodes_stability.py
Normal file
932
comfy_api_nodes/nodes_stability.py
Normal file
@ -0,0 +1,932 @@
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, Input, IO
|
||||
from comfy_api_nodes.apis.stability import (
|
||||
StabilityUpscaleConservativeRequest,
|
||||
StabilityUpscaleCreativeRequest,
|
||||
StabilityAsyncResponse,
|
||||
StabilityResultsGetResponse,
|
||||
StabilityStable3_5Request,
|
||||
StabilityStableUltraRequest,
|
||||
StabilityStableUltraResponse,
|
||||
StabilityAspectRatio,
|
||||
Stability_SD3_5_Model,
|
||||
Stability_SD3_5_GenerationMode,
|
||||
get_stability_style_presets,
|
||||
StabilityTextToAudioRequest,
|
||||
StabilityAudioToAudioRequest,
|
||||
StabilityAudioInpaintRequest,
|
||||
StabilityAudioResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
validate_audio_duration,
|
||||
validate_string,
|
||||
audio_input_to_mp3,
|
||||
bytesio_to_image_tensor,
|
||||
tensor_to_bytesio,
|
||||
audio_bytes_to_audio_input,
|
||||
sync_op,
|
||||
poll_op,
|
||||
ApiEndpoint,
|
||||
)
|
||||
|
||||
import torch
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class StabilityPollStatus(str, Enum):
|
||||
finished = "finished"
|
||||
in_progress = "in_progress"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
def get_async_dummy_status(x: StabilityResultsGetResponse):
|
||||
if x.name is not None or x.errors is not None:
|
||||
return StabilityPollStatus.failed
|
||||
elif x.finish_reason is not None:
|
||||
return StabilityPollStatus.finished
|
||||
return StabilityPollStatus.in_progress
|
||||
|
||||
|
||||
class StabilityStableImageUltraNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityStableImageUltraNode",
|
||||
display_name="Stability AI Stable Image Ultra",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
|
||||
"elements, colors, and subjects will lead to better results. " +
|
||||
"To control the weight of a given word use the format `(word:weight)`," +
|
||||
"where `word` is the word you'd like to control the weight of and `weight`" +
|
||||
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
|
||||
"would convey a sky that was blue and green, but more green than blue.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=StabilityAspectRatio,
|
||||
default=StabilityAspectRatio.ratio_1_1,
|
||||
tooltip="Aspect ratio of generated image.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"image_denoise",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.08}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
style_preset: str,
|
||||
seed: int,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
negative_prompt: str = "",
|
||||
image_denoise: Optional[float] = 0.5,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
if image is not None:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
|
||||
else:
|
||||
image_denoise = None
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityStableUltraRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
seed=seed,
|
||||
strength=image_denoise,
|
||||
style_preset=style_preset,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityStableImageSD_3_5Node(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityStableImageSD_3_5Node",
|
||||
display_name="Stability AI Stable Diffusion 3.5 Image",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Stability_SD3_5_Model,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=StabilityAspectRatio,
|
||||
default=StabilityAspectRatio.ratio_1_1,
|
||||
tooltip="Aspect ratio of generated image.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"cfg_scale",
|
||||
default=4.0,
|
||||
min=1.0,
|
||||
max=10.0,
|
||||
step=0.1,
|
||||
tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"image_denoise",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$contains(widgets.model,"large")
|
||||
? {"type":"usd","usd":0.065}
|
||||
: {"type":"usd","usd":0.035}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
style_preset: str,
|
||||
seed: int,
|
||||
cfg_scale: float,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
negative_prompt: str = "",
|
||||
image_denoise: Optional[float] = 0.5,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
mode = Stability_SD3_5_GenerationMode.text_to_image
|
||||
if image is not None:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
|
||||
mode = Stability_SD3_5_GenerationMode.image_to_image
|
||||
aspect_ratio = None
|
||||
else:
|
||||
image_denoise = None
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityStable3_5Request(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
seed=seed,
|
||||
strength=image_denoise,
|
||||
style_preset=style_preset,
|
||||
cfg_scale=cfg_scale,
|
||||
model=model,
|
||||
mode=mode,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleConservativeNode(IO.ComfyNode):
|
||||
"""
|
||||
Upscale image with minimal alterations to 4K resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleConservativeNode",
|
||||
display_name="Stability AI Upscale Conservative",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"creativity",
|
||||
default=0.35,
|
||||
min=0.2,
|
||||
max=0.5,
|
||||
step=0.01,
|
||||
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
creativity: float,
|
||||
seed: int,
|
||||
negative_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityUpscaleConservativeRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
creativity=round(creativity,2),
|
||||
seed=seed,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleCreativeNode(IO.ComfyNode):
|
||||
"""
|
||||
Upscale image with minimal alterations to 4K resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleCreativeNode",
|
||||
display_name="Stability AI Upscale Creative",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"creativity",
|
||||
default=0.3,
|
||||
min=0.1,
|
||||
max=0.5,
|
||||
step=0.01,
|
||||
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.6}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
creativity: float,
|
||||
style_preset: str,
|
||||
seed: int,
|
||||
negative_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
|
||||
response_model=StabilityAsyncResponse,
|
||||
data=StabilityUpscaleCreativeRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
creativity=round(creativity,2),
|
||||
style_preset=style_preset,
|
||||
seed=seed,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
|
||||
response_model=StabilityResultsGetResponse,
|
||||
poll_interval=3,
|
||||
status_extractor=lambda x: get_async_dummy_status(x),
|
||||
)
|
||||
|
||||
if response_poll.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_poll.result)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleFastNode(IO.ComfyNode):
|
||||
"""
|
||||
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleFastNode",
|
||||
display_name="Stability AI Upscale Fast",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.02}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, image: torch.Tensor) -> IO.NodeOutput:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityTextToAudio(IO.ComfyNode):
|
||||
"""Generates high-quality music and sound effects from text descriptions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityTextToAudio",
|
||||
display_name="Stability AI Text To Audio",
|
||||
category="partner/audio/Stability AI",
|
||||
essentials_category="Audio",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
max=190,
|
||||
step=1,
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="Controls the number of sampling steps.",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.2}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityAudioToAudio(IO.ComfyNode):
|
||||
"""Transforms existing audio samples into new high-quality compositions using text instructions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityAudioToAudio",
|
||||
display_name="Stability AI Audio To Audio",
|
||||
category="partner/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
max=190,
|
||||
step=1,
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="Controls the number of sampling steps.",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"strength",
|
||||
default=1,
|
||||
min=0.01,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Parameter controls how much influence the audio parameter has on the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.2}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
validate_audio_duration(audio, 6, 190)
|
||||
payload = StabilityAudioToAudioRequest(
|
||||
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
|
||||
)
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
files={"audio": audio_input_to_mp3(audio)},
|
||||
)
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityAudioInpaint(IO.ComfyNode):
|
||||
"""Transforms part of existing audio sample using text instructions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityAudioInpaint",
|
||||
display_name="Stability AI Audio Inpaint",
|
||||
category="partner/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
max=190,
|
||||
step=1,
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="Controls the number of sampling steps.",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"mask_start",
|
||||
default=30,
|
||||
min=0,
|
||||
max=190,
|
||||
step=1,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"mask_end",
|
||||
default=190,
|
||||
min=0,
|
||||
max=190,
|
||||
step=1,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.2}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
audio: Input.Audio,
|
||||
duration: int,
|
||||
seed: int,
|
||||
steps: int,
|
||||
mask_start: int,
|
||||
mask_end: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
if mask_end <= mask_start:
|
||||
raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})")
|
||||
validate_audio_duration(audio, 6, 190)
|
||||
|
||||
payload = StabilityAudioInpaintRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
steps=steps,
|
||||
mask_start=mask_start,
|
||||
mask_end=mask_end,
|
||||
)
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
files={"audio": audio_input_to_mp3(audio)},
|
||||
)
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
StabilityStableImageUltraNode,
|
||||
StabilityStableImageSD_3_5Node,
|
||||
StabilityUpscaleConservativeNode,
|
||||
StabilityUpscaleCreativeNode,
|
||||
StabilityUpscaleFastNode,
|
||||
StabilityTextToAudio,
|
||||
StabilityAudioToAudio,
|
||||
StabilityAudioInpaint,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> StabilityExtension:
|
||||
return StabilityExtension()
|
||||
@ -26,7 +26,6 @@ from .conversions import (
|
||||
text_filepath_to_base64_string,
|
||||
text_filepath_to_data_uri,
|
||||
trim_video,
|
||||
upscale_image_tensor_to_min_pixels,
|
||||
upscale_video_to_min_pixels,
|
||||
video_to_base64_string,
|
||||
)
|
||||
@ -100,7 +99,6 @@ __all__ = [
|
||||
"text_filepath_to_base64_string",
|
||||
"text_filepath_to_data_uri",
|
||||
"trim_video",
|
||||
"upscale_image_tensor_to_min_pixels",
|
||||
"upscale_video_to_min_pixels",
|
||||
"video_to_base64_string",
|
||||
# Validation utilities
|
||||
|
||||
@ -448,15 +448,6 @@ def _compute_upscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[in
|
||||
return new_w, new_h
|
||||
|
||||
|
||||
def upscale_image_tensor_to_min_pixels(image: torch.Tensor, total_pixels: int) -> torch.Tensor:
|
||||
samples = image.movedim(-1, 1)
|
||||
dims = _compute_upscale_dims(samples.shape[3], samples.shape[2], int(total_pixels))
|
||||
if dims is None:
|
||||
return image
|
||||
new_w, new_h = dims
|
||||
return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1)
|
||||
|
||||
|
||||
def upscale_video_to_min_pixels(video: Input.Video, min_pixels: int) -> Input.Video:
|
||||
"""Upscale a video to meet at least ``min_pixels`` (w * h), preserving aspect ratio.
|
||||
|
||||
|
||||
@ -166,6 +166,32 @@ def boxes_to_regions(boxes, width: int, height: int) -> list:
|
||||
return regions
|
||||
|
||||
|
||||
def normalize_incoming_boxes(bboxes) -> list:
|
||||
if isinstance(bboxes, dict):
|
||||
frame = [bboxes]
|
||||
elif not isinstance(bboxes, list) or not bboxes:
|
||||
frame = []
|
||||
elif isinstance(bboxes[0], dict):
|
||||
frame = bboxes
|
||||
else:
|
||||
frame = bboxes[0] if isinstance(bboxes[0], list) else []
|
||||
boxes = []
|
||||
for box in frame:
|
||||
if not isinstance(box, dict):
|
||||
continue
|
||||
norm = {
|
||||
"x": box.get("x", 0),
|
||||
"y": box.get("y", 0),
|
||||
"width": box.get("width", 0),
|
||||
"height": box.get("height", 0),
|
||||
}
|
||||
meta = box.get("metadata")
|
||||
if isinstance(meta, dict):
|
||||
norm["metadata"] = meta
|
||||
boxes.append(norm)
|
||||
return boxes
|
||||
|
||||
|
||||
def _norm_bbox(region: dict) -> list[int]:
|
||||
def grid(value: float) -> int:
|
||||
return max(0, min(1000, round(value * 1000)))
|
||||
@ -199,6 +225,8 @@ def build_elements(regions: list) -> list:
|
||||
|
||||
|
||||
class CreateBoundingBoxes(io.ComfyNode):
|
||||
_last_incoming: dict = {}
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
editor_state = io.BoundingBoxes.Input(
|
||||
@ -217,6 +245,12 @@ class CreateBoundingBoxes(io.ComfyNode):
|
||||
optional=True,
|
||||
tooltip="Optional image used as background in the canvas and preview.",
|
||||
),
|
||||
io.BoundingBox.Input(
|
||||
"bboxes",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
tooltip="Bounding boxes from an upstream node. A new upstream value seeds the canvas; edits you make on the canvas take priority and are kept until the upstream value changes again.",
|
||||
),
|
||||
io.Int.Input("width", default=1024, min=64, max=16384, step=16,
|
||||
tooltip="Width of the canvas and the pixel grid for the bounding boxes."),
|
||||
io.Int.Input("height", default=1024, min=64, max=16384, step=16,
|
||||
@ -228,18 +262,33 @@ class CreateBoundingBoxes(io.ComfyNode):
|
||||
io.BoundingBox.Output(display_name="bboxes"),
|
||||
io.Array.Output(display_name="elements"),
|
||||
],
|
||||
hidden=[io.Hidden.unique_id],
|
||||
is_output_node=True,
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, editor_state=None, background=None) -> io.NodeOutput:
|
||||
regions = boxes_to_regions(editor_state, width, height)
|
||||
def execute(cls, width, height, editor_state=None, background=None, bboxes=None) -> io.NodeOutput:
|
||||
incoming = normalize_incoming_boxes(bboxes)
|
||||
node_id = cls.hidden.unique_id
|
||||
if incoming:
|
||||
changed = cls._last_incoming.get(node_id) != incoming
|
||||
if changed:
|
||||
cls._last_incoming[node_id] = incoming
|
||||
else:
|
||||
changed = False
|
||||
cls._last_incoming.pop(node_id, None)
|
||||
source = incoming if changed else (editor_state or incoming)
|
||||
regions = boxes_to_regions(source, width, height)
|
||||
preview = render_preview(regions, width, height, _bg_from_image(background))
|
||||
ui = {"dims": [width, height]}
|
||||
if incoming:
|
||||
ui["input_bboxes"] = incoming
|
||||
return io.NodeOutput(
|
||||
preview,
|
||||
fractions_to_bbox_frame(regions, width, height),
|
||||
build_elements(regions),
|
||||
ui={"dims": [width, height]},
|
||||
ui=ui,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -264,59 +264,6 @@ def annotated_filepath(name: str) -> tuple[str, str | None]:
|
||||
return name, base_dir
|
||||
|
||||
|
||||
# Content types a browser may execute or render inline. File endpoints that
|
||||
# serve user-controlled content must force these to download (and ideally set
|
||||
# Content-Disposition: attachment) to avoid stored XSS. Centralised here so the
|
||||
# /view and /userdata handlers can't drift apart. mimetypes.guess_type may
|
||||
# return either the text/* or application/* spelling depending on platform, so
|
||||
# both are listed.
|
||||
DANGEROUS_CONTENT_TYPES = {
|
||||
'text/html', 'text/html-sandboxed', 'application/xhtml+xml',
|
||||
'text/javascript', 'application/javascript', 'application/x-javascript',
|
||||
'application/ecmascript', 'text/css',
|
||||
'image/svg+xml', 'application/xml', 'text/xml',
|
||||
# message/rfc822 (.mht/.mhtml) can carry script in some browsers.
|
||||
'message/rfc822',
|
||||
}
|
||||
|
||||
|
||||
def is_dangerous_content_type(content_type: str | None) -> bool:
|
||||
"""Return True if a browser may execute or render `content_type` inline.
|
||||
|
||||
Normalises before matching so the check can't be slipped past with a
|
||||
charset/boundary parameter (``text/html; charset=utf-8``) or casing
|
||||
(``TEXT/HTML``). Any XML dialect (``*+xml`` or ``*/xml``) is treated as
|
||||
dangerous because XML can carry inline script via stylesheet/entity tricks,
|
||||
which also covers the ``application/{xslt,rss,atom,rdf}+xml`` family without
|
||||
enumerating each one. Endpoints serving user-controlled content should route
|
||||
a dangerous type to ``application/octet-stream`` + ``Content-Disposition:
|
||||
attachment`` + ``X-Content-Type-Options: nosniff``.
|
||||
"""
|
||||
if not content_type:
|
||||
return False
|
||||
normalized = content_type.split(';', 1)[0].strip().lower()
|
||||
if normalized in DANGEROUS_CONTENT_TYPES:
|
||||
return True
|
||||
return normalized.endswith('+xml') or normalized.endswith('/xml')
|
||||
|
||||
|
||||
def is_within_directory(directory: str, target: str) -> bool:
|
||||
"""Return True if `target` resolves to a path inside `directory`.
|
||||
|
||||
Uses realpath on both operands so that a symlink placed inside `directory`
|
||||
that points elsewhere cannot escape the containment check at open time.
|
||||
"""
|
||||
try:
|
||||
directory = os.path.realpath(directory)
|
||||
target = os.path.realpath(target)
|
||||
return os.path.commonpath((directory, target)) == directory
|
||||
except ValueError:
|
||||
# ValueError is raised by realpath() on a path with an embedded null
|
||||
# byte, and by commonpath() on Windows when the paths are on different
|
||||
# drives. In either case the target is not safely within the directory.
|
||||
return False
|
||||
|
||||
|
||||
def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
|
||||
name, base_dir = annotated_filepath(name)
|
||||
|
||||
@ -326,12 +273,7 @@ def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
|
||||
else:
|
||||
base_dir = get_input_directory() # fallback path
|
||||
|
||||
filepath = os.path.abspath(os.path.join(base_dir, name))
|
||||
# Prevent path traversal: the resolved path must stay within base_dir.
|
||||
# repr() the name in the message so a crafted value can't inject log lines.
|
||||
if not is_within_directory(base_dir, filepath):
|
||||
raise ValueError("Invalid file path: {!r}".format(name))
|
||||
return filepath
|
||||
return os.path.join(base_dir, name)
|
||||
|
||||
|
||||
def exists_annotated_filepath(name) -> bool:
|
||||
@ -340,10 +282,7 @@ def exists_annotated_filepath(name) -> bool:
|
||||
if base_dir is None:
|
||||
base_dir = get_input_directory() # fallback path
|
||||
|
||||
filepath = os.path.abspath(os.path.join(base_dir, name))
|
||||
# Treat traversal attempts as non-existent rather than probing the filesystem.
|
||||
if not is_within_directory(base_dir, filepath):
|
||||
return False
|
||||
filepath = os.path.join(base_dir, name)
|
||||
return os.path.exists(filepath)
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.45.20
|
||||
comfyui-workflow-templates==0.11.2
|
||||
comfyui-workflow-templates==0.11.1
|
||||
comfyui-embedded-docs==0.5.6
|
||||
torch
|
||||
torchsde
|
||||
|
||||
26
server.py
26
server.py
@ -127,7 +127,6 @@ def create_cors_middleware(allowed_origin: str):
|
||||
|
||||
return cors_middleware
|
||||
|
||||
|
||||
def is_loopback(host):
|
||||
if host is None:
|
||||
return False
|
||||
@ -617,30 +616,15 @@ class PromptServer():
|
||||
or 'application/octet-stream'
|
||||
)
|
||||
|
||||
# For security, force renderable/active types (HTML, JS,
|
||||
# CSS, SVG, XML — anything that can carry inline <script>
|
||||
# and execute in the page origin) to download instead of
|
||||
# displaying inline, preventing stored XSS. The
|
||||
# attachment disposition is the load-bearing guard: a
|
||||
# bare filename= hint does not force a download per
|
||||
# RFC 6266, so we only attach it on the dangerous branch
|
||||
# to avoid breaking inline display of legitimate images.
|
||||
# Escape backslash/quote per RFC 6266 quoted-string so a
|
||||
# filename containing a double quote (which passes the
|
||||
# ".."/leading-slash filter above) can't break out of the
|
||||
# header's quoted-string and malform the disposition.
|
||||
safe_filename = filename.replace("\\", "\\\\").replace('"', '\\"')
|
||||
disposition = f"filename=\"{safe_filename}\""
|
||||
if folder_paths.is_dangerous_content_type(content_type):
|
||||
content_type = 'application/octet-stream'
|
||||
disposition = f"attachment; filename=\"{safe_filename}\""
|
||||
# For security, force certain mimetypes to download instead of display
|
||||
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
|
||||
content_type = 'application/octet-stream' # Forces download
|
||||
|
||||
return web.FileResponse(
|
||||
file,
|
||||
headers={
|
||||
"Content-Disposition": disposition,
|
||||
"Content-Type": content_type,
|
||||
"X-Content-Type-Options": "nosniff"
|
||||
"Content-Disposition": f"filename=\"{filename}\"",
|
||||
"Content-Type": content_type
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
import contextlib
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
@ -11,40 +9,6 @@ import requests
|
||||
from helpers import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
def test_download_svg_forced_to_attachment(http: requests.Session, api_base: str):
|
||||
"""GHSA-779p-m5rp-r4h4 CISA-5 (sibling route): an uploaded SVG must never be
|
||||
served inline from GET /api/assets/{id}/content, or an inline <script> runs
|
||||
in the app origin (stored XSS). Even with disposition=inline requested, a
|
||||
dangerous content type must be forced to application/octet-stream +
|
||||
Content-Disposition: attachment + nosniff. Regression guard for the stale
|
||||
inline blocklist that previously omitted image/svg+xml and ignored the
|
||||
centralized folder_paths.is_dangerous_content_type check.
|
||||
"""
|
||||
svg = b'<svg xmlns="http://www.w3.org/2000/svg"><script>alert(1)</script></svg>'
|
||||
files = {"file": ("evil.svg", svg, "image/svg+xml")}
|
||||
form_data = {
|
||||
"tags": json.dumps(["models", "checkpoints", "unit-tests", "svgxss"]),
|
||||
"name": "evil.svg",
|
||||
}
|
||||
up = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||
body = up.json()
|
||||
assert up.status_code in (200, 201), body
|
||||
aid = body["id"]
|
||||
try:
|
||||
r = http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline", timeout=120)
|
||||
r.content
|
||||
assert r.status_code == 200
|
||||
ct = r.headers.get("Content-Type", "").lower()
|
||||
cd = r.headers.get("Content-Disposition", "").lower()
|
||||
assert "svg" not in ct, f"SVG served with a renderable content type: {ct!r}"
|
||||
assert ct.startswith("application/octet-stream"), f"expected octet-stream, got {ct!r}"
|
||||
assert "attachment" in cd, f"inline disposition not overridden to attachment: {cd!r}"
|
||||
assert r.headers.get("X-Content-Type-Options", "").lower() == "nosniff"
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
|
||||
|
||||
|
||||
def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
|
||||
@ -1,204 +0,0 @@
|
||||
"""Unit tests for io.DynamicGroup: expansion/reconstruction (0-row and N-row cases)."""
|
||||
import sys
|
||||
import types
|
||||
import pytest
|
||||
|
||||
# Stub torch (type-hint only in _io.py; real torch not available in unit-test env)
|
||||
if "torch" not in sys.modules:
|
||||
_torch_stub = types.ModuleType("torch")
|
||||
_torch_stub.Tensor = object # type: ignore[attr-defined]
|
||||
sys.modules["torch"] = _torch_stub
|
||||
|
||||
from comfy_api.latest._io import ( # noqa: E402
|
||||
DynamicGroup,
|
||||
Float,
|
||||
Int,
|
||||
String,
|
||||
Boolean,
|
||||
get_finalized_class_inputs,
|
||||
build_nested_inputs,
|
||||
create_input_dict_v1,
|
||||
setup_dynamic_input_funcs,
|
||||
)
|
||||
|
||||
# Make sure dynamic input funcs are registered (may already be done at import time)
|
||||
setup_dynamic_input_funcs()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_class_inputs(group_input: DynamicGroup.Input) -> dict:
|
||||
"""Wrap a DynamicGroup.Input into the required/optional dict structure."""
|
||||
return create_input_dict_v1([group_input])
|
||||
|
||||
|
||||
def _run(group_input: DynamicGroup.Input, live_values: dict) -> dict:
|
||||
"""End-to-end helper: expand schema + reconstruct values.
|
||||
|
||||
Mirrors the production split in execution.py:
|
||||
1. get_finalized_class_inputs (schema expansion, line 162)
|
||||
2. build_nested_inputs (value reconstruction, line 281)
|
||||
|
||||
The two steps are separate in production because the engine resolves
|
||||
linked node outputs between them, but in tests we supply values directly.
|
||||
"""
|
||||
class_inputs = _make_class_inputs(group_input)
|
||||
_, _, v3_data = get_finalized_class_inputs(class_inputs, live_values)
|
||||
return build_nested_inputs(dict(live_values), v3_data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDynamicGroupInputConstruction:
|
||||
def test_basic_construction(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[
|
||||
Float.Input("strength", default=1.0),
|
||||
String.Input("name"),
|
||||
],
|
||||
min=0,
|
||||
max=10,
|
||||
)
|
||||
assert inp.id == "loras"
|
||||
assert inp.min == 0
|
||||
assert inp.max == 10
|
||||
assert len(inp.template) == 2
|
||||
|
||||
def test_get_all_includes_self_and_template(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"items",
|
||||
template=[Float.Input("value")],
|
||||
)
|
||||
all_inputs = inp.get_all()
|
||||
assert all_inputs[0] is inp
|
||||
assert all_inputs[1].id == "value"
|
||||
|
||||
def test_as_dict_has_template_min_max(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"items",
|
||||
template=[Float.Input("val", default=0.5)],
|
||||
min=1,
|
||||
max=5,
|
||||
)
|
||||
d = inp.as_dict()
|
||||
assert "template" in d
|
||||
assert d["min"] == 1
|
||||
assert d["max"] == 5
|
||||
|
||||
def test_duplicate_field_ids_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input(
|
||||
"bad",
|
||||
template=[Float.Input("x"), Float.Input("x")],
|
||||
)
|
||||
|
||||
def test_empty_template_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[])
|
||||
|
||||
def test_min_gt_max_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[Float.Input("x")], min=5, max=3)
|
||||
|
||||
def test_max_exceeds_limit_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[Float.Input("x")], max=101)
|
||||
|
||||
def test_dynamic_input_in_template_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input(
|
||||
"bad",
|
||||
template=[DynamicGroup.Input("nested", template=[Float.Input("x")])],
|
||||
)
|
||||
|
||||
def test_validate_calls_through(self):
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", min=-1.0, max=1.0)])
|
||||
inp.validate() # should not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 0-row case
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestZeroRows:
|
||||
def test_empty_live_inputs_produces_empty_list(self):
|
||||
"""With min=0 and no live values, the result should be an empty list."""
|
||||
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
|
||||
assert _run(inp, {}).get("loras") == []
|
||||
|
||||
def test_min_zero_with_values(self):
|
||||
"""min=0 but 2 rows of live data."""
|
||||
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
|
||||
result = _run(inp, {"loras.0.strength": 0.8, "loras.1.strength": 0.5})
|
||||
assert result["loras"] == [{"strength": 0.8}, {"strength": 0.5}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# N-row case
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNRows:
|
||||
def test_two_rows_two_fields(self):
|
||||
"""Two rows with two fields each produce a list[dict]."""
|
||||
inp = DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[String.Input("lora_name"), Float.Input("strength", default=1.0)],
|
||||
min=0, max=50,
|
||||
)
|
||||
result = _run(inp, {
|
||||
"loras.0.lora_name": "model_a.safetensors", "loras.0.strength": 0.9,
|
||||
"loras.1.lora_name": "model_b.safetensors", "loras.1.strength": 0.4,
|
||||
})
|
||||
assert result["loras"] == [
|
||||
{"lora_name": "model_a.safetensors", "strength": 0.9},
|
||||
{"lora_name": "model_b.safetensors", "strength": 0.4},
|
||||
]
|
||||
|
||||
def test_rows_are_sorted_by_index(self):
|
||||
"""Rows must be in ascending index order even if dict iteration is unordered."""
|
||||
inp = DynamicGroup.Input("items", template=[Int.Input("v", default=0)], min=0, max=10)
|
||||
result = _run(inp, {"items.0.v": 10, "items.2.v": 30, "items.1.v": 20})
|
||||
assert [row["v"] for row in result["items"]] == [10, 20, 30]
|
||||
|
||||
def test_min_rows_schema_slots(self):
|
||||
"""With min=2 and no live data, 2 slots must appear in the expanded schema."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
out, _, _ = get_finalized_class_inputs(_make_class_inputs(inp), {})
|
||||
all_slots = {**out.get("required", {}), **out.get("optional", {})}
|
||||
assert "items.0.val" in all_slots
|
||||
assert "items.1.val" in all_slots
|
||||
|
||||
def test_min_rows_reconstructs_when_no_values(self):
|
||||
"""min=2 with NO live values must still yield a 2-element list,
|
||||
not collapse to [] (regression: parent-path clobber)."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
result = _run(inp, {})
|
||||
assert len(result["items"]) == 2
|
||||
assert all("val" in row for row in result["items"])
|
||||
|
||||
def test_min_rows_reconstructs_with_partial_values(self):
|
||||
"""min=2 with only the first row's value present still yields 2 rows."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
result = _run(inp, {"items.0.val": 0.7})
|
||||
assert len(result["items"]) == 2
|
||||
assert result["items"][0]["val"] == 0.7
|
||||
assert result["items"][1]["val"] is None
|
||||
|
||||
def test_list_paths_in_v3_data(self):
|
||||
"""list_paths must contain the group id so build_nested_inputs knows to convert."""
|
||||
inp = DynamicGroup.Input("things", template=[Boolean.Input("flag")], min=0, max=5)
|
||||
_, _, v3_data = get_finalized_class_inputs(_make_class_inputs(inp), {})
|
||||
assert "things" in v3_data.get("list_paths", set())
|
||||
|
||||
def test_no_leftover_flat_keys(self):
|
||||
"""Flat keys must be consumed; only the reconstructed list remains."""
|
||||
inp = DynamicGroup.Input("rows", template=[Float.Input("x", default=0.0)], min=0, max=5)
|
||||
result = _run(inp, {"rows.0.x": 1.0, "rows.1.x": 2.0})
|
||||
assert "rows.0.x" not in result
|
||||
assert "rows.1.x" not in result
|
||||
assert isinstance(result["rows"], list)
|
||||
@ -53,11 +53,8 @@ def test_annotated_filepath():
|
||||
|
||||
def test_get_annotated_filepath():
|
||||
default_dir = "/default/dir"
|
||||
# get_annotated_filepath now normalizes with os.path.abspath (part of the
|
||||
# GHSA-779p traversal hardening), so compare against the normalized form —
|
||||
# on Windows abspath also prepends the current drive letter.
|
||||
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.abspath(os.path.join(default_dir, "test.txt"))
|
||||
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.abspath(os.path.join(folder_paths.get_output_directory(), "test.txt"))
|
||||
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
|
||||
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
|
||||
|
||||
def test_add_model_folder_path_append(clear_folder_paths):
|
||||
folder_paths.add_model_folder_path("test_folder", "/default/path", is_default=True)
|
||||
|
||||
@ -1,192 +0,0 @@
|
||||
"""CI unit tests for FIX #2 of GHSA-779p-m5rp-r4h4.
|
||||
|
||||
Path traversal / hardening in app/model_manager.py get_model_preview
|
||||
(route /experiment/models/preview/{folder}/{path_index}/{filename:.*}).
|
||||
|
||||
Reference: https://github.com/Comfy-Org/ComfyUI/security/advisories/GHSA-779p-m5rp-r4h4
|
||||
"""
|
||||
import pytest
|
||||
import yarl
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from aiohttp import web
|
||||
from unittest.mock import patch
|
||||
from app.model_manager import ModelFileManager
|
||||
|
||||
pytestmark = (
|
||||
pytest.mark.asyncio
|
||||
) # This applies the asyncio mark to all test functions in the module
|
||||
|
||||
@pytest.fixture
|
||||
def model_manager():
|
||||
return ModelFileManager()
|
||||
|
||||
@pytest.fixture
|
||||
def app(model_manager):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
model_manager.add_routes(routes)
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
|
||||
async def test_legit_preview_returns_200(aiohttp_client, app, tmp_path):
|
||||
"""Sanity: a real preview PNG inside the model folder is served as webp 200."""
|
||||
img = Image.new('RGB', (16, 16), color=(255, 0, 128))
|
||||
img.save(tmp_path / "test_model.png", format='PNG')
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/0/test_model.png')
|
||||
|
||||
assert response.status == 200
|
||||
assert response.content_type == 'image/webp'
|
||||
|
||||
img_bytes = BytesIO(await response.read())
|
||||
served = Image.open(img_bytes)
|
||||
assert served.format
|
||||
assert served.format.lower() == 'webp'
|
||||
served.close()
|
||||
|
||||
|
||||
async def test_non_integer_path_index_returns_400(aiohttp_client, app, tmp_path):
|
||||
"""A non-integer path_index segment must be rejected with 400."""
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/abc/test_model.png')
|
||||
|
||||
assert response.status == 400
|
||||
|
||||
|
||||
async def test_out_of_range_path_index_returns_404(aiohttp_client, app, tmp_path):
|
||||
"""A path_index beyond the configured folder list must return 404."""
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/99/test_model.png')
|
||||
|
||||
assert response.status == 404
|
||||
|
||||
|
||||
async def test_empty_filename_returns_400(aiohttp_client, app, tmp_path):
|
||||
"""The "{filename:.*}" capture also matches the empty string (trailing
|
||||
slash). It would resolve to the folder itself and must be rejected with 400."""
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/0/')
|
||||
|
||||
assert response.status == 400
|
||||
|
||||
|
||||
async def test_path_traversal_in_filename_returns_403(aiohttp_client, app, tmp_path):
|
||||
"""Path traversal in {filename} must be rejected with 403 and must NOT read
|
||||
a file outside the configured model directory.
|
||||
|
||||
GOTCHA: aiohttp/yarl collapses literal ``../`` dot-segments out of the URL
|
||||
path before it reaches the handler, which would make this test vacuously
|
||||
pass (the request would hit a different/non-existent route). We percent-encode
|
||||
the dots and slashes (``%2e%2e%2f``) and send the URL with
|
||||
``yarl.URL(..., encoded=True)`` so the bytes survive client-side normalization
|
||||
untouched; aiohttp's router then percent-decodes them into ``match_info``,
|
||||
delivering the literal ``../`` traversal to the handler's ``{filename:.*}``
|
||||
capture.
|
||||
|
||||
Without the fix the handler computes
|
||||
``os.path.normpath(os.path.join(folder, "../../../../etc/hosts"))``, which
|
||||
escapes ``tmp_path`` and would be passed straight to get_model_previews ->
|
||||
Image.open, serving bytes from outside the model dir (200/served bytes). The
|
||||
is_within_directory() containment check is the load-bearing fix that turns
|
||||
that escape into a 403.
|
||||
"""
|
||||
# Sanity-anchor: a legit preview exists inside tmp_path, so a 200 path is
|
||||
# genuinely reachable — proving the 403 below is the containment check
|
||||
# firing, not an unrelated 404.
|
||||
img = Image.new('RGB', (16, 16), color=(255, 0, 128))
|
||||
img.save(tmp_path / "test_model.png", format='PNG')
|
||||
|
||||
# Percent-encoded "../../../../etc/hosts" so yarl does not collapse the
|
||||
# dot-segments before the request leaves the client.
|
||||
encoded_traversal = '%2e%2e%2f' * 4 + 'etc%2fhosts'
|
||||
raw_path = '/experiment/models/preview/test_folder/0/' + encoded_traversal
|
||||
url = yarl.URL(raw_path, encoded=True)
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get(url)
|
||||
|
||||
# Confirm the traversal actually reached the handler intact: a 200 here
|
||||
# would mean either normalization stripped the ``../`` (vacuous pass) or
|
||||
# the containment check failed open and served outside-dir bytes.
|
||||
assert response.status == 403, (
|
||||
f"expected 403 from is_within_directory() containment check, "
|
||||
f"got {response.status}; traversal may have been normalized away "
|
||||
f"or the fix failed open"
|
||||
)
|
||||
body = await response.read()
|
||||
assert body == b"", "403 response must not carry any file bytes"
|
||||
|
||||
|
||||
async def test_symlink_companion_preview_returns_403(aiohttp_client, app, tmp_path):
|
||||
"""A companion preview file is selected by a glob inside get_model_previews
|
||||
and then opened. If that companion is a symlink whose path is in-dir but
|
||||
whose target escapes the model folder, it must be rejected with 403 — not
|
||||
served. The requested path itself stays in-dir (so the first containment
|
||||
check passes); the load-bearing fix is the SECOND is_within_directory check
|
||||
on the file actually opened.
|
||||
"""
|
||||
model_dir = tmp_path / "models"
|
||||
model_dir.mkdir()
|
||||
secret_dir = tmp_path / "secret"
|
||||
secret_dir.mkdir()
|
||||
# A real image OUTSIDE the model dir — valid, so without the fix Image.open
|
||||
# would succeed and its bytes would be served (200).
|
||||
secret = secret_dir / "secret.png"
|
||||
Image.new('RGB', (8, 8), color=(0, 0, 0)).save(secret, format='PNG')
|
||||
# Companion preview, in-dir by name but a symlink escaping the model dir.
|
||||
# (No real model file is needed — get_model_previews globs companions by
|
||||
# basename, and omitting a .safetensors avoids the metadata-header read.)
|
||||
companion = model_dir / "model.preview.png"
|
||||
try:
|
||||
companion.symlink_to(secret)
|
||||
except (OSError, NotImplementedError):
|
||||
pytest.skip("symlinks not supported on this platform/filesystem")
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(model_dir)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/0/model.safetensors')
|
||||
|
||||
assert response.status == 403, (
|
||||
f"expected 403 — the globbed companion preview is a symlink resolving "
|
||||
f"outside the model dir and must not be served; got {response.status}"
|
||||
)
|
||||
assert await response.read() == b""
|
||||
|
||||
|
||||
async def test_null_byte_in_filename_no_500(aiohttp_client, app, tmp_path):
|
||||
"""A NUL byte in the filename must yield a clean client rejection, not a 500
|
||||
from an uncaught ValueError in is_within_directory's realpath() call."""
|
||||
raw_path = '/experiment/models/preview/test_folder/0/' + 'a%00b'
|
||||
url = yarl.URL(raw_path, encoded=True)
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get(url)
|
||||
|
||||
assert response.status != 500, (
|
||||
f"NUL byte produced a 500 (uncaught ValueError); expected a clean "
|
||||
f"4xx rejection, got {response.status}"
|
||||
)
|
||||
assert 400 <= response.status < 500
|
||||
@ -1,165 +0,0 @@
|
||||
"""Security tests for GHSA-779p-m5rp-r4h4 — FIX #3.
|
||||
|
||||
Path traversal in folder_paths.get_annotated_filepath / exists_annotated_filepath,
|
||||
plus the shared is_within_directory() containment helper.
|
||||
|
||||
These are pure-function tests (no running server). The input/output/temp
|
||||
directories are pointed at tmp_path via the folder_paths setters, so a crafted
|
||||
name containing `../`, an absolute path, or a symlink that escapes the base
|
||||
directory must be rejected.
|
||||
|
||||
Reference: https://github.com/Comfy-Org/ComfyUI/security/advisories/GHSA-779p-m5rp-r4h4
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
import folder_paths
|
||||
from comfy.options import enable_args_parsing
|
||||
enable_args_parsing()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sandbox(tmp_path):
|
||||
"""Point folder_paths' input/output/temp dirs at a real temp sandbox.
|
||||
|
||||
Yields the realpath'd base, input, output and temp directories. The original
|
||||
directory values are restored afterward so tests stay isolated.
|
||||
"""
|
||||
base = os.path.realpath(str(tmp_path))
|
||||
input_dir = os.path.join(base, "input")
|
||||
output_dir = os.path.join(base, "output")
|
||||
temp_dir = os.path.join(base, "temp")
|
||||
for d in (input_dir, output_dir, temp_dir):
|
||||
os.makedirs(d, exist_ok=True)
|
||||
|
||||
orig_input = folder_paths.get_input_directory()
|
||||
orig_output = folder_paths.get_output_directory()
|
||||
orig_temp = folder_paths.get_temp_directory()
|
||||
|
||||
folder_paths.set_input_directory(input_dir)
|
||||
folder_paths.set_output_directory(output_dir)
|
||||
folder_paths.set_temp_directory(temp_dir)
|
||||
|
||||
yield {
|
||||
"base": base,
|
||||
"input": input_dir,
|
||||
"output": output_dir,
|
||||
"temp": temp_dir,
|
||||
}
|
||||
|
||||
folder_paths.set_input_directory(orig_input)
|
||||
folder_paths.set_output_directory(orig_output)
|
||||
folder_paths.set_temp_directory(orig_temp)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_within_directory() — the shared containment helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_is_within_directory_legit_child(sandbox):
|
||||
base = sandbox["input"]
|
||||
child = os.path.join(base, "sub", "image.png")
|
||||
assert folder_paths.is_within_directory(base, child) is True
|
||||
|
||||
|
||||
def test_is_within_directory_dotdot_escape(sandbox):
|
||||
base = sandbox["input"]
|
||||
escape = os.path.join(base, "..", "..", "etc", "passwd")
|
||||
assert folder_paths.is_within_directory(base, escape) is False
|
||||
|
||||
|
||||
def test_is_within_directory_symlink_escape(sandbox):
|
||||
"""A symlink created INSIDE base that points OUTSIDE base must not pass.
|
||||
|
||||
This is the key new hardening: is_within_directory realpath()s both operands,
|
||||
so a symlink planted in the base directory can't be used to read files
|
||||
elsewhere. We create a real on-disk symlink and a real secret target to
|
||||
verify the check actually resolves the link.
|
||||
"""
|
||||
base = sandbox["input"]
|
||||
|
||||
# A directory living outside the base, holding a secret file.
|
||||
outside = os.path.join(sandbox["base"], "outside_secret_dir")
|
||||
os.makedirs(outside, exist_ok=True)
|
||||
secret = os.path.join(outside, "secret.txt")
|
||||
with open(secret, "w") as f:
|
||||
f.write("top secret")
|
||||
|
||||
# Plant a symlink inside base that points at the outside directory.
|
||||
# symlink creation can require elevated privileges / Developer Mode on
|
||||
# Windows, so skip cleanly where it isn't available (same guard as the
|
||||
# sibling test in test_ghsa_779p_02_preview_traversal.py).
|
||||
link = os.path.join(base, "escape_link")
|
||||
try:
|
||||
os.symlink(outside, link)
|
||||
except (OSError, NotImplementedError):
|
||||
pytest.skip("symlinks not supported on this platform/filesystem")
|
||||
|
||||
# Accessing the secret "through" the in-base symlink must be rejected.
|
||||
target_via_link = os.path.join(link, "secret.txt")
|
||||
assert folder_paths.is_within_directory(base, target_via_link) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_annotated_filepath()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_annotated_filepath_legit_name(sandbox):
|
||||
result = folder_paths.get_annotated_filepath("image.png")
|
||||
assert result == os.path.join(sandbox["input"], "image.png")
|
||||
assert folder_paths.is_within_directory(sandbox["input"], result)
|
||||
|
||||
|
||||
def test_get_annotated_filepath_input_annotation(sandbox):
|
||||
result = folder_paths.get_annotated_filepath("image.png [input]")
|
||||
assert result == os.path.join(sandbox["input"], "image.png")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_output_annotation(sandbox):
|
||||
result = folder_paths.get_annotated_filepath("image.png [output]")
|
||||
assert result == os.path.join(sandbox["output"], "image.png")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_temp_annotation(sandbox):
|
||||
result = folder_paths.get_annotated_filepath("image.png [temp]")
|
||||
assert result == os.path.join(sandbox["temp"], "image.png")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_dotdot_raises(sandbox):
|
||||
with pytest.raises(ValueError):
|
||||
folder_paths.get_annotated_filepath("../etc/passwd")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_dotdot_with_annotation_raises(sandbox):
|
||||
with pytest.raises(ValueError):
|
||||
folder_paths.get_annotated_filepath("../../etc/passwd [output]")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_absolute_escape_raises(sandbox):
|
||||
with pytest.raises(ValueError):
|
||||
folder_paths.get_annotated_filepath("/etc/passwd")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# exists_annotated_filepath()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_exists_annotated_filepath_existing_legit_file(sandbox):
|
||||
real = os.path.join(sandbox["input"], "real.png")
|
||||
with open(real, "w") as f:
|
||||
f.write("data")
|
||||
assert folder_paths.exists_annotated_filepath("real.png") is True
|
||||
|
||||
|
||||
def test_exists_annotated_filepath_traversal_returns_false(sandbox):
|
||||
"""A traversal name must return False without raising and without probing
|
||||
outside the base directory (must never reach os.path.exists for the escape).
|
||||
"""
|
||||
# /etc/passwd exists on POSIX; the function must still report False because
|
||||
# the resolved path escapes the input directory.
|
||||
assert folder_paths.exists_annotated_filepath("../../../../../../etc/passwd") is False
|
||||
|
||||
|
||||
def test_exists_annotated_filepath_absolute_returns_false(sandbox):
|
||||
assert folder_paths.exists_annotated_filepath("/etc/passwd") is False
|
||||
@ -1,147 +0,0 @@
|
||||
"""
|
||||
CI unit tests for FIX #4 of GHSA-779p-m5rp-r4h4.
|
||||
|
||||
Stored-XSS hardening on GET /userdata/{file} in app/user_manager.py.
|
||||
|
||||
User data files are arbitrary user-supplied content and must never render
|
||||
inline in the app origin. The getuserdata handler:
|
||||
- forces Content-Type to application/octet-stream for any type in
|
||||
folder_paths.DANGEROUS_CONTENT_TYPES (text/html, image/svg+xml,
|
||||
text/javascript, ...),
|
||||
- sets X-Content-Type-Options: nosniff,
|
||||
- sets Content-Disposition: attachment.
|
||||
|
||||
These tests pre-create files in tmp_path and GET them back, asserting the
|
||||
secure response headers. They mirror the aiohttp_client pattern in
|
||||
tests-unit/prompt_server_test/user_manager_test.py.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
from aiohttp import web
|
||||
from app.user_manager import UserManager
|
||||
|
||||
pytestmark = (
|
||||
pytest.mark.asyncio
|
||||
) # This applies the asyncio mark to all test functions in the module
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_manager(tmp_path):
|
||||
um = UserManager()
|
||||
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
|
||||
tmp_path, file
|
||||
) if file else tmp_path
|
||||
return um
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(user_manager):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
user_manager.add_routes(routes)
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
|
||||
async def test_html_served_as_octet_stream(aiohttp_client, app, tmp_path):
|
||||
(tmp_path / "evil.html").write_text(
|
||||
"<script>console.log('xss-marker-ghsa-779p')</script>"
|
||||
)
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/evil.html")
|
||||
|
||||
assert resp.status == 200
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
# The load-bearing assertion: a .html file must NOT be served as text/html.
|
||||
assert "text/html" not in ct.lower(), (
|
||||
f"Content-Type {ct!r} would let a browser render/execute the file (stored XSS)."
|
||||
)
|
||||
assert ct == "application/octet-stream"
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
|
||||
|
||||
async def test_svg_served_as_octet_stream(aiohttp_client, app, tmp_path):
|
||||
(tmp_path / "evil.svg").write_text(
|
||||
'<?xml version="1.0"?>'
|
||||
'<svg xmlns="http://www.w3.org/2000/svg">'
|
||||
'<script>console.log("xss-marker-ghsa-779p")</script>'
|
||||
"</svg>"
|
||||
)
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/evil.svg")
|
||||
|
||||
assert resp.status == 200
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
# SVG can carry inline <script>; it must not be served as image/svg+xml.
|
||||
assert "svg" not in ct.lower(), (
|
||||
f"Content-Type {ct!r} would let a browser render the SVG and execute embedded scripts."
|
||||
)
|
||||
assert ct == "application/octet-stream"
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
|
||||
|
||||
async def test_js_served_as_octet_stream(aiohttp_client, app, tmp_path):
|
||||
(tmp_path / "evil.js").write_text("alert('xss-marker-ghsa-779p')")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/evil.js")
|
||||
|
||||
assert resp.status == 200
|
||||
ct = resp.headers.get("Content-Type", "").lower()
|
||||
# Must not be served as any executable JavaScript content type.
|
||||
assert "javascript" not in ct, (
|
||||
f"Content-Type {ct!r} is an executable JS type."
|
||||
)
|
||||
assert "ecmascript" not in ct, (
|
||||
f"Content-Type {ct!r} is an executable JS type."
|
||||
)
|
||||
assert ct == "application/octet-stream"
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
|
||||
|
||||
async def test_xml_dialect_served_as_octet_stream(aiohttp_client, app, tmp_path):
|
||||
"""An XML dialect outside the original blocklist (.xslt -> application/xslt+xml)
|
||||
must still be forced to download. This pins the normalised *+xml family rule
|
||||
in folder_paths.is_dangerous_content_type(); a plain set-membership test would
|
||||
have served this inline."""
|
||||
(tmp_path / "evil.xslt").write_text(
|
||||
'<?xml version="1.0"?>'
|
||||
'<xsl:stylesheet version="1.0" '
|
||||
'xmlns:xsl="http://www.w3.org/1999/XSL/Transform">'
|
||||
"<!-- xss-marker-ghsa-779p -->"
|
||||
"</xsl:stylesheet>"
|
||||
)
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/evil.xslt")
|
||||
|
||||
assert resp.status == 200
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
assert ct == "application/octet-stream", (
|
||||
f"Content-Type {ct!r}: an *+xml dialect must be forced to octet-stream "
|
||||
f"(it can carry inline script via stylesheet/entity tricks)."
|
||||
)
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
|
||||
|
||||
async def test_benign_txt_still_served(aiohttp_client, app, tmp_path):
|
||||
(tmp_path / "note.txt").write_text("just a harmless note")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/note.txt")
|
||||
|
||||
assert resp.status == 200
|
||||
assert await resp.text() == "just a harmless note"
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
# text/plain is not in the dangerous set, so it is acceptable here. The
|
||||
# defence-in-depth headers must still be present regardless.
|
||||
assert "text/plain" in ct.lower()
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
@ -1,138 +0,0 @@
|
||||
"""CI unit guard for FIX #5 of GHSA-779p-m5rp-r4h4 — the /view forced-download set.
|
||||
|
||||
Vuln #5 was stored XSS via SVG upload: the /view endpoint's Content-Type
|
||||
blocklist covered text/html, text/javascript, etc. but was missing
|
||||
image/svg+xml, so an uploaded SVG carrying an inline <script> was served as
|
||||
image/svg+xml and executed in the page origin when rendered.
|
||||
|
||||
The /view forced-download decision lives in the view_image closure registered by
|
||||
server.PromptServer.add_routes (server.py ~line 596), which calls
|
||||
`folder_paths.is_dangerous_content_type(content_type)` — a normalising check that
|
||||
strips charset/boundary parameters and casing and folds in the whole */xml and
|
||||
*+xml dialect family — rather than a bypassable raw
|
||||
`content_type in folder_paths.DANGEROUS_CONTENT_TYPES` membership test. On a match
|
||||
it rewrites the response to application/octet-stream with a
|
||||
Content-Disposition: attachment header. server.py cannot be imported in a unit
|
||||
test (importing it spins up the full PromptServer/aiohttp app and its global side
|
||||
effects), so these tests pin the underlying dangerous-content data
|
||||
(folder_paths.DANGEROUS_CONTENT_TYPES) and the normalising is_dangerous_content_type()
|
||||
helper that the closure actually calls.
|
||||
|
||||
The end-to-end /view assertion (upload an SVG, GET /view, confirm the response
|
||||
is not served as image/svg+xml) lives in the live POC at
|
||||
.security/pocs/test_security_ghsa_779p.py::TestViewSvgContentType, which
|
||||
requires a running server. This file is the fast, server-free CI guard on the
|
||||
set contents so the blocklist can't silently regress.
|
||||
"""
|
||||
|
||||
import folder_paths
|
||||
|
||||
|
||||
# Active/renderable content types that must be forced to download. Each of these
|
||||
# can carry an inline <script> (or otherwise execute) in the page origin if a
|
||||
# browser renders it. image/svg+xml is the original missing item that caused
|
||||
# vuln #5.
|
||||
DANGEROUS = [
|
||||
'image/svg+xml',
|
||||
'application/xml',
|
||||
'text/xml',
|
||||
'text/html',
|
||||
'text/html-sandboxed',
|
||||
'application/xhtml+xml',
|
||||
'text/javascript',
|
||||
'application/javascript',
|
||||
'application/x-javascript',
|
||||
'application/ecmascript',
|
||||
'text/css',
|
||||
]
|
||||
|
||||
# Benign image types that browsers display inline and that must keep rendering;
|
||||
# forcing these to download would break legitimate previews.
|
||||
BENIGN_INLINE_IMAGES = [
|
||||
'image/png',
|
||||
'image/jpeg',
|
||||
'image/webp',
|
||||
'image/gif',
|
||||
]
|
||||
|
||||
|
||||
def test_dangerous_content_types_is_a_set():
|
||||
assert isinstance(folder_paths.DANGEROUS_CONTENT_TYPES, set)
|
||||
|
||||
|
||||
def test_svg_is_in_the_blocklist():
|
||||
"""The specific item whose absence caused vuln #5."""
|
||||
assert 'image/svg+xml' in folder_paths.DANGEROUS_CONTENT_TYPES, (
|
||||
"image/svg+xml missing from DANGEROUS_CONTENT_TYPES — this is exactly "
|
||||
"the regression that reopens GHSA-779p-m5rp-r4h4 vuln #5 (stored XSS "
|
||||
"via SVG upload on /view)."
|
||||
)
|
||||
|
||||
|
||||
def test_all_dangerous_types_present():
|
||||
missing = [ct for ct in DANGEROUS if ct not in folder_paths.DANGEROUS_CONTENT_TYPES]
|
||||
assert not missing, (
|
||||
f"DANGEROUS_CONTENT_TYPES is missing required active/renderable types: "
|
||||
f"{missing}. The /view closure only forces a download for content types "
|
||||
f"in this set; anything missing here is served inline and can execute."
|
||||
)
|
||||
|
||||
|
||||
def test_benign_inline_image_types_absent():
|
||||
leaked = [ct for ct in BENIGN_INLINE_IMAGES if ct in folder_paths.DANGEROUS_CONTENT_TYPES]
|
||||
assert not leaked, (
|
||||
f"Benign inline-displayable image types found in DANGEROUS_CONTENT_TYPES: "
|
||||
f"{leaked}. Forcing these to download would break legitimate image "
|
||||
f"previews in /view — they must keep rendering inline."
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_dangerous_content_type() — the normalising check the /view and /userdata
|
||||
# handlers now call instead of a raw `in DANGEROUS_CONTENT_TYPES` membership
|
||||
# test. An exact-string membership test was bypassable with a charset parameter
|
||||
# or odd casing, and missed the wider XML dialect family; these tests pin the
|
||||
# normalisation so that bypass can't reopen.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_function_matches_plain_dangerous_types():
|
||||
for ct in DANGEROUS:
|
||||
assert folder_paths.is_dangerous_content_type(ct) is True, ct
|
||||
|
||||
|
||||
def test_function_strips_parameters_and_casing():
|
||||
"""A charset/boundary parameter or casing must not slip a type past the check.
|
||||
|
||||
This is the bypass surfaced by review: the /view blake3 branch can serve an
|
||||
attacker-controlled, unvalidated asset mime_type like 'text/html; charset=utf-8',
|
||||
which an exact-string set test missed.
|
||||
"""
|
||||
for ct in (
|
||||
'text/html; charset=utf-8',
|
||||
'TEXT/HTML',
|
||||
'Text/HTML; charset=UTF-8',
|
||||
'image/svg+xml; charset=utf-8',
|
||||
' text/html ',
|
||||
):
|
||||
assert folder_paths.is_dangerous_content_type(ct) is True, ct
|
||||
|
||||
|
||||
def test_function_covers_xml_dialect_family():
|
||||
"""Any *+xml / */xml dialect is dangerous without enumerating each one."""
|
||||
for ct in (
|
||||
'application/xslt+xml',
|
||||
'application/rss+xml',
|
||||
'application/atom+xml',
|
||||
'application/rdf+xml',
|
||||
'application/mathml+xml',
|
||||
'message/rfc822',
|
||||
):
|
||||
assert folder_paths.is_dangerous_content_type(ct) is True, ct
|
||||
|
||||
|
||||
def test_function_allows_benign_and_empty():
|
||||
for ct in BENIGN_INLINE_IMAGES + ['application/octet-stream', 'text/plain']:
|
||||
assert folder_paths.is_dangerous_content_type(ct) is False, ct
|
||||
# None / empty (mimetypes.guess_type miss) must not be treated as dangerous.
|
||||
assert folder_paths.is_dangerous_content_type(None) is False
|
||||
assert folder_paths.is_dangerous_content_type('') is False
|
||||
Reference in New Issue
Block a user