mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 07:16:09 +08:00
Compare commits
22 Commits
hack-hunte
...
v0.9.1
| Author | SHA1 | Date | |
|---|---|---|---|
| 5ac1372533 | |||
| 1dcbd9efaf | |||
| db9e6edfa1 | |||
| 8af13b439b | |||
| acd0e53653 | |||
| 117e7a5853 | |||
| b3c0e4de57 | |||
| ecaeeb990d | |||
| c2b65e2fce | |||
| fd5c0755af | |||
| c881a1d689 | |||
| a3b5d4996a | |||
| c6238047ee | |||
| 5cd1113236 | |||
| 2f642d5d9b | |||
| cd912963f1 | |||
| 6e4b1f9d00 | |||
| dc202a2e51 | |||
| 153bc524bf | |||
| 393d2880dd | |||
| 4484b93d61 | |||
| bd0e6825e8 |
@ -183,7 +183,7 @@ Simply download, extract with [7-Zip](https://7-zip.org) or with the windows exp
|
||||
|
||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
Update your Nvidia drivers if it doesn't start.
|
||||
The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
|
||||
|
||||
#### Alternative Downloads:
|
||||
|
||||
@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
|
||||
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
|
||||
|
||||
### Instructions:
|
||||
|
||||
|
||||
@ -92,14 +92,23 @@ def seed_from_paths_batch(
|
||||
session.execute(ins_asset, chunk)
|
||||
|
||||
# try to claim AssetCacheState (file_path)
|
||||
winners_by_path: set[str] = set()
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
|
||||
ins_state = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
.returning(AssetCacheState.file_path)
|
||||
)
|
||||
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
|
||||
winners_by_path.update((session.execute(ins_state, chunk)).scalars().all())
|
||||
session.execute(ins_state, chunk)
|
||||
|
||||
# Query to find which of our paths won (were actually inserted)
|
||||
winners_by_path: set[str] = set()
|
||||
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetCacheState.file_path)
|
||||
.where(AssetCacheState.file_path.in_(chunk))
|
||||
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
|
||||
)
|
||||
winners_by_path.update(result.scalars().all())
|
||||
|
||||
all_paths_set = set(path_list)
|
||||
losers_by_path = all_paths_set - winners_by_path
|
||||
@ -112,16 +121,23 @@ def seed_from_paths_batch(
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
|
||||
|
||||
# insert AssetInfo only for winners
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
|
||||
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
|
||||
ins_info = (
|
||||
sqlite.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
.returning(AssetInfo.id)
|
||||
)
|
||||
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
|
||||
inserted_info_ids.update((session.execute(ins_info, chunk)).scalars().all())
|
||||
session.execute(ins_info, chunk)
|
||||
|
||||
# Query to find which info rows were actually inserted (by matching our generated IDs)
|
||||
all_info_ids = [row["id"] for row in winner_info_rows]
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
|
||||
)
|
||||
inserted_info_ids.update(result.scalars().all())
|
||||
|
||||
# build and insert tag + meta rows for the AssetInfo
|
||||
tag_rows: list[dict] = []
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.ops
|
||||
import math
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
@ -21,6 +22,39 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5):
|
||||
def scale_dim(size, scale):
|
||||
scaled = math.ceil(size * scale / patch_size) * patch_size
|
||||
return max(patch_size, int(scaled))
|
||||
|
||||
# Binary search for optimal scale
|
||||
lo, hi = eps / 10, 100.0
|
||||
while hi - lo >= eps:
|
||||
mid = (lo + hi) / 2
|
||||
h, w = scale_dim(oh, mid), scale_dim(ow, mid)
|
||||
if (h // patch_size) * (w // patch_size) <= max_num_patches:
|
||||
lo = mid
|
||||
else:
|
||||
hi = mid
|
||||
|
||||
return scale_dim(oh, lo), scale_dim(ow, lo)
|
||||
|
||||
def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True):
|
||||
if size > 0:
|
||||
return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop)
|
||||
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
|
||||
b, c, h, w = image.shape
|
||||
h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True)
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1])
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||
super().__init__()
|
||||
@ -175,6 +209,27 @@ class CLIPTextModel(torch.nn.Module):
|
||||
out = self.text_projection(x[2])
|
||||
return (x[0], x[1], out, x[2])
|
||||
|
||||
def siglip2_pos_embed(embed_weight, embeds, orig_shape):
|
||||
embed_weight_len = round(embed_weight.shape[0] ** 0.5)
|
||||
embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len)
|
||||
embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True)
|
||||
embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1)
|
||||
return embeds + embed_weight
|
||||
|
||||
class Siglip2Embeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
|
||||
self.patch_size = patch_size
|
||||
|
||||
def forward(self, pixel_values):
|
||||
b, c, h, w = pixel_values.shape
|
||||
img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c)
|
||||
img = img.permute(0, 1, 3, 2, 4, 5)
|
||||
img = img.reshape(b, img.shape[1] * img.shape[2], -1)
|
||||
img = self.patch_embedding(img)
|
||||
return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size))
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
|
||||
@ -218,8 +273,11 @@ class CLIPVision(torch.nn.Module):
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
model_type = config_dict["model_type"]
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type == "siglip_vision_model":
|
||||
if model_type in ["siglip2_vision_model"]:
|
||||
self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type in ["siglip_vision_model", "siglip2_vision_model"]:
|
||||
self.pre_layrnorm = lambda a: a
|
||||
self.output_layernorm = True
|
||||
else:
|
||||
|
||||
@ -21,6 +21,7 @@ clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from br
|
||||
IMAGE_ENCODERS = {
|
||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
}
|
||||
|
||||
@ -32,9 +33,10 @@ class ClipVisionModel():
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
model_type = config.get("model_type", "clip_vision_model")
|
||||
model_class = IMAGE_ENCODERS.get(model_type)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.model_type = config.get("model_type", "clip_vision_model")
|
||||
self.config = config.copy()
|
||||
model_class = IMAGE_ENCODERS.get(self.model_type)
|
||||
if self.model_type == "siglip_vision_model":
|
||||
self.return_all_hidden_states = True
|
||||
else:
|
||||
self.return_all_hidden_states = False
|
||||
@ -55,7 +57,10 @@ class ClipVisionModel():
|
||||
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
if self.model_type == "siglip2_vision_model":
|
||||
pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
else:
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||
|
||||
outputs = Output()
|
||||
@ -107,10 +112,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
if embed_shape == 729:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif embed_shape == 1024:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
||||
patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape
|
||||
if len(patch_embedding_shape) == 2:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json")
|
||||
else:
|
||||
if embed_shape == 729:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif embed_shape == 1024:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
||||
elif embed_shape == 577:
|
||||
if "multi_modal_projector.linear_1.bias" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
|
||||
|
||||
14
comfy/clip_vision_siglip2_base_naflex.json
Normal file
14
comfy/clip_vision_siglip2_base_naflex.json
Normal file
@ -0,0 +1,14 @@
|
||||
{
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": -1,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip2_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 16,
|
||||
"num_patches": 256,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}
|
||||
118
comfy/float.py
118
comfy/float.py
@ -65,3 +65,121 @@ def stochastic_rounding(value, dtype, seed=0):
|
||||
return output
|
||||
|
||||
return value.to(dtype=dtype)
|
||||
|
||||
|
||||
# TODO: improve this?
|
||||
def stochastic_float_to_fp4_e2m1(x, generator):
|
||||
orig_shape = x.shape
|
||||
sign = torch.signbit(x).to(torch.uint8)
|
||||
|
||||
exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3)
|
||||
x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
|
||||
|
||||
x = x.abs()
|
||||
exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3)
|
||||
|
||||
mantissa = torch.where(
|
||||
exp > 0,
|
||||
(x / (2.0 ** (exp - 1)) - 1.0) * 2.0,
|
||||
(x * 2.0),
|
||||
out=x
|
||||
).round().to(torch.uint8)
|
||||
del x
|
||||
|
||||
exp = exp.to(torch.uint8)
|
||||
|
||||
fp4 = (sign << 3) | (exp << 1) | mantissa
|
||||
del sign, exp, mantissa
|
||||
|
||||
fp4_flat = fp4.view(-1)
|
||||
packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
|
||||
return packed.reshape(list(orig_shape)[:-1] + [-1])
|
||||
|
||||
|
||||
def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
|
||||
See:
|
||||
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
|
||||
|
||||
Args:
|
||||
input_matrix: Input tensor of shape (H, W)
|
||||
Returns:
|
||||
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
|
||||
"""
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
rows, cols = input_matrix.shape
|
||||
n_row_blocks = ceil_div(rows, 128)
|
||||
n_col_blocks = ceil_div(cols, 4)
|
||||
|
||||
# Calculate the padded shape
|
||||
padded_rows = n_row_blocks * 128
|
||||
padded_cols = n_col_blocks * 4
|
||||
|
||||
padded = input_matrix
|
||||
if (rows, cols) != (padded_rows, padded_cols):
|
||||
padded = torch.zeros(
|
||||
(padded_rows, padded_cols),
|
||||
device=input_matrix.device,
|
||||
dtype=input_matrix.dtype,
|
||||
)
|
||||
padded[:rows, :cols] = input_matrix
|
||||
|
||||
# Rearrange the blocks
|
||||
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
|
||||
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
||||
if flatten:
|
||||
return rearranged.flatten()
|
||||
|
||||
return rearranged.reshape(padded_rows, padded_cols)
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
|
||||
F4_E2M1_MAX = 6.0
|
||||
F8_E4M3_MAX = 448.0
|
||||
|
||||
def roundup(x: int, multiple: int) -> int:
|
||||
"""Round up x to the nearest multiple."""
|
||||
return ((x + multiple - 1) // multiple) * multiple
|
||||
|
||||
orig_shape = x.shape
|
||||
|
||||
# Handle padding
|
||||
if pad_16x:
|
||||
rows, cols = x.shape
|
||||
padded_rows = roundup(rows, 16)
|
||||
padded_cols = roundup(cols, 16)
|
||||
if padded_rows != rows or padded_cols != cols:
|
||||
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||
# Note: We update orig_shape because the output tensor logic below assumes x.shape matches
|
||||
# what we want to produce. If we pad here, we want the padded output.
|
||||
orig_shape = x.shape
|
||||
|
||||
block_size = 16
|
||||
|
||||
x = x.reshape(orig_shape[0], -1, block_size)
|
||||
max_abs = torch.amax(torch.abs(x), dim=-1)
|
||||
block_scale = max_abs / F4_E2M1_MAX
|
||||
scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
|
||||
scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
|
||||
total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
|
||||
|
||||
# Handle zero blocks (from padding): avoid 0/0 NaN
|
||||
zero_scale_mask = (total_scale == 0)
|
||||
total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
|
||||
|
||||
x = x / total_scale_safe.unsqueeze(-1)
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
|
||||
|
||||
x = x.view(orig_shape)
|
||||
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
|
||||
|
||||
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
|
||||
return data_lp, blocked_scales
|
||||
|
||||
@ -11,6 +11,69 @@ from comfy.ldm.lightricks.model import (
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
"""
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
|
||||
"""
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
|
||||
# All patches in a frame are identical, so we only keep the first one
|
||||
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
|
||||
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
|
||||
else:
|
||||
# Not divisible or too small - store directly without compression
|
||||
self.patches_per_frame = 1
|
||||
self.num_frames = num_tokens
|
||||
self.data = tensor
|
||||
|
||||
def expand(self):
|
||||
"""Expand back to original tensor."""
|
||||
if self.patches_per_frame == 1:
|
||||
return self.data
|
||||
|
||||
# [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim]
|
||||
expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim)
|
||||
return expanded.reshape(self.batch_size, -1, self.feature_dim)
|
||||
|
||||
def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)):
|
||||
"""Compute ada values on compressed per-frame data, then expand spatially."""
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
# No compression - compute directly
|
||||
if self.patches_per_frame == 1:
|
||||
num_tokens = self.data.shape[1]
|
||||
dim_per_param = self.feature_dim // num_ada_params
|
||||
reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :]
|
||||
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype)
|
||||
ada_values = (table_values + reshaped).unbind(dim=2)
|
||||
return ada_values
|
||||
|
||||
# Compressed: compute on per-frame data then expand spatially
|
||||
# Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param]
|
||||
frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :]
|
||||
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(
|
||||
device=self.data.device, dtype=self.data.dtype
|
||||
)
|
||||
frame_ada = (table_values + frame_reshaped).unbind(dim=2)
|
||||
|
||||
# Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim]
|
||||
return tuple(
|
||||
frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1)
|
||||
.reshape(batch_size, -1, frame_val.shape[-1])
|
||||
for frame_val in frame_ada
|
||||
)
|
||||
|
||||
class BasicAVTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -119,6 +182,9 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
def get_ada_values(
|
||||
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
|
||||
):
|
||||
if isinstance(timestep, CompressedTimestep):
|
||||
return timestep.expand_for_computation(scale_shift_table, batch_size, indices)
|
||||
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
ada_values = (
|
||||
@ -146,10 +212,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
gate_timestep,
|
||||
)
|
||||
|
||||
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
|
||||
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
|
||||
|
||||
return (*scale_shift_chunks, *gate_ada_values)
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -543,72 +606,80 @@ class LTXAVModel(LTXVModel):
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
|
||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
timestep_scaled.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
|
||||
v_embedded_timestep = v_embedded_timestep.view(
|
||||
batch_size, -1, v_embedded_timestep.shape[-1]
|
||||
)
|
||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
v_patches_per_frame = None
|
||||
if orig_shape is not None and len(orig_shape) == 5:
|
||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
|
||||
# Reshape to [batch_size, num_tokens, dim] and compress for storage
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
if a_timestep is not None:
|
||||
a_timestep = a_timestep * self.timestep_scale_multiplier
|
||||
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
||||
a_timestep_flat = a_timestep_scaled.flatten()
|
||||
timestep_flat = timestep_scaled.flatten()
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
|
||||
|
||||
# Cross-attention timesteps - compress these too
|
||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
a_timestep.flatten(),
|
||||
a_timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
timestep.flatten(),
|
||||
timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
timestep.flatten() * av_ca_factor,
|
||||
timestep_flat * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
a_timestep.flatten() * av_ca_factor,
|
||||
a_timestep_flat * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
|
||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
|
||||
]
|
||||
|
||||
a_timestep, a_embedded_timestep = self.audio_adaln_single(
|
||||
a_timestep.flatten(),
|
||||
a_timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
# Audio timesteps
|
||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||
a_embedded_timestep = a_embedded_timestep.view(
|
||||
batch_size, -1, a_embedded_timestep.shape[-1]
|
||||
)
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep,
|
||||
av_ca_video_scale_shift_timestep,
|
||||
av_ca_a2v_gate_noise_timestep,
|
||||
av_ca_v2a_gate_noise_timestep,
|
||||
]
|
||||
cross_av_timestep_ss = list(
|
||||
[t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
|
||||
)
|
||||
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
||||
else:
|
||||
a_timestep = timestep
|
||||
a_timestep = timestep_scaled
|
||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||
cross_av_timestep_ss = []
|
||||
|
||||
@ -767,6 +838,11 @@ class LTXAVModel(LTXVModel):
|
||||
ax = x[1]
|
||||
v_embedded_timestep = embedded_timestep[0]
|
||||
a_embedded_timestep = embedded_timestep[1]
|
||||
|
||||
# Expand compressed video timestep if needed
|
||||
if isinstance(v_embedded_timestep, CompressedTimestep):
|
||||
v_embedded_timestep = v_embedded_timestep.expand()
|
||||
|
||||
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
|
||||
|
||||
# Process audio output
|
||||
|
||||
@ -322,6 +322,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["transformer.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.Kandinsky5):
|
||||
for k in sdk:
|
||||
|
||||
@ -237,6 +237,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
else:
|
||||
dit_config["vec_in_dim"] = None
|
||||
|
||||
dit_config["num_heads"] = dit_config["hidden_size"] // sum(dit_config["axes_dim"])
|
||||
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
||||
|
||||
@ -368,7 +368,7 @@ try:
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if rocm_version >= (7, 0):
|
||||
if any((a in arch) for a in ["gfx1201"]):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
|
||||
|
||||
31
comfy/ops.py
31
comfy/ops.py
@ -546,7 +546,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
raise ValueError(f"Missing weight for layer {layer_name}")
|
||||
logging.warning(f"Missing weight for layer {layer_name}")
|
||||
return
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
@ -624,21 +625,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
missing_keys.remove(key)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
layout_cls = self.weight._layout_cls
|
||||
if destination is not None:
|
||||
sd = destination
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
# Check if it's any FP8 variant (E4M3 or E5M2)
|
||||
if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
|
||||
sd["{}weight_scale".format(prefix)] = self.weight._params.scale
|
||||
elif layout_cls == "TensorCoreNVFP4Layout":
|
||||
sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
|
||||
sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
|
||||
if self.bias is not None:
|
||||
sd["{}bias".format(prefix)] = self.bias
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
sd[k] = sd_out[k]
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
if self._full_precision_mm_config:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
input_scale = getattr(self, 'input_scale', None)
|
||||
if input_scale is not None:
|
||||
sd["{}input_scale".format(prefix)] = input_scale
|
||||
else:
|
||||
sd["{}weight".format(prefix)] = self.weight
|
||||
return sd
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
@ -690,7 +699,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
# dtype is now implicit in the layout class
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
|
||||
@ -7,7 +7,7 @@ try:
|
||||
QuantizedTensor,
|
||||
QuantizedLayout,
|
||||
TensorCoreFP8Layout as _CKFp8Layout,
|
||||
TensorCoreNVFP4Layout, # Direct import, no wrapper needed
|
||||
TensorCoreNVFP4Layout as _CKNvfp4Layout,
|
||||
register_layout_op,
|
||||
register_layout_class,
|
||||
get_layout_class,
|
||||
@ -34,7 +34,7 @@ except ImportError as e:
|
||||
class _CKFp8Layout:
|
||||
pass
|
||||
|
||||
class TensorCoreNVFP4Layout:
|
||||
class _CKNvfp4Layout:
|
||||
pass
|
||||
|
||||
def register_layout_class(name, cls):
|
||||
@ -84,6 +84,39 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
if tensor.dim() != 2:
|
||||
raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D")
|
||||
|
||||
orig_dtype = tensor.dtype
|
||||
orig_shape = tuple(tensor.shape)
|
||||
|
||||
if scale is None or (isinstance(scale, str) and scale == "recalculate"):
|
||||
scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX)
|
||||
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
|
||||
padded_shape = cls.get_padded_shape(orig_shape)
|
||||
needs_padding = padded_shape != orig_shape
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
|
||||
else:
|
||||
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)
|
||||
|
||||
params = cls.Params(
|
||||
scale=scale,
|
||||
orig_dtype=orig_dtype,
|
||||
orig_shape=orig_shape,
|
||||
block_scale=block_scale,
|
||||
)
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
|
||||
|
||||
@ -1059,9 +1059,9 @@ def detect_te_model(sd):
|
||||
return TEModel.JINA_CLIP_2
|
||||
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
if weight.shape[-1] == 4096:
|
||||
if weight.shape[0] == 10240:
|
||||
return TEModel.T5_XXL
|
||||
elif weight.shape[-1] == 2048:
|
||||
elif weight.shape[0] == 5120:
|
||||
return TEModel.T5_XL
|
||||
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
|
||||
return TEModel.T5_XXL_OLD
|
||||
|
||||
@ -845,7 +845,7 @@ class LTXAV(LTXV):
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = 0.061 # TODO
|
||||
self.memory_usage_factor = 0.077 # TODO
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.LTXAV(self, device=device)
|
||||
|
||||
@ -36,7 +36,7 @@ def te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype is None:
|
||||
if dtype_t5 is not None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return CosmosTEModel_
|
||||
|
||||
@ -32,7 +32,7 @@ def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype is None:
|
||||
if dtype_t5 is not None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return MochiTEModel_
|
||||
|
||||
@ -36,7 +36,7 @@ def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype is None:
|
||||
if dtype_t5 is not None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return PixArtTEModel_
|
||||
|
||||
41
comfy_api_nodes/apis/vidu.py
Normal file
41
comfy_api_nodes/apis/vidu.py
Normal file
@ -0,0 +1,41 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SubjectReference(BaseModel):
|
||||
id: str = Field(...)
|
||||
images: list[str] = Field(...)
|
||||
|
||||
|
||||
class TaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(..., max_length=2000)
|
||||
duration: int = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
aspect_ratio: str | None = Field(None)
|
||||
resolution: str | None = Field(None)
|
||||
movement_amplitude: str | None = Field(None)
|
||||
images: list[str] | None = Field(None, description="Base64 encoded string or image URL")
|
||||
subjects: list[SubjectReference] | None = Field(None)
|
||||
bgm: bool | None = Field(None)
|
||||
audio: bool | None = Field(None)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
task_id: str = Field(...)
|
||||
state: str = Field(...)
|
||||
created_at: str = Field(...)
|
||||
code: int | None = Field(None, description="Error code")
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
id: str = Field(..., description="Creation id")
|
||||
url: str = Field(..., description="The URL of the generated results, valid for one hour")
|
||||
cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour")
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
state: str = Field(...)
|
||||
err_code: str | None = Field(None)
|
||||
progress: float | None = Field(None)
|
||||
credits: int | None = Field(None)
|
||||
creations: list[TaskResult] = Field(..., description="Generated results")
|
||||
@ -567,7 +567,7 @@ async def execute_lipsync(
|
||||
# Upload the audio file to Comfy API and get download URL
|
||||
if audio:
|
||||
audio_url = await upload_audio_to_comfyapi(
|
||||
cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3"
|
||||
cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg"
|
||||
)
|
||||
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
|
||||
else:
|
||||
|
||||
@ -2,7 +2,6 @@ import builtins
|
||||
from io import BytesIO
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
@ -138,7 +137,7 @@ class TopazImageEnhance(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str = "",
|
||||
subject_detection: str = "All",
|
||||
face_enhancement: bool = True,
|
||||
@ -153,7 +152,9 @@ class TopazImageEnhance(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Only one input image is supported.")
|
||||
download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png")
|
||||
download_url = await upload_images_to_comfyapi(
|
||||
cls, image, max_images=1, mime_type="image/png", total_pixels=4096*4096
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"),
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.vidu import (
|
||||
SubjectReference,
|
||||
TaskCreationRequest,
|
||||
TaskCreationResponse,
|
||||
TaskResult,
|
||||
TaskStatusResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_video_output,
|
||||
@ -17,6 +18,7 @@ from comfy_api_nodes.util import (
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_images_aspect_ratio_closeness,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
|
||||
@ -25,98 +27,33 @@ VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video"
|
||||
VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video"
|
||||
VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class VideoModelName(str, Enum):
|
||||
vidu_q1 = "viduq1"
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
r_16_9 = "16:9"
|
||||
r_9_16 = "9:16"
|
||||
r_1_1 = "1:1"
|
||||
|
||||
|
||||
class Resolution(str, Enum):
|
||||
r_1080p = "1080p"
|
||||
|
||||
|
||||
class MovementAmplitude(str, Enum):
|
||||
auto = "auto"
|
||||
small = "small"
|
||||
medium = "medium"
|
||||
large = "large"
|
||||
|
||||
|
||||
class TaskCreationRequest(BaseModel):
|
||||
model: VideoModelName = VideoModelName.vidu_q1
|
||||
prompt: Optional[str] = Field(None, max_length=1500)
|
||||
duration: Optional[Literal[5]] = 5
|
||||
seed: Optional[int] = Field(0, ge=0, le=2147483647)
|
||||
aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9
|
||||
resolution: Optional[Resolution] = Resolution.r_1080p
|
||||
movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto
|
||||
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
task_id: str = Field(...)
|
||||
state: str = Field(...)
|
||||
created_at: str = Field(...)
|
||||
code: Optional[int] = Field(None, description="Error code")
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
id: str = Field(..., description="Creation id")
|
||||
url: str = Field(..., description="The URL of the generated results, valid for one hour")
|
||||
cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour")
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
state: str = Field(...)
|
||||
err_code: Optional[str] = Field(None)
|
||||
creations: list[TaskResult] = Field(..., description="Generated results")
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
if response.creations:
|
||||
return response.creations[0].url
|
||||
return None
|
||||
|
||||
|
||||
def get_video_from_response(response) -> TaskResult:
|
||||
if not response.creations:
|
||||
error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}"
|
||||
logging.info(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
logging.info("Vidu task %s succeeded. Video URL: %s", response.creations[0].id, response.creations[0].url)
|
||||
return response.creations[0]
|
||||
|
||||
|
||||
async def execute_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
vidu_endpoint: str,
|
||||
payload: TaskCreationRequest,
|
||||
estimated_duration: int,
|
||||
) -> R:
|
||||
response = await sync_op(
|
||||
) -> list[TaskResult]:
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=payload,
|
||||
)
|
||||
if response.state == "failed":
|
||||
error_msg = f"Vidu request failed. Code: {response.code}"
|
||||
logging.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
return await poll_op(
|
||||
if task_creation_response.state == "failed":
|
||||
raise RuntimeError(f"Vidu request failed. Code: {task_creation_response.code}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
|
||||
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % task_creation_response.task_id),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.state,
|
||||
estimated_duration=estimated_duration,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
max_poll_attempts=320,
|
||||
)
|
||||
if not response.creations:
|
||||
raise RuntimeError(
|
||||
f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}"
|
||||
)
|
||||
return response.creations
|
||||
|
||||
|
||||
class ViduTextToVideoNode(IO.ComfyNode):
|
||||
@ -127,14 +64,9 @@ class ViduTextToVideoNode(IO.ComfyNode):
|
||||
node_id="ViduTextToVideoNode",
|
||||
display_name="Vidu Text To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from text prompt",
|
||||
description="Generate video from a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=VideoModelName,
|
||||
default=VideoModelName.vidu_q1,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
@ -163,22 +95,19 @@ class ViduTextToVideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=AspectRatio,
|
||||
default=AspectRatio.r_16_9,
|
||||
options=["16:9", "9:16", "1:1"],
|
||||
tooltip="The aspect ratio of the output video",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=Resolution,
|
||||
default=Resolution.r_1080p,
|
||||
options=["1080p"],
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=MovementAmplitude,
|
||||
default=MovementAmplitude.auto,
|
||||
options=["auto", "small", "medium", "large"],
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
@ -208,7 +137,7 @@ class ViduTextToVideoNode(IO.ComfyNode):
|
||||
if not prompt:
|
||||
raise ValueError("The prompt field is required and cannot be empty.")
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
@ -216,8 +145,8 @@ class ViduTextToVideoNode(IO.ComfyNode):
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class ViduImageToVideoNode(IO.ComfyNode):
|
||||
@ -230,12 +159,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from image and optional prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=VideoModelName,
|
||||
default=VideoModelName.vidu_q1,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="An image to be used as the start frame of the generated video",
|
||||
@ -270,15 +194,13 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=Resolution,
|
||||
default=Resolution.r_1080p,
|
||||
options=["1080p"],
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=MovementAmplitude,
|
||||
default=MovementAmplitude.auto.value,
|
||||
options=["auto", "small", "medium", "large"],
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
@ -298,7 +220,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
@ -309,7 +231,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
raise ValueError("Only one input image is allowed.")
|
||||
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
@ -322,8 +244,8 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
)
|
||||
results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
@ -334,14 +256,9 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
node_id="ViduReferenceVideoNode",
|
||||
display_name="Vidu Reference To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from multiple images and prompt",
|
||||
description="Generate video from multiple images and a prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=VideoModelName,
|
||||
default=VideoModelName.vidu_q1,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
IO.Image.Input(
|
||||
"images",
|
||||
tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).",
|
||||
@ -374,22 +291,19 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=AspectRatio,
|
||||
default=AspectRatio.r_16_9,
|
||||
options=["16:9", "9:16", "1:1"],
|
||||
tooltip="The aspect ratio of the output video",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[model.value for model in Resolution],
|
||||
default=Resolution.r_1080p.value,
|
||||
options=["1080p"],
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=[model.value for model in MovementAmplitude],
|
||||
default=MovementAmplitude.auto.value,
|
||||
options=["auto", "small", "medium", "large"],
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
@ -409,7 +323,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
images: torch.Tensor,
|
||||
images: Input.Image,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
@ -426,7 +340,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||
validate_image_dimensions(image, min_width=128, min_height=128)
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
@ -440,8 +354,8 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
max_images=7,
|
||||
mime_type="image/png",
|
||||
)
|
||||
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
@ -454,12 +368,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
category="api node/video/Vidu",
|
||||
description="Generate a video from start and end frames and a prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in VideoModelName],
|
||||
default=VideoModelName.vidu_q1.value,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
IO.Image.Input(
|
||||
"first_frame",
|
||||
tooltip="Start frame",
|
||||
@ -497,15 +406,13 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[model.value for model in Resolution],
|
||||
default=Resolution.r_1080p.value,
|
||||
options=["1080p"],
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=[model.value for model in MovementAmplitude],
|
||||
default=MovementAmplitude.auto.value,
|
||||
options=["auto", "small", "medium", "large"],
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
@ -525,8 +432,8 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
first_frame: torch.Tensor,
|
||||
end_frame: torch.Tensor,
|
||||
first_frame: Input.Image,
|
||||
end_frame: Input.Image,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
@ -535,7 +442,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
@ -546,8 +453,391 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
|
||||
for frame in (first_frame, end_frame)
|
||||
]
|
||||
results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class Vidu2TextToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2TextToVideoNode",
|
||||
display_name="Vidu2 Text-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A textual description for video generation, with a maximum length of 2000 characters.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=1,
|
||||
max=10,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "3:4", "4:3", "1:1"]),
|
||||
IO.Combo.Input("resolution", options=["720p", "1080p"]),
|
||||
IO.Boolean.Input(
|
||||
"background_music",
|
||||
default=False,
|
||||
tooltip="Whether to add background music to the generated video.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
background_music: bool,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2000)
|
||||
results = await execute_task(
|
||||
cls,
|
||||
VIDU_TEXT_TO_VIDEO,
|
||||
TaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio,
|
||||
resolution=resolution,
|
||||
bgm=background_music,
|
||||
),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class Vidu2ImageToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2ImageToVideoNode",
|
||||
display_name="Vidu2 Image-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate a video from an image and an optional prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="An image to be used as the start frame of the generated video.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="An optional text prompt for video generation (max 2000 characters).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=1,
|
||||
max=10,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720p", "1080p"],
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=["auto", "small", "medium", "large"],
|
||||
tooltip="The movement amplitude of objects in the frame.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) > 1:
|
||||
raise ValueError("Only one input image is allowed.")
|
||||
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||
validate_string(prompt, max_length=2000)
|
||||
results = await execute_task(
|
||||
cls,
|
||||
VIDU_IMAGE_TO_VIDEO,
|
||||
TaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
images=await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
),
|
||||
),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class Vidu2ReferenceVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2ReferenceVideoNode",
|
||||
display_name="Vidu2 Reference-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate a video from multiple reference images and a prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2"]),
|
||||
IO.Autogrow.Input(
|
||||
"subjects",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("reference_images"),
|
||||
names=["subject1", "subject2", "subject3"],
|
||||
min=1,
|
||||
),
|
||||
tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). "
|
||||
"Reference them in prompts via @subject{subject_id}.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="When enabled, the video will include generated speech and background music "
|
||||
"based on the prompt.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"audio",
|
||||
default=False,
|
||||
tooltip="When enabled video will contain generated speech and background music based on the prompt.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=1,
|
||||
max=10,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]),
|
||||
IO.Combo.Input("resolution", options=["720p"]),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=["auto", "small", "medium", "large"],
|
||||
tooltip="The movement amplitude of objects in the frame.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
subjects: IO.Autogrow.Type,
|
||||
prompt: str,
|
||||
audio: bool,
|
||||
duration: int,
|
||||
seed: int,
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2000)
|
||||
total_images = 0
|
||||
for i in subjects:
|
||||
if get_number_of_images(subjects[i]) > 3:
|
||||
raise ValueError("Maximum number of images per subject is 3.")
|
||||
for im in subjects[i]:
|
||||
total_images += 1
|
||||
validate_image_aspect_ratio(im, (1, 4), (4, 1))
|
||||
validate_image_dimensions(im, min_width=128, min_height=128)
|
||||
if total_images > 7:
|
||||
raise ValueError("Too many reference images; the maximum allowed is 7.")
|
||||
subjects_param: list[SubjectReference] = []
|
||||
for i in subjects:
|
||||
subjects_param.append(
|
||||
SubjectReference(
|
||||
id=i,
|
||||
images=await upload_images_to_comfyapi(
|
||||
cls,
|
||||
subjects[i],
|
||||
max_images=3,
|
||||
mime_type="image/png",
|
||||
wait_label=f"Uploading reference images for {i}",
|
||||
),
|
||||
),
|
||||
)
|
||||
payload = TaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
audio=audio,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio,
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
subjects=subjects_param,
|
||||
)
|
||||
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class Vidu2StartEndToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2StartEndToVideoNode",
|
||||
display_name="Vidu2 Start/End Frame-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate a video from a start frame, an end frame, and a prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
|
||||
IO.Image.Input("first_frame"),
|
||||
IO.Image.Input("end_frame"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Prompt description (max 2000 characters).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=2,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["720p", "1080p"]),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=["auto", "small", "medium", "large"],
|
||||
tooltip="The movement amplitude of objects in the frame.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
first_frame: Input.Image,
|
||||
end_frame: Input.Image,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=2000)
|
||||
if get_number_of_images(first_frame) > 1:
|
||||
raise ValueError("Only one input image is allowed for `first_frame`.")
|
||||
if get_number_of_images(end_frame) > 1:
|
||||
raise ValueError("Only one input image is allowed for `end_frame`.")
|
||||
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
||||
payload = TaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
images=[
|
||||
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
|
||||
for frame in (first_frame, end_frame)
|
||||
],
|
||||
)
|
||||
results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class ViduExtension(ComfyExtension):
|
||||
@ -558,6 +848,10 @@ class ViduExtension(ComfyExtension):
|
||||
ViduImageToVideoNode,
|
||||
ViduReferenceVideoNode,
|
||||
ViduStartEndToVideoNode,
|
||||
Vidu2TextToVideoNode,
|
||||
Vidu2ImageToVideoNode,
|
||||
Vidu2ReferenceVideoNode,
|
||||
Vidu2StartEndToVideoNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
|
||||
|
||||
def tensor_to_bytesio(
|
||||
image: torch.Tensor,
|
||||
name: str | None = None,
|
||||
*,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> BytesIO:
|
||||
@ -75,7 +75,7 @@ def tensor_to_bytesio(
|
||||
|
||||
pil_image = tensor_to_pil(image, total_pixels=total_pixels)
|
||||
img_binary = pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
|
||||
img_binary.name = f"{uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
|
||||
return img_binary
|
||||
|
||||
|
||||
|
||||
@ -49,6 +49,7 @@ async def upload_images_to_comfyapi(
|
||||
mime_type: str | None = None,
|
||||
wait_label: str | None = "Uploading",
|
||||
show_batch_index: bool = True,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Uploads images to ComfyUI API and returns download URLs.
|
||||
@ -63,7 +64,7 @@ async def upload_images_to_comfyapi(
|
||||
|
||||
for idx in range(num_to_upload):
|
||||
tensor = image[idx] if is_batch else image
|
||||
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
|
||||
img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type)
|
||||
|
||||
effective_label = wait_label
|
||||
if wait_label and show_batch_index and num_to_upload > 1:
|
||||
@ -81,7 +82,6 @@ async def upload_audio_to_comfyapi(
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
mime_type: str = "audio/mp4",
|
||||
filename: str = "uploaded_audio.mp4",
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single audio input to ComfyUI API and returns its download URL.
|
||||
@ -91,7 +91,7 @@ async def upload_audio_to_comfyapi(
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
|
||||
return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type)
|
||||
return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type)
|
||||
|
||||
|
||||
async def upload_video_to_comfyapi(
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.8.2"
|
||||
__version__ = "0.9.1"
|
||||
|
||||
13
nodes.py
13
nodes.py
@ -2132,12 +2132,6 @@ def get_module_name(module_path: str) -> str:
|
||||
base_path = os.path.splitext(base_path)[0]
|
||||
return base_path
|
||||
|
||||
def get_hacky_folder_paths():
|
||||
hacky_folder_paths = []
|
||||
for name, values in folder_paths.folder_names_and_paths.items():
|
||||
if len(values) > 2:
|
||||
hacky_folder_paths.append([name, values])
|
||||
return hacky_folder_paths
|
||||
|
||||
async def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
|
||||
module_name = get_module_name(module_path)
|
||||
@ -2247,7 +2241,6 @@ async def init_external_custom_nodes():
|
||||
base_node_names = set(NODE_CLASS_MAPPINGS.keys())
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
node_import_times = []
|
||||
found_first_hacky_folder_path = False
|
||||
for custom_node_path in node_paths:
|
||||
possible_modules = os.listdir(os.path.realpath(custom_node_path))
|
||||
if "__pycache__" in possible_modules:
|
||||
@ -2271,12 +2264,6 @@ async def init_external_custom_nodes():
|
||||
time_before = time.perf_counter()
|
||||
success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
|
||||
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
if not found_first_hacky_folder_path:
|
||||
hacky_folder_paths = get_hacky_folder_paths()
|
||||
if len(hacky_folder_paths) > 0:
|
||||
logging.warning(f"Found first custom node to have hacky folder paths: {module_path}")
|
||||
logging.warning(f"Hacky folder paths: {hacky_folder_paths}")
|
||||
found_first_hacky_folder_path = True
|
||||
|
||||
if len(node_import_times) > 0:
|
||||
logging.info("\nImport times for custom nodes:")
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.8.2"
|
||||
version = "0.9.1"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.36.13
|
||||
comfyui-workflow-templates==0.7.69
|
||||
comfyui-embedded-docs==0.3.1
|
||||
comfyui-frontend-package==1.36.14
|
||||
comfyui-workflow-templates==0.8.4
|
||||
comfyui-embedded-docs==0.4.0
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
@ -21,7 +21,7 @@ psutil
|
||||
alembic
|
||||
SQLAlchemy
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.5
|
||||
comfy-kitchen>=0.2.6
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
|
||||
@ -153,9 +153,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
state_dict2 = model.state_dict()
|
||||
|
||||
# Verify layer1.weight is a QuantizedTensor with scale preserved
|
||||
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
|
||||
self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0)
|
||||
self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout")
|
||||
self.assertTrue(torch.equal(state_dict2["layer1.weight"].view(torch.uint8), fp8_weight.view(torch.uint8)))
|
||||
self.assertEqual(state_dict2["layer1.weight_scale"].item(), 3.0)
|
||||
self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")
|
||||
|
||||
# Verify non-quantized layers are standard tensors
|
||||
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
|
||||
|
||||
Reference in New Issue
Block a user