mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-12 22:45:39 +08:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4a93a62371 | |||
| 66c18522fb | |||
| e5ae670a40 | |||
| 3fe61cedda | |||
| 2a4328d639 | |||
| d297a749a2 | |||
| 2b7cc7e3b6 | |||
| 4993411fd9 | |||
| 2c7cef4a23 | |||
| 76a7fa96db | |||
| cdcf4119b3 | |||
| dbe70b6821 | |||
| 00fff6019e | |||
| 123a7874a9 | |||
| f719f9c062 | |||
| fe053ba5eb |
36
.github/workflows/release-webhook.yml
vendored
36
.github/workflows/release-webhook.yml
vendored
@ -7,6 +7,8 @@ on:
|
||||
jobs:
|
||||
send-webhook:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
|
||||
steps:
|
||||
- name: Send release webhook
|
||||
env:
|
||||
@ -106,3 +108,37 @@ jobs:
|
||||
--fail --silent --show-error
|
||||
|
||||
echo "✅ Release webhook sent successfully"
|
||||
|
||||
- name: Send repository dispatch to desktop
|
||||
env:
|
||||
DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
|
||||
RELEASE_TAG: ${{ github.event.release.tag_name }}
|
||||
RELEASE_URL: ${{ github.event.release.html_url }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [ -z "${DISPATCH_TOKEN:-}" ]; then
|
||||
echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PAYLOAD="$(jq -n \
|
||||
--arg release_tag "$RELEASE_TAG" \
|
||||
--arg release_url "$RELEASE_URL" \
|
||||
'{
|
||||
event_type: "comfyui_release_published",
|
||||
client_payload: {
|
||||
release_tag: $release_tag,
|
||||
release_url: $release_url
|
||||
}
|
||||
}')"
|
||||
|
||||
curl -fsSL \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
|
||||
https://api.github.com/repos/Comfy-Org/desktop/dispatches \
|
||||
-d "$PAYLOAD"
|
||||
|
||||
echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
import math
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
from scipy import integrate
|
||||
import torch
|
||||
from torch import nn
|
||||
import torchsde
|
||||
from tqdm.auto import trange as trange_, tqdm
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from . import utils
|
||||
from . import deis
|
||||
@ -15,34 +14,7 @@ import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
import comfy.memory_management
|
||||
|
||||
|
||||
def trange(*args, **kwargs):
|
||||
if comfy.memory_management.aimdo_allocator is None:
|
||||
return trange_(*args, **kwargs)
|
||||
|
||||
pbar = trange_(*args, **kwargs, smoothing=1.0)
|
||||
pbar._i = 0
|
||||
pbar.set_postfix_str(" Model Initializing ... ")
|
||||
|
||||
_update = pbar.update
|
||||
|
||||
def warmup_update(n=1):
|
||||
pbar._i += 1
|
||||
if pbar._i == 1:
|
||||
pbar.i1_time = time.time()
|
||||
pbar.set_postfix_str(" Model Initialization complete! ")
|
||||
elif pbar._i == 2:
|
||||
#bring forward the effective start time based the the diff between first and second iteration
|
||||
#to attempt to remove load overhead from the final step rate estimate.
|
||||
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
|
||||
pbar.set_postfix_str("")
|
||||
|
||||
_update(n)
|
||||
|
||||
pbar.update = warmup_update
|
||||
return pbar
|
||||
|
||||
from comfy.utils import model_trange as trange
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
|
||||
@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
|
||||
|
||||
def preprocess_text_embeds(self, text_embeds, text_ids):
|
||||
def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
|
||||
if text_ids is not None:
|
||||
return self.llm_adapter(text_embeds, text_ids)
|
||||
out = self.llm_adapter(text_embeds, text_ids)
|
||||
if t5xxl_weights is not None:
|
||||
out = out * t5xxl_weights
|
||||
|
||||
if out.shape[1] < 512:
|
||||
out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
|
||||
return out
|
||||
else:
|
||||
return text_embeds
|
||||
|
||||
def forward(self, x, timesteps, context, **kwargs):
|
||||
t5xxl_ids = kwargs.pop("t5xxl_ids", None)
|
||||
if t5xxl_ids is not None:
|
||||
context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
|
||||
return super().forward(x, timesteps, context, **kwargs)
|
||||
|
||||
@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
|
||||
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
|
||||
def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
|
||||
|
||||
try:
|
||||
import comfy.quant_ops
|
||||
apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
q_apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
if comfy.model_management.in_training:
|
||||
return _apply_rope(xq, xk, freqs_cis)
|
||||
else:
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
def apply_rope1(x, freqs_cis):
|
||||
if comfy.model_management.in_training:
|
||||
return _apply_rope1(x, freqs_cis)
|
||||
else:
|
||||
return q_apply_rope1(x, freqs_cis)
|
||||
except:
|
||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
apply_rope = _apply_rope
|
||||
apply_rope1 = _apply_rope1
|
||||
|
||||
@ -1160,12 +1160,16 @@ class Anima(BaseModel):
|
||||
device = kwargs["device"]
|
||||
if cross_attn is not None:
|
||||
if t5xxl_ids is not None:
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
|
||||
if t5xxl_weights is not None:
|
||||
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
||||
t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
||||
t5xxl_ids = t5xxl_ids.unsqueeze(0)
|
||||
|
||||
if torch.is_inference_mode_enabled(): # if not we are training
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
|
||||
else:
|
||||
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
|
||||
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
||||
|
||||
if cross_attn.shape[1] < 512:
|
||||
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
|
||||
@ -55,6 +55,11 @@ cpu_state = CPUState.GPU
|
||||
|
||||
total_vram = 0
|
||||
|
||||
|
||||
# Training Related State
|
||||
in_training = False
|
||||
|
||||
|
||||
def get_supported_float8_types():
|
||||
float8_types = []
|
||||
try:
|
||||
@ -1208,8 +1213,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
||||
if signature is not None:
|
||||
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0]
|
||||
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
||||
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
||||
v_tensor = weight._v_tensor
|
||||
else:
|
||||
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
||||
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
|
||||
weight._v_tensor = v_tensor
|
||||
weight._v_signature = signature
|
||||
#Send it over
|
||||
v_tensor.copy_(weight, non_blocking=non_blocking)
|
||||
|
||||
@ -19,7 +19,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import math
|
||||
@ -317,7 +316,7 @@ class ModelPatcher:
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
|
||||
n.backup = self.backup
|
||||
n.object_patches_backup = self.object_patches_backup
|
||||
n.parent = self
|
||||
@ -1526,7 +1525,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
setattr(m, param_key + "_function", weight_function)
|
||||
geometry = weight
|
||||
if not isinstance(weight, QuantizedTensor):
|
||||
model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
|
||||
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
|
||||
weight._model_dtype = model_dtype
|
||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||
return comfy.memory_management.vram_aligned_size(geometry)
|
||||
@ -1543,7 +1542,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
if vbar is not None and not hasattr(m, "_v"):
|
||||
m._v = vbar.alloc(v_weight_size)
|
||||
m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to)
|
||||
allocated_size += v_weight_size
|
||||
|
||||
else:
|
||||
@ -1553,12 +1551,11 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
weight.seed_key = key
|
||||
set_dirty(weight, dirty)
|
||||
geometry = weight
|
||||
model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
|
||||
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
|
||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||
weight_size = geometry.numel() * geometry.element_size()
|
||||
if vbar is not None and not hasattr(weight, "_v"):
|
||||
weight._v = vbar.alloc(weight_size)
|
||||
weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to)
|
||||
weight._model_dtype = model_dtype
|
||||
allocated_size += weight_size
|
||||
vbar.set_watermark_limit(allocated_size)
|
||||
|
||||
25
comfy/ops.py
25
comfy/ops.py
@ -83,14 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
if signature is not None:
|
||||
xfer_dest = s._v_tensor
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
if signature is not None:
|
||||
if resident:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||
|
||||
if not resident:
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||
cast_dest = None
|
||||
|
||||
xfer_source = [ s.weight, s.bias ]
|
||||
@ -140,9 +144,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
post_cast.copy_(pre_cast)
|
||||
xfer_dest = cast_dest
|
||||
|
||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
if signature is not None:
|
||||
s._v_weight = weight
|
||||
s._v_bias = bias
|
||||
s._v_signature=signature
|
||||
|
||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
@ -169,8 +177,8 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
if orig.dtype == dtype and len(fns) == 0:
|
||||
#The layer actually wants our freshly saved QT
|
||||
x = y
|
||||
else:
|
||||
y = x
|
||||
elif update_weight:
|
||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
|
||||
if update_weight:
|
||||
orig.copy_(y)
|
||||
for f in fns:
|
||||
@ -182,7 +190,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
||||
if s.bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
||||
s._v_signature=signature
|
||||
|
||||
#FIXME: weird offload return protocol
|
||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||
|
||||
@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
|
||||
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
||||
return memory_required, minimum_memory_required
|
||||
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_prepare_sampling,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||
real_model: BaseModel = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
|
||||
if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
|
||||
memory_required = 1e20
|
||||
minimum_memory_required = None
|
||||
else:
|
||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||
memory_required += inference_memory
|
||||
minimum_memory_required += inference_memory
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||
real_model = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
@ -793,8 +793,6 @@ class VAE:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
self.first_stage_model = self.first_stage_model.eval()
|
||||
|
||||
model_management.archive_model_dtypes(self.first_stage_model)
|
||||
|
||||
if device is None:
|
||||
device = model_management.vae_device()
|
||||
self.device = device
|
||||
@ -803,6 +801,7 @@ class VAE:
|
||||
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
|
||||
self.vae_dtype = dtype
|
||||
self.first_stage_model.to(self.vae_dtype)
|
||||
model_management.archive_model_dtypes(self.first_stage_model)
|
||||
self.output_device = model_management.intermediate_device()
|
||||
|
||||
mp = comfy.model_patcher.CoreModelPatcher
|
||||
|
||||
@ -3,7 +3,6 @@ import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
import torch
|
||||
import math
|
||||
from tqdm.auto import trange
|
||||
import yaml
|
||||
import comfy.utils
|
||||
|
||||
@ -17,6 +16,7 @@ def sample_manual_loop_no_classes(
|
||||
temperature: float = 0.85,
|
||||
top_p: float = 0.9,
|
||||
top_k: int = None,
|
||||
min_p: float = 0.000,
|
||||
seed: int = 1,
|
||||
min_tokens: int = 1,
|
||||
max_new_tokens: int = 2048,
|
||||
@ -52,7 +52,7 @@ def sample_manual_loop_no_classes(
|
||||
|
||||
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
|
||||
|
||||
for step in trange(max_new_tokens, desc="LM sampling"):
|
||||
for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
|
||||
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
|
||||
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
|
||||
past_key_values = outputs[2]
|
||||
@ -81,6 +81,12 @@ def sample_manual_loop_no_classes(
|
||||
min_val = top_k_vals[..., -1, None]
|
||||
cfg_logits[cfg_logits < min_val] = remove_logit_value
|
||||
|
||||
if min_p is not None and min_p > 0:
|
||||
probs = torch.softmax(cfg_logits, dim=-1)
|
||||
p_max = probs.max(dim=-1, keepdim=True).values
|
||||
indices_to_remove = probs < (min_p * p_max)
|
||||
cfg_logits[indices_to_remove] = remove_logit_value
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
@ -111,7 +117,7 @@ def sample_manual_loop_no_classes(
|
||||
return output_audio_codes
|
||||
|
||||
|
||||
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
|
||||
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000):
|
||||
positive = [[token for token, _ in inner_list] for inner_list in positive]
|
||||
positive = positive[0]
|
||||
|
||||
@ -135,7 +141,7 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
|
||||
paddings = []
|
||||
ids = [positive]
|
||||
|
||||
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||
|
||||
|
||||
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
@ -193,6 +199,7 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
temperature = kwargs.get("temperature", 0.85)
|
||||
top_p = kwargs.get("top_p", 0.9)
|
||||
top_k = kwargs.get("top_k", 0.0)
|
||||
min_p = kwargs.get("min_p", 0.000)
|
||||
|
||||
duration = math.ceil(duration)
|
||||
kwargs["duration"] = duration
|
||||
@ -240,6 +247,7 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
"min_p": min_p,
|
||||
}
|
||||
return out
|
||||
|
||||
@ -300,7 +308,7 @@ class ACE15TEModel(torch.nn.Module):
|
||||
|
||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
||||
if lm_metadata["generate_audio_codes"]:
|
||||
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
|
||||
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
|
||||
out["audio_codes"] = [audio_codes]
|
||||
|
||||
return base_out, None, out
|
||||
|
||||
@ -27,6 +27,7 @@ from PIL import Image
|
||||
import logging
|
||||
import itertools
|
||||
from torch.nn.functional import interpolate
|
||||
from tqdm.auto import trange
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
import json
|
||||
@ -1155,6 +1156,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
||||
|
||||
def model_trange(*args, **kwargs):
|
||||
if comfy.memory_management.aimdo_allocator is None:
|
||||
return trange(*args, **kwargs)
|
||||
|
||||
pbar = trange(*args, **kwargs, smoothing=1.0)
|
||||
pbar._i = 0
|
||||
pbar.set_postfix_str(" Model Initializing ... ")
|
||||
|
||||
_update = pbar.update
|
||||
|
||||
def warmup_update(n=1):
|
||||
pbar._i += 1
|
||||
if pbar._i == 1:
|
||||
pbar.i1_time = time.time()
|
||||
pbar.set_postfix_str(" Model Initialization complete! ")
|
||||
elif pbar._i == 2:
|
||||
#bring forward the effective start time based the the diff between first and second iteration
|
||||
#to attempt to remove load overhead from the final step rate estimate.
|
||||
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
|
||||
pbar.set_postfix_str("")
|
||||
|
||||
_update(n)
|
||||
|
||||
pbar.update = warmup_update
|
||||
return pbar
|
||||
|
||||
PROGRESS_BAR_ENABLED = True
|
||||
def set_progress_bar_enabled(enabled):
|
||||
global PROGRESS_BAR_ENABLED
|
||||
@ -1376,3 +1403,21 @@ def string_to_seed(data):
|
||||
else:
|
||||
crc >>= 1
|
||||
return crc ^ 0xFFFFFFFF
|
||||
|
||||
def deepcopy_list_dict(obj, memo=None):
|
||||
if memo is None:
|
||||
memo = {}
|
||||
|
||||
obj_id = id(obj)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id]
|
||||
|
||||
if isinstance(obj, dict):
|
||||
res = {deepcopy_list_dict(k, memo): deepcopy_list_dict(v, memo) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
res = [deepcopy_list_dict(i, memo) for i in obj]
|
||||
else:
|
||||
res = obj
|
||||
|
||||
memo[obj_id] = res
|
||||
return res
|
||||
|
||||
@ -21,6 +21,7 @@ from typing import Optional, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
||||
from comfy.patcher_extension import PatcherInjection
|
||||
|
||||
@ -181,18 +182,21 @@ class BypassForwardHook:
|
||||
)
|
||||
return # Already injected
|
||||
|
||||
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
|
||||
device = None
|
||||
# Move adapter weights to compute device (GPU)
|
||||
# Use get_torch_device() instead of module.weight.device because
|
||||
# with offloading, module weights may be on CPU while compute happens on GPU
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
# Get dtype from module weight if available
|
||||
dtype = None
|
||||
if hasattr(self.module, "weight") and self.module.weight is not None:
|
||||
device = self.module.weight.device
|
||||
dtype = self.module.weight.dtype
|
||||
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
|
||||
device = self.module.W_q.device
|
||||
dtype = self.module.W_q.dtype
|
||||
|
||||
if device is not None:
|
||||
self._move_adapter_weights_to_device(device, dtype)
|
||||
# Only use dtype if it's a standard float type, not quantized
|
||||
if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
|
||||
dtype = None
|
||||
|
||||
self._move_adapter_weights_to_device(device, dtype)
|
||||
|
||||
self.original_forward = self.module.forward
|
||||
self.module.forward = self._bypass_forward
|
||||
|
||||
@ -34,6 +34,21 @@ class VideoInput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def as_trimmed(
|
||||
self,
|
||||
start_time: float | None = None,
|
||||
duration: float | None = None,
|
||||
strict_duration: bool = False,
|
||||
) -> VideoInput | None:
|
||||
"""
|
||||
Create a new VideoInput which is trimmed to have the corresponding start_time and duration
|
||||
|
||||
Returns:
|
||||
A new VideoInput, or None if the result would have negative duration
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_stream_source(self) -> Union[str, io.BytesIO]:
|
||||
"""
|
||||
Get a streamable source for the video. This allows processing without
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Optional
|
||||
from .._input import AudioInput, VideoInput
|
||||
import av
|
||||
import io
|
||||
import itertools
|
||||
import json
|
||||
import numpy as np
|
||||
import math
|
||||
@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None:
|
||||
formats = container_format.split(",")
|
||||
return formats[0]
|
||||
|
||||
|
||||
def get_open_write_kwargs(
|
||||
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
||||
) -> dict:
|
||||
@ -57,12 +57,14 @@ class VideoFromFile(VideoInput):
|
||||
Class representing video input from a file.
|
||||
"""
|
||||
|
||||
def __init__(self, file: str | io.BytesIO):
|
||||
def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
|
||||
"""
|
||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
self.__start_time = start_time
|
||||
self.__duration = duration
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
@ -96,6 +98,16 @@ class VideoFromFile(VideoInput):
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
raw_duration = self._get_raw_duration()
|
||||
if self.__start_time < 0:
|
||||
duration_from_start = min(raw_duration, -self.__start_time)
|
||||
else:
|
||||
duration_from_start = raw_duration - self.__start_time
|
||||
if self.__duration:
|
||||
return min(self.__duration, duration_from_start)
|
||||
return duration_from_start
|
||||
|
||||
def _get_raw_duration(self) -> float:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
@ -113,9 +125,13 @@ class VideoFromFile(VideoInput):
|
||||
if video_stream and video_stream.average_rate:
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
frame_iterator = (
|
||||
container.decode(video_stream)
|
||||
if video_stream.codec.capabilities & 0x100
|
||||
else container.demux(video_stream)
|
||||
)
|
||||
for packet in frame_iterator:
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
@ -131,36 +147,54 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# 1. Prefer the frames field if available
|
||||
if video_stream.frames and video_stream.frames > 0:
|
||||
# 1. Prefer the frames field if available and usable
|
||||
if (
|
||||
video_stream.frames
|
||||
and video_stream.frames > 0
|
||||
and not self.__start_time
|
||||
and not self.__duration
|
||||
):
|
||||
return int(video_stream.frames)
|
||||
|
||||
# 2. Try to estimate from duration and average_rate using only metadata
|
||||
if container.duration is not None and video_stream.average_rate:
|
||||
duration_seconds = float(container.duration / av.time_base)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
if (
|
||||
getattr(video_stream, "duration", None) is not None
|
||||
and getattr(video_stream, "time_base", None) is not None
|
||||
and video_stream.average_rate
|
||||
):
|
||||
duration_seconds = float(video_stream.duration * video_stream.time_base)
|
||||
raw_duration = float(video_stream.duration * video_stream.time_base)
|
||||
if self.__start_time < 0:
|
||||
duration_from_start = min(raw_duration, -self.__start_time)
|
||||
else:
|
||||
duration_from_start = raw_duration - self.__start_time
|
||||
duration_seconds = min(self.__duration, duration_from_start)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
# 3. Last resort: decode frames and count them (streaming)
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
|
||||
if frame_count == 0:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
|
||||
if self.__start_time < 0:
|
||||
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
||||
else:
|
||||
start_time = self.__start_time
|
||||
frame_count = 1
|
||||
start_pts = int(start_time / video_stream.time_base)
|
||||
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
frame_iterator = (
|
||||
container.decode(video_stream)
|
||||
if video_stream.codec.capabilities & 0x100
|
||||
else container.demux(video_stream)
|
||||
)
|
||||
for frame in frame_iterator:
|
||||
if frame.pts >= start_pts:
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
|
||||
for frame in frame_iterator:
|
||||
if frame.pts >= end_pts:
|
||||
break
|
||||
frame_count += 1
|
||||
return frame_count
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
@ -199,9 +233,21 @@ class VideoFromFile(VideoInput):
|
||||
return container.format.name
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
if self.__start_time < 0:
|
||||
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
||||
else:
|
||||
start_time = self.__start_time
|
||||
# Get video frames
|
||||
frames = []
|
||||
for frame in container.decode(video=0):
|
||||
start_pts = int(start_time / video_stream.time_base)
|
||||
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
for frame in container.decode(video_stream):
|
||||
if frame.pts < start_pts:
|
||||
continue
|
||||
if self.__duration and frame.pts >= end_pts:
|
||||
break
|
||||
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
|
||||
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
|
||||
frames.append(img)
|
||||
@ -209,31 +255,44 @@ class VideoFromFile(VideoInput):
|
||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
|
||||
|
||||
# Get frame rate
|
||||
video_stream = next(s for s in container.streams if s.type == 'video')
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
|
||||
|
||||
# Get audio if available
|
||||
audio = None
|
||||
try:
|
||||
container.seek(0) # Reset the container to the beginning
|
||||
for stream in container.streams:
|
||||
if stream.type != 'audio':
|
||||
continue
|
||||
assert isinstance(stream, av.AudioStream)
|
||||
audio_frames = []
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
assert isinstance(frame, av.AudioFrame)
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
|
||||
})
|
||||
except StopIteration:
|
||||
pass # No audio stream
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
# Use last stream for consistency
|
||||
if len(container.streams.audio):
|
||||
audio_stream = container.streams.audio[-1]
|
||||
audio_frames = []
|
||||
resample = av.audio.resampler.AudioResampler(format='fltp').resample
|
||||
frames = itertools.chain.from_iterable(
|
||||
map(resample, container.decode(audio_stream))
|
||||
)
|
||||
|
||||
has_first_frame = False
|
||||
for frame in frames:
|
||||
offset_seconds = start_time - frame.pts * audio_stream.time_base
|
||||
to_skip = int(offset_seconds * audio_stream.sample_rate)
|
||||
if to_skip < frame.samples:
|
||||
has_first_frame = True
|
||||
break
|
||||
if has_first_frame:
|
||||
audio_frames.append(frame.to_ndarray()[..., to_skip:])
|
||||
|
||||
for frame in frames:
|
||||
if frame.time > start_time + self.__duration:
|
||||
break
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
if self.__duration:
|
||||
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
|
||||
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
|
||||
})
|
||||
|
||||
metadata = container.metadata
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
@ -250,7 +309,7 @@ class VideoFromFile(VideoInput):
|
||||
path: str | io.BytesIO,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
@ -262,15 +321,14 @@ class VideoFromFile(VideoInput):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
if self.__start_time or self.__duration:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path,
|
||||
format=format,
|
||||
codec=codec,
|
||||
metadata=metadata
|
||||
path, format=format, codec=codec, metadata=metadata
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
@ -304,10 +362,21 @@ class VideoFromFile(VideoInput):
|
||||
output_container.mux(packet)
|
||||
|
||||
def _get_first_video_stream(self, container: InputContainer):
|
||||
video_stream = next((s for s in container.streams if s.type == "video"), None)
|
||||
if video_stream is None:
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
return video_stream
|
||||
if len(container.streams.video):
|
||||
return container.streams.video[0]
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def as_trimmed(
|
||||
self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
|
||||
) -> VideoInput | None:
|
||||
trimmed = VideoFromFile(
|
||||
self.get_stream_source(),
|
||||
start_time=start_time + self.__start_time,
|
||||
duration=duration,
|
||||
)
|
||||
if trimmed.get_duration() < duration and strict_duration:
|
||||
return None
|
||||
return trimmed
|
||||
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
@ -322,7 +391,7 @@ class VideoFromComponents(VideoInput):
|
||||
return VideoComponents(
|
||||
images=self.__components.images,
|
||||
audio=self.__components.audio,
|
||||
frame_rate=self.__components.frame_rate
|
||||
frame_rate=self.__components.frame_rate,
|
||||
)
|
||||
|
||||
def save_to(
|
||||
@ -330,7 +399,7 @@ class VideoFromComponents(VideoInput):
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
@ -357,7 +426,10 @@ class VideoFromComponents(VideoInput):
|
||||
audio_stream: Optional[av.AudioStream] = None
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||
waveform = self.__components.audio['waveform']
|
||||
waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
|
||||
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
@ -372,12 +444,21 @@ class VideoFromComponents(VideoInput):
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
waveform = self.__components.audio['waveform']
|
||||
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
|
||||
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
|
||||
frame.sample_rate = audio_sample_rate
|
||||
frame.pts = 0
|
||||
output.mux(audio_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output.mux(audio_stream.encode(None))
|
||||
|
||||
def as_trimmed(
|
||||
self,
|
||||
start_time: float | None = None,
|
||||
duration: float | None = None,
|
||||
strict_duration: bool = True,
|
||||
) -> VideoInput | None:
|
||||
if self.get_duration() < start_time + duration:
|
||||
return None
|
||||
#TODO Consider tracking duration and trimming at time of save?
|
||||
return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)
|
||||
|
||||
@ -30,6 +30,30 @@ from comfy_api_nodes.util import (
|
||||
validate_image_dimensions,
|
||||
)
|
||||
|
||||
_EUR_TO_USD = 1.19
|
||||
|
||||
|
||||
def _tier_price_eur(megapixels: float) -> float:
|
||||
"""Price in EUR for a single Magnific upscaling step based on input megapixels."""
|
||||
if megapixels <= 1.3:
|
||||
return 0.143
|
||||
if megapixels <= 3.0:
|
||||
return 0.286
|
||||
if megapixels <= 6.4:
|
||||
return 0.429
|
||||
return 1.716
|
||||
|
||||
|
||||
def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
|
||||
"""Calculate total Magnific upscale price in USD for given input dimensions and scale factor."""
|
||||
num_steps = int(math.log2(scale))
|
||||
total_eur = 0.0
|
||||
pixels = width * height
|
||||
for _ in range(num_steps):
|
||||
total_eur += _tier_price_eur(pixels / 1_000_000)
|
||||
pixels *= 4
|
||||
return round(total_eur * _EUR_TO_USD, 2)
|
||||
|
||||
|
||||
class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -103,11 +127,20 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]),
|
||||
expr="""
|
||||
(
|
||||
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
|
||||
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
|
||||
$ad := widgets.auto_downscale;
|
||||
$mins := $ad
|
||||
? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515}
|
||||
: {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
|
||||
$maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187};
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": $lookup($mins, widgets.scale_factor),
|
||||
"max_usd": $lookup($maxs, widgets.scale_factor),
|
||||
"format": { "approximate": true }
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -168,6 +201,10 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
f"Use a smaller input image or lower scale factor."
|
||||
)
|
||||
|
||||
final_height, final_width = get_image_dimensions(image)
|
||||
actual_scale = int(scale_factor.rstrip("x"))
|
||||
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
|
||||
@ -189,6 +226,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
price_extractor=lambda _: price_usd,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
@ -257,8 +295,14 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
|
||||
expr="""
|
||||
(
|
||||
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
|
||||
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
|
||||
$mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
|
||||
$maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06};
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": $lookup($mins, widgets.scale_factor),
|
||||
"max_usd": $lookup($maxs, widgets.scale_factor),
|
||||
"format": { "approximate": true }
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -321,6 +365,9 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
f"Use a smaller input image or lower scale factor."
|
||||
)
|
||||
|
||||
final_height, final_width = get_image_dimensions(image)
|
||||
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
|
||||
@ -339,6 +386,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
price_extractor=lambda _: price_usd,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
@ -877,8 +925,8 @@ class MagnificExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
# MagnificImageUpscalerCreativeNode,
|
||||
# MagnificImageUpscalerPreciseV2Node,
|
||||
MagnificImageUpscalerCreativeNode,
|
||||
MagnificImageUpscalerPreciseV2Node,
|
||||
MagnificImageStyleTransferNode,
|
||||
MagnificImageRelightNode,
|
||||
MagnificImageSkinEnhancerNode,
|
||||
|
||||
@ -57,6 +57,7 @@ class _RequestConfig:
|
||||
files: dict[str, Any] | list[tuple[str, Any]] | None
|
||||
multipart_parser: Callable | None
|
||||
max_retries: int
|
||||
max_retries_on_rate_limit: int
|
||||
retry_delay: float
|
||||
retry_backoff: float
|
||||
wait_label: str = "Waiting"
|
||||
@ -65,6 +66,7 @@ class _RequestConfig:
|
||||
final_label_on_success: str | None = "Completed"
|
||||
progress_origin_ts: float | None = None
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -78,7 +80,7 @@ class _PollUIState:
|
||||
active_since: float | None = None # start time of current active interval (None if queued)
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
|
||||
@ -103,6 +105,8 @@ async def sync_op(
|
||||
final_label_on_success: str | None = "Completed",
|
||||
progress_origin_ts: float | None = None,
|
||||
monitor_progress: bool = True,
|
||||
max_retries_on_rate_limit: int = 16,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
) -> M:
|
||||
raw = await sync_op_raw(
|
||||
cls,
|
||||
@ -122,6 +126,8 @@ async def sync_op(
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
monitor_progress=monitor_progress,
|
||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||
is_rate_limited=is_rate_limited,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
@ -143,9 +149,9 @@ async def poll_op(
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 160,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
max_retries_per_poll: int = 10,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
retry_backoff_per_poll: float = 1.4,
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
@ -194,6 +200,8 @@ async def sync_op_raw(
|
||||
final_label_on_success: str | None = "Completed",
|
||||
progress_origin_ts: float | None = None,
|
||||
monitor_progress: bool = True,
|
||||
max_retries_on_rate_limit: int = 16,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
) -> dict[str, Any] | bytes:
|
||||
"""
|
||||
Make a single network request.
|
||||
@ -222,6 +230,8 @@ async def sync_op_raw(
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
price_extractor=price_extractor,
|
||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||
is_rate_limited=is_rate_limited,
|
||||
)
|
||||
return await _request_base(cfg, expect_binary=as_binary)
|
||||
|
||||
@ -240,9 +250,9 @@ async def poll_op_raw(
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 160,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
max_retries_per_poll: int = 10,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
retry_backoff_per_poll: float = 1.4,
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
@ -506,7 +516,7 @@ def _friendly_http_message(status: int, body: Any) -> str:
|
||||
if status == 409:
|
||||
return "There is a problem with your account. Please contact support@comfy.org."
|
||||
if status == 429:
|
||||
return "Rate Limit Exceeded: Please try again later."
|
||||
return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again."
|
||||
try:
|
||||
if isinstance(body, dict):
|
||||
err = body.get("error")
|
||||
@ -586,6 +596,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
|
||||
attempt = 0
|
||||
delay = cfg.retry_delay
|
||||
rate_limit_attempts = 0
|
||||
rate_limit_delay = cfg.retry_delay
|
||||
operation_succeeded: bool = False
|
||||
final_elapsed_seconds: int | None = None
|
||||
extracted_price: float | None = None
|
||||
@ -653,17 +665,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
payload_headers["Content-Type"] = "application/json"
|
||||
payload_kw["json"] = cfg.data or {}
|
||||
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request logging failed: %s", _log_e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
)
|
||||
|
||||
req_coro = sess.request(method, url, params=params, **payload_kw)
|
||||
req_task = asyncio.create_task(req_coro)
|
||||
@ -688,41 +697,33 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
body = await resp.json()
|
||||
except (ContentTypeError, json.JSONDecodeError):
|
||||
body = await resp.text()
|
||||
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
|
||||
should_retry = False
|
||||
wait_time = 0.0
|
||||
retry_label = ""
|
||||
is_rl = resp.status == 429 or (
|
||||
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
|
||||
)
|
||||
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
|
||||
rate_limit_attempts += 1
|
||||
wait_time = min(rate_limit_delay, 30.0)
|
||||
rate_limit_delay *= cfg.retry_backoff
|
||||
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
|
||||
should_retry = True
|
||||
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
|
||||
wait_time = delay
|
||||
delay *= cfg.retry_backoff
|
||||
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
|
||||
should_retry = True
|
||||
|
||||
if should_retry:
|
||||
logging.warning(
|
||||
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
|
||||
"HTTP %s %s -> %s. Waiting %.2fs (%s).",
|
||||
method,
|
||||
url,
|
||||
resp.status,
|
||||
delay,
|
||||
attempt,
|
||||
cfg.max_retries,
|
||||
wait_time,
|
||||
retry_label,
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=_friendly_http_message(resp.status, body),
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
delay *= cfg.retry_backoff
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
@ -730,10 +731,27 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=msg,
|
||||
error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
await sleep_with_interrupt(
|
||||
wait_time,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=msg,
|
||||
)
|
||||
raise Exception(msg)
|
||||
|
||||
if expect_binary:
|
||||
@ -753,17 +771,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
bytes_payload = bytes(buff)
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
return bytes_payload
|
||||
else:
|
||||
try:
|
||||
@ -780,45 +795,39 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
return payload
|
||||
|
||||
except ProcessingInterrupted:
|
||||
logging.debug("Polling was interrupted by user")
|
||||
raise
|
||||
except (ClientError, OSError) as e:
|
||||
if attempt <= cfg.max_retries:
|
||||
if (attempt - rate_limit_attempts) <= cfg.max_retries:
|
||||
logging.warning(
|
||||
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
|
||||
method,
|
||||
url,
|
||||
delay,
|
||||
attempt,
|
||||
attempt - rate_limit_attempts,
|
||||
cfg.max_retries,
|
||||
str(e),
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
@ -831,23 +840,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
continue
|
||||
diag = await _diagnose_connectivity()
|
||||
if not diag["internet_accessible"]:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"LocalNetworkError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
@ -855,10 +847,21 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"ApiServerError: {str(e)}",
|
||||
error_message=f"LocalNetworkError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"ApiServerError: {str(e)}",
|
||||
)
|
||||
raise ApiServerError(
|
||||
f"The API server at {default_base_url()} is currently unreachable. "
|
||||
f"The service may be experiencing issues."
|
||||
|
||||
@ -167,27 +167,25 @@ async def download_url_to_bytesio(
|
||||
with contextlib.suppress(Exception):
|
||||
dest.seek(0)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=f"[streamed {written} bytes to dest]",
|
||||
)
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=f"[streamed {written} bytes to dest]",
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
except (ClientError, OSError) as e:
|
||||
if attempt <= max_retries:
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(delay, cls, None, None, None)
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
|
||||
@ -8,7 +8,6 @@ from typing import Any
|
||||
|
||||
import folder_paths
|
||||
|
||||
# Get the logger instance
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -91,38 +90,41 @@ def log_request_response(
|
||||
Filenames are sanitized and length-limited for cross-platform safety.
|
||||
If we still fail to write, we fall back to appending into api.log.
|
||||
"""
|
||||
log_dir = get_log_directory()
|
||||
filepath = _build_log_filepath(log_dir, operation_id, request_url)
|
||||
|
||||
log_content: list[str] = []
|
||||
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
|
||||
log_content.append(f"Operation ID: {operation_id}")
|
||||
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
|
||||
log_content.append(f"Method: {request_method}")
|
||||
log_content.append(f"URL: {request_url}")
|
||||
if request_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
|
||||
if request_params:
|
||||
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
|
||||
if request_data is not None:
|
||||
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
|
||||
|
||||
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
|
||||
if response_status_code is not None:
|
||||
log_content.append(f"Status Code: {response_status_code}")
|
||||
if response_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
|
||||
if response_content is not None:
|
||||
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
|
||||
if error_message:
|
||||
log_content.append(f"Error:\n{error_message}")
|
||||
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_content))
|
||||
logger.debug("API log saved to: %s", filepath)
|
||||
except Exception as e:
|
||||
logger.error("Error writing API log to %s: %s", filepath, str(e))
|
||||
log_dir = get_log_directory()
|
||||
filepath = _build_log_filepath(log_dir, operation_id, request_url)
|
||||
|
||||
log_content: list[str] = []
|
||||
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
|
||||
log_content.append(f"Operation ID: {operation_id}")
|
||||
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
|
||||
log_content.append(f"Method: {request_method}")
|
||||
log_content.append(f"URL: {request_url}")
|
||||
if request_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
|
||||
if request_params:
|
||||
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
|
||||
if request_data is not None:
|
||||
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
|
||||
|
||||
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
|
||||
if response_status_code is not None:
|
||||
log_content.append(f"Status Code: {response_status_code}")
|
||||
if response_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
|
||||
if response_content is not None:
|
||||
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
|
||||
if error_message:
|
||||
log_content.append(f"Error:\n{error_message}")
|
||||
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_content))
|
||||
logger.debug("API log saved to: %s", filepath)
|
||||
except Exception as e:
|
||||
logger.error("Error writing API log to %s: %s", filepath, str(e))
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] log_request_response failed: %s", _log_e)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -255,17 +255,14 @@ async def upload_file(
|
||||
monitor_task = asyncio.create_task(_monitor())
|
||||
sess: aiohttp.ClientSession | None = None
|
||||
try:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_params=None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug("[DEBUG] upload request logging failed: %s", e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_params=None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
)
|
||||
|
||||
sess = aiohttp.ClientSession(timeout=timeout)
|
||||
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
|
||||
@ -311,31 +308,27 @@ async def upload_file(
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
raise Exception(f"Failed to upload (HTTP {resp.status}).")
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content="File uploaded successfully.",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug("[DEBUG] upload response logging failed: %s", e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content="File uploaded successfully.",
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
except (aiohttp.ClientError, OSError) as e:
|
||||
if attempt <= max_retries:
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cls,
|
||||
|
||||
@ -20,10 +20,60 @@ class JobStatus:
|
||||
|
||||
|
||||
# Media types that can be previewed in the frontend
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
|
||||
|
||||
# 3D file extensions for preview fallback (no dedicated media_type exists)
|
||||
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
|
||||
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
|
||||
|
||||
|
||||
def has_3d_extension(filename: str) -> bool:
|
||||
lower = filename.lower()
|
||||
return any(lower.endswith(ext) for ext in THREE_D_EXTENSIONS)
|
||||
|
||||
|
||||
def normalize_output_item(item):
|
||||
"""Normalize a single output list item for the jobs API.
|
||||
|
||||
Returns the normalized item, or None to exclude it.
|
||||
String items with 3D extensions become {filename, type, subfolder} dicts.
|
||||
"""
|
||||
if item is None:
|
||||
return None
|
||||
if isinstance(item, str):
|
||||
if has_3d_extension(item):
|
||||
return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
|
||||
return None
|
||||
if isinstance(item, dict):
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
def normalize_outputs(outputs: dict) -> dict:
|
||||
"""Normalize raw node outputs for the jobs API.
|
||||
|
||||
Transforms string 3D filenames into file output dicts and removes
|
||||
None items. All other items (non-3D strings, dicts, etc.) are
|
||||
preserved as-is.
|
||||
"""
|
||||
normalized = {}
|
||||
for node_id, node_outputs in outputs.items():
|
||||
if not isinstance(node_outputs, dict):
|
||||
normalized[node_id] = node_outputs
|
||||
continue
|
||||
normalized_node = {}
|
||||
for media_type, items in node_outputs.items():
|
||||
if media_type == 'animated' or not isinstance(items, list):
|
||||
normalized_node[media_type] = items
|
||||
continue
|
||||
normalized_items = []
|
||||
for item in items:
|
||||
if item is None:
|
||||
continue
|
||||
norm = normalize_output_item(item)
|
||||
normalized_items.append(norm if norm is not None else item)
|
||||
normalized_node[media_type] = normalized_items
|
||||
normalized[node_id] = normalized_node
|
||||
return normalized
|
||||
|
||||
|
||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||
@ -45,9 +95,9 @@ def is_previewable(media_type: str, item: dict) -> bool:
|
||||
Maintains backwards compatibility with existing logic.
|
||||
|
||||
Priority:
|
||||
1. media_type is 'images', 'video', or 'audio'
|
||||
1. media_type is 'images', 'video', 'audio', or '3d'
|
||||
2. format field starts with 'video/' or 'audio/'
|
||||
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
|
||||
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb, .usdz)
|
||||
"""
|
||||
if media_type in PREVIEWABLE_MEDIA_TYPES:
|
||||
return True
|
||||
@ -139,7 +189,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
|
||||
})
|
||||
|
||||
if include_outputs:
|
||||
job['outputs'] = outputs
|
||||
job['outputs'] = normalize_outputs(outputs)
|
||||
job['execution_status'] = status_info
|
||||
job['workflow'] = {
|
||||
'prompt': prompt,
|
||||
@ -171,18 +221,23 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
|
||||
continue
|
||||
|
||||
for item in items:
|
||||
count += 1
|
||||
|
||||
if not isinstance(item, dict):
|
||||
normalized = normalize_output_item(item)
|
||||
if normalized is None:
|
||||
continue
|
||||
|
||||
if preview_output is None and is_previewable(media_type, item):
|
||||
count += 1
|
||||
|
||||
if preview_output is not None:
|
||||
continue
|
||||
|
||||
if isinstance(normalized, dict) and is_previewable(media_type, normalized):
|
||||
enriched = {
|
||||
**item,
|
||||
**normalized,
|
||||
'nodeId': node_id,
|
||||
'mediaType': media_type
|
||||
}
|
||||
if item.get('type') == 'output':
|
||||
if 'mediaType' not in normalized:
|
||||
enriched['mediaType'] = media_type
|
||||
if normalized.get('type') == 'output':
|
||||
preview_output = enriched
|
||||
elif fallback_preview is None:
|
||||
fallback_preview = enriched
|
||||
|
||||
@ -49,13 +49,14 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
|
||||
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
|
||||
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
|
||||
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
|
||||
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import os
|
||||
import numpy as np
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from tqdm.auto import trange
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
||||
"""
|
||||
CFGGuider with modifications for training specific logic
|
||||
"""
|
||||
|
||||
def __init__(self, *args, offloading=False, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.offloading = offloading
|
||||
|
||||
def outer_sample(
|
||||
self,
|
||||
noise,
|
||||
@ -45,9 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
||||
noise.shape,
|
||||
self.conds,
|
||||
self.model_options,
|
||||
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
|
||||
force_full_load=not self.offloading,
|
||||
force_offload=self.offloading,
|
||||
)
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
if denoise_mask is not None:
|
||||
@ -404,16 +412,97 @@ def find_all_highest_child_module_with_forward(
|
||||
return result
|
||||
|
||||
|
||||
def patch(m):
|
||||
def find_modules_at_depth(
|
||||
model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
|
||||
) -> list[nn.Module]:
|
||||
"""
|
||||
Find modules at a specific depth level for gradient checkpointing.
|
||||
|
||||
Args:
|
||||
model: The model to search
|
||||
depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
|
||||
result: Accumulator for results
|
||||
current_depth: Current recursion depth
|
||||
name: Current module name for logging
|
||||
|
||||
Returns:
|
||||
List of modules at the target depth
|
||||
"""
|
||||
if result is None:
|
||||
result = []
|
||||
name = name or "root"
|
||||
|
||||
# Skip container modules (they don't have meaningful forward)
|
||||
is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
|
||||
has_forward = hasattr(model, "forward") and not is_container
|
||||
|
||||
if has_forward:
|
||||
current_depth += 1
|
||||
if current_depth == depth:
|
||||
result.append(model)
|
||||
logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
|
||||
return result
|
||||
|
||||
# Recurse into children
|
||||
for next_name, child in model.named_children():
|
||||
find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class OffloadCheckpointFunction(torch.autograd.Function):
|
||||
"""
|
||||
Gradient checkpointing that works with weight offloading.
|
||||
|
||||
Forward: no_grad -> compute -> weights can be freed
|
||||
Backward: enable_grad -> recompute -> backward -> weights can be freed
|
||||
|
||||
For single input, single output modules (Linear, Conv*).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x: torch.Tensor, forward_fn):
|
||||
ctx.save_for_backward(x)
|
||||
ctx.forward_fn = forward_fn
|
||||
with torch.no_grad():
|
||||
return forward_fn(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out: torch.Tensor):
|
||||
x, = ctx.saved_tensors
|
||||
forward_fn = ctx.forward_fn
|
||||
|
||||
# Clear context early
|
||||
ctx.forward_fn = None
|
||||
|
||||
with torch.enable_grad():
|
||||
x_detached = x.detach().requires_grad_(True)
|
||||
y = forward_fn(x_detached)
|
||||
y.backward(grad_out)
|
||||
grad_x = x_detached.grad
|
||||
|
||||
# Explicit cleanup
|
||||
del y, x_detached, forward_fn
|
||||
|
||||
return grad_x, None
|
||||
|
||||
|
||||
def patch(m, offloading=False):
|
||||
if not hasattr(m, "forward"):
|
||||
return
|
||||
org_forward = m.forward
|
||||
|
||||
def fwd(args, kwargs):
|
||||
return org_forward(*args, **kwargs)
|
||||
# Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
|
||||
if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
||||
def checkpointing_fwd(x):
|
||||
return OffloadCheckpointFunction.apply(x, org_forward)
|
||||
# Branch 2: Others -> standard checkpoint
|
||||
else:
|
||||
def fwd(args, kwargs):
|
||||
return org_forward(*args, **kwargs)
|
||||
|
||||
def checkpointing_fwd(*args, **kwargs):
|
||||
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
|
||||
def checkpointing_fwd(*args, **kwargs):
|
||||
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
|
||||
|
||||
m.org_forward = org_forward
|
||||
m.forward = checkpointing_fwd
|
||||
@ -936,6 +1025,18 @@ class TrainLoraNode(io.ComfyNode):
|
||||
default=True,
|
||||
tooltip="Use gradient checkpointing for training.",
|
||||
),
|
||||
io.Int.Input(
|
||||
"checkpoint_depth",
|
||||
default=1,
|
||||
min=1,
|
||||
max=5,
|
||||
tooltip="Depth level for gradient checkpointing.",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"offloading",
|
||||
default=False,
|
||||
tooltip="Depth level for gradient checkpointing.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"existing_lora",
|
||||
options=folder_paths.get_filename_list("loras") + ["[None]"],
|
||||
@ -982,6 +1083,8 @@ class TrainLoraNode(io.ComfyNode):
|
||||
lora_dtype,
|
||||
algorithm,
|
||||
gradient_checkpointing,
|
||||
checkpoint_depth,
|
||||
offloading,
|
||||
existing_lora,
|
||||
bucket_mode,
|
||||
bypass_mode,
|
||||
@ -1000,6 +1103,8 @@ class TrainLoraNode(io.ComfyNode):
|
||||
lora_dtype = lora_dtype[0]
|
||||
algorithm = algorithm[0]
|
||||
gradient_checkpointing = gradient_checkpointing[0]
|
||||
offloading = offloading[0]
|
||||
checkpoint_depth = checkpoint_depth[0]
|
||||
existing_lora = existing_lora[0]
|
||||
bucket_mode = bucket_mode[0]
|
||||
bypass_mode = bypass_mode[0]
|
||||
@ -1054,16 +1159,18 @@ class TrainLoraNode(io.ComfyNode):
|
||||
|
||||
# Setup gradient checkpointing
|
||||
if gradient_checkpointing:
|
||||
for m in find_all_highest_child_module_with_forward(
|
||||
mp.model.diffusion_model
|
||||
):
|
||||
patch(m)
|
||||
modules_to_patch = find_modules_at_depth(
|
||||
mp.model.diffusion_model, depth=checkpoint_depth
|
||||
)
|
||||
logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
|
||||
for m in modules_to_patch:
|
||||
patch(m, offloading=offloading)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
# With force_full_load=False we should be able to have offloading
|
||||
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
|
||||
comfy.model_management.load_models_gpu(
|
||||
[mp], memory_required=1e20, force_full_load=True
|
||||
[mp], memory_required=1e20, force_full_load=not offloading
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@ -1100,7 +1207,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
)
|
||||
|
||||
# Setup guider
|
||||
guider = TrainGuider(mp)
|
||||
guider = TrainGuider(mp, offloading=offloading)
|
||||
guider.set_conds(positive)
|
||||
|
||||
# Inject bypass hooks if bypass mode is enabled
|
||||
@ -1113,6 +1220,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
|
||||
# Run training loop
|
||||
try:
|
||||
comfy.model_management.in_training = True
|
||||
_run_training_loop(
|
||||
guider,
|
||||
train_sampler,
|
||||
@ -1123,6 +1231,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
multi_res,
|
||||
)
|
||||
finally:
|
||||
comfy.model_management.in_training = False
|
||||
# Eject bypass hooks if they were injected
|
||||
if bypass_injections is not None:
|
||||
for injection in bypass_injections:
|
||||
@ -1132,19 +1241,20 @@ class TrainLoraNode(io.ComfyNode):
|
||||
unpatch(m)
|
||||
del train_sampler, optimizer
|
||||
|
||||
# Finalize adapters
|
||||
for param in lora_sd:
|
||||
lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
|
||||
|
||||
for adapter in all_weight_adapters:
|
||||
adapter.requires_grad_(False)
|
||||
|
||||
for param in lora_sd:
|
||||
lora_sd[param] = lora_sd[param].to(lora_dtype)
|
||||
del adapter
|
||||
del all_weight_adapters
|
||||
|
||||
# mp in train node is highly specialized for training
|
||||
# use it in inference will result in bad behavior so we don't return it
|
||||
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
|
||||
|
||||
|
||||
class LoraModelLoader(io.ComfyNode):#
|
||||
class LoraModelLoader(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
@ -1166,6 +1276,11 @@ class LoraModelLoader(io.ComfyNode):#
|
||||
max=100.0,
|
||||
tooltip="How strongly to modify the diffusion model. This value can be negative.",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"bypass",
|
||||
default=False,
|
||||
tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(
|
||||
@ -1175,13 +1290,18 @@ class LoraModelLoader(io.ComfyNode):#
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, lora, strength_model):
|
||||
def execute(cls, model, lora, strength_model, bypass=False):
|
||||
if strength_model == 0:
|
||||
return io.NodeOutput(model)
|
||||
|
||||
model_lora, _ = comfy.sd.load_lora_for_models(
|
||||
model, None, lora, strength_model, 0
|
||||
)
|
||||
if bypass:
|
||||
model_lora, _ = comfy.sd.load_bypass_lora_for_models(
|
||||
model, None, lora, strength_model, 0
|
||||
)
|
||||
else:
|
||||
model_lora, _ = comfy.sd.load_lora_for_models(
|
||||
model, None, lora, strength_model, 0
|
||||
)
|
||||
return io.NodeOutput(model_lora)
|
||||
|
||||
|
||||
|
||||
@ -202,6 +202,56 @@ class LoadVideo(io.ComfyNode):
|
||||
|
||||
return True
|
||||
|
||||
class VideoSlice(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Video Slice",
|
||||
display_name="Video Slice",
|
||||
search_aliases=[
|
||||
"trim video duration",
|
||||
"skip first frames",
|
||||
"frame load cap",
|
||||
"start time",
|
||||
],
|
||||
category="image/video",
|
||||
inputs=[
|
||||
io.Video.Input("video"),
|
||||
io.Float.Input(
|
||||
"start_time",
|
||||
default=0.0,
|
||||
max=1e5,
|
||||
min=-1e5,
|
||||
step=0.001,
|
||||
tooltip="Start time in seconds",
|
||||
),
|
||||
io.Float.Input(
|
||||
"duration",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
step=0.001,
|
||||
tooltip="Duration in seconds, or 0 for unlimited duration",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"strict_duration",
|
||||
default=False,
|
||||
tooltip="If True, when the specified duration is not possible, an error will be raised.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Video.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
|
||||
trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
|
||||
if trimmed is not None:
|
||||
return io.NodeOutput(trimmed)
|
||||
raise ValueError(
|
||||
f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
|
||||
)
|
||||
|
||||
|
||||
class VideoExtension(ComfyExtension):
|
||||
@override
|
||||
@ -212,6 +262,7 @@ class VideoExtension(ComfyExtension):
|
||||
CreateVideo,
|
||||
GetVideoComponents,
|
||||
LoadVideo,
|
||||
VideoSlice,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> VideoExtension:
|
||||
|
||||
@ -623,6 +623,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
|
||||
logging.error("Got an OOM, unloading all loaded models.")
|
||||
comfy.model_management.unload_all_models()
|
||||
elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type:
|
||||
tips = "\n\nTIPS: If you have any \"Load CLIP\" or \"*CLIP Loader\" nodes in your workflow connected to this sampler node make sure the correct file(s) and type is selected."
|
||||
|
||||
error_details = {
|
||||
"node_id": real_node_id,
|
||||
|
||||
@ -5,8 +5,11 @@ from comfy_execution.jobs import (
|
||||
is_previewable,
|
||||
normalize_queue_item,
|
||||
normalize_history_item,
|
||||
normalize_output_item,
|
||||
normalize_outputs,
|
||||
get_outputs_summary,
|
||||
apply_sorting,
|
||||
has_3d_extension,
|
||||
)
|
||||
|
||||
|
||||
@ -35,8 +38,8 @@ class TestIsPreviewable:
|
||||
"""Unit tests for is_previewable()"""
|
||||
|
||||
def test_previewable_media_types(self):
|
||||
"""Images, video, audio media types should be previewable."""
|
||||
for media_type in ['images', 'video', 'audio']:
|
||||
"""Images, video, audio, 3d media types should be previewable."""
|
||||
for media_type in ['images', 'video', 'audio', '3d']:
|
||||
assert is_previewable(media_type, {}) is True
|
||||
|
||||
def test_non_previewable_media_types(self):
|
||||
@ -46,7 +49,7 @@ class TestIsPreviewable:
|
||||
|
||||
def test_3d_extensions_previewable(self):
|
||||
"""3D file extensions should be previewable regardless of media_type."""
|
||||
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
|
||||
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
|
||||
item = {'filename': f'model{ext}'}
|
||||
assert is_previewable('files', item) is True
|
||||
|
||||
@ -160,7 +163,7 @@ class TestGetOutputsSummary:
|
||||
|
||||
def test_3d_files_previewable(self):
|
||||
"""3D file extensions should be previewable."""
|
||||
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
|
||||
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
|
||||
outputs = {
|
||||
'node1': {
|
||||
'files': [{'filename': f'model{ext}', 'type': 'output'}]
|
||||
@ -192,6 +195,64 @@ class TestGetOutputsSummary:
|
||||
assert preview['mediaType'] == 'images'
|
||||
assert preview['subfolder'] == 'outputs'
|
||||
|
||||
def test_string_3d_filename_creates_preview(self):
|
||||
"""String items with 3D extensions should synthesize a preview (Preview3D node output).
|
||||
Only the .glb counts — nulls and non-file strings are excluded."""
|
||||
outputs = {
|
||||
'node1': {
|
||||
'result': ['preview3d_abc123.glb', None, None]
|
||||
}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert count == 1
|
||||
assert preview is not None
|
||||
assert preview['filename'] == 'preview3d_abc123.glb'
|
||||
assert preview['mediaType'] == '3d'
|
||||
assert preview['nodeId'] == 'node1'
|
||||
assert preview['type'] == 'output'
|
||||
|
||||
def test_string_non_3d_filename_no_preview(self):
|
||||
"""String items without 3D extensions should not create a preview."""
|
||||
outputs = {
|
||||
'node1': {
|
||||
'result': ['data.json', None]
|
||||
}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert count == 0
|
||||
assert preview is None
|
||||
|
||||
def test_string_3d_filename_used_as_fallback(self):
|
||||
"""String 3D preview should be used when no dict items are previewable."""
|
||||
outputs = {
|
||||
'node1': {
|
||||
'latents': [{'filename': 'latent.safetensors'}],
|
||||
},
|
||||
'node2': {
|
||||
'result': ['model.glb', None]
|
||||
}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert preview is not None
|
||||
assert preview['filename'] == 'model.glb'
|
||||
assert preview['mediaType'] == '3d'
|
||||
|
||||
|
||||
class TestHas3DExtension:
|
||||
"""Unit tests for has_3d_extension()"""
|
||||
|
||||
def test_recognized_extensions(self):
|
||||
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
|
||||
assert has_3d_extension(f'model{ext}') is True
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert has_3d_extension('MODEL.GLB') is True
|
||||
assert has_3d_extension('Scene.GLTF') is True
|
||||
|
||||
def test_non_3d_extensions(self):
|
||||
for name in ['photo.png', 'video.mp4', 'data.json', 'model']:
|
||||
assert has_3d_extension(name) is False
|
||||
|
||||
|
||||
class TestApplySorting:
|
||||
"""Unit tests for apply_sorting()"""
|
||||
@ -395,3 +456,142 @@ class TestNormalizeHistoryItem:
|
||||
'prompt': {'nodes': {'1': {}}},
|
||||
'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
|
||||
}
|
||||
|
||||
def test_include_outputs_normalizes_3d_strings(self):
|
||||
"""Detail view should transform string 3D filenames into file output dicts."""
|
||||
history_item = {
|
||||
'prompt': (
|
||||
5,
|
||||
'prompt-3d',
|
||||
{'nodes': {}},
|
||||
{'create_time': 1234567890},
|
||||
['node1'],
|
||||
),
|
||||
'status': {'status_str': 'success', 'completed': True, 'messages': []},
|
||||
'outputs': {
|
||||
'node1': {
|
||||
'result': ['preview3d_abc123.glb', None, None]
|
||||
}
|
||||
},
|
||||
}
|
||||
job = normalize_history_item('prompt-3d', history_item, include_outputs=True)
|
||||
|
||||
assert job['outputs_count'] == 1
|
||||
result_items = job['outputs']['node1']['result']
|
||||
assert len(result_items) == 1
|
||||
assert result_items[0] == {
|
||||
'filename': 'preview3d_abc123.glb',
|
||||
'type': 'output',
|
||||
'subfolder': '',
|
||||
'mediaType': '3d',
|
||||
}
|
||||
|
||||
def test_include_outputs_preserves_dict_items(self):
|
||||
"""Detail view normalization should pass dict items through unchanged."""
|
||||
history_item = {
|
||||
'prompt': (
|
||||
5,
|
||||
'prompt-img',
|
||||
{'nodes': {}},
|
||||
{'create_time': 1234567890},
|
||||
['node1'],
|
||||
),
|
||||
'status': {'status_str': 'success', 'completed': True, 'messages': []},
|
||||
'outputs': {
|
||||
'node1': {
|
||||
'images': [
|
||||
{'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
job = normalize_history_item('prompt-img', history_item, include_outputs=True)
|
||||
|
||||
assert job['outputs_count'] == 1
|
||||
assert job['outputs']['node1']['images'] == [
|
||||
{'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
|
||||
]
|
||||
|
||||
|
||||
class TestNormalizeOutputItem:
|
||||
"""Unit tests for normalize_output_item()"""
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert normalize_output_item(None) is None
|
||||
|
||||
def test_string_3d_extension_synthesizes_dict(self):
|
||||
result = normalize_output_item('model.glb')
|
||||
assert result == {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
|
||||
|
||||
def test_string_non_3d_extension_returns_none(self):
|
||||
assert normalize_output_item('data.json') is None
|
||||
|
||||
def test_string_no_extension_returns_none(self):
|
||||
assert normalize_output_item('camera_info_string') is None
|
||||
|
||||
def test_dict_passes_through(self):
|
||||
item = {'filename': 'test.png', 'type': 'output'}
|
||||
assert normalize_output_item(item) is item
|
||||
|
||||
def test_other_types_return_none(self):
|
||||
assert normalize_output_item(42) is None
|
||||
assert normalize_output_item(True) is None
|
||||
|
||||
|
||||
class TestNormalizeOutputs:
|
||||
"""Unit tests for normalize_outputs()"""
|
||||
|
||||
def test_empty_outputs(self):
|
||||
assert normalize_outputs({}) == {}
|
||||
|
||||
def test_dict_items_pass_through(self):
|
||||
outputs = {
|
||||
'node1': {
|
||||
'images': [{'filename': 'a.png', 'type': 'output'}],
|
||||
}
|
||||
}
|
||||
result = normalize_outputs(outputs)
|
||||
assert result == outputs
|
||||
|
||||
def test_3d_string_synthesized(self):
|
||||
outputs = {
|
||||
'node1': {
|
||||
'result': ['model.glb', None, None],
|
||||
}
|
||||
}
|
||||
result = normalize_outputs(outputs)
|
||||
assert result == {
|
||||
'node1': {
|
||||
'result': [
|
||||
{'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
def test_animated_key_preserved(self):
|
||||
outputs = {
|
||||
'node1': {
|
||||
'images': [{'filename': 'a.png', 'type': 'output'}],
|
||||
'animated': [True],
|
||||
}
|
||||
}
|
||||
result = normalize_outputs(outputs)
|
||||
assert result['node1']['animated'] == [True]
|
||||
|
||||
def test_non_dict_node_outputs_preserved(self):
|
||||
outputs = {'node1': 'unexpected_value'}
|
||||
result = normalize_outputs(outputs)
|
||||
assert result == {'node1': 'unexpected_value'}
|
||||
|
||||
def test_none_items_filtered_but_other_types_preserved(self):
|
||||
outputs = {
|
||||
'node1': {
|
||||
'result': ['data.json', None, [1, 2, 3]],
|
||||
}
|
||||
}
|
||||
result = normalize_outputs(outputs)
|
||||
assert result == {
|
||||
'node1': {
|
||||
'result': ['data.json', [1, 2, 3]],
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user