mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-07 12:16:00 +08:00
Compare commits
28 Commits
v0.12.0
...
jk/node-re
| Author | SHA1 | Date | |
|---|---|---|---|
| 739ed21714 | |||
| a2d4c0f98b | |||
| d5b3da823d | |||
| 8bbd8f7d65 | |||
| d6b217a7f8 | |||
| 04f89c75d1 | |||
| 588bc6b257 | |||
| c9dbe13c0c | |||
| 7024486e37 | |||
| a50c32d63f | |||
| 6125b80979 | |||
| c8fcbd66ee | |||
| 26dd7eb421 | |||
| e77b34dfea | |||
| ef73070ea4 | |||
| d30c609f5a | |||
| 5087f1d497 | |||
| a31681564d | |||
| 855849c658 | |||
| fe2511468d | |||
| 3be0175166 | |||
| b8315e66cb | |||
| ab1050bec3 | |||
| fb23935c11 | |||
| 85fc35e8fa | |||
| 223364743c | |||
| affe881354 | |||
| f5030e26fd |
38
app/node_replace_manager.py
Normal file
38
app/node_replace_manager.py
Normal file
@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy_api.latest._node_replace import NodeReplace
|
||||
|
||||
|
||||
class NodeReplaceManager:
|
||||
"""Manages node replacement registrations."""
|
||||
|
||||
def __init__(self):
|
||||
self._replacements: dict[str, list[NodeReplace]] = {}
|
||||
|
||||
def register(self, node_replace: NodeReplace):
|
||||
"""Register a node replacement mapping."""
|
||||
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
|
||||
|
||||
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
|
||||
"""Get replacements for an old node ID."""
|
||||
return self._replacements.get(old_node_id)
|
||||
|
||||
def has_replacement(self, old_node_id: str) -> bool:
|
||||
"""Check if a replacement exists for an old node ID."""
|
||||
return old_node_id in self._replacements
|
||||
|
||||
def as_dict(self):
|
||||
"""Serialize all replacements to dict."""
|
||||
return {
|
||||
k: [v.as_dict() for v in v_list]
|
||||
for k, v_list in self._replacements.items()
|
||||
}
|
||||
|
||||
def add_routes(self, routes):
|
||||
@routes.get("/node_replacements")
|
||||
async def get_node_replacements(request):
|
||||
return web.json_response(self.as_dict())
|
||||
@ -183,7 +183,7 @@ class AceStepAttention(nn.Module):
|
||||
else:
|
||||
attn_bias = window_bias
|
||||
|
||||
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True)
|
||||
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
@ -1035,8 +1035,7 @@ class AceStepConditionGenerationModel(nn.Module):
|
||||
audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847)
|
||||
lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype)
|
||||
else:
|
||||
assert False
|
||||
# TODO ?
|
||||
lm_hints_5Hz, indices = self.tokenizer.tokenize(refer_audio_acoustic_hidden_states_packed)
|
||||
|
||||
lm_hints = self.detokenizer(lm_hints_5Hz)
|
||||
|
||||
|
||||
@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if kwargs.get("low_precision_attention", True) is False:
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||
|
||||
exception_fallback = False
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
|
||||
@ -332,6 +332,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["{}".format(key_lora)] = k
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, comfy.model_base.ACEStep15):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
|
||||
key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
@ -1548,6 +1548,7 @@ class ACEStep15(BaseModel):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
device = kwargs["device"]
|
||||
noise = kwargs["noise"]
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
@ -1571,15 +1572,19 @@ class ACEStep15(BaseModel):
|
||||
1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
|
||||
-8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
|
||||
-5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
|
||||
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750)
|
||||
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, noise.shape[2])
|
||||
pass_audio_codes = True
|
||||
else:
|
||||
refer_audio = refer_audio[-1]
|
||||
refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
|
||||
pass_audio_codes = False
|
||||
|
||||
if pass_audio_codes:
|
||||
audio_codes = kwargs.get("audio_codes", None)
|
||||
if audio_codes is not None:
|
||||
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
|
||||
refer_audio = refer_audio[:, :, :750]
|
||||
|
||||
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
|
||||
|
||||
audio_codes = kwargs.get("audio_codes", None)
|
||||
if audio_codes is not None:
|
||||
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
|
||||
|
||||
return out
|
||||
|
||||
class Omnigen2(BaseModel):
|
||||
|
||||
@ -1724,11 +1724,9 @@ def soft_empty_cache(force=False):
|
||||
elif is_mlu():
|
||||
torch.mlu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
if comfy.memory_management.aimdo_allocator is None:
|
||||
#Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
@ -1400,7 +1400,7 @@ class ModelPatcher:
|
||||
continue
|
||||
key = "diffusion_model." + k
|
||||
unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
|
||||
return self.model.state_dict_for_saving(unet_state_dict)
|
||||
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||
|
||||
def __del__(self):
|
||||
self.unpin_all_weights()
|
||||
|
||||
@ -54,6 +54,8 @@ try:
|
||||
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
if q.nelement() < 1024 * 128: # arbitrary number, for small inputs cudnn attention seems slower
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
else:
|
||||
|
||||
15
comfy/sd.py
15
comfy/sd.py
@ -554,6 +554,8 @@ class VAE:
|
||||
elif "decoder.layers.1.layers.0.beta" in sd:
|
||||
config = {}
|
||||
param_key = None
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
if "decoder.layers.2.layers.1.weight_v" in sd:
|
||||
param_key = "decoder.layers.2.layers.1.weight_v"
|
||||
if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd:
|
||||
@ -562,6 +564,8 @@ class VAE:
|
||||
if sd[param_key].shape[-1] == 12:
|
||||
config["strides"] = [2, 4, 4, 6, 10]
|
||||
self.audio_sample_rate = 48000
|
||||
self.upscale_ratio = 1920
|
||||
self.downscale_ratio = 1920
|
||||
|
||||
self.first_stage_model = AudioOobleckVAE(**config)
|
||||
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
|
||||
@ -569,8 +573,6 @@ class VAE:
|
||||
self.latent_channels = 64
|
||||
self.output_channels = 2
|
||||
self.pad_channel_value = "replicate"
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
self.latent_dim = 1
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
@ -870,7 +872,7 @@ class VAE:
|
||||
/ 3.0)
|
||||
return output
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
|
||||
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
||||
if samples.ndim == 3:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
else:
|
||||
@ -1442,7 +1444,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
|
||||
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
|
||||
elif clip_type == CLIPType.ACE:
|
||||
clip_target.clip = comfy.text_encoders.ace15.te(**llama_detect(clip_data))
|
||||
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
|
||||
if TEModel.QWEN3_4B in te_models:
|
||||
model_type = "qwen3_4b"
|
||||
else:
|
||||
model_type = "qwen3_2b"
|
||||
clip_target.clip = comfy.text_encoders.ace15.te(lm_model=model_type, **llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
|
||||
@ -1625,8 +1625,16 @@ class ACEStep15(supported_models_base.BASE):
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**hunyuan_detect))
|
||||
detect_2b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
|
||||
detect_4b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
|
||||
if "dtype_llama" in detect_2b:
|
||||
detect = detect_2b
|
||||
detect["lm_model"] = "qwen3_2b"
|
||||
elif "dtype_llama" in detect_4b:
|
||||
detect = detect_4b
|
||||
detect["lm_model"] = "qwen3_4b"
|
||||
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
|
||||
@ -3,6 +3,7 @@ import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
import torch
|
||||
import math
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def sample_manual_loop_no_classes(
|
||||
@ -18,6 +19,7 @@ def sample_manual_loop_no_classes(
|
||||
min_tokens: int = 1,
|
||||
max_new_tokens: int = 2048,
|
||||
audio_start_id: int = 151669, # The cutoff ID for audio codes
|
||||
audio_end_id: int = 215669,
|
||||
eos_token_id: int = 151645,
|
||||
):
|
||||
device = model.execution_device
|
||||
@ -42,6 +44,8 @@ def sample_manual_loop_no_classes(
|
||||
for x in range(model_config.num_hidden_layers):
|
||||
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
|
||||
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
|
||||
|
||||
for step in range(max_new_tokens):
|
||||
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
|
||||
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
|
||||
@ -54,8 +58,10 @@ def sample_manual_loop_no_classes(
|
||||
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
|
||||
eos_score = cfg_logits[:, eos_token_id].clone()
|
||||
|
||||
remove_logit_value = torch.finfo(cfg_logits.dtype).min
|
||||
# Only generate audio tokens
|
||||
cfg_logits[:, :audio_start_id] = float('-inf')
|
||||
cfg_logits[:, :audio_start_id] = remove_logit_value
|
||||
cfg_logits[:, audio_end_id:] = remove_logit_value
|
||||
|
||||
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
|
||||
cfg_logits[:, eos_token_id] = eos_score
|
||||
@ -63,7 +69,7 @@ def sample_manual_loop_no_classes(
|
||||
if top_k is not None and top_k > 0:
|
||||
top_k_vals, _ = torch.topk(cfg_logits, top_k)
|
||||
min_val = top_k_vals[..., -1, None]
|
||||
cfg_logits[cfg_logits < min_val] = float('-inf')
|
||||
cfg_logits[cfg_logits < min_val] = remove_logit_value
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
|
||||
@ -72,7 +78,7 @@ def sample_manual_loop_no_classes(
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
cfg_logits[indices_to_remove] = float('-inf')
|
||||
cfg_logits[indices_to_remove] = remove_logit_value
|
||||
|
||||
if temperature > 0:
|
||||
cfg_logits = cfg_logits / temperature
|
||||
@ -90,13 +96,12 @@ def sample_manual_loop_no_classes(
|
||||
attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
|
||||
|
||||
output_audio_codes.append(token - audio_start_id)
|
||||
progress_bar.update_absolute(step)
|
||||
|
||||
return output_audio_codes
|
||||
|
||||
|
||||
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0):
|
||||
cfg_scale = 2.0
|
||||
|
||||
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
|
||||
positive = [[token for token, _ in inner_list] for inner_list in positive]
|
||||
negative = [[token for token, _ in inner_list] for inner_list in negative]
|
||||
positive = positive[0]
|
||||
@ -113,7 +118,7 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
|
||||
positive = [model.special_tokens["pad"]] * pos_pad + positive
|
||||
|
||||
paddings = [pos_pad, neg_pad]
|
||||
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||
|
||||
|
||||
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
@ -130,6 +135,12 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
language = kwargs.get("language", "en")
|
||||
seed = kwargs.get("seed", 0)
|
||||
|
||||
generate_audio_codes = kwargs.get("generate_audio_codes", True)
|
||||
cfg_scale = kwargs.get("cfg_scale", 2.0)
|
||||
temperature = kwargs.get("temperature", 0.85)
|
||||
top_p = kwargs.get("top_p", 0.9)
|
||||
top_k = kwargs.get("top_k", 0.0)
|
||||
|
||||
duration = math.ceil(duration)
|
||||
meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature)
|
||||
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n<think>\n{}\n</think>\n\n<|im_end|>\n"
|
||||
@ -140,7 +151,14 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
||||
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs)
|
||||
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
|
||||
out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed}
|
||||
out["lm_metadata"] = {"min_tokens": duration * 5,
|
||||
"seed": seed,
|
||||
"generate_audio_codes": generate_audio_codes,
|
||||
"cfg_scale": cfg_scale,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
@ -157,14 +175,34 @@ class Qwen3_2B_ACE15(sd1_clip.SDClipModel):
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class Qwen3_4B_ACE15(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class ACE15TEModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, dtype_llama=None, model_options={}):
|
||||
def __init__(self, device="cpu", dtype=None, dtype_llama=None, lm_model=None, model_options={}):
|
||||
super().__init__()
|
||||
if dtype_llama is None:
|
||||
dtype_llama = dtype
|
||||
|
||||
model = None
|
||||
self.constant = 0.4375
|
||||
if lm_model == "qwen3_4b":
|
||||
model = Qwen3_4B_ACE15
|
||||
self.constant = 0.5625
|
||||
elif lm_model == "qwen3_2b":
|
||||
model = Qwen3_2B_ACE15
|
||||
|
||||
self.lm_model = lm_model
|
||||
self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options)
|
||||
self.qwen3_2b = Qwen3_2B_ACE15(device=device, dtype=dtype_llama, model_options=model_options)
|
||||
if model is not None:
|
||||
setattr(self, self.lm_model, model(device=device, dtype=dtype_llama, model_options=model_options))
|
||||
|
||||
self.dtypes = set([dtype, dtype_llama])
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
@ -176,18 +214,26 @@ class ACE15TEModel(torch.nn.Module):
|
||||
self.qwen3_06b.set_clip_options({"layer": [0]})
|
||||
lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)
|
||||
|
||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
||||
audio_codes = generate_audio_codes(self.qwen3_2b, token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"])
|
||||
out = {"conditioning_lyrics": lyrics_embeds[:, 0]}
|
||||
|
||||
return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]}
|
||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
||||
if lm_metadata["generate_audio_codes"]:
|
||||
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
|
||||
out["audio_codes"] = [audio_codes]
|
||||
|
||||
return base_out, None, out
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.qwen3_06b.set_clip_options(options)
|
||||
self.qwen3_2b.set_clip_options(options)
|
||||
lm_model = getattr(self, self.lm_model, None)
|
||||
if lm_model is not None:
|
||||
lm_model.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.qwen3_06b.reset_clip_options()
|
||||
self.qwen3_2b.reset_clip_options()
|
||||
lm_model = getattr(self, self.lm_model, None)
|
||||
if lm_model is not None:
|
||||
lm_model.reset_clip_options()
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
@ -195,11 +241,11 @@ class ACE15TEModel(torch.nn.Module):
|
||||
if shape[0] == 1024:
|
||||
return self.qwen3_06b.load_sd(sd)
|
||||
else:
|
||||
return self.qwen3_2b.load_sd(sd)
|
||||
return getattr(self, self.lm_model).load_sd(sd)
|
||||
|
||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
||||
constant = 0.4375
|
||||
constant = self.constant
|
||||
if comfy.model_management.should_use_bf16(device):
|
||||
constant *= 0.5
|
||||
|
||||
@ -208,11 +254,11 @@ class ACE15TEModel(torch.nn.Module):
|
||||
num_tokens += lm_metadata['min_tokens']
|
||||
return num_tokens * constant * 1024 * 1024
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
|
||||
class ACE15TEModel_(ACE15TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype_llama=dtype_llama, dtype=dtype, model_options=model_options)
|
||||
super().__init__(device=device, dtype_llama=dtype_llama, lm_model=lm_model, dtype=dtype, model_options=model_options)
|
||||
return ACE15TEModel_
|
||||
|
||||
@ -6,6 +6,7 @@ import math
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.clip_model
|
||||
|
||||
@ -149,6 +150,29 @@ class Qwen3_2B_ACE15_lm_Config:
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_4B_ACE15_lm_Config:
|
||||
vocab_size: int = 217204
|
||||
hidden_size: int = 2560
|
||||
intermediate_size: int = 9728
|
||||
num_hidden_layers: int = 36
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 40960
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_4BConfig:
|
||||
vocab_size: int = 151936
|
||||
@ -627,10 +651,10 @@ class Llama2_(nn.Module):
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)
|
||||
|
||||
if seq_len > 1:
|
||||
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1)
|
||||
if mask is not None:
|
||||
mask += causal_mask
|
||||
else:
|
||||
@ -738,6 +762,21 @@ class BaseLlama:
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
return self.model(input_ids, *args, **kwargs)
|
||||
|
||||
class BaseQwen3:
|
||||
def logits(self, x):
|
||||
input = x[:, -1:]
|
||||
module = self.model.embed_tokens
|
||||
|
||||
offload_stream = None
|
||||
if module.comfy_cast_weights:
|
||||
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
|
||||
else:
|
||||
weight = self.model.embed_tokens.weight.to(x)
|
||||
|
||||
x = torch.nn.functional.linear(input, weight, None)
|
||||
|
||||
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
|
||||
return x
|
||||
|
||||
class Llama2(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
@ -766,7 +805,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_06B(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_06BConfig(**config_dict)
|
||||
@ -775,7 +814,7 @@ class Qwen3_06B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_06B_ACE15(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_06B_ACE15_Config(**config_dict)
|
||||
@ -784,7 +823,7 @@ class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_2B_ACE15_lm_Config(**config_dict)
|
||||
@ -793,10 +832,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
def logits(self, x):
|
||||
return torch.nn.functional.linear(x[:, -1:], self.model.embed_tokens.weight.to(x), None)
|
||||
|
||||
class Qwen3_4B(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_4BConfig(**config_dict)
|
||||
@ -805,7 +841,16 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_8B(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_4B_ACE15_lm_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_8BConfig(**config_dict)
|
||||
|
||||
@ -82,14 +82,12 @@ _TYPES = {
|
||||
def load_safetensors(ckpt):
|
||||
f = open(ckpt, "rb")
|
||||
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
||||
mv = memoryview(mapping)
|
||||
|
||||
header_size = struct.unpack("<Q", mapping[:8])[0]
|
||||
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
|
||||
|
||||
with warnings.catch_warnings():
|
||||
#We are working with read-only RAM by design
|
||||
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
||||
data_area = torch.frombuffer(mapping, dtype=torch.uint8)[8 + header_size:]
|
||||
mv = mv[8 + header_size:]
|
||||
|
||||
sd = {}
|
||||
for name, info in header.items():
|
||||
@ -97,7 +95,13 @@ def load_safetensors(ckpt):
|
||||
continue
|
||||
|
||||
start, end = info["data_offsets"]
|
||||
sd[name] = data_area[start:end].view(_TYPES[info["dtype"]]).view(info["shape"])
|
||||
if start == end:
|
||||
sd[name] = torch.empty(info["shape"], dtype =_TYPES[info["dtype"]])
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
#We are working with read-only RAM by design
|
||||
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
||||
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||
|
||||
return sd, header.get("__metadata__", {}),
|
||||
|
||||
|
||||
@ -7,9 +7,10 @@ from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
from . import _node_replace_public as node_replace
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||
from PIL import Image
|
||||
@ -21,6 +22,14 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
VERSION = "latest"
|
||||
STABLE = False
|
||||
|
||||
class NodeReplacement(ProxiedSingleton):
|
||||
async def register(self, node_replace: 'node_replace.NodeReplace') -> None:
|
||||
"""Register a node replacement mapping."""
|
||||
from server import PromptServer
|
||||
PromptServer.instance.node_replace_manager.register(node_replace)
|
||||
|
||||
node_replacement: NodeReplacement
|
||||
|
||||
class Execution(ProxiedSingleton):
|
||||
async def set_progress(
|
||||
self,
|
||||
@ -105,6 +114,7 @@ class Types:
|
||||
VideoComponents = VideoComponents
|
||||
MESH = MESH
|
||||
VOXEL = VOXEL
|
||||
File3D = File3D
|
||||
|
||||
ComfyAPI = ComfyAPI_latest
|
||||
|
||||
@ -130,4 +140,5 @@ __all__ = [
|
||||
"IO",
|
||||
"ui",
|
||||
"UI",
|
||||
"node_replace",
|
||||
]
|
||||
|
||||
@ -27,7 +27,7 @@ if TYPE_CHECKING:
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL, SVG as _SVG
|
||||
from ._util import MESH, VOXEL, SVG as _SVG, File3D
|
||||
|
||||
|
||||
class FolderType(str, Enum):
|
||||
@ -667,6 +667,49 @@ class Voxel(ComfyTypeIO):
|
||||
class Mesh(ComfyTypeIO):
|
||||
Type = MESH
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D")
|
||||
class File3DAny(ComfyTypeIO):
|
||||
"""General 3D file type - accepts any supported 3D format."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_GLB")
|
||||
class File3DGLB(ComfyTypeIO):
|
||||
"""GLB format 3D file - binary glTF, best for web and cross-platform."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_GLTF")
|
||||
class File3DGLTF(ComfyTypeIO):
|
||||
"""GLTF format 3D file - JSON-based glTF with external resources."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_FBX")
|
||||
class File3DFBX(ComfyTypeIO):
|
||||
"""FBX format 3D file - best for game engines and animation."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_OBJ")
|
||||
class File3DOBJ(ComfyTypeIO):
|
||||
"""OBJ format 3D file - simple geometry format."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_STL")
|
||||
class File3DSTL(ComfyTypeIO):
|
||||
"""STL format 3D file - best for 3D printing."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_USDZ")
|
||||
class File3DUSDZ(ComfyTypeIO):
|
||||
"""USDZ format 3D file - Apple AR format."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="HOOKS")
|
||||
class Hooks(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
@ -2037,6 +2080,13 @@ __all__ = [
|
||||
"LossMap",
|
||||
"Voxel",
|
||||
"Mesh",
|
||||
"File3DAny",
|
||||
"File3DGLB",
|
||||
"File3DGLTF",
|
||||
"File3DFBX",
|
||||
"File3DOBJ",
|
||||
"File3DSTL",
|
||||
"File3DUSDZ",
|
||||
"Hooks",
|
||||
"HookKeyframes",
|
||||
"TimestepsRange",
|
||||
|
||||
94
comfy_api/latest/_node_replace.py
Normal file
94
comfy_api/latest/_node_replace.py
Normal file
@ -0,0 +1,94 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class NodeReplace:
|
||||
"""
|
||||
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
|
||||
|
||||
Also supports assigning specific values to the input widgets of the new node.
|
||||
"""
|
||||
def __init__(self,
|
||||
new_node_id: str,
|
||||
old_node_id: str,
|
||||
old_widget_ids: list[str] | None=None,
|
||||
input_mapping: list[InputMap] | None=None,
|
||||
output_mapping: list[OutputMap] | None=None,
|
||||
):
|
||||
self.new_node_id = new_node_id
|
||||
self.old_node_id = old_node_id
|
||||
self.old_widget_ids = old_widget_ids
|
||||
self.input_mapping = input_mapping
|
||||
self.output_mapping = output_mapping
|
||||
|
||||
def as_dict(self):
|
||||
"""Create serializable representation of the node replacement."""
|
||||
return {
|
||||
"new_node_id": self.new_node_id,
|
||||
"old_node_id": self.old_node_id,
|
||||
"old_widget_ids": self.old_widget_ids,
|
||||
"input_mapping": [m.as_dict() for m in self.input_mapping] if self.input_mapping else None,
|
||||
"output_mapping": [m.as_dict() for m in self.output_mapping] if self.output_mapping else None,
|
||||
}
|
||||
|
||||
|
||||
class InputMap:
|
||||
"""
|
||||
Map inputs of node replacement.
|
||||
|
||||
Use InputMap.OldId or InputMap.SetValue for mapping purposes.
|
||||
"""
|
||||
class _Assign:
|
||||
def __init__(self, assign_type: str):
|
||||
self.assign_type = assign_type
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"assign_type": self.assign_type,
|
||||
}
|
||||
|
||||
class OldId(_Assign):
|
||||
"""Connect the input of the old node with given id to new node when replacing."""
|
||||
def __init__(self, old_id: str):
|
||||
super().__init__("old_id")
|
||||
self.old_id = old_id
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | {
|
||||
"old_id": self.old_id,
|
||||
}
|
||||
|
||||
class SetValue(_Assign):
|
||||
"""Use the given value for the input of the new node when replacing; assumes input is a widget."""
|
||||
def __init__(self, value: Any):
|
||||
super().__init__("set_value")
|
||||
self.value = value
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | {
|
||||
"value": self.value,
|
||||
}
|
||||
|
||||
def __init__(self, new_id: str, assign: OldId | SetValue):
|
||||
self.new_id = new_id
|
||||
self.assign = assign
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"new_id": self.new_id,
|
||||
"assign": self.assign.as_dict(),
|
||||
}
|
||||
|
||||
|
||||
class OutputMap:
|
||||
"""Map outputs of node replacement via indexes, as that's how outputs are stored."""
|
||||
def __init__(self, new_idx: int, old_idx: int):
|
||||
self.new_idx = new_idx
|
||||
self.old_idx = old_idx
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"new_idx": self.new_idx,
|
||||
"old_idx": self.old_idx,
|
||||
}
|
||||
1
comfy_api/latest/_node_replace_public.py
Normal file
1
comfy_api/latest/_node_replace_public.py
Normal file
@ -0,0 +1 @@
|
||||
from ._node_replace import * # noqa: F403
|
||||
@ -1,5 +1,5 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH
|
||||
from .geometry_types import VOXEL, MESH, File3D
|
||||
from .image_types import SVG
|
||||
|
||||
__all__ = [
|
||||
@ -9,5 +9,6 @@ __all__ = [
|
||||
"VideoComponents",
|
||||
"VOXEL",
|
||||
"MESH",
|
||||
"File3D",
|
||||
"SVG",
|
||||
]
|
||||
|
||||
@ -1,3 +1,8 @@
|
||||
import shutil
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import IO
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@ -10,3 +15,75 @@ class MESH:
|
||||
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
|
||||
self.vertices = vertices
|
||||
self.faces = faces
|
||||
|
||||
|
||||
class File3D:
|
||||
"""Class representing a 3D file from a file path or binary stream.
|
||||
|
||||
Supports both disk-backed (file path) and memory-backed (BytesIO) storage.
|
||||
"""
|
||||
|
||||
def __init__(self, source: str | IO[bytes], file_format: str = ""):
|
||||
self._source = source
|
||||
self._format = file_format or self._infer_format()
|
||||
|
||||
def _infer_format(self) -> str:
|
||||
if isinstance(self._source, str):
|
||||
return Path(self._source).suffix.lstrip(".").lower()
|
||||
return ""
|
||||
|
||||
@property
|
||||
def format(self) -> str:
|
||||
return self._format
|
||||
|
||||
@format.setter
|
||||
def format(self, value: str) -> None:
|
||||
self._format = value.lstrip(".").lower() if value else ""
|
||||
|
||||
@property
|
||||
def is_disk_backed(self) -> bool:
|
||||
return isinstance(self._source, str)
|
||||
|
||||
def get_source(self) -> str | IO[bytes]:
|
||||
if isinstance(self._source, str):
|
||||
return self._source
|
||||
if hasattr(self._source, "seek"):
|
||||
self._source.seek(0)
|
||||
return self._source
|
||||
|
||||
def get_data(self) -> BytesIO:
|
||||
if isinstance(self._source, str):
|
||||
with open(self._source, "rb") as f:
|
||||
result = BytesIO(f.read())
|
||||
return result
|
||||
if hasattr(self._source, "seek"):
|
||||
self._source.seek(0)
|
||||
if isinstance(self._source, BytesIO):
|
||||
return self._source
|
||||
return BytesIO(self._source.read())
|
||||
|
||||
def save_to(self, path: str) -> str:
|
||||
dest = Path(path)
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if isinstance(self._source, str):
|
||||
if Path(self._source).resolve() != dest.resolve():
|
||||
shutil.copy2(self._source, dest)
|
||||
else:
|
||||
if hasattr(self._source, "seek"):
|
||||
self._source.seek(0)
|
||||
with open(dest, "wb") as f:
|
||||
f.write(self._source.read())
|
||||
return str(dest)
|
||||
|
||||
def get_bytes(self) -> bytes:
|
||||
if isinstance(self._source, str):
|
||||
return Path(self._source).read_bytes()
|
||||
if hasattr(self._source, "seek"):
|
||||
self._source.seek(0)
|
||||
return self._source.read()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if isinstance(self._source, str):
|
||||
return f"File3D(source={self._source!r}, format={self._format!r})"
|
||||
return f"File3D(<stream>, format={self._format!r})"
|
||||
|
||||
@ -6,7 +6,7 @@ from comfy_api.latest import (
|
||||
)
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401
|
||||
from comfy_api.latest import io, ui, IO, UI, ComfyExtension, node_replace #noqa: F401
|
||||
|
||||
|
||||
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
|
||||
@ -46,4 +46,5 @@ __all__ = [
|
||||
"IO",
|
||||
"ui",
|
||||
"UI",
|
||||
"node_replace",
|
||||
]
|
||||
|
||||
@ -109,14 +109,19 @@ class MeshyTextureRequest(BaseModel):
|
||||
|
||||
class MeshyModelsUrls(BaseModel):
|
||||
glb: str = Field("")
|
||||
fbx: str = Field("")
|
||||
usdz: str = Field("")
|
||||
obj: str = Field("")
|
||||
|
||||
|
||||
class MeshyRiggedModelsUrls(BaseModel):
|
||||
rigged_character_glb_url: str = Field("")
|
||||
rigged_character_fbx_url: str = Field("")
|
||||
|
||||
|
||||
class MeshyAnimatedModelsUrls(BaseModel):
|
||||
animation_glb_url: str = Field("")
|
||||
animation_fbx_url: str = Field("")
|
||||
|
||||
|
||||
class MeshyResultTextureUrls(BaseModel):
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
import os
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
@ -14,7 +12,7 @@ from comfy_api_nodes.apis.hunyuan3d import (
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
downscale_image_tensor_by_max_side,
|
||||
poll_op,
|
||||
sync_op,
|
||||
@ -22,14 +20,13 @@ from comfy_api_nodes.util import (
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
from folder_paths import get_output_directory
|
||||
|
||||
|
||||
def get_glb_obj_from_response(response_objs: list[ResultFile3D]) -> ResultFile3D:
|
||||
def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None:
|
||||
for i in response_objs:
|
||||
if i.Type.lower() == "glb":
|
||||
if i.Type.lower() == file_type.lower():
|
||||
return i
|
||||
raise ValueError("No GLB file found in response. Please report this to the developers.")
|
||||
return None
|
||||
|
||||
|
||||
class TencentTextToModelNode(IO.ComfyNode):
|
||||
@ -74,7 +71,9 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -124,19 +123,20 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
task_id = response.JobId
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||
data=To3DProTaskQueryRequest(JobId=task_id),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
model_file = f"hunyuan_model_{response.JobId}.glb"
|
||||
await download_url_to_bytesio(
|
||||
get_glb_obj_from_response(result.ResultFile3Ds).Url,
|
||||
os.path.join(get_output_directory(), model_file),
|
||||
glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
|
||||
obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
|
||||
file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
|
||||
return IO.NodeOutput(
|
||||
file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
|
||||
)
|
||||
return IO.NodeOutput(model_file)
|
||||
|
||||
|
||||
class TencentImageToModelNode(IO.ComfyNode):
|
||||
@ -184,7 +184,9 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -269,19 +271,20 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
task_id = response.JobId
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||
data=To3DProTaskQueryRequest(JobId=task_id),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
model_file = f"hunyuan_model_{response.JobId}.glb"
|
||||
await download_url_to_bytesio(
|
||||
get_glb_obj_from_response(result.ResultFile3Ds).Url,
|
||||
os.path.join(get_output_directory(), model_file),
|
||||
glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
|
||||
obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
|
||||
file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
|
||||
return IO.NodeOutput(
|
||||
file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
|
||||
)
|
||||
return IO.NodeOutput(model_file)
|
||||
|
||||
|
||||
class TencentHunyuan3DExtension(ComfyExtension):
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
import os
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
@ -20,13 +18,12 @@ from comfy_api_nodes.apis.meshy import (
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
from folder_paths import get_output_directory
|
||||
|
||||
|
||||
class MeshyTextToModelNode(IO.ComfyNode):
|
||||
@ -79,8 +76,10 @@ class MeshyTextToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -122,16 +121,20 @@ class MeshyTextToModelNode(IO.ComfyNode):
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyRefineNode(IO.ComfyNode):
|
||||
@ -167,8 +170,10 @@ class MeshyRefineNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -210,16 +215,20 @@ class MeshyRefineNode(IO.ComfyNode):
|
||||
ai_model=model,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyImageToModelNode(IO.ComfyNode):
|
||||
@ -303,8 +312,10 @@ class MeshyImageToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -368,16 +379,20 @@ class MeshyImageToModelNode(IO.ComfyNode):
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyMultiImageToModelNode(IO.ComfyNode):
|
||||
@ -464,8 +479,10 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -531,16 +548,20 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyRigModelNode(IO.ComfyNode):
|
||||
@ -571,8 +592,10 @@ class MeshyRigModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_RIGGED_TASK_ID").Output(display_name="rig_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -606,18 +629,20 @@ class MeshyRigModelNode(IO.ComfyNode):
|
||||
texture_image_url=texture_image_url,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{task_id}"),
|
||||
response_model=MeshyRiggedResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(
|
||||
result.result.rigged_character_glb_url, os.path.join(get_output_directory(), model_file)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.result.rigged_character_glb_url, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.result.rigged_character_fbx_url, "fbx", task_id=task_id),
|
||||
)
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
|
||||
|
||||
class MeshyAnimateModelNode(IO.ComfyNode):
|
||||
@ -640,7 +665,9 @@ class MeshyAnimateModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -669,16 +696,19 @@ class MeshyAnimateModelNode(IO.ComfyNode):
|
||||
action_id=action_id,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{task_id}"),
|
||||
response_model=MeshyAnimationResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.result.animation_glb_url, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
await download_url_to_file_3d(result.result.animation_glb_url, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.result.animation_fbx_url, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyTextureNode(IO.ComfyNode):
|
||||
@ -715,8 +745,10 @@ class MeshyTextureNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -760,16 +792,20 @@ class MeshyTextureNode(IO.ComfyNode):
|
||||
image_style_url=image_style_url,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyExtension(ComfyExtension):
|
||||
|
||||
@ -10,7 +10,6 @@ import folder_paths as comfy_paths
|
||||
import os
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional
|
||||
from io import BytesIO
|
||||
from typing_extensions import override
|
||||
from PIL import Image
|
||||
@ -28,8 +27,9 @@ from comfy_api_nodes.util import (
|
||||
poll_op,
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
)
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
|
||||
|
||||
COMMON_PARAMETERS = [
|
||||
@ -177,7 +177,7 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
|
||||
return "DONE"
|
||||
return "Generating"
|
||||
|
||||
def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]:
|
||||
def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None:
|
||||
if not response.jobs:
|
||||
return None
|
||||
completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done)
|
||||
@ -207,17 +207,25 @@ async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3D
|
||||
)
|
||||
|
||||
|
||||
async def download_files(url_list, task_uuid: str):
|
||||
async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.File3D | None]:
|
||||
result_folder_name = f"Rodin3D_{task_uuid}"
|
||||
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
model_file_path = None
|
||||
file_3d = None
|
||||
|
||||
for i in url_list.list:
|
||||
file_path = os.path.join(save_path, i.name)
|
||||
if file_path.endswith(".glb"):
|
||||
if i.name.lower().endswith(".glb"):
|
||||
model_file_path = os.path.join(result_folder_name, i.name)
|
||||
await download_url_to_bytesio(i.url, file_path)
|
||||
return model_file_path
|
||||
file_3d = await download_url_to_file_3d(i.url, "glb")
|
||||
# Save to disk for backward compatibility
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_3d.get_bytes())
|
||||
else:
|
||||
await download_url_to_bytesio(i.url, file_path)
|
||||
|
||||
return model_file_path, file_3d
|
||||
|
||||
|
||||
class Rodin3D_Regular(IO.ComfyNode):
|
||||
@ -234,7 +242,10 @@ class Rodin3D_Regular(IO.ComfyNode):
|
||||
IO.Image.Input("Images"),
|
||||
*COMMON_PARAMETERS,
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@ -271,9 +282,9 @@ class Rodin3D_Regular(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Detail(IO.ComfyNode):
|
||||
@ -290,7 +301,10 @@ class Rodin3D_Detail(IO.ComfyNode):
|
||||
IO.Image.Input("Images"),
|
||||
*COMMON_PARAMETERS,
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@ -327,9 +341,9 @@ class Rodin3D_Detail(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Smooth(IO.ComfyNode):
|
||||
@ -346,7 +360,10 @@ class Rodin3D_Smooth(IO.ComfyNode):
|
||||
IO.Image.Input("Images"),
|
||||
*COMMON_PARAMETERS,
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@ -382,9 +399,9 @@ class Rodin3D_Smooth(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Sketch(IO.ComfyNode):
|
||||
@ -408,7 +425,10 @@ class Rodin3D_Sketch(IO.ComfyNode):
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@ -441,9 +461,9 @@ class Rodin3D_Sketch(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Gen2(IO.ComfyNode):
|
||||
@ -475,7 +495,10 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input("TAPose", default=False),
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@ -511,9 +534,9 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3DExtension(ComfyExtension):
|
||||
|
||||
@ -1,10 +1,6 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.tripo import (
|
||||
TripoAnimateRetargetRequest,
|
||||
TripoAnimateRigRequest,
|
||||
@ -26,12 +22,11 @@ from comfy_api_nodes.apis.tripo import (
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_as_bytesio,
|
||||
download_url_to_file_3d,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
)
|
||||
from folder_paths import get_output_directory
|
||||
|
||||
|
||||
def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
||||
@ -45,7 +40,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
||||
async def poll_until_finished(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
response: TripoTaskResponse,
|
||||
average_duration: Optional[int] = None,
|
||||
average_duration: int | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
"""Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
if response.code != 0:
|
||||
@ -69,12 +64,8 @@ async def poll_until_finished(
|
||||
)
|
||||
if response_poll.data.status == TripoTaskStatus.SUCCESS:
|
||||
url = get_model_url_from_response(response_poll)
|
||||
bytesio = await download_url_as_bytesio(url)
|
||||
# Save the downloaded model file
|
||||
model_file = f"tripo_model_{task_id}.glb"
|
||||
with open(os.path.join(get_output_directory(), model_file), "wb") as f:
|
||||
f.write(bytesio.getvalue())
|
||||
return IO.NodeOutput(model_file, task_id)
|
||||
file_glb = await download_url_to_file_3d(url, "glb", task_id=task_id)
|
||||
return IO.NodeOutput(f"{task_id}.glb", task_id, file_glb)
|
||||
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
|
||||
|
||||
|
||||
@ -107,8 +98,9 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -155,18 +147,18 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = None,
|
||||
negative_prompt: str | None = None,
|
||||
model_version=None,
|
||||
style: Optional[str] = None,
|
||||
texture: Optional[bool] = None,
|
||||
pbr: Optional[bool] = None,
|
||||
image_seed: Optional[int] = None,
|
||||
model_seed: Optional[int] = None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
geometry_quality: Optional[str] = None,
|
||||
face_limit: Optional[int] = None,
|
||||
quad: Optional[bool] = None,
|
||||
style: str | None = None,
|
||||
texture: bool | None = None,
|
||||
pbr: bool | None = None,
|
||||
image_seed: int | None = None,
|
||||
model_seed: int | None = None,
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
geometry_quality: str | None = None,
|
||||
face_limit: int | None = None,
|
||||
quad: bool | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
style_enum = None if style == "None" else style
|
||||
if not prompt:
|
||||
@ -232,8 +224,9 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -279,19 +272,19 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
model_version: Optional[str] = None,
|
||||
style: Optional[str] = None,
|
||||
texture: Optional[bool] = None,
|
||||
pbr: Optional[bool] = None,
|
||||
model_seed: Optional[int] = None,
|
||||
image: Input.Image,
|
||||
model_version: str | None = None,
|
||||
style: str | None = None,
|
||||
texture: bool | None = None,
|
||||
pbr: bool | None = None,
|
||||
model_seed: int | None = None,
|
||||
orientation=None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
geometry_quality: Optional[str] = None,
|
||||
texture_alignment: Optional[str] = None,
|
||||
face_limit: Optional[int] = None,
|
||||
quad: Optional[bool] = None,
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
geometry_quality: str | None = None,
|
||||
texture_alignment: str | None = None,
|
||||
face_limit: int | None = None,
|
||||
quad: bool | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
style_enum = None if style == "None" else style
|
||||
if image is None:
|
||||
@ -368,8 +361,9 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -411,21 +405,21 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
image_left: Optional[torch.Tensor] = None,
|
||||
image_back: Optional[torch.Tensor] = None,
|
||||
image_right: Optional[torch.Tensor] = None,
|
||||
model_version: Optional[str] = None,
|
||||
orientation: Optional[str] = None,
|
||||
texture: Optional[bool] = None,
|
||||
pbr: Optional[bool] = None,
|
||||
model_seed: Optional[int] = None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
geometry_quality: Optional[str] = None,
|
||||
texture_alignment: Optional[str] = None,
|
||||
face_limit: Optional[int] = None,
|
||||
quad: Optional[bool] = None,
|
||||
image: Input.Image,
|
||||
image_left: Input.Image | None = None,
|
||||
image_back: Input.Image | None = None,
|
||||
image_right: Input.Image | None = None,
|
||||
model_version: str | None = None,
|
||||
orientation: str | None = None,
|
||||
texture: bool | None = None,
|
||||
pbr: bool | None = None,
|
||||
model_seed: int | None = None,
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
geometry_quality: str | None = None,
|
||||
texture_alignment: str | None = None,
|
||||
face_limit: int | None = None,
|
||||
quad: bool | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if image is None:
|
||||
raise RuntimeError("front image for multiview is required")
|
||||
@ -487,8 +481,9 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -512,11 +507,11 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model_task_id,
|
||||
texture: Optional[bool] = None,
|
||||
pbr: Optional[bool] = None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
texture_alignment: Optional[str] = None,
|
||||
texture: bool | None = None,
|
||||
pbr: bool | None = None,
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
texture_alignment: str | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
@ -547,8 +542,9 @@ class TripoRefineNode(IO.ComfyNode):
|
||||
IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -583,8 +579,9 @@ class TripoRigNode(IO.ComfyNode):
|
||||
category="api node/3d/Tripo",
|
||||
inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("RIG_TASK_ID").Output(display_name="rig task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -642,8 +639,9 @@ class TripoRetargetNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("RETARGET_TASK_ID").Output(display_name="retarget task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
|
||||
@ -28,6 +28,7 @@ from .conversions import (
|
||||
from .download_helpers import (
|
||||
download_url_as_bytesio,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
)
|
||||
@ -69,6 +70,7 @@ __all__ = [
|
||||
# Download helpers
|
||||
"download_url_as_bytesio",
|
||||
"download_url_to_bytesio",
|
||||
"download_url_to_file_3d",
|
||||
"download_url_to_image_tensor",
|
||||
"download_url_to_video_output",
|
||||
# Conversions
|
||||
|
||||
@ -11,7 +11,8 @@ import torch
|
||||
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||||
|
||||
from comfy_api.latest import IO as COMFY_IO
|
||||
from comfy_api.latest import InputImpl
|
||||
from comfy_api.latest import InputImpl, Types
|
||||
from folder_paths import get_output_directory
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
@ -261,3 +262,38 @@ def _generate_operation_id(method: str, url: str, attempt: int) -> str:
|
||||
except Exception:
|
||||
slug = "download"
|
||||
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
async def download_url_to_file_3d(
|
||||
url: str,
|
||||
file_format: str,
|
||||
*,
|
||||
task_id: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 5,
|
||||
cls: type[COMFY_IO.ComfyNode] = None,
|
||||
) -> Types.File3D:
|
||||
"""Downloads a 3D model file from a URL into memory as BytesIO.
|
||||
|
||||
If task_id is provided, also writes the file to disk in the output directory
|
||||
for backward compatibility with the old save-to-disk behavior.
|
||||
"""
|
||||
file_format = file_format.lstrip(".").lower()
|
||||
data = BytesIO()
|
||||
await download_url_to_bytesio(
|
||||
url,
|
||||
data,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
cls=cls,
|
||||
)
|
||||
|
||||
if task_id is not None:
|
||||
# This is only for backward compatability with current behavior when every 3D node is output node
|
||||
# All new API nodes should not use "task_id" and instead users should use "SaveGLB" node to save results
|
||||
output_dir = Path(get_output_directory())
|
||||
output_path = output_dir / f"{task_id}.{file_format}"
|
||||
output_path.write_bytes(data.getvalue())
|
||||
data.seek(0)
|
||||
|
||||
return Types.File3D(source=data, file_format=file_format)
|
||||
|
||||
@ -44,13 +44,18 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
|
||||
io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
|
||||
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
|
||||
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
|
||||
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
|
||||
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
|
||||
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
|
||||
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
|
||||
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed)
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
@ -100,14 +105,15 @@ class EmptyAceStep15LatentAudio(io.ComfyNode):
|
||||
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
class ReferenceTimbreAudio(io.ComfyNode):
|
||||
class ReferenceAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ReferenceTimbreAudio",
|
||||
display_name="Reference Audio",
|
||||
category="advanced/conditioning/audio",
|
||||
is_experimental=True,
|
||||
description="This node sets the reference audio for timbre (for ace step 1.5)",
|
||||
description="This node sets the reference audio for ace step 1.5",
|
||||
inputs=[
|
||||
io.Conditioning.Input("conditioning"),
|
||||
io.Latent.Input("latent", optional=True),
|
||||
@ -131,7 +137,7 @@ class AceExtension(ComfyExtension):
|
||||
EmptyAceStepLatentAudio,
|
||||
TextEncodeAceStepAudio15,
|
||||
EmptyAceStep15LatentAudio,
|
||||
ReferenceTimbreAudio,
|
||||
ReferenceAudio,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AceExtension:
|
||||
|
||||
@ -618,18 +618,31 @@ class SaveGLB(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveGLB",
|
||||
display_name="Save 3D Model",
|
||||
search_aliases=["export 3d model", "save mesh"],
|
||||
category="3d",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.MultiType.Input(
|
||||
IO.Mesh.Input("mesh"),
|
||||
types=[
|
||||
IO.File3DGLB,
|
||||
IO.File3DGLTF,
|
||||
IO.File3DOBJ,
|
||||
IO.File3DFBX,
|
||||
IO.File3DSTL,
|
||||
IO.File3DUSDZ,
|
||||
IO.File3DAny,
|
||||
],
|
||||
tooltip="Mesh or 3D file to save",
|
||||
),
|
||||
IO.String.Input("filename_prefix", default="mesh/ComfyUI"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, filename_prefix) -> IO.NodeOutput:
|
||||
def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||
results = []
|
||||
|
||||
@ -641,15 +654,27 @@ class SaveGLB(IO.ComfyNode):
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
|
||||
if isinstance(mesh, Types.File3D):
|
||||
# Handle File3D input - save BytesIO data to output folder
|
||||
ext = mesh.format or "glb"
|
||||
f = f"{filename}_{counter:05}_.{ext}"
|
||||
mesh.save_to(os.path.join(full_output_folder, f))
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
counter += 1
|
||||
else:
|
||||
# Handle Mesh input - save vertices and faces as GLB
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
counter += 1
|
||||
return IO.NodeOutput(ui={"3d": results})
|
||||
|
||||
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import nodes
|
||||
import folder_paths
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import IO, ComfyExtension, InputImpl, UI
|
||||
from comfy_api.latest import IO, UI, ComfyExtension, InputImpl, Types
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
@ -44,6 +45,7 @@ class Load3D(IO.ComfyNode):
|
||||
IO.Image.Output(display_name="normal"),
|
||||
IO.Load3DCamera.Output(display_name="camera_info"),
|
||||
IO.Video.Output(display_name="recording_video"),
|
||||
IO.File3DAny.Output(display_name="model_3d"),
|
||||
],
|
||||
)
|
||||
|
||||
@ -65,7 +67,8 @@ class Load3D(IO.ComfyNode):
|
||||
|
||||
video = InputImpl.VideoFromFile(recording_video_path)
|
||||
|
||||
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video)
|
||||
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
|
||||
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d)
|
||||
|
||||
process = execute # TODO: remove
|
||||
|
||||
@ -81,7 +84,19 @@ class Preview3D(IO.ComfyNode):
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.String.Input("model_file", default="", multiline=False),
|
||||
IO.MultiType.Input(
|
||||
IO.String.Input("model_file", default="", multiline=False),
|
||||
types=[
|
||||
IO.File3DGLB,
|
||||
IO.File3DGLTF,
|
||||
IO.File3DFBX,
|
||||
IO.File3DOBJ,
|
||||
IO.File3DSTL,
|
||||
IO.File3DUSDZ,
|
||||
IO.File3DAny,
|
||||
],
|
||||
tooltip="3D model file or path string",
|
||||
),
|
||||
IO.Load3DCamera.Input("camera_info", optional=True),
|
||||
IO.Image.Input("bg_image", optional=True),
|
||||
],
|
||||
@ -89,10 +104,15 @@ class Preview3D(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_file, **kwargs) -> IO.NodeOutput:
|
||||
def execute(cls, model_file: str | Types.File3D, **kwargs) -> IO.NodeOutput:
|
||||
if isinstance(model_file, Types.File3D):
|
||||
filename = f"preview3d_{uuid.uuid4().hex}.{model_file.format}"
|
||||
model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
|
||||
else:
|
||||
filename = model_file
|
||||
camera_info = kwargs.get("camera_info", None)
|
||||
bg_image = kwargs.get("bg_image", None)
|
||||
return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image))
|
||||
return IO.NodeOutput(ui=UI.PreviewUI3D(filename, camera_info, bg_image=bg_image))
|
||||
|
||||
process = execute # TODO: remove
|
||||
|
||||
|
||||
@ -655,6 +655,103 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
|
||||
batched = batch_masks(values)
|
||||
return io.NodeOutput(batched)
|
||||
|
||||
|
||||
from comfy_api.latest import node_replace
|
||||
from server import PromptServer
|
||||
|
||||
def _register(nr: node_replace.NodeReplace):
|
||||
"""Helper to register replacements via PromptServer."""
|
||||
PromptServer.instance.node_replace_manager.register(nr)
|
||||
|
||||
async def register_replacements():
|
||||
"""Register all built-in node replacements."""
|
||||
register_replacements_longeredge()
|
||||
register_replacements_batchimages()
|
||||
register_replacements_upscaleimage()
|
||||
register_replacements_controlnet()
|
||||
register_replacements_load3d()
|
||||
register_replacements_preview3d()
|
||||
register_replacements_svdimg2vid()
|
||||
register_replacements_conditioningavg()
|
||||
|
||||
def register_replacements_longeredge():
|
||||
# No dynamic inputs here
|
||||
_register(node_replace.NodeReplace(
|
||||
new_node_id="ImageScaleToMaxDimension",
|
||||
old_node_id="ResizeImagesByLongerEdge",
|
||||
old_widget_ids=["longer_edge"],
|
||||
input_mapping=[
|
||||
node_replace.InputMap(new_id="image", assign=node_replace.InputMap.OldId("images")),
|
||||
node_replace.InputMap(new_id="largest_size", assign=node_replace.InputMap.OldId("longer_edge")),
|
||||
node_replace.InputMap(new_id="upscale_method", assign=node_replace.InputMap.SetValue("lanczos")),
|
||||
],
|
||||
# just to test the frontend output_mapping code, does nothing really here
|
||||
output_mapping=[node_replace.OutputMap(new_idx=0, old_idx=0)],
|
||||
))
|
||||
|
||||
def register_replacements_batchimages():
|
||||
# BatchImages node uses Autogrow
|
||||
_register(node_replace.NodeReplace(
|
||||
new_node_id="BatchImagesNode",
|
||||
old_node_id="ImageBatch",
|
||||
input_mapping=[
|
||||
node_replace.InputMap(new_id="images.image0", assign=node_replace.InputMap.OldId("image1")),
|
||||
node_replace.InputMap(new_id="images.image1", assign=node_replace.InputMap.OldId("image2")),
|
||||
],
|
||||
))
|
||||
|
||||
def register_replacements_upscaleimage():
|
||||
# ResizeImageMaskNode uses DynamicCombo
|
||||
_register(node_replace.NodeReplace(
|
||||
new_node_id="ResizeImageMaskNode",
|
||||
old_node_id="ImageScaleBy",
|
||||
old_widget_ids=["upscale_method", "scale_by"],
|
||||
input_mapping=[
|
||||
node_replace.InputMap(new_id="input", assign=node_replace.InputMap.OldId("image")),
|
||||
node_replace.InputMap(new_id="resize_type", assign=node_replace.InputMap.SetValue("scale by multiplier")),
|
||||
node_replace.InputMap(new_id="resize_type.multiplier", assign=node_replace.InputMap.OldId("scale_by")),
|
||||
node_replace.InputMap(new_id="scale_method", assign=node_replace.InputMap.OldId("upscale_method")),
|
||||
],
|
||||
))
|
||||
|
||||
def register_replacements_controlnet():
|
||||
# T2IAdapterLoader → ControlNetLoader
|
||||
_register(node_replace.NodeReplace(
|
||||
new_node_id="ControlNetLoader",
|
||||
old_node_id="T2IAdapterLoader",
|
||||
input_mapping=[
|
||||
node_replace.InputMap(new_id="control_net_name", assign=node_replace.InputMap.OldId("t2i_adapter_name")),
|
||||
],
|
||||
))
|
||||
|
||||
def register_replacements_load3d():
|
||||
# Load3DAnimation merged into Load3D
|
||||
_register(node_replace.NodeReplace(
|
||||
new_node_id="Load3D",
|
||||
old_node_id="Load3DAnimation",
|
||||
))
|
||||
|
||||
def register_replacements_preview3d():
|
||||
# Preview3DAnimation merged into Preview3D
|
||||
_register(node_replace.NodeReplace(
|
||||
new_node_id="Preview3D",
|
||||
old_node_id="Preview3DAnimation",
|
||||
))
|
||||
|
||||
def register_replacements_svdimg2vid():
|
||||
# Typo fix: SDV → SVD
|
||||
_register(node_replace.NodeReplace(
|
||||
new_node_id="SVD_img2vid_Conditioning",
|
||||
old_node_id="SDV_img2vid_Conditioning",
|
||||
))
|
||||
|
||||
def register_replacements_conditioningavg():
|
||||
# Typo fix: trailing space in node name
|
||||
_register(node_replace.NodeReplace(
|
||||
new_node_id="ConditioningAverage",
|
||||
old_node_id="ConditioningAverage ",
|
||||
))
|
||||
|
||||
class PostProcessingExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@ -672,4 +769,5 @@ class PostProcessingExtension(ComfyExtension):
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> PostProcessingExtension:
|
||||
await register_replacements()
|
||||
return PostProcessingExtension()
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.12.0"
|
||||
__version__ = "0.12.2"
|
||||
|
||||
7
main.py
7
main.py
@ -192,7 +192,10 @@ import comfy_aimdo.control
|
||||
import comfy_aimdo.torch
|
||||
|
||||
if enables_dynamic_vram():
|
||||
if comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||
if comfy.model_management.torch_version_numeric < (2, 8):
|
||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
comfy.memory_management.aimdo_allocator = None
|
||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||
if args.verbose == 'DEBUG':
|
||||
comfy_aimdo.control.set_log_debug()
|
||||
elif args.verbose == 'CRITICAL':
|
||||
@ -208,7 +211,7 @@ if enables_dynamic_vram():
|
||||
comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator()
|
||||
logging.info("DynamicVRAM support detected and enabled")
|
||||
else:
|
||||
logging.info("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
comfy.memory_management.aimdo_allocator = None
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.12.0"
|
||||
version = "0.12.2"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -40,6 +40,7 @@ from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from app.subgraph_manager import SubgraphManager
|
||||
from app.node_replace_manager import NodeReplaceManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from protocol import BinaryEventTypes
|
||||
@ -204,6 +205,7 @@ class PromptServer():
|
||||
self.model_file_manager = ModelFileManager()
|
||||
self.custom_node_manager = CustomNodeManager()
|
||||
self.subgraph_manager = SubgraphManager()
|
||||
self.node_replace_manager = NodeReplaceManager()
|
||||
self.internal_routes = InternalRoutes(self)
|
||||
self.supports = ["custom_nodes_from_web"]
|
||||
self.prompt_queue = execution.PromptQueue(self)
|
||||
@ -995,6 +997,7 @@ class PromptServer():
|
||||
self.model_file_manager.add_routes(self.routes)
|
||||
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.node_replace_manager.add_routes(self.routes)
|
||||
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
||||
|
||||
# Prefix every route with /api for easier matching for delegation.
|
||||
|
||||
Reference in New Issue
Block a user