mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-07 12:16:00 +08:00
Compare commits
6 Commits
essentials
...
v0.12.3
| Author | SHA1 | Date | |
|---|---|---|---|
| cb459573c8 | |||
| 35183543e0 | |||
| a246cc02b2 | |||
| a50c32d63f | |||
| 6125b80979 | |||
| c8fcbd66ee |
@ -183,7 +183,7 @@ class AceStepAttention(nn.Module):
|
||||
else:
|
||||
attn_bias = window_bias
|
||||
|
||||
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True)
|
||||
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
@ -1035,8 +1035,7 @@ class AceStepConditionGenerationModel(nn.Module):
|
||||
audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847)
|
||||
lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype)
|
||||
else:
|
||||
assert False
|
||||
# TODO ?
|
||||
lm_hints_5Hz, indices = self.tokenizer.tokenize(refer_audio_acoustic_hidden_states_packed)
|
||||
|
||||
lm_hints = self.detokenizer(lm_hints_5Hz)
|
||||
|
||||
|
||||
@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if kwargs.get("low_precision_attention", True) is False:
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||
|
||||
exception_fallback = False
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
|
||||
@ -1548,6 +1548,7 @@ class ACEStep15(BaseModel):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
device = kwargs["device"]
|
||||
noise = kwargs["noise"]
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
@ -1571,15 +1572,19 @@ class ACEStep15(BaseModel):
|
||||
1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
|
||||
-8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
|
||||
-5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
|
||||
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750)
|
||||
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, noise.shape[2])
|
||||
pass_audio_codes = True
|
||||
else:
|
||||
refer_audio = refer_audio[-1]
|
||||
refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
|
||||
pass_audio_codes = False
|
||||
|
||||
if pass_audio_codes:
|
||||
audio_codes = kwargs.get("audio_codes", None)
|
||||
if audio_codes is not None:
|
||||
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
|
||||
refer_audio = refer_audio[:, :, :750]
|
||||
|
||||
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
|
||||
|
||||
audio_codes = kwargs.get("audio_codes", None)
|
||||
if audio_codes is not None:
|
||||
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
|
||||
|
||||
return out
|
||||
|
||||
class Omnigen2(BaseModel):
|
||||
|
||||
@ -54,6 +54,8 @@ try:
|
||||
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
if q.nelement() < 1024 * 128: # arbitrary number, for small inputs cudnn attention seems slower
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
else:
|
||||
|
||||
@ -976,7 +976,7 @@ class VAE:
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
if dims == 1:
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
|
||||
@ -3,6 +3,7 @@ import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
import torch
|
||||
import math
|
||||
import yaml
|
||||
import comfy.utils
|
||||
|
||||
|
||||
@ -101,9 +102,7 @@ def sample_manual_loop_no_classes(
|
||||
return output_audio_codes
|
||||
|
||||
|
||||
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0):
|
||||
cfg_scale = 2.0
|
||||
|
||||
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
|
||||
positive = [[token for token, _ in inner_list] for inner_list in positive]
|
||||
negative = [[token for token, _ in inner_list] for inner_list in negative]
|
||||
positive = positive[0]
|
||||
@ -120,34 +119,80 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
|
||||
positive = [model.special_tokens["pad"]] * pos_pad + positive
|
||||
|
||||
paddings = [pos_pad, neg_pad]
|
||||
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||
|
||||
|
||||
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer)
|
||||
|
||||
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
|
||||
user_metas = {
|
||||
k: kwargs.pop(k)
|
||||
for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
|
||||
if k in kwargs
|
||||
}
|
||||
timesignature = user_metas.get("timesignature")
|
||||
if isinstance(timesignature, str) and timesignature.endswith("/4"):
|
||||
user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
|
||||
user_metas = {
|
||||
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
|
||||
for k, v in user_metas.items()
|
||||
if v not in {"unspecified", None}
|
||||
}
|
||||
if len(user_metas):
|
||||
meta_yaml = yaml.dump(user_metas, allow_unicode=True, sort_keys=True).strip()
|
||||
else:
|
||||
meta_yaml = ""
|
||||
return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml
|
||||
|
||||
def _metas_to_cap(self, **kwargs) -> str:
|
||||
use_keys = ("bpm", "duration", "keyscale", "timesignature")
|
||||
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
|
||||
duration = user_metas["duration"]
|
||||
if duration == "N/A":
|
||||
user_metas["duration"] = "30 seconds"
|
||||
elif isinstance(duration, (str, int, float)):
|
||||
user_metas["duration"] = f"{math.ceil(float(duration))} seconds"
|
||||
else:
|
||||
raise TypeError("Unexpected type for duration key, must be str, int or float")
|
||||
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
lyrics = kwargs.get("lyrics", "")
|
||||
bpm = kwargs.get("bpm", 120)
|
||||
duration = kwargs.get("duration", 120)
|
||||
keyscale = kwargs.get("keyscale", "C major")
|
||||
timesignature = kwargs.get("timesignature", 2)
|
||||
language = kwargs.get("language", "en")
|
||||
language = kwargs.get("language")
|
||||
seed = kwargs.get("seed", 0)
|
||||
|
||||
generate_audio_codes = kwargs.get("generate_audio_codes", True)
|
||||
cfg_scale = kwargs.get("cfg_scale", 2.0)
|
||||
temperature = kwargs.get("temperature", 0.85)
|
||||
top_p = kwargs.get("top_p", 0.9)
|
||||
top_k = kwargs.get("top_k", 0.0)
|
||||
|
||||
|
||||
duration = math.ceil(duration)
|
||||
meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature)
|
||||
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n<think>\n{}\n</think>\n\n<|im_end|>\n"
|
||||
kwargs["duration"] = duration
|
||||
|
||||
meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration)
|
||||
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True)
|
||||
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True)
|
||||
cot_text = self._metas_to_cot(caption = text, **kwargs)
|
||||
meta_cap = self._metas_to_cap(**kwargs)
|
||||
|
||||
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs)
|
||||
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
|
||||
out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed}
|
||||
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"
|
||||
|
||||
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
|
||||
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "<think>\n</think>"), disable_weights=True)
|
||||
|
||||
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
|
||||
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
|
||||
out["lm_metadata"] = {"min_tokens": duration * 5,
|
||||
"seed": seed,
|
||||
"generate_audio_codes": generate_audio_codes,
|
||||
"cfg_scale": cfg_scale,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
@ -203,10 +248,14 @@ class ACE15TEModel(torch.nn.Module):
|
||||
self.qwen3_06b.set_clip_options({"layer": [0]})
|
||||
lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)
|
||||
|
||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
||||
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"])
|
||||
out = {"conditioning_lyrics": lyrics_embeds[:, 0]}
|
||||
|
||||
return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]}
|
||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
||||
if lm_metadata["generate_audio_codes"]:
|
||||
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
|
||||
out["audio_codes"] = [audio_codes]
|
||||
|
||||
return base_out, None, out
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.qwen3_06b.set_clip_options(options)
|
||||
|
||||
@ -44,13 +44,18 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
|
||||
io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
|
||||
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
|
||||
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
|
||||
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
|
||||
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
|
||||
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
|
||||
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
|
||||
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed)
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
@ -100,14 +105,15 @@ class EmptyAceStep15LatentAudio(io.ComfyNode):
|
||||
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
class ReferenceTimbreAudio(io.ComfyNode):
|
||||
class ReferenceAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ReferenceTimbreAudio",
|
||||
display_name="Reference Audio",
|
||||
category="advanced/conditioning/audio",
|
||||
is_experimental=True,
|
||||
description="This node sets the reference audio for timbre (for ace step 1.5)",
|
||||
description="This node sets the reference audio for ace step 1.5",
|
||||
inputs=[
|
||||
io.Conditioning.Input("conditioning"),
|
||||
io.Latent.Input("latent", optional=True),
|
||||
@ -131,7 +137,7 @@ class AceExtension(ComfyExtension):
|
||||
EmptyAceStepLatentAudio,
|
||||
TextEncodeAceStepAudio15,
|
||||
EmptyAceStep15LatentAudio,
|
||||
ReferenceTimbreAudio,
|
||||
ReferenceAudio,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AceExtension:
|
||||
|
||||
@ -94,6 +94,19 @@ class VAEEncodeAudio(IO.ComfyNode):
|
||||
encode = execute # TODO: remove
|
||||
|
||||
|
||||
def vae_decode_audio(vae, samples, tile=None, overlap=None):
|
||||
if tile is not None:
|
||||
audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1)
|
||||
else:
|
||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||
|
||||
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
||||
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
|
||||
|
||||
|
||||
class VAEDecodeAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -111,16 +124,33 @@ class VAEDecodeAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, samples) -> IO.NodeOutput:
|
||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
||||
return IO.NodeOutput({"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]})
|
||||
return IO.NodeOutput(vae_decode_audio(vae, samples))
|
||||
|
||||
decode = execute # TODO: remove
|
||||
|
||||
|
||||
class VAEDecodeAudioTiled(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeAudioTiled",
|
||||
search_aliases=["latent to audio"],
|
||||
display_name="VAE Decode Audio (Tiled)",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Int.Input("tile_size", default=512, min=32, max=8192, step=8),
|
||||
IO.Int.Input("overlap", default=64, min=0, max=1024, step=8),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, samples, tile_size, overlap) -> IO.NodeOutput:
|
||||
return IO.NodeOutput(vae_decode_audio(vae, samples, tile_size, overlap))
|
||||
|
||||
|
||||
class SaveAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -675,6 +705,7 @@ class AudioExtension(ComfyExtension):
|
||||
EmptyLatentAudio,
|
||||
VAEEncodeAudio,
|
||||
VAEDecodeAudio,
|
||||
VAEDecodeAudioTiled,
|
||||
SaveAudio,
|
||||
SaveAudioMP3,
|
||||
SaveAudioOpus,
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.12.2"
|
||||
__version__ = "0.12.3"
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.12.2"
|
||||
version = "0.12.3"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
Reference in New Issue
Block a user