mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-20 16:36:41 +08:00
Compare commits
8 Commits
synap5e/fe
...
feature/cu
| Author | SHA1 | Date | |
|---|---|---|---|
| 8f82b16993 | |||
| 72fe66a18b | |||
| 07ff14ae02 | |||
| ba1c039a04 | |||
| 6220400ad5 | |||
| af55a2308f | |||
| 3a649984f2 | |||
| a145651cc0 |
@ -38,7 +38,7 @@
|
||||
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
|
||||
- ComfyUI natively supports the latest open-source state of the art models.
|
||||
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
|
||||
- It is available on Windows, Linux, and macOS, locally with our [desktop application](https://www.comfy.org/download), our [portable install](#installing) or on our [cloud](https://www.comfy.org/cloud).
|
||||
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
|
||||
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
|
||||
- It integrates seamlessly into production pipelines with our API endpoints.
|
||||
|
||||
|
||||
@ -39,7 +39,6 @@ from app.assets.services import (
|
||||
update_asset_metadata,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.services.path_utils import compute_paths_for_response
|
||||
from app.assets.services.tagging import list_tag_histogram
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
@ -161,16 +160,9 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
preview_url = None
|
||||
else:
|
||||
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
|
||||
if result.ref.file_path:
|
||||
paths = compute_paths_for_response(result.ref.file_path)
|
||||
file_path, display_name = paths if paths else (None, None)
|
||||
else:
|
||||
file_path, display_name = None, None
|
||||
return schemas_out.Asset(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
file_path=file_path,
|
||||
display_name=display_name,
|
||||
asset_hash=result.asset.hash if result.asset else None,
|
||||
size=int(result.asset.size_bytes) if result.asset else None,
|
||||
mime_type=result.asset.mime_type if result.asset else None,
|
||||
|
||||
@ -10,8 +10,6 @@ class Asset(BaseModel):
|
||||
|
||||
id: str
|
||||
name: str
|
||||
file_path: str | None = None
|
||||
display_name: str | None = None
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
|
||||
@ -8,8 +8,6 @@ from app.assets.helpers import normalize_tags
|
||||
|
||||
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
|
||||
|
||||
RootCategory = Literal["input", "output", "temp", "models"]
|
||||
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build list of (folder_name, base_paths[]) for all model locations.
|
||||
@ -67,109 +65,35 @@ def validate_path_within_base(candidate: str, base: str) -> None:
|
||||
raise ValueError("destination escapes base directory")
|
||||
|
||||
|
||||
def compute_paths_for_response(
|
||||
file_path: str,
|
||||
) -> tuple[str, str | None] | None:
|
||||
"""Compute (file_path, display_name) for an Asset response.
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
`file_path` is a logical locator under the asset namespace: `<root>/<rel>`
|
||||
for input/output/temp assets and `<root>/<bucket>/<rel>` for model assets.
|
||||
`display_name` is the path below that root or model bucket, suitable for UI
|
||||
labels. Returns None when the absolute path is not under a known asset root.
|
||||
For non-model paths, returns None.
|
||||
"""
|
||||
try:
|
||||
root, bucket, rel = get_asset_root_bucket_and_filepath(file_path)
|
||||
root_category, rel_path = get_asset_category_and_relative_path(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
display_name = rel or None
|
||||
if bucket is None:
|
||||
response_file_path = f"{root}/{rel}" if rel else root
|
||||
else:
|
||||
response_file_path = f"{root}/{bucket}/{rel}" if rel else f"{root}/{bucket}"
|
||||
return response_file_path, display_name
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
|
||||
def compute_display_name(file_path: str) -> str | None:
|
||||
"""Return the asset's `display_name`, or None for unknown paths."""
|
||||
result = compute_paths_for_response(file_path)
|
||||
return result[1] if result else None
|
||||
|
||||
|
||||
def compute_file_path(file_path: str) -> str | None:
|
||||
"""Return the asset's logical `file_path`, or None for unknown paths."""
|
||||
result = compute_paths_for_response(file_path)
|
||||
return result[0] if result else None
|
||||
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the path relative to the asset root or model category, using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
/.../input/sub/image.png -> "sub/image.png"
|
||||
|
||||
For unknown paths, returns None.
|
||||
"""
|
||||
return compute_display_name(file_path)
|
||||
|
||||
|
||||
def get_asset_root_bucket_and_filepath(
|
||||
file_path: str,
|
||||
) -> tuple[RootCategory, str | None, str]:
|
||||
"""Decompose an absolute path into (root, bucket, path-under-bucket).
|
||||
|
||||
`bucket` is only set for model assets. The returned relative path always
|
||||
uses `/` separators and is empty when the path is exactly the matched root.
|
||||
|
||||
Raises:
|
||||
ValueError: path does not belong to any known root.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _check_is_within(child: str, parent: str) -> bool:
|
||||
return Path(child).is_relative_to(parent)
|
||||
|
||||
def _compute_relative(child: str, parent: str) -> str:
|
||||
# Normalize relative path, stripping any leading ".." components
|
||||
# by anchoring to root (os.sep) then computing relpath back from it.
|
||||
rel = os.path.relpath(
|
||||
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
|
||||
)
|
||||
return "" if rel == "." else rel.replace(os.sep, "/")
|
||||
|
||||
for root_tag, getter in (
|
||||
("input", folder_paths.get_input_directory),
|
||||
("output", folder_paths.get_output_directory),
|
||||
("temp", folder_paths.get_temp_directory),
|
||||
):
|
||||
base = os.path.abspath(getter())
|
||||
if _check_is_within(fp_abs, base):
|
||||
return root_tag, None, _compute_relative(fp_abs, base)
|
||||
|
||||
# models: check deepest matching base to avoid ambiguity.
|
||||
best: tuple[int, str, str] | None = None
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _check_is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
return "models", bucket, rel_inside
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not within input, output, temp, or configured model bases: {file_path}"
|
||||
)
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc) – drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_asset_category_and_relative_path(
|
||||
file_path: str,
|
||||
) -> tuple[RootCategory, str]:
|
||||
) -> tuple[Literal["input", "output", "temp", "models"], str]:
|
||||
"""Determine which root category a file path belongs to.
|
||||
|
||||
Categories:
|
||||
|
||||
@ -44,14 +44,7 @@ class BackgroundRemovalModel():
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
H, W = image.shape[1], image.shape[2]
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
|
||||
|
||||
if pixel_values.shape[0] > 1:
|
||||
out = torch.cat([
|
||||
self.model(pixel_values=pixel_values[i:i+1])
|
||||
for i in range(pixel_values.shape[0])
|
||||
], dim=0)
|
||||
else:
|
||||
out = self.model(pixel_values=pixel_values)
|
||||
out = self.model(pixel_values=pixel_values)
|
||||
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
|
||||
|
||||
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
|
||||
@ -141,7 +141,8 @@ manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", he
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||
vram_group.add_argument("--lowvram", action="store_true", help="Doesn't do anything if dynamic vram is enabled. If dynamic vram isn't being used this option makes the text encoders run on the CPU.")
|
||||
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
|
||||
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
|
||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||
|
||||
|
||||
@ -1691,13 +1691,6 @@ class HiDreamO1(BaseModel):
|
||||
if text_input_ids is None or noise is None:
|
||||
return out
|
||||
|
||||
# handle area conds
|
||||
area = kwargs.get("area", None)
|
||||
if area is not None:
|
||||
crop_h = min(noise.shape[-2] - area[2], area[0])
|
||||
crop_w = min(noise.shape[-1] - area[3], area[1])
|
||||
noise = torch.empty((noise.shape[0], 3, crop_h, crop_w), dtype=noise.dtype, device=noise.device)
|
||||
|
||||
conds = build_extra_conds(
|
||||
text_input_ids, noise,
|
||||
ref_images=kwargs.get("reference_latents", None),
|
||||
|
||||
@ -1493,30 +1493,27 @@ class ModelPatcher:
|
||||
self.unpatch_hooks()
|
||||
self.clear_cached_hook_weights()
|
||||
|
||||
def model_state_dict_for_saving(self, model=None, prefix=""):
|
||||
if model is None:
|
||||
model = self.model
|
||||
|
||||
original_state_dict = model.state_dict()
|
||||
output_state_dict = {}
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
original_state_dict = self.model.diffusion_model.state_dict()
|
||||
unet_state_dict = {}
|
||||
keys = list(original_state_dict)
|
||||
while len(keys) > 0:
|
||||
k = keys.pop(0)
|
||||
v = original_state_dict[k]
|
||||
op_keys = k.rsplit('.', 1)
|
||||
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
|
||||
output_state_dict[k] = v
|
||||
unet_state_dict[k] = v
|
||||
continue
|
||||
try:
|
||||
op = comfy.utils.get_attr(model, op_keys[0])
|
||||
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
|
||||
except:
|
||||
output_state_dict[k] = v
|
||||
unet_state_dict[k] = v
|
||||
continue
|
||||
if not op or not hasattr(op, "comfy_cast_weights") or \
|
||||
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
|
||||
output_state_dict[k] = v
|
||||
unet_state_dict[k] = v
|
||||
continue
|
||||
key = prefix + k
|
||||
key = "diffusion_model." + k
|
||||
weight = comfy.utils.get_attr(self.model, key)
|
||||
if isinstance(weight, QuantizedTensor) and k in original_state_dict:
|
||||
qt_state_dict = weight.state_dict(k)
|
||||
@ -1524,14 +1521,10 @@ class ModelPatcher:
|
||||
for group_key in (x for x in qt_state_dict if x in original_state_dict):
|
||||
if group_key in keys:
|
||||
keys.remove(group_key)
|
||||
output_state_dict.pop(group_key, "")
|
||||
output_state_dict[group_key] = LazyCastingParamPiece(caster, prefix + group_key, original_state_dict[group_key])
|
||||
unet_state_dict.pop(group_key, "")
|
||||
unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key])
|
||||
continue
|
||||
output_state_dict[k] = LazyCastingParam(self, key, weight)
|
||||
return output_state_dict
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
unet_state_dict = self.model_state_dict_for_saving(self.model.diffusion_model, "diffusion_model.")
|
||||
unet_state_dict[k] = LazyCastingParam(self, key, weight)
|
||||
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||
|
||||
def __del__(self):
|
||||
|
||||
@ -1376,7 +1376,6 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
|
||||
if not fp8_compute:
|
||||
disabled.add("float8_e4m3fn")
|
||||
disabled.add("float8_e5m2")
|
||||
logging.info("Native ops: {} {}".format(", ".join(QUANT_ALGOS.keys() - disabled), ", emulated ops: {}".format(", ".join(disabled)) if len(disabled) > 0 else ""))
|
||||
return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
|
||||
|
||||
if (
|
||||
|
||||
15
comfy/sd.py
15
comfy/sd.py
@ -79,7 +79,7 @@ import comfy.latent_formats
|
||||
|
||||
import comfy.ldm.flux.redux
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
@ -91,8 +91,6 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
k = new_modelpatcher.add_patches(loaded, strength_model)
|
||||
if lora_metadata:
|
||||
new_modelpatcher.set_attachments("lora_metadata", lora_metadata)
|
||||
else:
|
||||
k = ()
|
||||
new_modelpatcher = None
|
||||
@ -100,8 +98,6 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_
|
||||
if clip is not None:
|
||||
new_clip = clip.clone()
|
||||
k1 = new_clip.add_patches(loaded, strength_clip)
|
||||
if lora_metadata:
|
||||
new_clip.patcher.set_attachments("lora_metadata", lora_metadata)
|
||||
else:
|
||||
k1 = ()
|
||||
new_clip = None
|
||||
@ -423,13 +419,6 @@ class CLIP:
|
||||
sd_clip[k] = sd_tokenizer[k]
|
||||
return sd_clip
|
||||
|
||||
def state_dict_for_saving(self):
|
||||
sd_clip = self.patcher.model_state_dict_for_saving()
|
||||
sd_tokenizer = self.tokenizer.state_dict()
|
||||
for k in sd_tokenizer:
|
||||
sd_clip[k] = sd_tokenizer[k]
|
||||
return sd_clip
|
||||
|
||||
def load_model(self, tokens={}):
|
||||
memory_used = 0
|
||||
if hasattr(self.cond_stage_model, "memory_estimation_function"):
|
||||
@ -1915,7 +1904,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
||||
load_models = [model]
|
||||
if clip is not None:
|
||||
load_models.append(clip.load_model())
|
||||
clip_sd = clip.state_dict_for_saving()
|
||||
clip_sd = clip.get_sd()
|
||||
vae_sd = None
|
||||
if vae is not None:
|
||||
vae_sd = vae.get_sd()
|
||||
|
||||
@ -760,7 +760,7 @@ class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
|
||||
image = kwargs.get("image", None)
|
||||
if image is not None and len(images) == 0:
|
||||
images = [image[i:i + 1] for i in range(image.shape[0])]
|
||||
images = [image]
|
||||
|
||||
skip_template = False
|
||||
if text.startswith('<|im_start|>'):
|
||||
@ -771,16 +771,13 @@ class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
if skip_template:
|
||||
llama_text = text
|
||||
else:
|
||||
if llama_template is not None:
|
||||
template = llama_template
|
||||
elif len(images) == 0:
|
||||
template = self.llama_template
|
||||
if llama_template is None:
|
||||
if len(images) > 0:
|
||||
llama_text = self.llama_template_images.format(text)
|
||||
else:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
template = self.llama_template_images
|
||||
if len(images) > 1:
|
||||
vision_block = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
template = template.replace(vision_block, vision_block * len(images), 1)
|
||||
llama_text = template.format(text)
|
||||
llama_text = llama_template.format(text)
|
||||
if not thinking:
|
||||
llama_text += "<think>\n</think>\n"
|
||||
|
||||
|
||||
@ -1,101 +0,0 @@
|
||||
"""Pydantic models for BytePlus ModelArk Responses API.
|
||||
|
||||
See: https://docs.byteplus.com/en/docs/ModelArk/1585128 (request)
|
||||
https://docs.byteplus.com/en/docs/ModelArk/1783703 (response)
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BytePlusInputText(BaseModel):
|
||||
type: Literal["input_text"] = "input_text"
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class BytePlusInputImage(BaseModel):
|
||||
type: Literal["input_image"] = "input_image"
|
||||
image_url: str = Field(..., description="Image URL or `data:image/...;base64,...` payload")
|
||||
detail: str = Field("auto", description="One of high, low, auto")
|
||||
|
||||
|
||||
class BytePlusInputVideo(BaseModel):
|
||||
type: Literal["input_video"] = "input_video"
|
||||
video_url: str = Field(..., description="Video URL or `data:video/...;base64,...` payload")
|
||||
fps: float | None = Field(None, ge=0.2, le=5.0)
|
||||
|
||||
|
||||
BytePlusMessageContent = BytePlusInputText | BytePlusInputImage | BytePlusInputVideo
|
||||
|
||||
|
||||
class BytePlusInputMessage(BaseModel):
|
||||
type: Literal["message"] = "message"
|
||||
role: str = Field(..., description="One of user, system, assistant, developer")
|
||||
content: list[BytePlusMessageContent] = Field(...)
|
||||
|
||||
|
||||
class BytePlusResponseCreateRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: list[BytePlusInputMessage] = Field(...)
|
||||
instructions: str | None = Field(None)
|
||||
max_output_tokens: int | None = Field(None, ge=1)
|
||||
temperature: float | None = Field(None, ge=0.0, le=2.0)
|
||||
store: bool | None = Field(False)
|
||||
stream: bool | None = Field(False)
|
||||
|
||||
|
||||
class BytePlusOutputText(BaseModel):
|
||||
type: Literal["output_text"] = "output_text"
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class BytePlusOutputRefusal(BaseModel):
|
||||
type: Literal["refusal"] = "refusal"
|
||||
refusal: str = Field(...)
|
||||
|
||||
|
||||
class BytePlusOutputContent(BaseModel):
|
||||
type: str = Field(...)
|
||||
text: str | None = Field(None)
|
||||
refusal: str | None = Field(None)
|
||||
|
||||
|
||||
class BytePlusOutputMessage(BaseModel):
|
||||
type: str = Field(...)
|
||||
id: str | None = Field(None)
|
||||
role: str | None = Field(None)
|
||||
status: str | None = Field(None)
|
||||
content: list[BytePlusOutputContent] | None = Field(None)
|
||||
|
||||
|
||||
class BytePlusInputTokensDetails(BaseModel):
|
||||
cached_tokens: int | None = Field(None)
|
||||
|
||||
|
||||
class BytePlusOutputTokensDetails(BaseModel):
|
||||
reasoning_tokens: int | None = Field(None)
|
||||
|
||||
|
||||
class BytePlusResponseUsage(BaseModel):
|
||||
input_tokens: int | None = Field(None)
|
||||
output_tokens: int | None = Field(None)
|
||||
total_tokens: int | None = Field(None)
|
||||
input_tokens_details: BytePlusInputTokensDetails | None = Field(None)
|
||||
output_tokens_details: BytePlusOutputTokensDetails | None = Field(None)
|
||||
|
||||
|
||||
class BytePlusResponseError(BaseModel):
|
||||
code: str = Field(...)
|
||||
message: str = Field(...)
|
||||
|
||||
|
||||
class BytePlusResponseObject(BaseModel):
|
||||
id: str | None = Field(None)
|
||||
object: str | None = Field(None)
|
||||
created_at: int | None = Field(None)
|
||||
model: str | None = Field(None)
|
||||
status: str | None = Field(None)
|
||||
error: BytePlusResponseError | None = Field(None)
|
||||
output: list[BytePlusOutputMessage] | None = Field(None)
|
||||
usage: BytePlusResponseUsage | None = Field(None)
|
||||
@ -49,7 +49,7 @@ def _claude_model_inputs():
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random. Ignored for Opus 4.7.",
|
||||
tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
@ -208,7 +208,7 @@ class ClaudeNode(IO.ComfyNode):
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_label = model["model"]
|
||||
max_tokens = model["max_tokens"]
|
||||
temperature = None if model_label == "Opus 4.7" else model["temperature"]
|
||||
temperature = model["temperature"]
|
||||
|
||||
image_tensors: list[Input.Image] = [t for t in (images or {}).values() if t is not None]
|
||||
if sum(get_number_of_images(t) for t in image_tensors) > CLAUDE_MAX_IMAGES:
|
||||
|
||||
@ -1,271 +0,0 @@
|
||||
"""API Nodes for ByteDance Seed LLM via the BytePlus ModelArk Responses API.
|
||||
|
||||
See: https://docs.byteplus.com/en/docs/ModelArk/1585128
|
||||
"""
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bytedance_llm import (
|
||||
BytePlusInputImage,
|
||||
BytePlusInputMessage,
|
||||
BytePlusInputText,
|
||||
BytePlusInputVideo,
|
||||
BytePlusMessageContent,
|
||||
BytePlusResponseCreateRequest,
|
||||
BytePlusResponseObject,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
BYTEPLUS_RESPONSES_ENDPOINT = "/proxy/byteplus/api/v3/responses"
|
||||
SEED_MAX_IMAGES = 20
|
||||
SEED_MAX_VIDEOS = 4
|
||||
|
||||
SEED_MODELS: dict[str, str] = {
|
||||
"Seed 2.0 Pro": "seed-2-0-pro-260328",
|
||||
"Seed 2.0 Lite": "seed-2-0-lite-260228",
|
||||
"Seed 2.0 Mini": "seed-2-0-mini-260215",
|
||||
}
|
||||
|
||||
# USD per 1M tokens: (input, cache_hit_input, output)
|
||||
_SEED_PRICES_PER_MILLION: dict[str, tuple[float, float, float]] = {
|
||||
"seed-2-0-pro-260328": (0.50, 0.10, 3.00),
|
||||
"seed-2-0-lite-260228": (0.25, 0.05, 2.00),
|
||||
"seed-2-0-mini-260215": (0.10, 0.02, 0.40),
|
||||
}
|
||||
|
||||
|
||||
def _seed_model_inputs(max_images: int = SEED_MAX_IMAGES, max_videos: int = SEED_MAX_VIDEOS):
|
||||
return [
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, max_images + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional image(s) to use as context for the model. Up to {max_images} images.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"videos",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Video.Input("video"),
|
||||
names=[f"video_{i}" for i in range(1, max_videos + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional video(s) to use as context for the model. Up to {max_videos} videos.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.01,
|
||||
tooltip="Controls randomness. 0.0 is deterministic, higher values are more random.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _calculate_price(model_id: str, response: BytePlusResponseObject) -> float | None:
|
||||
"""Compute approximate USD price from response usage."""
|
||||
if not response.usage:
|
||||
return None
|
||||
rates = _SEED_PRICES_PER_MILLION.get(model_id)
|
||||
if rates is None:
|
||||
return None
|
||||
input_rate, cache_hit_rate, output_rate = rates
|
||||
input_tokens = response.usage.input_tokens or 0
|
||||
output_tokens = response.usage.output_tokens or 0
|
||||
cached = 0
|
||||
if response.usage.input_tokens_details:
|
||||
cached = response.usage.input_tokens_details.cached_tokens or 0
|
||||
fresh_input = max(0, input_tokens - cached)
|
||||
total = fresh_input * input_rate + cached * cache_hit_rate + output_tokens * output_rate
|
||||
return total / 1_000_000.0
|
||||
|
||||
|
||||
def _get_text_from_response(response: BytePlusResponseObject) -> str:
|
||||
"""Extract concatenated text from all assistant message output_text blocks."""
|
||||
if not response.output:
|
||||
return ""
|
||||
chunks: list[str] = []
|
||||
for item in response.output:
|
||||
if item.type != "message" or not item.content:
|
||||
continue
|
||||
for block in item.content:
|
||||
if block.type == "output_text" and block.text:
|
||||
chunks.append(block.text)
|
||||
elif block.type == "refusal" and block.refusal:
|
||||
raise ValueError(f"Model refused to respond: {block.refusal}")
|
||||
return "\n".join(chunks)
|
||||
|
||||
|
||||
async def _build_image_content_blocks(
|
||||
cls: type[IO.ComfyNode],
|
||||
image_tensors: list[Input.Image],
|
||||
) -> list[BytePlusInputImage]:
|
||||
urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image_tensors,
|
||||
max_images=SEED_MAX_IMAGES,
|
||||
wait_label="Uploading reference images",
|
||||
)
|
||||
return [BytePlusInputImage(image_url=url) for url in urls]
|
||||
|
||||
|
||||
async def _build_video_content_blocks(
|
||||
cls: type[IO.ComfyNode],
|
||||
videos: list[Input.Video],
|
||||
) -> list[BytePlusInputVideo]:
|
||||
blocks: list[BytePlusInputVideo] = []
|
||||
total = len(videos)
|
||||
for idx, video in enumerate(videos):
|
||||
label = "Uploading reference video"
|
||||
if total > 1:
|
||||
label = f"{label} ({idx + 1}/{total})"
|
||||
url = await upload_video_to_comfyapi(cls, video, wait_label=label)
|
||||
blocks.append(BytePlusInputVideo(video_url=url))
|
||||
return blocks
|
||||
|
||||
|
||||
class ByteDanceSeedNode(IO.ComfyNode):
|
||||
"""Generate text responses from a ByteDance Seed 2.0 model."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedNode",
|
||||
display_name="ByteDance Seed",
|
||||
category="api node/text/ByteDance",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses with ByteDance's Seed 2.0 models. "
|
||||
"Provide a text prompt and optionally one or more images or videos for multimodal context.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text input to the model.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[IO.DynamicCombo.Option(label, _seed_model_inputs()) for label in SEED_MODELS],
|
||||
tooltip="The Seed model used to generate the response.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Foundational instructions that dictate the model's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.String.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$contains($m, "mini") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.00025, 0.0009],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "lite") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.0003, 0.002],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "pro") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.0005, 0.003],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: {"type":"text", "text":"Token-based"}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_label = model["model"]
|
||||
temperature = model["temperature"]
|
||||
model_id = SEED_MODELS[model_label]
|
||||
|
||||
image_tensors: list[Input.Image] = [t for t in (model.get("images") or {}).values() if t is not None]
|
||||
if sum(get_number_of_images(t) for t in image_tensors) > SEED_MAX_IMAGES:
|
||||
raise ValueError(f"Up to {SEED_MAX_IMAGES} images are supported per request.")
|
||||
|
||||
video_inputs: list[Input.Video] = [v for v in (model.get("videos") or {}).values() if v is not None]
|
||||
if len(video_inputs) > SEED_MAX_VIDEOS:
|
||||
raise ValueError(f"Up to {SEED_MAX_VIDEOS} videos are supported per request.")
|
||||
|
||||
content: list[BytePlusMessageContent] = []
|
||||
if image_tensors:
|
||||
content.extend(await _build_image_content_blocks(cls, image_tensors))
|
||||
if video_inputs:
|
||||
content.extend(await _build_video_content_blocks(cls, video_inputs))
|
||||
content.append(BytePlusInputText(text=prompt))
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_RESPONSES_ENDPOINT, method="POST"),
|
||||
response_model=BytePlusResponseObject,
|
||||
data=BytePlusResponseCreateRequest(
|
||||
model=model_id,
|
||||
input=[BytePlusInputMessage(role="user", content=content)],
|
||||
instructions=system_prompt or None,
|
||||
temperature=temperature,
|
||||
store=False,
|
||||
stream=False,
|
||||
),
|
||||
price_extractor=lambda r: _calculate_price(model_id, r),
|
||||
)
|
||||
if response.error:
|
||||
raise ValueError(f"Seed API error ({response.error.code}): {response.error.message}")
|
||||
result = _get_text_from_response(response)
|
||||
if not result:
|
||||
raise ValueError("Empty response from Seed model.")
|
||||
return IO.NodeOutput(result)
|
||||
|
||||
|
||||
class ByteDanceLLMExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [ByteDanceSeedNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ByteDanceLLMExtension:
|
||||
return ByteDanceLLMExtension()
|
||||
@ -14,49 +14,6 @@ from typing_extensions import override
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
ICLoRAParameters = io.Custom("IC_LORA_PARAMETERS")
|
||||
|
||||
|
||||
class GetICLoRAParameters(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="GetICLoRAParameters",
|
||||
display_name="Get IC-LoRA Parameters",
|
||||
description="Extracts IC-LoRA parameters from the safetensors metadata of a LoRA-loaded "
|
||||
"model and outputs them for LTXVAddGuide (eg. reference_downscale_factor).",
|
||||
category="conditioning/video_models",
|
||||
search_aliases=["ic-lora", "ic lora", "iclora", "downscale factor", "reference downscale"],
|
||||
inputs=[
|
||||
io.Model.Input(
|
||||
"iclora_model",
|
||||
tooltip="Direct output from a LoRA Loader for the specific IC-LoRA "
|
||||
"from which to extract the metadata.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
ICLoRAParameters.Output(
|
||||
"iclora_parameters",
|
||||
tooltip="IC-LoRA parameters extracted from the LoRA metadata "
|
||||
"(eg. reference_downscale_factor). Connect to LTXVAddGuide "
|
||||
"if the LoRA requires special handling of the guides.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, iclora_model) -> io.NodeOutput:
|
||||
metadata = iclora_model.get_attachment("lora_metadata")
|
||||
factor = 1
|
||||
if metadata:
|
||||
try:
|
||||
factor = max(1, round(float(metadata.get("reference_downscale_factor", 1))))
|
||||
except (TypeError, ValueError):
|
||||
factor = 1
|
||||
parameters = {"reference_downscale_factor": factor}
|
||||
return io.NodeOutput(parameters)
|
||||
|
||||
|
||||
class EmptyLTXVLatentVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -175,7 +132,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0, attention_mask=None):
|
||||
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
|
||||
"""Append a guide_attention_entry to both positive and negative conditioning.
|
||||
|
||||
Each entry tracks one guide reference for per-reference attention control.
|
||||
@ -184,10 +141,9 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s
|
||||
new_entry = {
|
||||
"pre_filter_count": pre_filter_count,
|
||||
"strength": strength,
|
||||
"pixel_mask": attention_mask.unsqueeze(0).unsqueeze(0) if attention_mask is not None else None, # reshape to (1, 1, F, H, W)
|
||||
"pixel_mask": None,
|
||||
"latent_shape": latent_shape,
|
||||
}
|
||||
|
||||
results = []
|
||||
for cond in (positive, negative):
|
||||
# Read existing entries from this specific conditioning
|
||||
@ -197,7 +153,8 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s
|
||||
if found is not None:
|
||||
existing = found
|
||||
break
|
||||
# Shallow copy only and append (pixel_mask is never mutated).
|
||||
# Shallow copy and append (no deepcopy needed — entries contain
|
||||
# only scalars and None for pixel_mask at this call site).
|
||||
entries = [*existing, new_entry]
|
||||
results.append(node_helpers.conditioning_set_values(
|
||||
cond, {"guide_attention_entries": entries}
|
||||
@ -263,20 +220,6 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
|
||||
),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Mask.Input(
|
||||
"attention_mask",
|
||||
optional=True,
|
||||
tooltip="Optional pixel-space spatial mask. Controls per-region "
|
||||
"conditioning influence via self-attention, multiplied by strength.",
|
||||
),
|
||||
ICLoRAParameters.Input(
|
||||
"iclora_parameters",
|
||||
optional=True,
|
||||
tooltip="Optional IC-LoRA parameters from a Get IC-LoRA Parameters node. "
|
||||
"Used for adjusting guide processing as required by certain IC-LoRAs "
|
||||
"(eg. those with a reference_downscale_factor > 1). "
|
||||
"When chained, each LTXVAddGuide uses only the parameters connected to it.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
@ -286,41 +229,14 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1):
|
||||
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
|
||||
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
|
||||
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
|
||||
target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
|
||||
target_height = int(latent_height * height_scale_factor / latent_downscale_factor)
|
||||
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="center").movedim(1, -1)
|
||||
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
return encode_pixels, t
|
||||
|
||||
@classmethod
|
||||
def dilate_latent(cls, guide_latent, latent_downscale_factor):
|
||||
if latent_downscale_factor <= 1:
|
||||
return guide_latent, None
|
||||
scale = int(latent_downscale_factor)
|
||||
dilated_shape = guide_latent.shape[:3] + (guide_latent.shape[3] * scale, guide_latent.shape[4] * scale)
|
||||
dilated = torch.zeros(dilated_shape, device=guide_latent.device, dtype=guide_latent.dtype)
|
||||
dilated[..., ::scale, ::scale] = guide_latent
|
||||
dilated_mask = torch.full(
|
||||
(dilated.shape[0], 1, dilated.shape[2], dilated.shape[3], dilated.shape[4]),
|
||||
-1.0, device=guide_latent.device, dtype=guide_latent.dtype,
|
||||
)
|
||||
dilated_mask[..., ::scale, ::scale] = 1.0
|
||||
return dilated, dilated_mask
|
||||
|
||||
@classmethod
|
||||
def get_reference_downscale_factor(cls, iclora_parameters):
|
||||
if not iclora_parameters:
|
||||
return 1
|
||||
try:
|
||||
factor = max(1, round(float(iclora_parameters.get("reference_downscale_factor", 1))))
|
||||
except (TypeError, ValueError):
|
||||
factor = 1
|
||||
return factor
|
||||
|
||||
@classmethod
|
||||
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
|
||||
time_scale_factor, _, _ = scale_factors
|
||||
@ -416,21 +332,13 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return latent_image, noise_mask
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, attention_mask=None, iclora_parameters=None) -> io.NodeOutput:
|
||||
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
|
||||
scale_factors = vae.downscale_index_formula
|
||||
latent_image = latent["samples"]
|
||||
noise_mask = get_noise_mask(latent)
|
||||
|
||||
_, _, latent_length, latent_height, latent_width = latent_image.shape
|
||||
|
||||
latent_downscale_factor = cls.get_reference_downscale_factor(iclora_parameters)
|
||||
if latent_downscale_factor > 1:
|
||||
if latent_width % latent_downscale_factor != 0 or latent_height % latent_downscale_factor != 0:
|
||||
raise ValueError(
|
||||
f"Latent spatial size {latent_width}x{latent_height} must be divisible by "
|
||||
f"reference_downscale_factor {latent_downscale_factor} from the IC-LoRA parameters."
|
||||
)
|
||||
|
||||
# For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot
|
||||
time_scale_factor = scale_factors[0]
|
||||
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
|
||||
@ -443,17 +351,12 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
if not causal_fix:
|
||||
image = torch.cat([image[:1], image], dim=0)
|
||||
|
||||
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor)
|
||||
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
|
||||
|
||||
if not causal_fix:
|
||||
t = t[:, :, 1:, :, :]
|
||||
image = image[1:]
|
||||
|
||||
guide_latent_shape = list(t.shape[2:]) # pre-dilation [F, H, W] for spatial-mask downsampling
|
||||
guide_mask = None
|
||||
if latent_downscale_factor > 1:
|
||||
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
|
||||
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
@ -466,16 +369,14 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
t,
|
||||
strength,
|
||||
scale_factors,
|
||||
guide_mask=guide_mask,
|
||||
latent_downscale_factor=latent_downscale_factor,
|
||||
causal_fix=causal_fix,
|
||||
)
|
||||
|
||||
# Track this guide for per-reference attention control.
|
||||
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
|
||||
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
|
||||
positive, negative = _append_guide_attention_entry(
|
||||
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
@ -893,7 +794,6 @@ class LtxvExtension(ComfyExtension):
|
||||
ModelSamplingLTXV,
|
||||
LTXVConditioning,
|
||||
LTXVScheduler,
|
||||
GetICLoRAParameters,
|
||||
LTXVAddGuide,
|
||||
LTXVPreprocess,
|
||||
LTXVCropGuides,
|
||||
|
||||
@ -330,7 +330,7 @@ class FeatherMask(IO.ComfyNode):
|
||||
|
||||
for x in range(right):
|
||||
feather_rate = (x + 1) / right
|
||||
output[:, :, -(x + 1)] *= feather_rate
|
||||
output[:, :, -x] *= feather_rate
|
||||
|
||||
for y in range(top):
|
||||
feather_rate = (y + 1) / top
|
||||
@ -338,7 +338,7 @@ class FeatherMask(IO.ComfyNode):
|
||||
|
||||
for y in range(bottom):
|
||||
feather_rate = (y + 1) / bottom
|
||||
output[:, -(y + 1), :] *= feather_rate
|
||||
output[:, -y, :] *= feather_rate
|
||||
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
|
||||
@ -276,8 +276,8 @@ class CLIPSave:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
clip.load_model()
|
||||
clip_sd = clip.state_dict_for_saving()
|
||||
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
|
||||
clip_sd = clip.get_sd()
|
||||
|
||||
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
|
||||
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
|
||||
|
||||
@ -568,7 +568,7 @@ def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Ten
|
||||
class BatchImagesNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", min=1, max=50)
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", min=2, max=50)
|
||||
return io.Schema(
|
||||
node_id="BatchImagesNode",
|
||||
display_name="Batch Images",
|
||||
@ -590,7 +590,7 @@ class BatchImagesNode(io.ComfyNode):
|
||||
class BatchMasksNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=1, max=50)
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50)
|
||||
return io.Schema(
|
||||
node_id="BatchMasksNode",
|
||||
search_aliases=["combine masks", "stack masks", "merge masks"],
|
||||
@ -611,7 +611,7 @@ class BatchMasksNode(io.ComfyNode):
|
||||
class BatchLatentsNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=1, max=50)
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50)
|
||||
return io.Schema(
|
||||
node_id="BatchLatentsNode",
|
||||
search_aliases=["combine latents", "stack latents", "merge latents"],
|
||||
|
||||
@ -626,7 +626,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
|
||||
if comfy.model_management.is_oom(ex):
|
||||
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
|
||||
logging.info("Memory summary:\n{}".format(comfy.model_management.debug_memory_summary()))
|
||||
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
|
||||
logging.error("Got an OOM, unloading all loaded models.")
|
||||
comfy.model_management.unload_all_models()
|
||||
elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type:
|
||||
|
||||
9
main.py
9
main.py
@ -27,6 +27,7 @@ from utils.mime_types import init_mime_types
|
||||
import faulthandler
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_api import feature_flags
|
||||
@ -148,6 +149,14 @@ def execute_prestartup_script():
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
|
||||
from nodes import record_node_startup_error
|
||||
record_node_startup_error(
|
||||
module_path=os.path.dirname(script_path),
|
||||
source="custom_nodes",
|
||||
phase="prestartup",
|
||||
error=e,
|
||||
tb=traceback.format_exc(),
|
||||
)
|
||||
return False
|
||||
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
|
||||
91
nodes.py
91
nodes.py
@ -700,19 +700,17 @@ class LoraLoader:
|
||||
|
||||
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
|
||||
lora = None
|
||||
lora_metadata = None
|
||||
if self.loaded_lora is not None:
|
||||
if self.loaded_lora[0] == lora_path:
|
||||
lora = self.loaded_lora[1]
|
||||
lora_metadata = self.loaded_lora[2] if len(self.loaded_lora) > 2 else None
|
||||
else:
|
||||
self.loaded_lora = None
|
||||
|
||||
if lora is None:
|
||||
lora, lora_metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
|
||||
self.loaded_lora = (lora_path, lora, lora_metadata)
|
||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||
self.loaded_lora = (lora_path, lora)
|
||||
|
||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=lora_metadata)
|
||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
|
||||
return (model_lora, clip_lora)
|
||||
|
||||
class LoraLoaderModelOnly(LoraLoader):
|
||||
@ -2156,6 +2154,71 @@ EXTENSION_WEB_DIRS = {}
|
||||
# Dictionary of successfully loaded module names and associated directories.
|
||||
LOADED_MODULE_DIRS = {}
|
||||
|
||||
# Dictionary of custom node startup errors, keyed by "<source>:<module_name>"
|
||||
# so that name collisions across custom_nodes / comfy_extras / comfy_api_nodes
|
||||
# do not overwrite each other. Each value contains: source, module_name,
|
||||
# module_path, error, traceback, phase.
|
||||
#
|
||||
# `source` is the same string as the internal `module_parent` used at load
|
||||
# time (e.g. "custom_nodes", "comfy_extras", "comfy_api_nodes"). It is
|
||||
# intentionally a free-form string rather than a fixed enum so the contract
|
||||
# survives node-source layouts evolving (e.g. comfy_api_nodes eventually
|
||||
# moving out of core). Consumers should treat any new value as a new bucket
|
||||
# rather than rejecting it.
|
||||
NODE_STARTUP_ERRORS: dict[str, dict] = {}
|
||||
|
||||
|
||||
def _read_pyproject_metadata(module_path: str) -> dict | None:
|
||||
"""Best-effort extraction of node-pack identity from pyproject.toml.
|
||||
|
||||
Returns a dict with the Comfy Registry-style identity (pack_id,
|
||||
display_name, publisher_id, version, repository) when the module
|
||||
directory contains a pyproject.toml. Returns None when no toml is
|
||||
present or parsing fails for any reason — startup-error tracking
|
||||
must never itself raise.
|
||||
"""
|
||||
if not module_path or not os.path.isdir(module_path):
|
||||
return None
|
||||
toml_path = os.path.join(module_path, "pyproject.toml")
|
||||
if not os.path.isfile(toml_path):
|
||||
return None
|
||||
try:
|
||||
from comfy_config import config_parser
|
||||
|
||||
cfg = config_parser.extract_node_configuration(module_path)
|
||||
if cfg is None:
|
||||
return None
|
||||
meta = {
|
||||
"pack_id": cfg.project.name or None,
|
||||
"display_name": cfg.tool_comfy.display_name or None,
|
||||
"publisher_id": cfg.tool_comfy.publisher_id or None,
|
||||
"version": cfg.project.version or None,
|
||||
"repository": cfg.project.urls.repository or None,
|
||||
}
|
||||
# Drop empty fields so the API payload stays compact.
|
||||
return {k: v for k, v in meta.items() if v}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def record_node_startup_error(
|
||||
*, module_path: str, source: str, phase: str, error: BaseException, tb: str
|
||||
) -> None:
|
||||
"""Record a startup error for a node module so it can be exposed via the API."""
|
||||
module_name = get_module_name(module_path)
|
||||
entry = {
|
||||
"source": source,
|
||||
"module_name": module_name,
|
||||
"module_path": module_path,
|
||||
"error": str(error),
|
||||
"traceback": tb,
|
||||
"phase": phase,
|
||||
}
|
||||
pyproject = _read_pyproject_metadata(module_path)
|
||||
if pyproject:
|
||||
entry["pyproject"] = pyproject
|
||||
NODE_STARTUP_ERRORS[f"{source}:{module_name}"] = entry
|
||||
|
||||
|
||||
def get_module_name(module_path: str) -> str:
|
||||
"""
|
||||
@ -2265,14 +2328,30 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
|
||||
return True
|
||||
except Exception as e:
|
||||
tb = traceback.format_exc()
|
||||
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
|
||||
record_node_startup_error(
|
||||
module_path=module_path,
|
||||
source=module_parent,
|
||||
phase="entrypoint",
|
||||
error=e,
|
||||
tb=tb,
|
||||
)
|
||||
return False
|
||||
else:
|
||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or comfy_entrypoint (need one).")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.warning(traceback.format_exc())
|
||||
tb = traceback.format_exc()
|
||||
logging.warning(tb)
|
||||
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
|
||||
record_node_startup_error(
|
||||
module_path=module_path,
|
||||
source=module_parent,
|
||||
phase="import",
|
||||
error=e,
|
||||
tb=tb,
|
||||
)
|
||||
return False
|
||||
|
||||
async def init_external_custom_nodes():
|
||||
|
||||
12
openapi.yaml
12
openapi.yaml
@ -6329,16 +6329,6 @@ components:
|
||||
name:
|
||||
type: string
|
||||
description: Name of the asset file
|
||||
file_path:
|
||||
type: string
|
||||
nullable: true
|
||||
x-runtime: [cloud, local]
|
||||
description: "Logical asset locator under the namespace root. Not a unique reference key; use `id` for identity."
|
||||
display_name:
|
||||
type: string
|
||||
nullable: true
|
||||
x-runtime: [cloud, local]
|
||||
description: "Human-facing display label for the asset. Not a unique reference key; use `id` for identity."
|
||||
hash:
|
||||
type: string
|
||||
nullable: true
|
||||
@ -8119,4 +8109,4 @@ components:
|
||||
items:
|
||||
$ref: "#/components/schemas/TaskEntry"
|
||||
pagination:
|
||||
$ref: "#/components/schemas/PaginationInfo"
|
||||
$ref: "#/components/schemas/PaginationInfo"
|
||||
20
server.py
20
server.py
@ -765,6 +765,26 @@ class PromptServer():
|
||||
out[node_class] = node_info(node_class)
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/node_startup_errors")
|
||||
async def get_node_startup_errors(request):
|
||||
# Group errors by source so the frontend/Manager can render them
|
||||
# in distinct sections. `source` is the same string as the
|
||||
# module_parent used at load time (e.g. "custom_nodes",
|
||||
# "comfy_extras", "comfy_api_nodes") and is left as a free-form
|
||||
# string so the contract survives node-source layouts evolving.
|
||||
# The response only contains source buckets that actually had a
|
||||
# failure; consumers should not assume any particular set of keys
|
||||
# is always present.
|
||||
#
|
||||
# `module_path` is stripped because the absolute on-disk path is
|
||||
# internal detail that the frontend has no use for.
|
||||
grouped: dict[str, dict[str, dict]] = {}
|
||||
for entry in nodes.NODE_STARTUP_ERRORS.values():
|
||||
source = entry.get("source", "custom_nodes")
|
||||
public_entry = {k: v for k, v in entry.items() if k != "module_path"}
|
||||
grouped.setdefault(source, {})[entry["module_name"]] = public_entry
|
||||
return web.json_response(grouped)
|
||||
|
||||
@routes.get("/api/jobs")
|
||||
async def get_jobs(request):
|
||||
"""List all jobs with filtering, sorting, and pagination.
|
||||
|
||||
@ -6,11 +6,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.services.path_utils import (
|
||||
compute_display_name,
|
||||
compute_file_path,
|
||||
get_asset_category_and_relative_path,
|
||||
)
|
||||
from app.assets.services.path_utils import get_asset_category_and_relative_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -83,27 +79,3 @@ class TestGetAssetCategoryAndRelativePath:
|
||||
def test_unknown_path_raises(self, fake_dirs):
|
||||
with pytest.raises(ValueError, match="not within"):
|
||||
get_asset_category_and_relative_path("/some/random/path.png")
|
||||
|
||||
|
||||
class TestResponsePaths:
|
||||
def test_input_file_path_and_display_name_include_subfolder(self, fake_dirs):
|
||||
sub = fake_dirs["input"] / "some" / "folder"
|
||||
sub.mkdir(parents=True)
|
||||
f = sub / "image.png"
|
||||
f.touch()
|
||||
|
||||
assert compute_file_path(str(f)) == "input/some/folder/image.png"
|
||||
assert compute_display_name(str(f)) == "some/folder/image.png"
|
||||
|
||||
def test_model_file_path_includes_bucket_display_name_drops_it(self, fake_dirs):
|
||||
sub = fake_dirs["models"] / "flux"
|
||||
sub.mkdir()
|
||||
f = sub / "model.safetensors"
|
||||
f.touch()
|
||||
|
||||
assert compute_file_path(str(f)) == "models/checkpoints/flux/model.safetensors"
|
||||
assert compute_display_name(str(f)) == "flux/model.safetensors"
|
||||
|
||||
def test_unknown_path_returns_none(self, fake_dirs):
|
||||
assert compute_file_path("/some/random/path.png") is None
|
||||
assert compute_display_name("/some/random/path.png") is None
|
||||
|
||||
@ -5,8 +5,6 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
import requests
|
||||
import pytest
|
||||
|
||||
from helpers import get_asset_filename
|
||||
|
||||
|
||||
def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes):
|
||||
name = "dup_a.safetensors"
|
||||
@ -65,14 +63,6 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
|
||||
assert r2.status_code == 200, b2 # fast path returns 200 with created_new == False
|
||||
assert b2["created_new"] is False
|
||||
assert b2["asset_hash"] == h
|
||||
assert b2.get("file_path") is None
|
||||
assert b2.get("display_name") is None
|
||||
|
||||
rg = http.get(f"{api_base}/api/assets/{b2['id']}", timeout=120)
|
||||
detail = rg.json()
|
||||
assert rg.status_code == 200, detail
|
||||
assert detail.get("file_path") is None
|
||||
assert detail.get("display_name") is None
|
||||
|
||||
|
||||
def test_upload_fastpath_with_known_hash_and_file(
|
||||
@ -117,54 +107,6 @@ def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base
|
||||
assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("tags", "extension", "expected_prefix", "expected_display_prefix"),
|
||||
[
|
||||
(["input", "unit-tests"], ".png", "input", ""),
|
||||
(["models", "checkpoints", "unit-tests"], ".safetensors", "models/checkpoints", ""),
|
||||
],
|
||||
)
|
||||
def test_upload_response_includes_file_path_and_display_name(
|
||||
tags: list[str],
|
||||
extension: str,
|
||||
expected_prefix: str,
|
||||
expected_display_prefix: str,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
scope = f"response-paths-{uuid.uuid4().hex[:6]}"
|
||||
scoped_tags = [*tags, scope]
|
||||
name = f"asset_response_path{extension}"
|
||||
|
||||
created = asset_factory(name, scoped_tags, {}, make_asset_bytes(name, 1024))
|
||||
stored_filename = get_asset_filename(created["asset_hash"], extension)
|
||||
expected_suffix = f"unit-tests/{scope}/{stored_filename}"
|
||||
expected_file_path = f"{expected_prefix}/{expected_suffix}"
|
||||
expected_display_name = f"{expected_display_prefix}{expected_suffix}"
|
||||
|
||||
assert created["file_path"] == expected_file_path
|
||||
assert created["display_name"] == expected_display_name
|
||||
|
||||
detail_r = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
|
||||
detail = detail_r.json()
|
||||
assert detail_r.status_code == 200, detail
|
||||
assert detail["file_path"] == expected_file_path
|
||||
assert detail["display_name"] == expected_display_name
|
||||
|
||||
list_r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "limit": "50"},
|
||||
timeout=120,
|
||||
)
|
||||
listed = list_r.json()
|
||||
assert list_r.status_code == 200, listed
|
||||
match = next(a for a in listed["assets"] if a["id"] == created["id"])
|
||||
assert match["file_path"] == expected_file_path
|
||||
assert match["display_name"] == expected_display_name
|
||||
|
||||
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
def test_concurrent_upload_identical_bytes_different_names(
|
||||
root: str,
|
||||
|
||||
146
tests-unit/node_startup_errors_test.py
Normal file
146
tests-unit/node_startup_errors_test.py
Normal file
@ -0,0 +1,146 @@
|
||||
"""Tests for the custom node startup error tracking introduced for
|
||||
Comfy-Org/ComfyUI-Launcher#303.
|
||||
|
||||
Covers:
|
||||
- load_custom_node populates NODE_STARTUP_ERRORS with the correct source
|
||||
for each module_parent (custom_nodes / comfy_extras / comfy_api_nodes).
|
||||
- Composite keying prevents collisions between modules with the same name
|
||||
in different sources.
|
||||
- record_node_startup_error stores the expected fields.
|
||||
- pyproject.toml metadata is attached when present and omitted when absent.
|
||||
"""
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
import nodes
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_startup_errors():
|
||||
nodes.NODE_STARTUP_ERRORS.clear()
|
||||
yield
|
||||
nodes.NODE_STARTUP_ERRORS.clear()
|
||||
|
||||
|
||||
def _write_broken_module(tmp_path, name: str) -> str:
|
||||
path = tmp_path / f"{name}.py"
|
||||
path.write_text(textwrap.dedent("""\
|
||||
# Deliberately broken module to exercise startup-error tracking.
|
||||
raise RuntimeError("boom from " + __name__)
|
||||
"""))
|
||||
return str(path)
|
||||
|
||||
|
||||
def test_record_node_startup_error_fields(tmp_path):
|
||||
err = ValueError("kaboom")
|
||||
nodes.record_node_startup_error(
|
||||
module_path=str(tmp_path / "my_pack"),
|
||||
source="custom_nodes",
|
||||
phase="import",
|
||||
error=err,
|
||||
tb="traceback-text",
|
||||
)
|
||||
assert "custom_nodes:my_pack" in nodes.NODE_STARTUP_ERRORS
|
||||
entry = nodes.NODE_STARTUP_ERRORS["custom_nodes:my_pack"]
|
||||
assert entry["source"] == "custom_nodes"
|
||||
assert entry["module_name"] == "my_pack"
|
||||
assert entry["phase"] == "import"
|
||||
assert entry["error"] == "kaboom"
|
||||
assert entry["traceback"] == "traceback-text"
|
||||
assert entry["module_path"].endswith("my_pack")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"module_parent",
|
||||
["custom_nodes", "comfy_extras", "comfy_api_nodes"],
|
||||
)
|
||||
async def test_load_custom_node_records_source(tmp_path, module_parent):
|
||||
# `source` in the entry should be the same string as `module_parent`.
|
||||
module_path = _write_broken_module(tmp_path, "broken_pack")
|
||||
|
||||
success = await nodes.load_custom_node(module_path, module_parent=module_parent)
|
||||
assert success is False
|
||||
|
||||
key = f"{module_parent}:broken_pack"
|
||||
assert key in nodes.NODE_STARTUP_ERRORS, nodes.NODE_STARTUP_ERRORS
|
||||
entry = nodes.NODE_STARTUP_ERRORS[key]
|
||||
assert entry["source"] == module_parent
|
||||
assert entry["module_name"] == "broken_pack"
|
||||
assert entry["phase"] == "import"
|
||||
assert "boom from" in entry["error"]
|
||||
assert "RuntimeError" in entry["traceback"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_custom_node_collision_across_sources(tmp_path):
|
||||
# Same module name registered as both a custom node and a comfy_extra;
|
||||
# composite keying should keep both entries.
|
||||
cn_dir = tmp_path / "cn"
|
||||
extras_dir = tmp_path / "extras"
|
||||
cn_dir.mkdir()
|
||||
extras_dir.mkdir()
|
||||
cn_path = _write_broken_module(cn_dir, "nodes_audio")
|
||||
extras_path = _write_broken_module(extras_dir, "nodes_audio")
|
||||
|
||||
assert await nodes.load_custom_node(cn_path, module_parent="custom_nodes") is False
|
||||
assert await nodes.load_custom_node(extras_path, module_parent="comfy_extras") is False
|
||||
|
||||
assert "custom_nodes:nodes_audio" in nodes.NODE_STARTUP_ERRORS
|
||||
assert "comfy_extras:nodes_audio" in nodes.NODE_STARTUP_ERRORS
|
||||
assert (
|
||||
nodes.NODE_STARTUP_ERRORS["custom_nodes:nodes_audio"]["module_path"]
|
||||
!= nodes.NODE_STARTUP_ERRORS["comfy_extras:nodes_audio"]["module_path"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_custom_node_attaches_pyproject_metadata(tmp_path):
|
||||
pack_dir = tmp_path / "MyCoolPack"
|
||||
pack_dir.mkdir()
|
||||
(pack_dir / "__init__.py").write_text("raise RuntimeError('boom')\n")
|
||||
(pack_dir / "pyproject.toml").write_text(textwrap.dedent("""\
|
||||
[project]
|
||||
name = "comfyui-mycoolpack"
|
||||
version = "1.2.3"
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://github.com/example/comfyui-mycoolpack"
|
||||
|
||||
[tool.comfy]
|
||||
PublisherId = "example"
|
||||
DisplayName = "My Cool Pack"
|
||||
"""))
|
||||
|
||||
success = await nodes.load_custom_node(str(pack_dir), module_parent="custom_nodes")
|
||||
assert success is False
|
||||
|
||||
entry = nodes.NODE_STARTUP_ERRORS["custom_nodes:MyCoolPack"]
|
||||
assert "pyproject" in entry, entry
|
||||
py = entry["pyproject"]
|
||||
assert py["pack_id"] == "comfyui-mycoolpack"
|
||||
assert py["display_name"] == "My Cool Pack"
|
||||
assert py["publisher_id"] == "example"
|
||||
assert py["version"] == "1.2.3"
|
||||
assert py["repository"] == "https://github.com/example/comfyui-mycoolpack"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_custom_node_no_pyproject_skips_metadata(tmp_path):
|
||||
# Single-file extras-style module: no pyproject.toml exists alongside it,
|
||||
# so the entry must not contain a 'pyproject' key.
|
||||
module_path = _write_broken_module(tmp_path, "lonely")
|
||||
assert await nodes.load_custom_node(module_path, module_parent="comfy_extras") is False
|
||||
entry = nodes.NODE_STARTUP_ERRORS["comfy_extras:lonely"]
|
||||
assert "pyproject" not in entry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_custom_node_arbitrary_module_parent_passes_through(tmp_path):
|
||||
# `source` is a free-form string — an unknown module_parent (e.g. a future
|
||||
# node-source bucket) should be recorded as-is, not coerced or rejected.
|
||||
module_path = _write_broken_module(tmp_path, "future_pack")
|
||||
assert await nodes.load_custom_node(module_path, module_parent="future_source") is False
|
||||
entry = nodes.NODE_STARTUP_ERRORS["future_source:future_pack"]
|
||||
assert entry["source"] == "future_source"
|
||||
Reference in New Issue
Block a user