mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-30 16:56:32 +08:00
Compare commits
14 Commits
sortblock
...
js/progres
| Author | SHA1 | Date | |
|---|---|---|---|
| e9db9554bd | |||
| 98be8e1969 | |||
| 766ff74207 | |||
| b1b5f87534 | |||
| cf45fd1742 | |||
| e7314f49e6 | |||
| e2d1e5dad9 | |||
| 27e067ce50 | |||
| 9b15155972 | |||
| 32a627bf1f | |||
| fe442fac2e | |||
| d2c502e629 | |||
| fea9ea8268 | |||
| f949094b3c |
30
.github/workflows/test-execution.yml
vendored
Normal file
30
.github/workflows/test-execution.yml
vendored
Normal file
@ -0,0 +1,30 @@
|
||||
# CI workflow: runs the execution test suite (tests/execution) on Linux,
# Windows and macOS for pushes and pull requests targeting main/master.
name: Execution Tests

on:
  push:
    branches: [ main, master ]
  pull_request:
    branches: [ main, master ]

jobs:
  test:
    strategy:
      matrix:
        os: [ubuntu-latest, windows-latest, macos-latest]
    runs-on: ${{ matrix.os }}
    # A failing job on one OS does not fail the overall workflow run.
    continue-on-error: true
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python
      # v5 runs on the Node 20 runtime; v4 used the deprecated Node 16 runtime.
      uses: actions/setup-python@v5
      with:
        python-version: '3.12'
    - name: Install requirements
      run: |
        python -m pip install --upgrade pip
        pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
        pip install -r requirements.txt
        pip install -r tests-unit/requirements.txt
    - name: Run Execution Tests
      run: |
        python -m pytest tests/execution -v --skip-timing-checks
|
||||
@ -143,6 +143,7 @@ class PerformanceFeature(enum.Enum):
|
||||
Fp16Accumulation = "fp16_accumulation"
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
|
||||
# The list of valid values is generated from PerformanceFeature so the help
# text cannot drift out of date when a new member (e.g. "autotune") is added.
parser.add_argument(
    "--fast",
    nargs="*",
    type=PerformanceFeature,
    help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list of specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(
        " ".join(feature.value for feature in PerformanceFeature)
    ),
)
|
||||
|
||||
|
||||
@ -171,6 +171,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
|
||||
return sigmas
|
||||
|
||||
|
||||
def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
    """Return h * phi_1(h) for exponential integrator methods.

    With phi_1(h) = (exp(h) - 1) / h, the product h * phi_1(h) reduces to
    exp(h) - 1, evaluated element-wise via ``expm1``.
    """
    return h.expm1()
|
||||
|
||||
|
||||
def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
    """Return h * phi_2(h) for exponential integrator methods.

    With phi_2(h) = (exp(h) - 1 - h) / h**2, the product h * phi_2(h) equals
    (expm1(h) - h) / h. That closed form is 0/0 at h == 0 and suffers from
    cancellation when |h| is small, so a truncated Taylor series
    h/2 + h**2/6 + h**3/24 + ... is used for |h| < 1e-2. For larger |h| the
    closed form is evaluated exactly as before.

    Args:
        h: Step size tensor (any shape).

    Returns:
        Tensor of the same shape as ``h`` holding h * phi_2(h) element-wise.
    """
    is_small = h.abs() < 1e-2
    # Substitute a harmless denominator where the series branch is selected so
    # the closed-form branch never evaluates 0/0.
    safe_h = torch.where(is_small, torch.ones_like(h), h)
    closed_form = (torch.expm1(safe_h) - safe_h) / safe_h
    # Horner evaluation of sum_{k>=1} h**k / (k+1)!, truncated after h**6/5040.
    series = h * (1 / 2 + h * (1 / 6 + h * (1 / 24 + h * (1 / 120 + h * (1 / 720 + h / 5040)))))
    return torch.where(is_small, series, closed_form)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
@ -1550,13 +1560,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
@ -1564,55 +1573,53 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
fac = 1 / (2 * r)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = lambda_s + r * h
|
||||
fac = 1 / (2 * r)
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
continue
|
||||
|
||||
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
|
||||
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
# 0 < r < 1
|
||||
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
|
||||
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
|
||||
if inject_noise:
|
||||
sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
|
||||
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
# Step 2
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
if inject_noise:
|
||||
segment_factor = (r - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
|
||||
arXiv: https://arxiv.org/abs/2305.14267
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
@ -1624,45 +1631,49 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = lambda_s + r_1 * h
|
||||
lambda_s_2 = lambda_s + r_2 * h
|
||||
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
|
||||
continue
|
||||
|
||||
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
|
||||
lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
|
||||
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
|
||||
|
||||
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
# 0 < r_1 < r_2 < 1
|
||||
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
|
||||
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
|
||||
if inject_noise:
|
||||
sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
|
||||
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
||||
if inject_noise:
|
||||
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
# Step 2
|
||||
a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
|
||||
a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
|
||||
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
|
||||
if inject_noise:
|
||||
segment_factor = (r_1 - r_2) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
|
||||
x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
|
||||
# Step 3
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
||||
# Step 3
|
||||
b3 = ei_h_phi_2(-h_eta) / r_2
|
||||
b1 = ei_h_phi_1(-h_eta) - b3
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
|
||||
if inject_noise:
|
||||
segment_factor = (r_2 - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@ -133,7 +133,6 @@ class Attention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
**cross_attention_kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
@ -141,7 +140,6 @@ class Attention(nn.Module):
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
@ -368,7 +366,6 @@ class CustomerAttnProcessor2_0:
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
transformer_options={},
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@ -436,7 +433,7 @@ class CustomerAttnProcessor2_0:
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = optimized_attention(
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
|
||||
).to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
@ -700,7 +697,6 @@ class LinearTransformerBlock(nn.Module):
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
temb: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
):
|
||||
|
||||
N = hidden_states.shape[0]
|
||||
@ -724,7 +720,6 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
attn_output, _ = self.attn(
|
||||
@ -734,7 +729,6 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=None,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=None,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if self.use_adaln_single:
|
||||
@ -749,7 +743,6 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
|
||||
@ -314,7 +314,6 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output_length: int = 0,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
transformer_options={},
|
||||
):
|
||||
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
|
||||
temb = self.t_block(embedded_timestep)
|
||||
@ -340,7 +339,6 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
|
||||
temb=temb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||
@ -395,7 +393,6 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
|
||||
output_length = hidden_states.shape[-1]
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
output = self.decode(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
@ -405,7 +402,6 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output_length=output_length,
|
||||
block_controlnet_hidden_states=block_controlnet_hidden_states,
|
||||
controlnet_scale=controlnet_scale,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@ -298,8 +298,7 @@ class Attention(nn.Module):
|
||||
mask = None,
|
||||
context_mask = None,
|
||||
rotary_pos_emb = None,
|
||||
causal = None,
|
||||
transformer_options={},
|
||||
causal = None
|
||||
):
|
||||
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
|
||||
|
||||
@ -364,7 +363,7 @@ class Attention(nn.Module):
|
||||
heads_per_kv_head = h // kv_h
|
||||
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True)
|
||||
out = self.to_out(out)
|
||||
|
||||
if mask is not None:
|
||||
@ -489,8 +488,7 @@ class TransformerBlock(nn.Module):
|
||||
global_cond=None,
|
||||
mask = None,
|
||||
context_mask = None,
|
||||
rotary_pos_emb = None,
|
||||
transformer_options={}
|
||||
rotary_pos_emb = None
|
||||
):
|
||||
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
|
||||
|
||||
@ -500,12 +498,12 @@ class TransformerBlock(nn.Module):
|
||||
residual = x
|
||||
x = self.pre_norm(x)
|
||||
x = x * (1 + scale_self) + shift_self
|
||||
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
|
||||
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||
x = x * torch.sigmoid(1 - gate_self)
|
||||
x = x + residual
|
||||
|
||||
if context is not None:
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||
|
||||
if self.conformer is not None:
|
||||
x = x + self.conformer(x)
|
||||
@ -519,10 +517,10 @@ class TransformerBlock(nn.Module):
|
||||
x = x + residual
|
||||
|
||||
else:
|
||||
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
|
||||
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||
|
||||
if context is not None:
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||
|
||||
if self.conformer is not None:
|
||||
x = x + self.conformer(x)
|
||||
@ -608,8 +606,7 @@ class ContinuousTransformer(nn.Module):
|
||||
return_info = False,
|
||||
**kwargs
|
||||
):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
|
||||
batch, seq, device = *x.shape[:2], x.device
|
||||
context = kwargs["context"]
|
||||
|
||||
@ -648,13 +645,13 @@ class ContinuousTransformer(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
|
||||
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
|
||||
if return_info:
|
||||
|
||||
@ -85,7 +85,7 @@ class SingleAttention(nn.Module):
|
||||
)
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c, transformer_options={}):
|
||||
def forward(self, c):
|
||||
|
||||
bsz, seqlen1, _ = c.shape
|
||||
|
||||
@ -95,7 +95,7 @@ class SingleAttention(nn.Module):
|
||||
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
|
||||
q, k = self.q_norm1(q), self.k_norm1(k)
|
||||
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
|
||||
c = self.w1o(output)
|
||||
return c
|
||||
|
||||
@ -144,7 +144,7 @@ class DoubleAttention(nn.Module):
|
||||
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c, x, transformer_options={}):
|
||||
def forward(self, c, x):
|
||||
|
||||
bsz, seqlen1, _ = c.shape
|
||||
bsz, seqlen2, _ = x.shape
|
||||
@ -168,7 +168,7 @@ class DoubleAttention(nn.Module):
|
||||
torch.cat([cv, xv], dim=1),
|
||||
)
|
||||
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
|
||||
|
||||
c, x = output.split([seqlen1, seqlen2], dim=1)
|
||||
c = self.w1o(c)
|
||||
@ -207,7 +207,7 @@ class MMDiTBlock(nn.Module):
|
||||
self.is_last = is_last
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c, x, global_cond, transformer_options={}, **kwargs):
|
||||
def forward(self, c, x, global_cond, **kwargs):
|
||||
|
||||
cres, xres = c, x
|
||||
|
||||
@ -225,7 +225,7 @@ class MMDiTBlock(nn.Module):
|
||||
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
|
||||
|
||||
# attention
|
||||
c, x = self.attn(c, x, transformer_options=transformer_options)
|
||||
c, x = self.attn(c, x)
|
||||
|
||||
|
||||
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
|
||||
@ -255,13 +255,13 @@ class DiTBlock(nn.Module):
|
||||
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, cx, global_cond, transformer_options={}, **kwargs):
|
||||
def forward(self, cx, global_cond, **kwargs):
|
||||
cxres = cx
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
|
||||
global_cond
|
||||
).chunk(6, dim=1)
|
||||
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
|
||||
cx = self.attn(cx, transformer_options=transformer_options)
|
||||
cx = self.attn(cx)
|
||||
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
|
||||
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
|
||||
cx = gate_mlp.unsqueeze(1) * mlpout
|
||||
@ -473,14 +473,13 @@ class MMDiT(nn.Module):
|
||||
out = {}
|
||||
out["txt"], out["img"] = layer(args["txt"],
|
||||
args["img"],
|
||||
args["vec"],
|
||||
transformer_options=args["transformer_options"])
|
||||
args["vec"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
|
||||
c = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)
|
||||
c, x = layer(c, x, global_cond, **kwargs)
|
||||
|
||||
if len(self.single_layers) > 0:
|
||||
c_len = c.size(1)
|
||||
@ -489,13 +488,13 @@ class MMDiT(nn.Module):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
|
||||
out["img"] = layer(args["img"], args["vec"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
|
||||
cx = out["img"]
|
||||
else:
|
||||
cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)
|
||||
cx = layer(cx, global_cond, **kwargs)
|
||||
|
||||
x = cx[:, c_len:]
|
||||
|
||||
|
||||
@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
|
||||
|
||||
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, q, k, v, transformer_options={}):
|
||||
def forward(self, q, k, v):
|
||||
q = self.to_q(q)
|
||||
k = self.to_k(k)
|
||||
v = self.to_v(v)
|
||||
|
||||
out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
|
||||
out = optimized_attention(q, k, v, self.heads)
|
||||
|
||||
return self.out_proj(out)
|
||||
|
||||
@ -47,13 +47,13 @@ class Attention2D(nn.Module):
|
||||
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
|
||||
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, kv, self_attn=False, transformer_options={}):
|
||||
def forward(self, x, kv, self_attn=False):
|
||||
orig_shape = x.shape
|
||||
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
|
||||
if self_attn:
|
||||
kv = torch.cat([x, kv], dim=1)
|
||||
# x = self.attn(x, kv, kv, need_weights=False)[0]
|
||||
x = self.attn(x, kv, kv, transformer_options=transformer_options)
|
||||
x = self.attn(x, kv, kv)
|
||||
x = x.permute(0, 2, 1).view(*orig_shape)
|
||||
return x
|
||||
|
||||
@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
|
||||
operations.Linear(c_cond, c, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, kv, transformer_options={}):
|
||||
def forward(self, x, kv):
|
||||
kv = self.kv_mapper(kv)
|
||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
|
||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@ -173,7 +173,7 @@ class StageB(nn.Module):
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
|
||||
def _down_encode(self, x, r_embed, clip, transformer_options={}):
|
||||
def _down_encode(self, x, r_embed, clip):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
@ -187,7 +187,7 @@ class StageB(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@ -199,7 +199,7 @@ class StageB(nn.Module):
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
|
||||
def _up_decode(self, level_outputs, r_embed, clip):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
@ -216,7 +216,7 @@ class StageB(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@ -228,7 +228,7 @@ class StageB(nn.Module):
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
|
||||
if pixels is None:
|
||||
pixels = x.new_zeros(x.size(0), 3, 8, 8)
|
||||
|
||||
@ -245,8 +245,8 @@ class StageB(nn.Module):
|
||||
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
|
||||
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
|
||||
align_corners=True)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
|
||||
level_outputs = self._down_encode(x, r_embed, clip)
|
||||
x = self._up_decode(level_outputs, r_embed, clip)
|
||||
return self.clf(x)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
|
||||
|
||||
@ -182,7 +182,7 @@ class StageC(nn.Module):
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
|
||||
def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
|
||||
def _down_encode(self, x, r_embed, clip, cnet=None):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
@ -201,7 +201,7 @@ class StageC(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@ -213,7 +213,7 @@ class StageC(nn.Module):
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
|
||||
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
@ -235,7 +235,7 @@ class StageC(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@ -247,7 +247,7 @@ class StageC(nn.Module):
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
|
||||
# Process the conditioning embeddings
|
||||
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
|
||||
for c in self.t_conds:
|
||||
@ -262,8 +262,8 @@ class StageC(nn.Module):
|
||||
|
||||
# Model Blocks
|
||||
x = self.embedding(x)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, cnet)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, cnet)
|
||||
return self.clf(x)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
|
||||
|
||||
@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
|
||||
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
pe=pe, mask=attn_mask)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
|
||||
mod = vec
|
||||
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module):
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x.addcmul_(mod.gate, output)
|
||||
|
||||
@ -193,16 +193,14 @@ class Chroma(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": double_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@ -211,8 +209,7 @@ class Chroma(nn.Module):
|
||||
txt=txt,
|
||||
vec=double_mod,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@ -232,19 +229,17 @@ class Chroma(nn.Module):
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": single_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
|
||||
@ -176,7 +176,6 @@ class Attention(nn.Module):
|
||||
context=None,
|
||||
mask=None,
|
||||
rope_emb=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -185,7 +184,7 @@ class Attention(nn.Module):
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
|
||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options)
|
||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
|
||||
del q, k, v
|
||||
out = rearrange(out, " b n s c -> s b (n c)")
|
||||
return self.to_out(out)
|
||||
@ -547,7 +546,6 @@ class VideoAttn(nn.Module):
|
||||
context: Optional[torch.Tensor] = None,
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for video attention.
|
||||
@ -573,7 +571,6 @@ class VideoAttn(nn.Module):
|
||||
context_M_B_D,
|
||||
crossattn_mask,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
|
||||
return x_T_H_W_B_D
|
||||
@ -668,7 +665,6 @@ class DITBuildingBlock(nn.Module):
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for dynamically configured blocks with adaptive normalization.
|
||||
@ -706,7 +702,6 @@ class DITBuildingBlock(nn.Module):
|
||||
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||
context=None,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
elif self.block_type in ["cross_attn", "ca"]:
|
||||
x = x + gate_1_1_1_B_D * self.block(
|
||||
@ -714,7 +709,6 @@ class DITBuildingBlock(nn.Module):
|
||||
context=crossattn_emb,
|
||||
crossattn_mask=crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown block type: {self.block_type}")
|
||||
@ -790,7 +784,6 @@ class GeneralDITTransformerBlock(nn.Module):
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
for block in self.blocks:
|
||||
x = block(
|
||||
@ -800,6 +793,5 @@ class GeneralDITTransformerBlock(nn.Module):
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
return x
|
||||
|
||||
@ -520,7 +520,6 @@ class GeneralDIT(nn.Module):
|
||||
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
||||
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
for _, block in self.blocks.items():
|
||||
assert (
|
||||
self.blocks["block0"].x_format == block.x_format
|
||||
@ -535,7 +534,6 @@ class GeneralDIT(nn.Module):
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
|
||||
|
||||
@ -44,7 +44,7 @@ class GPT2FeedForward(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
|
||||
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes multi-head attention using PyTorch's native implementation.
|
||||
|
||||
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
|
||||
@ -71,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
|
||||
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
|
||||
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)
|
||||
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@ -180,8 +180,8 @@ class Attention(nn.Module):
|
||||
|
||||
return q, k, v
|
||||
|
||||
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
|
||||
result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
|
||||
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
result = self.attn_op(q, k, v) # [B, S, H, D]
|
||||
return self.output_dropout(self.output_proj(result))
|
||||
|
||||
def forward(
|
||||
@ -189,7 +189,6 @@ class Attention(nn.Module):
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -197,7 +196,7 @@ class Attention(nn.Module):
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
|
||||
return self.compute_attention(q, k, v, transformer_options=transformer_options)
|
||||
return self.compute_attention(q, k, v)
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
@ -460,7 +459,6 @@ class Block(nn.Module):
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||
@ -514,7 +512,6 @@ class Block(nn.Module):
|
||||
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
None,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
@ -528,7 +525,6 @@ class Block(nn.Module):
|
||||
layer_norm_cross_attn: Callable,
|
||||
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
_normalized_x_B_T_H_W_D = _fn(
|
||||
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
|
||||
@ -538,7 +534,6 @@ class Block(nn.Module):
|
||||
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
crossattn_emb,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
@ -552,7 +547,6 @@ class Block(nn.Module):
|
||||
self.layer_norm_cross_attn,
|
||||
scale_cross_attn_B_T_1_1_D,
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
||||
|
||||
@ -871,7 +865,6 @@ class MiniTrainDIT(nn.Module):
|
||||
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
|
||||
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
|
||||
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
"transformer_options": kwargs.get("transformer_options", {}),
|
||||
}
|
||||
for block in self.blocks:
|
||||
x_B_T_H_W_D = block(
|
||||
|
||||
@ -159,7 +159,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
@ -182,7 +182,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
attn = attention(torch.cat((img_q, txt_q), dim=2),
|
||||
torch.cat((img_k, txt_k), dim=2),
|
||||
torch.cat((img_v, txt_v), dim=2),
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
pe=pe, mask=attn_mask)
|
||||
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
@ -190,7 +190,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
pe=pe, mask=attn_mask)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
@ -244,7 +244,7 @@ class SingleStreamBlock(nn.Module):
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
@ -252,7 +252,7 @@ class SingleStreamBlock(nn.Module):
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
|
||||
@ -6,7 +6,7 @@ from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
|
||||
@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@ -128,7 +128,6 @@ class Flux(nn.Module):
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block"] = ("double_block", i, 2)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -136,16 +135,14 @@ class Flux(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@ -154,8 +151,7 @@ class Flux(nn.Module):
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@ -170,26 +166,23 @@ class Flux(nn.Module):
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block"] = ("single_block", i, 1)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@ -240,12 +233,18 @@ class Flux(nn.Module):
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
|
||||
ref_latents_method = kwargs.get("ref_latents_method", "offset")
|
||||
for ref in ref_latents:
|
||||
if index_ref_method:
|
||||
if ref_latents_method == "index":
|
||||
index += 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
elif ref_latents_method == "uso":
|
||||
index = 0
|
||||
h_offset = h_len * patch_size + h
|
||||
w_offset = w_len * patch_size + w
|
||||
h += ref.shape[-2]
|
||||
w += ref.shape[-1]
|
||||
else:
|
||||
index = 1
|
||||
h_offset = 0
|
||||
|
||||
@ -109,7 +109,6 @@ class AsymmetricAttention(nn.Module):
|
||||
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
||||
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
||||
crop_y,
|
||||
transformer_options={},
|
||||
**rope_rotation,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
rope_cos = rope_rotation.get("rope_cos")
|
||||
@ -144,7 +143,7 @@ class AsymmetricAttention(nn.Module):
|
||||
|
||||
xy = optimized_attention(q,
|
||||
k,
|
||||
v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
v, self.num_heads, skip_reshape=True)
|
||||
|
||||
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
||||
x = self.proj_x(x)
|
||||
@ -225,7 +224,6 @@ class AsymmetricJointBlock(nn.Module):
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
transformer_options={},
|
||||
**attn_kwargs,
|
||||
):
|
||||
"""Forward pass of a block.
|
||||
@ -258,7 +256,6 @@ class AsymmetricJointBlock(nn.Module):
|
||||
y,
|
||||
scale_x=scale_msa_x,
|
||||
scale_y=scale_msa_y,
|
||||
transformer_options=transformer_options,
|
||||
**attn_kwargs,
|
||||
)
|
||||
|
||||
@ -527,11 +524,10 @@ class AsymmDiTJoint(nn.Module):
|
||||
args["txt"],
|
||||
rope_cos=args["rope_cos"],
|
||||
rope_sin=args["rope_sin"],
|
||||
crop_y=args["num_tokens"],
|
||||
transformer_options=args["transformer_options"]
|
||||
crop_y=args["num_tokens"]
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
|
||||
y_feat = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
@ -542,7 +538,6 @@ class AsymmDiTJoint(nn.Module):
|
||||
rope_cos=rope_cos,
|
||||
rope_sin=rope_sin,
|
||||
crop_y=num_tokens,
|
||||
transformer_options=transformer_options,
|
||||
) # (B, M, D), (B, L, D)
|
||||
del y_feat # Final layers don't use dense text features.
|
||||
|
||||
|
||||
@ -72,8 +72,8 @@ class TimestepEmbed(nn.Module):
|
||||
return t_emb
|
||||
|
||||
|
||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
|
||||
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options)
|
||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
|
||||
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
|
||||
|
||||
|
||||
class HiDreamAttnProcessor_flashattn:
|
||||
@ -86,7 +86,6 @@ class HiDreamAttnProcessor_flashattn:
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
@ -134,7 +133,7 @@ class HiDreamAttnProcessor_flashattn:
|
||||
query = torch.cat([query_1, query_2], dim=-1)
|
||||
key = torch.cat([key_1, key_2], dim=-1)
|
||||
|
||||
hidden_states = attention(query, key, value, transformer_options=transformer_options)
|
||||
hidden_states = attention(query, key, value)
|
||||
|
||||
if not attn.single:
|
||||
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
|
||||
@ -200,7 +199,6 @@ class HiDreamAttention(nn.Module):
|
||||
image_tokens_masks: torch.FloatTensor = None,
|
||||
norm_text_tokens: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
@ -208,7 +206,6 @@ class HiDreamAttention(nn.Module):
|
||||
image_tokens_masks = image_tokens_masks,
|
||||
text_tokens = norm_text_tokens,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
|
||||
@ -409,7 +406,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
|
||||
@ -422,7 +419,6 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
norm_image_tokens,
|
||||
image_tokens_masks,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
|
||||
@ -487,7 +483,6 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
|
||||
@ -505,7 +500,6 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
image_tokens_masks,
|
||||
norm_text_tokens,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
@ -556,7 +550,6 @@ class HiDreamImageBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.FloatTensor:
|
||||
return self.block(
|
||||
image_tokens,
|
||||
@ -564,7 +557,6 @@ class HiDreamImageBlock(nn.Module):
|
||||
text_tokens,
|
||||
adaln_input,
|
||||
rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
|
||||
@ -794,7 +786,6 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
text_tokens = cur_encoder_hidden_states,
|
||||
adaln_input = adaln_input,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
|
||||
block_id += 1
|
||||
@ -818,7 +809,6 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
text_tokens=None,
|
||||
adaln_input=adaln_input,
|
||||
rope=rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
hidden_states = hidden_states[:, :hidden_states_seq_len]
|
||||
block_id += 1
|
||||
|
||||
@ -99,16 +99,14 @@ class Hunyuan3Dv2(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args["transformer_options"])
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@ -117,8 +115,7 @@ class Hunyuan3Dv2(nn.Module):
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
attn_mask=attn_mask)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
@ -129,19 +126,17 @@ class Hunyuan3Dv2(nn.Module):
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args["transformer_options"])
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
img = img[:, txt.shape[1]:, ...]
|
||||
img = self.final_layer(img, vec)
|
||||
|
||||
@ -78,13 +78,13 @@ class TokenRefinerBlock(nn.Module):
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask, transformer_options={}):
|
||||
def forward(self, x, c, mask):
|
||||
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn.qkv(norm_x)
|
||||
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
|
||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
|
||||
|
||||
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
|
||||
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
|
||||
@ -115,14 +115,14 @@ class IndividualTokenRefiner(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask, transformer_options={}):
|
||||
def forward(self, x, c, mask):
|
||||
m = None
|
||||
if mask is not None:
|
||||
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
|
||||
m = m + m.transpose(2, 3)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, c, m, transformer_options=transformer_options)
|
||||
x = block(x, c, m)
|
||||
return x
|
||||
|
||||
|
||||
@ -150,7 +150,6 @@ class TokenRefiner(nn.Module):
|
||||
x,
|
||||
timesteps,
|
||||
mask,
|
||||
transformer_options={},
|
||||
):
|
||||
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||
# m = mask.float().unsqueeze(-1)
|
||||
@ -159,7 +158,7 @@ class TokenRefiner(nn.Module):
|
||||
|
||||
c = t + self.c_embedder(c.to(x.dtype))
|
||||
x = self.input_embedder(x)
|
||||
x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options)
|
||||
x = self.individual_token_refiner(x, c, mask)
|
||||
return x
|
||||
|
||||
class HunyuanVideo(nn.Module):
|
||||
@ -268,7 +267,7 @@ class HunyuanVideo(nn.Module):
|
||||
if txt_mask is not None and not torch.is_floating_point(txt_mask):
|
||||
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
|
||||
txt = self.txt_in(txt, timesteps, txt_mask)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
@ -286,14 +285,14 @@ class HunyuanVideo(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"])
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options)
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@ -308,13 +307,13 @@ class HunyuanVideo(nn.Module):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"])
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
|
||||
@ -271,7 +271,7 @@ class CrossAttention(nn.Module):
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
|
||||
def forward(self, x, context=None, mask=None, pe=None):
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
@ -285,9 +285,9 @@ class CrossAttention(nn.Module):
|
||||
k = apply_rotary_emb(k, pe)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
@ -303,12 +303,12 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
x += self.attn2(x, context=context, mask=attention_mask)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
||||
x += self.ff(y) * gate_mlp
|
||||
@ -479,10 +479,10 @@ class LTXVModel(torch.nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
@ -490,8 +490,7 @@ class LTXVModel(torch.nn.Module):
|
||||
context=context,
|
||||
attention_mask=attention_mask,
|
||||
timestep=timestep,
|
||||
pe=pe,
|
||||
transformer_options=transformer_options,
|
||||
pe=pe
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
|
||||
@ -104,7 +104,6 @@ class JointAttention(nn.Module):
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
transformer_options={},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
@ -141,7 +140,7 @@ class JointAttention(nn.Module):
|
||||
if n_rep >= 1:
|
||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
|
||||
|
||||
return self.out(output)
|
||||
|
||||
@ -269,7 +268,6 @@ class JointTransformerBlock(nn.Module):
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor]=None,
|
||||
transformer_options={},
|
||||
):
|
||||
"""
|
||||
Perform a forward pass through the TransformerBlock.
|
||||
@ -292,7 +290,6 @@ class JointTransformerBlock(nn.Module):
|
||||
modulate(self.attention_norm1(x), scale_msa),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||
@ -307,7 +304,6 @@ class JointTransformerBlock(nn.Module):
|
||||
self.attention_norm1(x),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
)
|
||||
x = x + self.ffn_norm2(
|
||||
@ -498,7 +494,7 @@ class NextDiT(nn.Module):
|
||||
return imgs
|
||||
|
||||
def patchify_and_embed(
|
||||
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
|
||||
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
||||
bsz = len(x)
|
||||
pH = pW = self.patch_size
|
||||
@ -558,7 +554,7 @@ class NextDiT(nn.Module):
|
||||
|
||||
# refine context
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
|
||||
|
||||
# refine image
|
||||
flat_x = []
|
||||
@ -577,7 +573,7 @@ class NextDiT(nn.Module):
|
||||
padded_img_embed = self.x_embedder(padded_img_embed)
|
||||
padded_img_mask = padded_img_mask.unsqueeze(1)
|
||||
for layer in self.noise_refiner:
|
||||
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
|
||||
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
|
||||
|
||||
if cap_mask is not None:
|
||||
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
|
||||
@ -620,13 +616,12 @@ class NextDiT(nn.Module):
|
||||
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
||||
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
|
||||
freqs_cis = freqs_cis.to(x.device)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
x = layer(x, mask, freqs_cis, adaln_input)
|
||||
|
||||
x = self.final_layer(x, adaln_input)
|
||||
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
|
||||
|
||||
@ -5,9 +5,8 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from typing import Optional, Any, Callable, Union
|
||||
from typing import Optional
|
||||
import logging
|
||||
import functools
|
||||
|
||||
from .diffusionmodules.util import AlphaBlender, timestep_embedding
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
@ -18,45 +17,23 @@ if model_management.xformers_enabled():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
SAGE_ATTENTION_IS_AVAILABLE = False
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
SAGE_ATTENTION_IS_AVAILABLE = True
|
||||
except ModuleNotFoundError as e:
|
||||
if model_management.sage_attention_enabled():
|
||||
if model_management.sage_attention_enabled():
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
except ModuleNotFoundError as e:
|
||||
if e.name == "sageattention":
|
||||
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
||||
else:
|
||||
raise e
|
||||
exit(-1)
|
||||
|
||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
FLASH_ATTENTION_IS_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
if model_management.flash_attention_enabled():
|
||||
if model_management.flash_attention_enabled():
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
except ModuleNotFoundError:
|
||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||
exit(-1)
|
||||
|
||||
REGISTERED_ATTENTION_FUNCTIONS = {}
|
||||
def register_attention_function(name: str, func: Callable):
|
||||
# avoid replacing existing functions
|
||||
if name not in REGISTERED_ATTENTION_FUNCTIONS:
|
||||
REGISTERED_ATTENTION_FUNCTIONS[name] = func
|
||||
else:
|
||||
logging.warning(f"Attention function {name} already registered, skipping registration.")
|
||||
|
||||
def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
|
||||
if name == "optimized":
|
||||
return optimized_attention
|
||||
elif name not in REGISTERED_ATTENTION_FUNCTIONS:
|
||||
if default is ...:
|
||||
raise KeyError(f"Attention function {name} not found.")
|
||||
else:
|
||||
return default
|
||||
return REGISTERED_ATTENTION_FUNCTIONS[name]
|
||||
|
||||
from comfy.cli_args import args
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@ -114,27 +91,7 @@ class FeedForward(nn.Module):
|
||||
def Normalize(in_channels, dtype=None, device=None):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||
|
||||
|
||||
def wrap_attn(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
remove_attn_wrapper_key = False
|
||||
try:
|
||||
if "_inside_attn_wrapper" not in kwargs:
|
||||
transformer_options = kwargs.get("transformer_options", None)
|
||||
remove_attn_wrapper_key = True
|
||||
kwargs["_inside_attn_wrapper"] = True
|
||||
if transformer_options is not None:
|
||||
if "optimized_attention_override" in transformer_options:
|
||||
return transformer_options["optimized_attention_override"](func, *args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
if remove_attn_wrapper_key:
|
||||
del kwargs["_inside_attn_wrapper"]
|
||||
return wrapper
|
||||
|
||||
@wrap_attn
|
||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
attn_precision = get_attn_precision(attn_precision, q.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
@ -202,8 +159,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
)
|
||||
return out
|
||||
|
||||
@wrap_attn
|
||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
|
||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
attn_precision = get_attn_precision(attn_precision, query.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
@ -273,8 +230,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||
return hidden_states
|
||||
|
||||
@wrap_attn
|
||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
attn_precision = get_attn_precision(attn_precision, q.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
@ -403,8 +359,7 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
@wrap_attn
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
b = q.shape[0]
|
||||
dim_head = q.shape[-1]
|
||||
# check to make sure xformers isn't broken
|
||||
@ -419,7 +374,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
disabled_xformers = True
|
||||
|
||||
if disabled_xformers:
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
|
||||
|
||||
if skip_reshape:
|
||||
# b h k d -> b k h d
|
||||
@ -472,8 +427,8 @@ else:
|
||||
#TODO: other GPUs ?
|
||||
SDP_BATCH_LIMIT = 2**31
|
||||
|
||||
@wrap_attn
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
@ -515,8 +470,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||
return out
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout = "HND"
|
||||
@ -546,7 +501,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
lambda t: t.transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
|
||||
|
||||
if tensor_layout == "HND":
|
||||
if not skip_output_reshape:
|
||||
@ -579,8 +534,8 @@ except AttributeError as error:
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
|
||||
|
||||
@wrap_attn
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
@ -642,19 +597,6 @@ else:
|
||||
|
||||
optimized_attention_masked = optimized_attention
|
||||
|
||||
|
||||
# register core-supported attention functions
|
||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("sage", attention_sage)
|
||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("flash", attention_flash)
|
||||
if model_management.xformers_enabled():
|
||||
register_attention_function("xformers", attention_xformers)
|
||||
register_attention_function("pytorch", attention_pytorch)
|
||||
register_attention_function("sub_quad", attention_sub_quad)
|
||||
register_attention_function("split", attention_split)
|
||||
|
||||
|
||||
def optimized_attention_for_device(device, mask=False, small_input=False):
|
||||
if small_input:
|
||||
if model_management.pytorch_attention_enabled():
|
||||
@ -687,7 +629,7 @@ class CrossAttention(nn.Module):
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x, context=None, value=None, mask=None, transformer_options={}):
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
@ -698,9 +640,9 @@ class CrossAttention(nn.Module):
|
||||
v = self.to_v(context)
|
||||
|
||||
if mask is None:
|
||||
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||
else:
|
||||
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
@ -804,7 +746,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
|
||||
n = self.attn1.to_out(n)
|
||||
else:
|
||||
n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options)
|
||||
n = self.attn1(n, context=context_attn1, value=value_attn1)
|
||||
|
||||
if "attn1_output_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_output_patch"]
|
||||
@ -844,7 +786,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
|
||||
n = self.attn2.to_out(n)
|
||||
else:
|
||||
n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options)
|
||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||
|
||||
if "attn2_output_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn2_output_patch"]
|
||||
@ -1075,7 +1017,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
|
||||
B, S, C = x_mix.shape
|
||||
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
|
||||
x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options)
|
||||
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
|
||||
x_mix = rearrange(
|
||||
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
|
||||
)
|
||||
|
||||
@ -606,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
return _block_mixing(*args, **kwargs)
|
||||
|
||||
|
||||
def _block_mixing(context, x, context_block, x_block, c, transformer_options={}):
|
||||
def _block_mixing(context, x, context_block, x_block, c):
|
||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||
|
||||
if x_block.x_block_self_attn:
|
||||
@ -622,7 +622,6 @@ def _block_mixing(context, x, context_block, x_block, c, transformer_options={})
|
||||
attn = optimized_attention(
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
heads=x_block.attn.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
context_attn, x_attn = (
|
||||
attn[:, : context_qkv[0].shape[1]],
|
||||
@ -638,7 +637,6 @@ def _block_mixing(context, x, context_block, x_block, c, transformer_options={})
|
||||
attn2 = optimized_attention(
|
||||
x_qkv2[0], x_qkv2[1], x_qkv2[2],
|
||||
heads=x_block.attn2.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
|
||||
else:
|
||||
@ -960,10 +958,10 @@ class MMDiT(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"])
|
||||
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
|
||||
context = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
@ -972,7 +970,6 @@ class MMDiT(nn.Module):
|
||||
x,
|
||||
c=c_mod,
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
if control is not None:
|
||||
control_o = control.get("output")
|
||||
|
||||
@ -120,7 +120,7 @@ class Attention(nn.Module):
|
||||
nn.Dropout(0.0)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
query = self.to_q(hidden_states)
|
||||
@ -146,7 +146,7 @@ class Attention(nn.Module):
|
||||
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
|
||||
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
|
||||
|
||||
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module):
|
||||
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.modulation:
|
||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
|
||||
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
|
||||
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
|
||||
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
|
||||
hidden_states = hidden_states + self.norm2(attn_output)
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
|
||||
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
|
||||
@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module):
|
||||
ref_img_sizes, img_sizes,
|
||||
)
|
||||
|
||||
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}):
|
||||
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
|
||||
batch_size = len(hidden_states)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module):
|
||||
shift += ref_img_len
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options)
|
||||
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
|
||||
|
||||
if ref_image_hidden_states is not None:
|
||||
for layer in self.ref_image_refiner:
|
||||
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options)
|
||||
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
|
||||
|
||||
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
_, _, H_padded, W_padded = hidden_states.shape
|
||||
@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module):
|
||||
)
|
||||
|
||||
for layer in self.context_refiner:
|
||||
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)
|
||||
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
|
||||
|
||||
img_len = hidden_states.shape[1]
|
||||
combined_img_hidden_states = self.img_patch_embed_and_refine(
|
||||
@ -453,14 +453,13 @@ class OmniGen2Transformer2DModel(nn.Module):
|
||||
noise_rotary_emb, ref_img_rotary_emb,
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
temb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
|
||||
attention_mask = None
|
||||
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options)
|
||||
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
|
||||
|
||||
@ -132,7 +132,6 @@ class Attention(nn.Module):
|
||||
encoder_hidden_states_mask: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
@ -160,7 +159,7 @@ class Attention(nn.Module):
|
||||
joint_key = joint_key.flatten(start_dim=2)
|
||||
joint_value = joint_value.flatten(start_dim=2)
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)
|
||||
|
||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
||||
@ -227,7 +226,6 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
encoder_hidden_states_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
img_mod_params = self.img_mod(temb)
|
||||
txt_mod_params = self.txt_mod(temb)
|
||||
@ -244,7 +242,6 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
encoder_hidden_states=txt_modulated,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + img_gate1 * img_attn_output
|
||||
@ -437,9 +434,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
else:
|
||||
@ -449,12 +446,11 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ class WanSelfAttention(nn.Module):
|
||||
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, freqs, transformer_options={}):
|
||||
def forward(self, x, freqs):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||
@ -75,7 +75,6 @@ class WanSelfAttention(nn.Module):
|
||||
k.view(b, s, n * d),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
x = self.o(x)
|
||||
@ -84,7 +83,7 @@ class WanSelfAttention(nn.Module):
|
||||
|
||||
class WanT2VCrossAttention(WanSelfAttention):
|
||||
|
||||
def forward(self, x, context, transformer_options={}, **kwargs):
|
||||
def forward(self, x, context, **kwargs):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C]
|
||||
@ -96,7 +95,7 @@ class WanT2VCrossAttention(WanSelfAttention):
|
||||
v = self.v(context)
|
||||
|
||||
# compute attention
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads)
|
||||
|
||||
x = self.o(x)
|
||||
return x
|
||||
@ -117,7 +116,7 @@ class WanI2VCrossAttention(WanSelfAttention):
|
||||
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
||||
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, context, context_img_len, transformer_options={}):
|
||||
def forward(self, x, context, context_img_len):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C]
|
||||
@ -132,9 +131,9 @@ class WanI2VCrossAttention(WanSelfAttention):
|
||||
v = self.v(context)
|
||||
k_img = self.norm_k_img(self.k_img(context_img))
|
||||
v_img = self.v_img(context_img)
|
||||
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
|
||||
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
|
||||
# compute attention
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads)
|
||||
|
||||
# output
|
||||
x = x + img_x
|
||||
@ -207,7 +206,6 @@ class WanAttentionBlock(nn.Module):
|
||||
freqs,
|
||||
context,
|
||||
context_img_len=257,
|
||||
transformer_options={},
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
@ -226,12 +224,12 @@ class WanAttentionBlock(nn.Module):
|
||||
# self-attention
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs, transformer_options=transformer_options)
|
||||
freqs)
|
||||
|
||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
@ -561,12 +559,12 @@ class WanModel(torch.nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
@ -744,17 +742,17 @@ class VaceWanModel(WanModel):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||
|
||||
ii = self.vace_layers_mapping.get(i, None)
|
||||
if ii is not None:
|
||||
for iii in range(len(c)):
|
||||
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||
x += c_skip * vace_strength[iii]
|
||||
del c_skip
|
||||
# head
|
||||
@ -843,12 +841,12 @@ class CameraWanModel(WanModel):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||
for k in sdk:
|
||||
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
|
||||
if k.endswith(".weight") and ".linear1." in k:
|
||||
key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3))
|
||||
|
||||
if isinstance(model, comfy.model_base.GenmoMochi):
|
||||
for k in sdk:
|
||||
|
||||
@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
|
||||
def convert_lora_wan_fun(sd):
    """Convert a Wan Fun LoRA state dict.

    Hands *sd* to ``comfy.utils.state_dict_prefix_replace`` with a mapping
    from the doubled-underscore prefix ``lora_unet__`` to ``lora_unet_`` and
    returns whatever that helper returns.
    """
    prefix_fixes = {"lora_unet__": "lora_unet_"}
    return comfy.utils.state_dict_prefix_replace(sd, prefix_fixes)
|
||||
|
||||
def convert_uso_lora(sd):
    """Return a new dict with USO LoRA keys renamed.

    Every key of *sd* is passed through a fixed series of substring
    replacements and then prefixed with ``diffusion_model.``. The values are
    carried over unchanged and *sd* itself is not modified.
    """
    # Applied sequentially, in exactly this order, to each key.
    renames = (
        (".down.weight", ".lora_down.weight"),
        (".up.weight", ".lora_up.weight"),
        (".qkv_lora2.", ".txt_attn.qkv."),
        (".qkv_lora1.", ".img_attn.qkv."),
        (".proj_lora1.", ".img_attn.proj."),
        (".proj_lora2.", ".txt_attn.proj."),
        (".qkv_lora.", ".linear1_qkv."),
        (".proj_lora.", ".linear2."),
        (".processor.", "."),
    )

    converted = {}
    for key, tensor in sd.items():
        renamed = key
        for old, new in renames:
            renamed = renamed.replace(old, new)
        converted["diffusion_model.{}".format(renamed)] = tensor
    return converted
|
||||
|
||||
|
||||
def convert_lora(sd):
    """Dispatch *sd* to the matching LoRA format converter.

    The format is recognized by the presence of specific marker keys, checked
    in a fixed order. When no marker matches, *sd* is returned unchanged
    (the very same object).
    """
    if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
        converter = convert_lora_bfl_control
    elif "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
        converter = convert_lora_wan_fun
    elif "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
        converter = convert_uso_lora
    else:
        # Unrecognized layout: hand the input back as-is.
        return sd
    return converter(sd)
|
||||
|
||||
@ -52,6 +52,9 @@ except (ModuleNotFoundError, TypeError):
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
    """Cast *weight* to the dtype and device of *input*.

    Thin wrapper that reads ``input.dtype`` and ``input.device`` and forwards
    them, together with ``non_blocking`` and ``copy``, to
    ``comfy.model_management.cast_to``; that function's result is returned.
    """
    target_dtype = input.dtype
    target_device = input.device
    return comfy.model_management.cast_to(
        weight,
        target_dtype,
        target_device,
        non_blocking=non_blocking,
        copy=copy,
    )
|
||||
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
from inspect import cleandoc
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api_nodes.apis.stability_api import (
|
||||
StabilityUpscaleConservativeRequest,
|
||||
StabilityUpscaleCreativeRequest,
|
||||
@ -46,87 +49,94 @@ def get_async_dummy_status(x: StabilityResultsGetResponse):
|
||||
return StabilityPollStatus.in_progress
|
||||
|
||||
|
||||
class StabilityStableImageUltraNode:
|
||||
class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
|
||||
"What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="StabilityStableImageUltraNode",
|
||||
display_name="Stability AI Stable Image Ultra",
|
||||
category="api node/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
|
||||
"elements, colors, and subjects will lead to better results. " +
|
||||
"To control the weight of a given word use the format `(word:weight)`," +
|
||||
"where `word` is the word you'd like to control the weight of and `weight`" +
|
||||
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
|
||||
"would convey a sky that was blue and green, but more green than blue."
|
||||
},
|
||||
"would convey a sky that was blue and green, but more green than blue.",
|
||||
),
|
||||
"aspect_ratio": ([x.value for x in StabilityAspectRatio],
|
||||
{
|
||||
"default": StabilityAspectRatio.ratio_1_1,
|
||||
"tooltip": "Aspect ratio of generated image.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[x.value for x in StabilityAspectRatio],
|
||||
default=StabilityAspectRatio.ratio_1_1.value,
|
||||
tooltip="Aspect ratio of generated image.",
|
||||
),
|
||||
"style_preset": (get_stability_style_presets(),
|
||||
{
|
||||
"tooltip": "Optional desired style of generated image.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 4294967294,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image": (IO.IMAGE,),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "A blurb of text describing what you do not wish to see in the output image. This is an advanced feature."
|
||||
},
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
),
|
||||
"image_denoise": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.5,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||
},
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
comfy_io.Float.Input(
|
||||
"image_denoise",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
async def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
|
||||
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||
**kwargs):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
style_preset: str,
|
||||
seed: int,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
negative_prompt: str = "",
|
||||
image_denoise: Optional[float] = 0.5,
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
@ -144,6 +154,11 @@ class StabilityStableImageUltraNode:
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/generate/ultra",
|
||||
@ -161,7 +176,7 @@ class StabilityStableImageUltraNode:
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
@ -171,95 +186,106 @@ class StabilityStableImageUltraNode:
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
return comfy_io.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityStableImageSD_3_5Node:
|
||||
class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="StabilityStableImageSD_3_5Node",
|
||||
display_name="Stability AI Stable Diffusion 3.5 Image",
|
||||
category="api node/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=[x.value for x in Stability_SD3_5_Model],
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[x.value for x in StabilityAspectRatio],
|
||||
default=StabilityAspectRatio.ratio_1_1.value,
|
||||
tooltip="Aspect ratio of generated image.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
"cfg_scale",
|
||||
default=4.0,
|
||||
min=1.0,
|
||||
max=10.0,
|
||||
step=0.1,
|
||||
tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
"image_denoise",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
|
||||
},
|
||||
),
|
||||
"model": ([x.value for x in Stability_SD3_5_Model],),
|
||||
"aspect_ratio": ([x.value for x in StabilityAspectRatio],
|
||||
{
|
||||
"default": StabilityAspectRatio.ratio_1_1,
|
||||
"tooltip": "Aspect ratio of generated image.",
|
||||
},
|
||||
),
|
||||
"style_preset": (get_stability_style_presets(),
|
||||
{
|
||||
"tooltip": "Optional desired style of generated image.",
|
||||
},
|
||||
),
|
||||
"cfg_scale": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 4.0,
|
||||
"min": 1.0,
|
||||
"max": 10.0,
|
||||
"step": 0.1,
|
||||
"tooltip": "How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 4294967294,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image": (IO.IMAGE,),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
|
||||
},
|
||||
),
|
||||
"image_denoise": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.5,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
async def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
|
||||
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||
**kwargs):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
style_preset: str,
|
||||
seed: int,
|
||||
cfg_scale: float,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
negative_prompt: str = "",
|
||||
image_denoise: Optional[float] = 0.5,
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
@ -280,6 +306,11 @@ class StabilityStableImageSD_3_5Node:
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/generate/sd3",
|
||||
@ -300,7 +331,7 @@ class StabilityStableImageSD_3_5Node:
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
@ -310,72 +341,75 @@ class StabilityStableImageSD_3_5Node:
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
return comfy_io.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleConservativeNode:
|
||||
class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Upscale image with minimal alterations to 4K resolution.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="StabilityUpscaleConservativeNode",
|
||||
display_name="Stability AI Upscale Conservative",
|
||||
category="api node/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Image.Input("image"),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
"creativity",
|
||||
default=0.35,
|
||||
min=0.2,
|
||||
max=0.5,
|
||||
step=0.01,
|
||||
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
|
||||
},
|
||||
),
|
||||
"creativity": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.35,
|
||||
"min": 0.2,
|
||||
"max": 0.5,
|
||||
"step": 0.01,
|
||||
"tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 4294967294,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
|
||||
**kwargs):
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
creativity: float,
|
||||
seed: int,
|
||||
negative_prompt: str = "",
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
@ -386,6 +420,11 @@ class StabilityUpscaleConservativeNode:
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
|
||||
@ -401,7 +440,7 @@ class StabilityUpscaleConservativeNode:
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
@ -411,77 +450,81 @@ class StabilityUpscaleConservativeNode:
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
return comfy_io.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleCreativeNode:
|
||||
class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Upscale image with minimal alterations to 4K resolution.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="StabilityUpscaleCreativeNode",
|
||||
display_name="Stability AI Upscale Creative",
|
||||
category="api node/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Image.Input("image"),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
"creativity",
|
||||
default=0.3,
|
||||
min=0.1,
|
||||
max=0.5,
|
||||
step=0.01,
|
||||
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
|
||||
},
|
||||
),
|
||||
"creativity": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.3,
|
||||
"min": 0.1,
|
||||
"max": 0.5,
|
||||
"step": 0.01,
|
||||
"tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
},
|
||||
),
|
||||
"style_preset": (get_stability_style_presets(),
|
||||
{
|
||||
"tooltip": "Optional desired style of generated image.",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 4294967294,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
|
||||
**kwargs):
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
creativity: float,
|
||||
style_preset: str,
|
||||
seed: int,
|
||||
negative_prompt: str = "",
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
@ -494,6 +537,11 @@ class StabilityUpscaleCreativeNode:
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/creative",
|
||||
@ -510,7 +558,7 @@ class StabilityUpscaleCreativeNode:
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
@ -525,7 +573,8 @@ class StabilityUpscaleCreativeNode:
|
||||
completed_statuses=[StabilityPollStatus.finished],
|
||||
failed_statuses=[StabilityPollStatus.failed],
|
||||
status_extractor=lambda x: get_async_dummy_status(x),
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
node_id=cls.hidden.unique_id,
|
||||
)
|
||||
response_poll: StabilityResultsGetResponse = await operation.execute()
|
||||
|
||||
@ -535,41 +584,48 @@ class StabilityUpscaleCreativeNode:
|
||||
image_data = base64.b64decode(response_poll.result)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
return comfy_io.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleFastNode:
|
||||
class StabilityUpscaleFastNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="StabilityUpscaleFastNode",
|
||||
display_name="Stability AI Upscale Fast",
|
||||
category="api node/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Image.Input("image"),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
async def api_call(self, image: torch.Tensor, **kwargs):
|
||||
async def execute(cls, image: torch.Tensor) -> comfy_io.NodeOutput:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/fast",
|
||||
@ -580,7 +636,7 @@ class StabilityUpscaleFastNode:
|
||||
request=EmptyRequest(),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
@ -590,24 +646,20 @@ class StabilityUpscaleFastNode:
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
return comfy_io.NodeOutput(returned_image)
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StabilityStableImageUltraNode": StabilityStableImageUltraNode,
|
||||
"StabilityStableImageSD_3_5Node": StabilityStableImageSD_3_5Node,
|
||||
"StabilityUpscaleConservativeNode": StabilityUpscaleConservativeNode,
|
||||
"StabilityUpscaleCreativeNode": StabilityUpscaleCreativeNode,
|
||||
"StabilityUpscaleFastNode": StabilityUpscaleFastNode,
|
||||
}
|
||||
class StabilityExtension(ComfyExtension):
    # Extension object that exposes the Stability AI API nodes defined in this module.

    @override
    async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
        # Return the node classes this extension provides, in registration order.
        stability_nodes: list[type[comfy_io.ComfyNode]] = [
            StabilityStableImageUltraNode,
            StabilityStableImageSD_3_5Node,
            StabilityUpscaleConservativeNode,
            StabilityUpscaleCreativeNode,
            StabilityUpscaleFastNode,
        ]
        return stability_nodes
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StabilityStableImageUltraNode": "Stability AI Stable Image Ultra",
|
||||
"StabilityStableImageSD_3_5Node": "Stability AI Stable Diffusion 3.5 Image",
|
||||
"StabilityUpscaleConservativeNode": "Stability AI Upscale Conservative",
|
||||
"StabilityUpscaleCreativeNode": "Stability AI Upscale Creative",
|
||||
"StabilityUpscaleFastNode": "Stability AI Upscale Fast",
|
||||
}
|
||||
|
||||
async def comfy_entrypoint() -> StabilityExtension:
    # Entry point coroutine: builds and returns the extension instance for this module.
    extension = StabilityExtension()
    return extension
|
||||
|
||||
@ -181,8 +181,9 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
}
|
||||
|
||||
# Send a combined progress_state message with all node states
|
||||
# Include client_id to ensure message is only sent to the initiating client
|
||||
self.server_instance.send_sync(
|
||||
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}
|
||||
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
@ -1,503 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from comfy_api.latest import io, ComfyExtension
|
||||
import comfy.patcher_extension
|
||||
import logging
|
||||
import torch
|
||||
import math
|
||||
import comfy.model_patcher
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
|
||||
def easysortblock_predict_noise_wrapper(executor, *args, **kwargs):
    """
    PREDICT_NOISE wrapper that decides, once per sampling step, whether blocks may reuse cached results.

    args[0] is the latent input x, args[1] the timestep (indexed with [0], so a tensor) and
    args[2] the model_options dict; the EasySortblockHolder is read from
    model_options["transformer_options"]["easycache"].

    On the holder's initial step the per-step predict ratios are built with torch.linspace over the
    sigmas that fall inside the holder's window. On later steps the change of the subsampled input
    is turned into an approximate output change rate, accumulated, and compared with
    reuse_threshold; when the accumulated value stays below the threshold the holder is flagged with
    skip_current_step=True. The wrapped executor is always called and its output returned; the flag
    is only a hint that the other wrappers read.
    """
    # get values from args
    x: torch.Tensor = args[0]
    timestep: float = args[1]
    model_options: dict[str] = args[2]
    easycache: EasySortblockHolder = model_options["transformer_options"]["easycache"]

    # initialize predict_ratios
    if easycache.initial_step:
        sample_sigmas = model_options["transformer_options"]["sample_sigmas"]
        relevant_indices = [i for i, sigma in enumerate(sample_sigmas) if easycache.check_if_within_timesteps(sigma)]
        # When no sigma falls inside the window, indexing [0] / [-1] would raise IndexError.
        # In that case the holder keeps the predict_ratios / predict_start_index it already has.
        if relevant_indices:
            start_index = relevant_indices[0]
            end_index = relevant_indices[-1]
            easycache.predict_ratios = torch.linspace(easycache.start_predict_ratio, easycache.end_predict_ratio, end_index - start_index + 1)
            easycache.predict_start_index = start_index

    easycache.skip_current_step = False
    if easycache.is_past_end_timestep(timestep):
        return executor(*args, **kwargs)
    # prepare next x_prev
    next_x_prev = x
    input_change = None
    do_easycache = easycache.should_do_easycache(timestep)
    if do_easycache:
        easycache.check_metadata(x)
        if easycache.has_x_prev_subsampled():
            input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
            if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
                approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
                easycache.cumulative_change_rate += approx_output_change_rate
                if easycache.cumulative_change_rate < easycache.reuse_threshold:
                    if easycache.verbose:
                        logging.info(f"EasySortblock [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                    # other conds should also skip this step
                    easycache.skip_current_step = True
                    easycache.steps_skipped.append(easycache.step_count)
                else:
                    if easycache.verbose:
                        logging.info(f"EasySortblock [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                    easycache.cumulative_change_rate = 0.0
    output: torch.Tensor = executor(*args, **kwargs)
    if easycache.has_output_prev_norm():
        output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
        if easycache.verbose:
            output_change_rate = output_change / easycache.output_prev_norm
            easycache.output_change_rates.append(output_change_rate.item())
            # input_change is None when no previous input was available this step;
            # the approximation cannot be computed then.
            if easycache.has_relative_transformation_rate() and input_change is not None:
                approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
                easycache.approx_output_change_rates.append(approx_output_change_rate.item())
                logging.info(f"EasySortblock [verbose] - approx_output_change_rate: {approx_output_change_rate}")
        if input_change is not None:
            easycache.relative_transformation_rate = output_change / input_change
        if easycache.verbose:
            logging.info(f"EasySortblock [verbose] - output_change_rate: {output_change_rate}")
    # remember this step's (subsampled) input/output for the next step's comparison
    easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
    easycache.output_prev_subsampled = easycache.subsample(output)
    easycache.output_prev_norm = output.flatten().abs().mean()
    if easycache.verbose:
        logging.info(f"EasySortblock [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")

    # increment step count
    easycache.step_count += 1
    easycache.initial_step = False

    return output
|
||||
|
||||
|
||||
def easysortblock_outer_sample_wrapper(executor, *args, **kwargs):
    """
    This OUTER_SAMPLE wrapper makes sure EasySortblock is prepped for current run, and all memory usage is cleared at the end.

    The guider's model_options are replaced by a clone holding a fresh, timestep-prepared copy of
    the EasySortblockHolder for the duration of the wrapped call. Afterwards the statistics are
    logged, the holder is reset, and the original model_options are put back, even when the
    wrapped call raises.
    """
    # Bound before the try block so the finally clause can always refer to them; if they were
    # assigned inside the try, a failure here would surface as an UnboundLocalError from the
    # cleanup code and hide the real exception.
    guider = executor.class_obj
    orig_model_options = guider.model_options
    try:
        guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
        # clone and prepare timesteps
        guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
        easycache: EasySortblockHolder = guider.model_options['transformer_options']['easycache']
        logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
        return executor(*args, **kwargs)
    finally:
        easycache = guider.model_options['transformer_options']['easycache']
        output_change_rates = easycache.output_change_rates
        approx_output_change_rates = easycache.approx_output_change_rates
        if easycache.verbose:
            logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
            logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
        # args[3] is used as the sigma schedule here: one more entry than there are steps
        total_steps = len(args[3])-1
        logging.info(f"{easycache.name} - skipped {len(easycache.steps_skipped)}/{total_steps} steps")
        logging.info(f"{easycache.name} - skipped steps: {easycache.steps_skipped}")
        easycache.reset()
        guider.model_options = orig_model_options
|
||||
|
||||
|
||||
def model_forward_wrapper(executor, *args, **kwargs):
    """
    DIFFUSION_MODEL wrapper that prepares the model's blocks and picks which blocks may skip.

    transformer_options is taken from the last positional argument when that is a dict, otherwise
    from kwargs, otherwise from the second-to-last positional argument. On the holder's initial
    step every block found under double_blocks / single_blocks / blocks of the wrapped model is
    passed to prepare_block. When the holder has flagged the current step as skippable, the blocks
    of each type are sorted by (consecutive_skipped_steps, prev_change_rate) and the first
    int(len * predict_ratio) of them get allowed_to_skip=True; all others get False.
    The wrapped executor is always called and its result returned.
    """
    # TODO: make work with batches of conds
    transformer_options: dict[str] = args[-1]
    if not isinstance(transformer_options, dict):
        transformer_options = kwargs.get("transformer_options")
        if not transformer_options:
            transformer_options = args[-2]
    sb_holder: EasySortblockHolder = transformer_options["easycache"]

    # if initial step, prepare everything for Sortblock
    if sb_holder.initial_step:
        model = executor.class_obj
        logging.info(f"EasySortblock: inside model {model.__class__.__name__}")
        # TODO: generalize for other models
        # these won't stick around past this step; should store on sb_holder instead
        # getattr with a default: not every model has double/single blocks, and the log line
        # must not fail before the hasattr-guarded loops below get a chance to run.
        logging.info(f"EasySortblock: preparing {len(getattr(model, 'double_blocks', ()))} double blocks and {len(getattr(model, 'single_blocks', ()))} single blocks")
        if hasattr(model, "double_blocks"):
            for block in model.double_blocks:
                prepare_block(block, sb_holder)
        if hasattr(model, "single_blocks"):
            for block in model.single_blocks:
                prepare_block(block, sb_holder)
        if hasattr(model, "blocks"):
            # iterate the attribute that was just checked ("blocks", not "block")
            for block in model.blocks:
                prepare_block(block, sb_holder)

    if sb_holder.skip_current_step:
        predict_index = max(0, sb_holder.step_count - sb_holder.predict_start_index)
        predict_ratio = sb_holder.predict_ratios[predict_index]
        logging.info(f"EasySortblock: skipping step {sb_holder.step_count}, predict_ratio: {predict_ratio}")
        for block_type, blocks in sb_holder.blocks_per_type.items():
            for block in blocks:
                cache: BlockCache = block.__block_cache
                cache.allowed_to_skip = False
            sorted_blocks = sorted(blocks, key=lambda x: (x.__block_cache.consecutive_skipped_steps, x.__block_cache.prev_change_rate))
            threshold_index = int(len(sorted_blocks) * predict_ratio)
            # the first threshold_index blocks in sort order are allowed to reuse their cache;
            # the remaining blocks are recomputed
            for block in sorted_blocks[:threshold_index]:
                cache: BlockCache = block.__block_cache
                cache.allowed_to_skip = True
                logging.info(f"EasySortblock: skip block {block.__class__.__name__} - consecutive_skipped_steps: {block.__block_cache.consecutive_skipped_steps}, prev_change_rate: {block.__block_cache.prev_change_rate}, index: {block.__block_cache.block_index}")
            not_skipped = [block for block in blocks if not block.__block_cache.allowed_to_skip]
            for block in not_skipped:
                logging.info(f"EasySortblock: reco block {block.__class__.__name__} - consecutive_skipped_steps: {block.__block_cache.consecutive_skipped_steps}, prev_change_rate: {block.__block_cache.prev_change_rate}, index: {block.__block_cache.block_index}")
            logging.info(f"EasySortblock: for {block_type}, selected {len(sorted_blocks[:threshold_index])} blocks for prediction and {len(sorted_blocks[threshold_index:])} blocks for recomputation")
    to_return = executor(*args, **kwargs)

    return to_return
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def block_forward_factory(func, block):
    """
    Build a replacement forward for `block` that can reuse a cached output delta instead of calling `func`.

    func is the block's original forward; block is the object carrying the per-block state in its
    `__block_cache` attribute. The returned wrapper reads the holder from
    kwargs["transformer_options"]["easycache"].

    NOTE(review): the nesting below was reconstructed from a copy of the file that had lost its
    indentation; confirm the branch structure against the original before relying on it.
    """
    def block_forward_wrapper(*args, **kwargs):
        transformer_options: dict[str] = kwargs.get("transformer_options")
        sigmas = transformer_options["sigmas"]
        sb_holder: EasySortblockHolder = transformer_options["easycache"]
        cache: BlockCache = block.__block_cache
        # make sure stream count is properly set for this block
        # (transformer_options['block'] is read as (block type, block index, stream count))
        if sb_holder.initial_step:
            sb_holder.add_to_blocks_per_type(block, transformer_options['block'][0])
            cache.block_index = transformer_options['block'][1]
            cache.stream_count = transformer_options['block'][2]

        # past the end of the window: plain passthrough, nothing is cached or compared
        if sb_holder.is_past_end_timestep(sigmas):
            return func(*args, **kwargs)
        # do sortblock stuff
        x = cache.get_next_x_prev(args, kwargs)
        # prepare next_x_prev (cloned, so it is independent of the tensors handed to func)
        next_x_prev = cache.get_next_x_prev(args, kwargs, clone=True)
        input_change = None
        do_sortblock = sb_holder.should_do_easycache(sigmas)
        if do_sortblock:
            # TODO: checkmetadata
            if cache.has_x_prev_subsampled():
                # mean absolute difference between this step's and the previous step's subsampled input
                input_change = (cache.subsample(x, clone=False) - cache.x_prev_subsampled).flatten().abs().mean()
                if cache.has_output_prev_norm() and cache.has_relative_transformation_rate():
                    approx_output_change_rate = (cache.relative_transformation_rate * input_change) / cache.output_prev_norm
                    cache.cumulative_change_rate += approx_output_change_rate
                    if cache.allowed_to_skip:
                        # if cache.cumulative_change_rate < sb_holder.reuse_threshold:
                        # accumulate error + skip block
                        # cache.want_to_skip = True
                        # if cache.allowed_to_skip:
                        cache.consecutive_skipped_steps += 1
                        cache.prev_change_rate = approx_output_change_rate
                        # skip: the result is the current input plus the cached (output - input) delta
                        return cache.apply_cache_diff(x, sb_holder)
                    else:
                        # reset error; NOT skipping block and recalculating
                        cache.cumulative_change_rate = 0.0
                        cache.prev_change_rate = approx_output_change_rate
                        cache.want_to_skip = False
                        cache.consecutive_skipped_steps = 0
        # output_raw is expected to have cache.stream_count elements if count is greater than 1 (double block, etc.)
        output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]] = func(*args, **kwargs)
        # if more than one stream from block, only use first one
        if isinstance(output_raw, tuple):
            output = output_raw[0]
        else:
            output = output_raw
        if cache.has_output_prev_norm():
            output_change = (cache.subsample(output, clone=False) - cache.output_prev_subsampled).flatten().abs().mean()
            # if verbose in future
            output_change_rate = output_change / cache.output_prev_norm
            cache.output_change_rates.append(output_change_rate.item())
            if cache.has_relative_transformation_rate():
                approx_output_change_rate = (cache.relative_transformation_rate * input_change) / cache.output_prev_norm
                cache.approx_output_change_rates.append(approx_output_change_rate.item())
            if input_change is not None:
                # ratio of observed output change to input change; used next step for the approximation
                cache.relative_transformation_rate = output_change / input_change
        # TODO: allow cache_diff to be offloaded
        cache.update_cache_diff(output_raw, next_x_prev)
        cache.x_prev_subsampled = cache.subsample(next_x_prev)
        cache.output_prev_subsampled = cache.subsample(output)
        cache.output_prev_norm = output.flatten().abs().mean()
        return output_raw
    return block_forward_wrapper
|
||||
|
||||
def prepare_block(block, sb_holder: EasySortblockHolder, stream_count: int=1):
    """
    Register `block` with the holder and swap its forward for the caching wrapper.

    The original forward is kept on the block as `__original_forward` and a fresh BlockCache is
    attached as `__block_cache`. stream_count is accepted but not used here.

    A block that already carries `__original_forward` is left untouched. Wrapping it again would
    overwrite the saved original with the wrapper (so cleanup could no longer restore the real
    forward) and would add the block to the holder's list a second time.
    """
    if hasattr(block, "__original_forward"):
        return
    sb_holder.add_to_all_blocks(block)
    block.__original_forward = block.forward
    block.forward = block_forward_factory(block.__original_forward, block)
    block.__block_cache = BlockCache(subsample_factor=sb_holder.subsample_factor, verbose=sb_holder.verbose)
|
||||
|
||||
def clean_block(block):
    """
    Undo the forward swap on `block`: restore the saved forward and drop the attached state.

    Safe to call on a block that is not (or no longer) prepared; such a block is left as it is,
    so cleaning the same block twice does not raise.
    """
    if not hasattr(block, "__original_forward"):
        return
    block.forward = block.__original_forward
    del block.__original_forward
    if hasattr(block, "__block_cache"):
        del block.__block_cache
|
||||
|
||||
class BlockCache:
    """
    Per-block bookkeeping: the previous (subsampled) input/output, the change-rate statistics
    derived from them, and the cached `output - input` delta that is re-applied when the block
    is skipped.
    """
    def __init__(self, subsample_factor: int=8, verbose: bool=False):
        self.subsample_factor = subsample_factor
        self.verbose = verbose
        self.stream_count = 1
        self.block_index = 0
        # control values
        self.relative_transformation_rate: float = None
        self.cumulative_change_rate = 0.0
        self.prev_change_rate = 0.0
        # cached values
        self.x_prev_subsampled: torch.Tensor = None
        self.output_prev_subsampled: torch.Tensor = None
        self.output_prev_norm: torch.Tensor = None
        self.cache_diff: list[torch.Tensor] = []
        self.output_change_rates = []
        self.approx_output_change_rates = []
        self.steps_skipped: list[int] = []
        self.consecutive_skipped_steps = 0
        # self.state_metadata = None
        self.want_to_skip = False
        self.allowed_to_skip = False

    def has_cache_diff(self) -> bool:
        """Return True once a delta has been stored; an empty cache_diff counts as 'no delta'."""
        # cache_diff starts out (and is reset to) an empty list, so check the length before indexing
        return len(self.cache_diff) > 0 and self.cache_diff[0] is not None

    def has_x_prev_subsampled(self) -> bool:
        return self.x_prev_subsampled is not None

    def has_output_prev_subsampled(self) -> bool:
        return self.output_prev_subsampled is not None

    def has_output_prev_norm(self) -> bool:
        return self.output_prev_norm is not None

    def has_relative_transformation_rate(self) -> bool:
        return self.relative_transformation_rate is not None

    def get_next_x_prev(self, d_args: tuple[torch.Tensor, ...], d_kwargs: dict[str, torch.Tensor], clone: bool=False) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """
        Pick the block's input stream(s) out of its call arguments.

        With a single stream this is the first positional argument (a tensor, not a tuple).
        With several streams it is a tuple of the first stream_count keyword argument values,
        in keyword order. clone=True returns copies.
        """
        if self.stream_count == 1:
            if clone:
                return d_args[0].clone()
            return d_args[0]
        keys = list(d_kwargs.keys())[:self.stream_count]
        if clone:
            return tuple(d_kwargs[key].clone() for key in keys)
        return tuple(d_kwargs[key] for key in keys)

    def subsample(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]], clone: bool = True) -> torch.Tensor:
        """Return every subsample_factor-th element along the last two dims (optionally cloned)."""
        # subsample only the first component
        if isinstance(x, tuple):
            return self.subsample(x[0], clone)
        if self.subsample_factor > 1:
            to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
            if clone:
                return to_return.clone()
            return to_return
        if clone:
            return x.clone()
        return x

    def apply_cache_diff(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]], sb_holder: EasySortblockHolder):
        """
        Produce the skipped block's output as `x + cached delta` for each stream, and record the
        holder's current step in steps_skipped. A single stream is returned as a bare tensor.
        """
        self.steps_skipped.append(sb_holder.step_count)
        if not isinstance(x, tuple):
            x = (x, )
        to_return = tuple([x[i] + self.cache_diff[i] for i in range(self.stream_count)])
        if len(to_return) == 1:
            return to_return[0]
        return to_return

    def update_cache_diff(self, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], x: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
        """Store `output - input` for each of the first stream_count streams."""
        if not isinstance(output_raw, tuple):
            output_raw = (output_raw, )
        if not isinstance(x, tuple):
            x = (x, )
        self.cache_diff = tuple([output_raw[i] - x[i] for i in range(self.stream_count)])

    def reset(self):
        """Return the cache to its freshly constructed state (configuration values are kept)."""
        # None, as in __init__, so that has_relative_transformation_rate() is False after a reset
        self.relative_transformation_rate = None
        self.cumulative_change_rate = 0.0
        self.prev_change_rate = 0.0
        self.x_prev_subsampled = None
        self.output_prev_subsampled = None
        self.output_prev_norm = None
        self.cache_diff = []
        self.output_change_rates = []
        self.approx_output_change_rates = []
        self.steps_skipped = []
        self.consecutive_skipped_steps = 0
        self.want_to_skip = False
        self.allowed_to_skip = False
        return self
|
||||
|
||||
|
||||
class EasySortblockHolder:
    """
    Run-level state for EasySortblock: the user settings, the sigma window in which it is active,
    the step-level change-rate statistics, and the list of blocks that were prepared for this run.
    """
    def __init__(self, reuse_threshold: float, start_predict_ratio: float, end_predict_ratio: float, max_skipped_steps: int,
                 start_percent: float, end_percent: float, subsample_factor: int, verbose: bool=False):
        self.name = "EasySortblock"
        self.reuse_threshold = reuse_threshold
        self.start_predict_ratio = start_predict_ratio
        self.end_predict_ratio = end_predict_ratio
        self.max_skipped_steps = max_skipped_steps
        self.start_percent = start_percent
        self.end_percent = end_percent
        self.subsample_factor = subsample_factor
        self.verbose = verbose
        # timestep values
        self.start_t = 0.0
        self.end_t = 0.0
        # control values
        self.relative_transformation_rate: float = None
        self.cumulative_change_rate = 0.0
        self.initial_step = True
        self.step_count = 0
        self.predict_ratios = []
        self.skip_current_step = False
        self.predict_start_index = 0
        # cache values
        self.x_prev_subsampled: torch.Tensor = None
        self.output_prev_subsampled: torch.Tensor = None
        self.output_prev_norm: torch.Tensor = None
        self.steps_skipped: list[int] = []
        self.output_change_rates = []
        self.approx_output_change_rates = []
        self.state_metadata = None
        self.all_blocks = []
        self.blocks_per_type = {}

    def add_to_all_blocks(self, block):
        self.all_blocks.append(block)

    def add_to_blocks_per_type(self, block, block_type: str):
        self.blocks_per_type.setdefault(block_type, []).append(block)

    def is_past_end_timestep(self, timestep: torch.Tensor) -> bool:
        """True when the first element of `timestep` is not greater than end_t."""
        return not (timestep[0] > self.end_t).item()

    def should_do_easycache(self, timestep: torch.Tensor) -> bool:
        """True when the first element of `timestep` is at or below start_t."""
        return (timestep[0] <= self.start_t).item()

    def check_if_within_timesteps(self, timestep: Union[float, torch.Tensor]) -> bool:
        """True when end_t < timestep <= start_t. `.item()` is called, so a tensor is expected."""
        return (timestep <= self.start_t).item() and (timestep > self.end_t).item()

    def has_x_prev_subsampled(self) -> bool:
        return self.x_prev_subsampled is not None

    def has_output_prev_subsampled(self) -> bool:
        return self.output_prev_subsampled is not None

    def has_output_prev_norm(self) -> bool:
        return self.output_prev_norm is not None

    def has_relative_transformation_rate(self) -> bool:
        return self.relative_transformation_rate is not None

    def prepare_timesteps(self, model_sampling):
        """Convert start_percent / end_percent to sigmas via model_sampling.percent_to_sigma; returns self."""
        self.start_t = model_sampling.percent_to_sigma(self.start_percent)
        self.end_t = model_sampling.percent_to_sigma(self.end_percent)
        return self

    def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor:
        """Return every subsample_factor-th element along the last two dims (optionally cloned)."""
        if self.subsample_factor > 1:
            to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
            if clone:
                return to_return.clone()
            return to_return
        if clone:
            return x.clone()
        return x

    def check_metadata(self, x: torch.Tensor) -> bool:
        """
        Compare (device, dtype, shape) of `x` with the values seen first.

        Returns True when they match (or when nothing was recorded yet); otherwise logs a
        warning, resets the whole holder, and returns False.
        """
        metadata = (x.device, x.dtype, x.shape)
        if self.state_metadata is None:
            self.state_metadata = metadata
            return True
        if metadata == self.state_metadata:
            return True
        logging.warning(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
        self.reset()
        return False

    def reset(self):
        """Restore every prepared block's forward and return all run state to its initial values."""
        logging.info(f"EasySortblock: resetting {len(self.all_blocks)} blocks")
        for block in self.all_blocks:
            clean_block(block)
        # None, as in __init__, so that has_relative_transformation_rate() is False after a reset
        self.relative_transformation_rate = None
        self.cumulative_change_rate = 0.0
        self.initial_step = True
        self.step_count = 0
        self.predict_ratios = []
        self.skip_current_step = False
        self.predict_start_index = 0
        self.x_prev_subsampled = None
        self.output_prev_subsampled = None
        self.output_prev_norm = None
        self.steps_skipped = []
        self.output_change_rates = []
        self.approx_output_change_rates = []
        self.state_metadata = None
        self.all_blocks = []
        self.blocks_per_type = {}
        return self

    def clone(self):
        """Return a new holder with the same settings and fresh (empty) run state."""
        return EasySortblockHolder(self.reuse_threshold, self.start_predict_ratio, self.end_predict_ratio, self.max_skipped_steps,
                                   self.start_percent, self.end_percent, self.subsample_factor, self.verbose)
|
||||
|
||||
class EasySortblockScaledNode(io.ComfyNode):
    # Node that attaches the EasySortblock holder and its three wrappers to a cloned model.

    @classmethod
    def define_schema(cls) -> io.Schema:
        # Inputs and outputs are declared first, then handed to the schema.
        schema_inputs = [
            io.Model.Input("model", tooltip="The model to add Sortblock to."),
            io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
            io.Float.Input("start_predict_ratio", min=0.0, default=0.2, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
            io.Float.Input("end_predict_ratio", min=0.0, default=0.9, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
            io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval at which to refresh the policy."),
            io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."),
            io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."),
            io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
        ]
        schema_outputs = [
            io.Model.Output(tooltip="The model with Sortblock."),
        ]
        return io.Schema(
            node_id="EasySortblockScaled",
            display_name="EasySortblockScaled",
            description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
            category="advanced/debug/model",
            is_experimental=True,
            inputs=schema_inputs,
            outputs=schema_outputs,
        )

    @classmethod
    def execute(cls, model: io.Model.Type, reuse_threshold: float, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
        # TODO: check for specific flavors of supported models
        patched = model.clone()
        # policy_refresh_interval is passed in the holder's max_skipped_steps position
        holder = EasySortblockHolder(reuse_threshold, start_predict_ratio, end_predict_ratio, policy_refresh_interval, start_percent, end_percent, subsample_factor=8, verbose=verbose)
        patched.model_options["transformer_options"]["easycache"] = holder
        # register the three wrappers under the shared "sortblock" key, in the same order as before
        wrappers = (
            (comfy.patcher_extension.WrappersMP.PREDICT_NOISE, easysortblock_predict_noise_wrapper),
            (comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, easysortblock_outer_sample_wrapper),
            (comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, model_forward_wrapper),
        )
        for wrapper_type, wrapper in wrappers:
            patched.add_wrapper_with_key(wrapper_type, "sortblock", wrapper)
        return io.NodeOutput(patched)
|
||||
|
||||
|
||||
class EasySortblockExtension(ComfyExtension):
    # Extension object that exposes the EasySortblock node(s) of this module.

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        # Only the scaled variant is exported; EasySortblockNode is currently left out of the list.
        exported_nodes: list[type[io.ComfyNode]] = [EasySortblockScaledNode]
        return exported_nodes
|
||||
|
||||
def comfy_entrypoint():
    # Entry point: builds and returns the extension instance for this module.
    extension = EasySortblockExtension()
    return extension
|
||||
|
||||
@ -105,7 +105,7 @@ class FluxKontextMultiReferenceLatentMethod:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"conditioning": ("CONDITIONING", ),
|
||||
"reference_latents_method": (("offset", "index"), ),
|
||||
"reference_latents_method": (("offset", "index", "uso"), ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
|
||||
@ -1,98 +1,109 @@
|
||||
# Primitive nodes that are evaluated at backend.
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class String(ComfyNodeABC):
|
||||
class String(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {"value": (IO.STRING, {})},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveString",
|
||||
display_name="String",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.String.Input("value"),
|
||||
],
|
||||
outputs=[io.String.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/primitive"
|
||||
|
||||
def execute(self, value: str) -> tuple[str]:
|
||||
return (value,)
|
||||
|
||||
|
||||
class StringMultiline(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {"value": (IO.STRING, {"multiline": True,},)},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/primitive"
|
||||
|
||||
def execute(self, value: str) -> tuple[str]:
|
||||
return (value,)
|
||||
def execute(cls, value: str) -> io.NodeOutput:
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
class Int(ComfyNodeABC):
|
||||
class StringMultiline(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveStringMultiline",
|
||||
display_name="String (Multiline)",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.String.Input("value", multiline=True),
|
||||
],
|
||||
outputs=[io.String.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.INT,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/primitive"
|
||||
|
||||
def execute(self, value: int) -> tuple[int]:
|
||||
return (value,)
|
||||
|
||||
|
||||
class Float(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.FLOAT,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/primitive"
|
||||
|
||||
def execute(self, value: float) -> tuple[float]:
|
||||
return (value,)
|
||||
def execute(cls, value: str) -> io.NodeOutput:
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
class Boolean(ComfyNodeABC):
|
||||
class Int(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {"value": (IO.BOOLEAN, {})},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveInt",
|
||||
display_name="Int",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True),
|
||||
],
|
||||
outputs=[io.Int.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.BOOLEAN,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/primitive"
|
||||
|
||||
def execute(self, value: bool) -> tuple[bool]:
|
||||
return (value,)
|
||||
@classmethod
|
||||
def execute(cls, value: int) -> io.NodeOutput:
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PrimitiveString": String,
|
||||
"PrimitiveStringMultiline": StringMultiline,
|
||||
"PrimitiveInt": Int,
|
||||
"PrimitiveFloat": Float,
|
||||
"PrimitiveBoolean": Boolean,
|
||||
}
|
||||
class Float(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveFloat",
|
||||
display_name="Float",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize),
|
||||
],
|
||||
outputs=[io.Float.Output()],
|
||||
)
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PrimitiveString": "String",
|
||||
"PrimitiveStringMultiline": "String (Multiline)",
|
||||
"PrimitiveInt": "Int",
|
||||
"PrimitiveFloat": "Float",
|
||||
"PrimitiveBoolean": "Boolean",
|
||||
}
|
||||
@classmethod
|
||||
def execute(cls, value: float) -> io.NodeOutput:
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
class Boolean(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveBoolean",
|
||||
display_name="Boolean",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.Boolean.Input("value"),
|
||||
],
|
||||
outputs=[io.Boolean.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, value: bool) -> io.NodeOutput:
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
class PrimitivesExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
String,
|
||||
StringMultiline,
|
||||
Int,
|
||||
Float,
|
||||
Boolean,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> PrimitivesExtension:
|
||||
return PrimitivesExtension()
|
||||
|
||||
@ -1,462 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from comfy_api.latest import io, ComfyExtension
|
||||
import comfy.patcher_extension
|
||||
import logging
|
||||
import torch
|
||||
import math
|
||||
import comfy.model_patcher
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
def prepare_noise_wrapper(executor, *args, **kwargs):
    """Wrap a noise-prediction call to set up and advance Sortblock step state.

    On the holder's initial step, the sampling sigmas are scanned for the ones
    inside the holder's active window, and a linear schedule of predict ratios
    (from ``start_predict_ratio`` to ``end_predict_ratio``) is stored on the
    holder together with the index of the first active sigma.

    After the wrapped call (whether it returns or raises) the holder's
    ``step_count`` is incremented, and ``active_steps`` as well when the holder
    reports Sortblock as active.

    Args:
        executor: callable that performs the wrapped call.
        *args: forwarded unchanged; ``args[2]`` must be a mapping that has a
            ``"transformer_options"`` entry containing ``"sortblock"`` and
            ``"sample_sigmas"``.
        **kwargs: forwarded unchanged.

    Returns:
        Whatever ``executor(*args, **kwargs)`` returns.
    """
    try:
        transformer_options = args[2]["transformer_options"]
        sb_holder = transformer_options["sortblock"]
        if sb_holder.initial_step:
            sample_sigmas = transformer_options["sample_sigmas"]
            # indices of the sampling sigmas that fall inside the active window
            active_indices = [
                i for i, sigma in enumerate(sample_sigmas)
                if sb_holder.check_if_within_timesteps(sigma)
            ]
            # With no sigma inside the window there is nothing to schedule; the
            # holder keeps its current predict_ratios / predict_start_index
            # instead of failing on an empty list.
            if active_indices:
                start_index = active_indices[0]
                end_index = active_indices[-1]
                # interpolate between start and end predict ratios over the active span
                sb_holder.predict_ratios = torch.linspace(
                    sb_holder.start_predict_ratio,
                    sb_holder.end_predict_ratio,
                    end_index - start_index + 1,
                )
                sb_holder.predict_start_index = start_index

        return executor(*args, **kwargs)
    finally:
        # Re-read the holder so the counters are advanced on whatever holder is
        # stored in transformer_options once the call has finished.
        sb_holder = args[2]["transformer_options"]["sortblock"]
        sb_holder.step_count += 1
        if sb_holder.should_do_sortblock():
            sb_holder.active_steps += 1
|
||||
|
||||
|
||||
def outer_sample_wrapper(executor, *args, **kwargs):
    """Wrap the outer sampling call so each run gets its own Sortblock holder.

    The guider's ``model_options`` are replaced by a clone for the duration of
    the call, and the ``"sortblock"`` holder inside that clone is replaced by a
    fresh clone of itself with its timestep window prepared from the model's
    ``model_sampling`` object. When the call finishes (normally or with an
    exception) the holder is reset and the guider's original ``model_options``
    are put back.

    Returns:
        Whatever ``executor(*args, **kwargs)`` returns.
    """
    try:
        logging.info("Sortblock: inside outer_sample!")
        guider = executor.class_obj
        orig_model_options = guider.model_options
        guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
        # swap in a per-run copy of the holder with its timestep window prepared
        run_options = guider.model_options["transformer_options"]
        template_holder = run_options["sortblock"]
        model_sampling = guider.model_patcher.model.model_sampling
        run_options["sortblock"] = template_holder.clone().prepare_timesteps(model_sampling)
        active_holder = guider.model_options["transformer_options"]["sortblock"]
        logging.info(f"Sortblock: enabled - threshold: {active_holder.start_predict_ratio}, start_percent: {active_holder.start_percent}, end_percent: {active_holder.end_percent}")
        return executor(*args, **kwargs)
    finally:
        # look the holder up again: this is the one that was used during the run
        finished_holder = guider.model_options["transformer_options"]["sortblock"]
        logging.info(f"Sortblock: final step count: {finished_holder.step_count}")
        finished_holder.reset()
        guider.model_options = orig_model_options
|
||||
|
||||
|
||||
def model_forward_wrapper(executor, *args, **kwargs):
    """Wrap the diffusion model forward pass and drive the Sortblock policy.

    Per call this function:

    1. Locates ``transformer_options`` and updates the holder's active state
       from the current sigmas.
    2. On the holder's initial step, wraps every block found on the model
       (``double_blocks``, ``single_blocks`` and ``blocks``, whichever exist)
       via ``prepare_block``.
    3. When ``step_modulus == 0``, records the step as an activated step and
       marks every block for recomputation.
    4. Runs the wrapped forward pass.
    5. When ``step_modulus == 1``, sorts the blocks of each type by cosine
       similarity and marks the least similar ones for recomputation and the
       rest for prediction, using the scheduled predict ratio for this step.

    Returns:
        Whatever ``executor(*args, **kwargs)`` returns.
    """
    # TODO: make work with batches of conds
    # transformer_options is taken from the last positional argument if that is
    # a dict, otherwise from the keyword argument, otherwise from args[-2].
    transformer_options = args[-1]
    if not isinstance(transformer_options, dict):
        transformer_options = kwargs.get("transformer_options")
        if not transformer_options:
            transformer_options = args[-2]
    sigmas = transformer_options["sigmas"]
    sb_holder = transformer_options["sortblock"]
    sb_holder.update_should_do_sortblock(sigmas)

    # if initial step, prepare everything for Sortblock
    if sb_holder.initial_step:
        diffusion_model = executor.class_obj
        logging.info(f"Sortblock: inside model {diffusion_model.__class__.__name__}")
        # TODO: generalize for other models
        # these won't stick around past this step; should store on sb_holder instead
        # Missing block containers are treated as empty so models that only
        # expose some of these attributes can still be prepared and logged.
        double_blocks = getattr(diffusion_model, "double_blocks", ())
        single_blocks = getattr(diffusion_model, "single_blocks", ())
        generic_blocks = getattr(diffusion_model, "blocks", ())
        logging.info(f"Sortblock: preparing {len(double_blocks)} double blocks and {len(single_blocks)} single blocks")
        for block_group in (double_blocks, single_blocks, generic_blocks):
            for block in block_group:
                prepare_block(block, sb_holder)

    # when 0: Initialization(1)
    if sb_holder.step_modulus == 0:
        logging.info(f"Sortblock: for step {sb_holder.step_count}, all blocks are marked for recomputation")
        # all features are computed, input-outputs changes for all DiT blocks are stored for relative step 'k'
        sb_holder.activated_steps.append(sb_holder.step_count)
        for block in sb_holder.all_blocks:
            block.__block_cache.mark_recompute()

    # all block operations are performed in forward pass of model
    to_return = executor(*args, **kwargs)

    # when 1: Select DiT blocks(4)
    if sb_holder.step_modulus == 1:
        predict_index = max(0, sb_holder.step_count - sb_holder.predict_start_index)
        predict_ratio = sb_holder.predict_ratios[predict_index]
        logging.info(f"Sortblock: for step {sb_holder.step_count}, selecting blocks for recomputation and prediction, predict_ratio: {predict_ratio}")
        reuse_ratio = 1.0 - predict_ratio
        for block_type, blocks in sb_holder.blocks_per_type.items():
            sorted_blocks = sorted(blocks, key=lambda x: x.__block_cache.cosine_similarity)
            threshold_index = int(len(sorted_blocks) * reuse_ratio)
            to_recompute = sorted_blocks[:threshold_index]
            to_predict = sorted_blocks[threshold_index:]
            # blocks with lower similarity are marked for recomputation
            for block in to_recompute:
                block.__block_cache.mark_recompute()
            # blocks with higher similarity are marked for prediction
            for block in to_predict:
                block.__block_cache.mark_predict()
            logging.info(f"Sortblock: for {block_type}, selected {len(to_recompute)} blocks for recomputation and {len(to_predict)} blocks for prediction")

    if sb_holder.initial_step:
        sb_holder.initial_step = False
    return to_return
|
||||
|
||||
def block_forward_factory(func, block):
    """Build a replacement ``forward`` for ``block`` that applies Sortblock.

    Args:
        func: the block's original forward callable.
        block: the block object; its ``__block_cache`` attribute is read on
            every call of the returned wrapper.

    Returns:
        A callable with the same call signature as ``func``. Depending on the
        cache state and the holder's ``step_modulus`` it either runs ``func``
        and records derivative/residual information, or returns a prediction
        built from cached Taylor factors without calling ``func``.
    """
    def block_forward_wrapper(*args, **kwargs):
        transformer_options = kwargs.get("transformer_options")
        sb_holder = transformer_options["sortblock"]
        cache = block.__block_cache
        # make sure stream count is properly set for this block
        if sb_holder.initial_step:
            block_info = transformer_options['block']
            sb_holder.add_to_blocks_per_type(block, block_info[0])
            cache.block_index = block_info[1]
            cache.stream_count = block_info[2]
        # do sortblock stuff
        if cache.recompute and sb_holder.step_modulus != 1:
            # clone relevant inputs
            orig_inputs = cache.get_orig_inputs(args, kwargs, clone=True)
            # get block outputs
            # NOTE: output_raw is expected to have cache.stream_count elements if count is greater than 1 (double block, etc.)
            output_raw = func(*args, **kwargs)
            # perform derivative approximation;
            cache.derivative_approximation(sb_holder, output_raw, orig_inputs)
            # if step_modulus is 0, input-output changes for DiT block are stored
            if sb_holder.step_modulus == 0:
                cache.cache_previous_residual(output_raw, orig_inputs)
        else:
            # if not to recompute, predict features for current timestep
            orig_inputs = cache.get_orig_inputs(args, kwargs, clone=False)
            # when 1: Linear Prediction(2)
            # if step_modulus is 1, store block residuals as 'current' after applying taylor_formula
            if sb_holder.step_modulus == 1:
                cache.cache_current_residual(sb_holder)
            # based on features computed in last timestep, all features for current timestep are predicted using Eq. 4,
            # input-output changes for all DiT blocks are stored for relative step 'k+1'
            output_raw = cache.apply_linear_prediction(sb_holder, orig_inputs)

        # when 1: Identify Changes(3)
        if sb_holder.step_modulus == 1:
            # compare the stored 'previous' and 'current' residuals of this block
            cache.calculate_cosine_similarity()

        return output_raw
    return block_forward_wrapper
|
||||
|
||||
|
||||
def perform_sortblock(blocks: list):
    """Placeholder with no implementation: ignores ``blocks`` and returns ``None``."""
    return None
|
||||
|
||||
def prepare_block(block, sb_holder: SortblockHolder, stream_count: int=1):
    """Register ``block`` with the holder and wrap its ``forward`` for Sortblock.

    The original ``forward`` is kept on the block as ``__original_forward``,
    ``forward`` is replaced by the wrapper from ``block_forward_factory``, and
    a fresh ``BlockCache`` is attached as ``__block_cache``.

    Args:
        block: the block object to wrap.
        sb_holder: holder that tracks all wrapped blocks and supplies the
            ``subsample_factor`` and ``verbose`` settings for the cache.
        stream_count: accepted but not used by this function.
    """
    sb_holder.add_to_all_blocks(block)
    original_forward = block.forward
    block.__original_forward = original_forward
    block.forward = block_forward_factory(original_forward, block)
    cache_settings = {
        "subsample_factor": sb_holder.subsample_factor,
        "verbose": sb_holder.verbose,
    }
    block.__block_cache = BlockCache(**cache_settings)
|
||||
|
||||
def clean_block(block):
    """Undo ``prepare_block``: restore ``forward`` and drop the Sortblock attributes.

    Raises:
        AttributeError: if ``block`` does not carry ``__original_forward`` or
            ``__block_cache``.
    """
    block.forward = getattr(block, "__original_forward")
    for attr_name in ("__original_forward", "__block_cache"):
        delattr(block, attr_name)
|
||||
|
||||
def subsample(x: torch.Tensor, factor: int, clone: bool=True) -> torch.Tensor:
    """Take every ``factor``-th element along the last two dimensions of ``x``.

    Args:
        x: tensor to subsample.
        factor: stride applied to the last two dimensions; values of 1 or less
            leave ``x`` unsliced.
        clone: when True a copy is returned, otherwise the (possibly strided)
            view or ``x`` itself.

    Returns:
        The subsampled tensor.
    """
    selected = x[..., ::factor, ::factor] if factor > 1 else x
    return selected.clone() if clone else selected
|
||||
|
||||
class BlockCache:
    """Per-block Sortblock state: recompute flag, residuals and Taylor factors.

    One instance is attached to every wrapped block. It stores subsampled
    residuals (output minus input) used for the cosine-similarity ranking, and
    zeroth/first-order Taylor factors used to predict a block's residual
    without running the block.
    """

    def __init__(self, subsample_factor: int=8, verbose: bool=False):
        self.subsample_factor = subsample_factor
        self.verbose = verbose
        # number of input/output streams of the block (set by the forward wrapper)
        self.stream_count = 1
        # True: run the real block on the next call; False: predict its output
        self.recompute = False
        self.block_index = 0
        # cached values
        self.previous_residual_subsampled: torch.Tensor = None
        self.current_residual_subsampled: torch.Tensor = None
        self.cosine_similarity: float = None
        # keyed by stream index; "previous" holds the latest residual (order 0),
        # "current" the finite-difference slope between the last two (order 1)
        self.previous_taylor_factors: dict[int, torch.Tensor] = {}
        self.current_taylor_factors: dict[int, torch.Tensor] = {}

    def mark_recompute(self):
        """Flag the block so its real forward is run."""
        self.recompute = True

    def mark_predict(self):
        """Flag the block so its output is predicted instead of computed."""
        self.recompute = False

    def cache_previous_residual(self, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
        """Store the subsampled residual ``output - input`` of the first stream."""
        if isinstance(output_raw, tuple):
            output_raw = output_raw[0]
        if isinstance(orig_inputs, tuple):
            orig_inputs = orig_inputs[0]
        # drop the old tensor before building the new one; keeping the attribute
        # (as None) means a failure below cannot leave the cache without it
        self.previous_residual_subsampled = None
        self.previous_residual_subsampled = subsample(output_raw - orig_inputs, self.subsample_factor, clone=True)

    def cache_current_residual(self, sb_holder: SortblockHolder):
        """Store the subsampled predicted residual of the first stream."""
        self.current_residual_subsampled = None
        self.current_residual_subsampled = subsample(self.use_taylor_formula(sb_holder)[0], self.subsample_factor, clone=True)

    def get_orig_inputs(self, d_args: tuple, d_kwargs: dict, clone: bool=True) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Extract the block's stream inputs from a forward call's arguments.

        With a single stream the first positional argument is returned as a
        bare tensor. Otherwise the values of the first ``stream_count`` keyword
        arguments are returned as a tuple.
        NOTE(review): this relies on the streams being the first keyword
        arguments of the call, in order — confirm against the block callers.
        """
        if self.stream_count == 1:
            return d_args[0].clone() if clone else d_args[0]
        stream_keys = list(d_kwargs.keys())[:self.stream_count]
        if clone:
            return tuple(d_kwargs[key].clone() for key in stream_keys)
        return tuple(d_kwargs[key] for key in stream_keys)

    def apply_linear_prediction(self, sb_holder: SortblockHolder, orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Add the predicted residuals onto the inputs and return them.

        The inputs are modified in place (``+=``). The result has the same
        form as ``orig_inputs``: a bare tensor or a tuple of tensors.
        """
        drop_tuple = False
        if not isinstance(orig_inputs, tuple):
            orig_inputs = (orig_inputs,)
            drop_tuple = True
        taylor_results = self.use_taylor_formula(sb_holder)
        for output, taylor_result in zip(orig_inputs, taylor_results):
            output += taylor_result
        if drop_tuple:
            orig_inputs = orig_inputs[0]
        return orig_inputs

    def calculate_cosine_similarity(self) -> None:
        """Store the mean cosine similarity of the previous and current residuals."""
        self.cosine_similarity = torch.nn.functional.cosine_similarity(self.previous_residual_subsampled, self.current_residual_subsampled, dim=-1).mean().item()

    def derivative_approximation(self, sb_holder: SortblockHolder, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
        """Update the Taylor factors of every stream from a real block evaluation.

        The new residual becomes the zeroth-order factor; the first-order
        factor is the difference to the previous residual divided by the
        distance between the last two activated steps.
        """
        # NOTE(review): this is 0 if the same step was appended twice to
        # activated_steps, which would make the division below produce inf/nan.
        activation_distance = sb_holder.activated_steps[-1] - sb_holder.activated_steps[-2]
        # make tuple if not already tuple, so that works with both single and double blocks
        if not isinstance(output_raw, tuple):
            output_raw = (output_raw,)
        if not isinstance(orig_inputs, tuple):
            orig_inputs = (orig_inputs,)

        for i, (output, x) in enumerate(zip(output_raw, orig_inputs)):
            feature = output.clone() - x
            has_previous_taylor_factor = self.previous_taylor_factors.get(i, None) is not None
            # NOTE: not sure why - 2, but that's what's in the original implementation. Maybe consider changing values?
            if has_previous_taylor_factor and sb_holder.step_count > (sb_holder.first_enhance - 2):
                self.current_taylor_factors[i] = (
                    feature - self.previous_taylor_factors[i]
                ) / activation_distance

            self.previous_taylor_factors[i] = feature

    def use_taylor_formula(self, sb_holder: SortblockHolder) -> tuple[torch.Tensor, ...]:
        """Predict each stream's residual from the cached Taylor factors.

        Returns one tensor per stream that has a zeroth-order factor. When the
        first-order factor of a stream has not been computed yet, only the
        zeroth-order term is used instead of failing with ``KeyError``.
        """
        step_distance = sb_holder.step_count - sb_holder.activated_steps[-1]

        output_predicted = []
        for key, previous_tf in self.previous_taylor_factors.items():
            predicted = taylor_formula(previous_tf, 0, step_distance)
            current_tf = self.current_taylor_factors.get(key)
            if current_tf is not None:
                predicted += taylor_formula(current_tf, 1, step_distance)
            output_predicted.append(predicted)

        return tuple(output_predicted)

    def reset(self):
        """Clear the recompute flag and every cached tensor/value."""
        self.recompute = False
        self.current_residual_subsampled = None
        self.previous_residual_subsampled = None
        self.cosine_similarity = None
        self.previous_taylor_factors = {}
        self.current_taylor_factors = {}
|
||||
|
||||
def taylor_formula(taylor_factor: torch.Tensor, i: int, step_distance: int):
    """Evaluate the ``i``-th Taylor term: ``taylor_factor * step_distance**i / i!``.

    Args:
        taylor_factor: the ``i``-th order factor.
        i: order of the term.
        step_distance: distance from the expansion point.

    Returns:
        The value of the term.
    """
    coefficient = 1 / math.factorial(i)
    distance_power = step_distance ** i
    return coefficient * taylor_factor * distance_power
|
||||
|
||||
class SortblockHolder:
    """Configuration and per-run state shared by the Sortblock wrappers.

    Configuration (predict ratios, refresh interval, start/end percent,
    subsample factor, verbosity) is set at construction and copied by
    ``clone``. Everything else is run state that ``reset`` returns to its
    initial value.
    """

    def __init__(self, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int,
                 start_percent: float, end_percent: float, subsample_factor: int=8, verbose: bool=False):
        self.start_predict_ratio = start_predict_ratio
        self.end_predict_ratio = end_predict_ratio
        self.start_percent = start_percent
        self.end_percent = end_percent
        self.subsample_factor = subsample_factor
        self.verbose = verbose

        # NOTE: number represents steps
        self.policy_refresh_interval = policy_refresh_interval
        self.active_policy_refresh_interval = 1
        self.first_enhance = 3  # NOTE: this value is 2 higher than the one actually used in code (subtracted by 2 in derivative_approximation)
        # timestep values
        self.start_t = 0.0
        self.end_t = 0.0
        self.curr_t = 0.0
        # control values
        self.initial_step = True
        self.step_count = 0
        self.activated_steps: list[int] = [0]
        self.step_modulus = 0
        self.do_sortblock = False
        self.active_steps = 0
        self.predict_ratios = []
        self.predict_start_index = 0

        # cache values
        self.all_blocks = []
        self.blocks_per_type = {}

    def add_to_all_blocks(self, block):
        """Remember ``block`` so ``reset`` can restore it later."""
        self.all_blocks.append(block)

    def add_to_blocks_per_type(self, block, block_type: str):
        """Append ``block`` to the list kept for ``block_type``."""
        self.blocks_per_type.setdefault(block_type, []).append(block)

    def prepare_timesteps(self, model_sampling):
        """Convert start/end percent to ``start_t``/``end_t`` and return ``self``.

        ``model_sampling`` must provide ``percent_to_sigma``.
        """
        self.start_t = model_sampling.percent_to_sigma(self.start_percent)
        self.end_t = model_sampling.percent_to_sigma(self.end_percent)
        return self

    def check_if_within_timesteps(self, timestep: Union[float, torch.Tensor]) -> bool:
        """Return True when ``end_t < timestep <= start_t``.

        Accepts a plain number or a single-element tensor; ``bool()`` is used
        for the conversion so both work.
        """
        return bool(timestep <= self.start_t) and bool(timestep > self.end_t)

    def update_should_do_sortblock(self, timestep: Union[torch.Tensor, list]) -> bool:
        """Update the active flag, refresh interval and step modulus.

        Args:
            timestep: indexable object; only its first element is compared
                against the active window, the whole object is stored as
                ``curr_t``.

        Returns:
            The new value of ``do_sortblock``.
        """
        self.do_sortblock = self.check_if_within_timesteps(timestep[0])
        self.curr_t = timestep
        # outside the active window the interval is 1, so step_modulus is always 0
        self.active_policy_refresh_interval = self.policy_refresh_interval if self.do_sortblock else 1
        self.update_step_modulus()
        return self.do_sortblock

    def update_step_modulus(self):
        """Recompute ``step_modulus`` as ``step_count`` modulo the active interval."""
        self.step_modulus = int(self.step_count % self.active_policy_refresh_interval)

    def should_do_sortblock(self) -> bool:
        """Return the flag last computed by ``update_should_do_sortblock``."""
        return self.do_sortblock

    def reset(self):
        """Restore every wrapped block, clear all run state and return ``self``."""
        self.initial_step = True
        self.curr_t = 0.0
        logging.info(f"Sortblock: resetting {len(self.all_blocks)} blocks")
        for block in self.all_blocks:
            clean_block(block)
        self.all_blocks = []
        self.blocks_per_type = {}
        self.step_count = 0
        self.activated_steps = [0]
        self.step_modulus = 0
        self.active_steps = 0
        self.predict_ratios = []
        self.do_sortblock = False
        self.predict_start_index = 0
        return self

    def clone(self):
        """Return a new holder with the same configuration and fresh run state."""
        return SortblockHolder(start_predict_ratio=self.start_predict_ratio, end_predict_ratio=self.end_predict_ratio, policy_refresh_interval=self.policy_refresh_interval,
                               start_percent=self.start_percent, end_percent=self.end_percent, subsample_factor=self.subsample_factor,
                               verbose=self.verbose)
|
||||
|
||||
|
||||
class SortblockNode(io.ComfyNode):
    """Node that attaches Sortblock to a model using one fixed predict ratio."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node's id, category, inputs and outputs."""
        # NOTE(review): predict_ratio allows values up to 3.0 here while the
        # scaled node caps its ratios at 1.0 — confirm the intended range.
        node_inputs = [
            io.Model.Input("model", tooltip="The model to add Sortblock to."),
            io.Float.Input("predict_ratio", min=0.0, default=0.8, max=3.0, step=0.01, tooltip="The ratio of blocks to predict."),
            io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval at which to refresh the policy."),
            io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."),
            io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."),
            io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
        ]
        node_outputs = [
            io.Model.Output(tooltip="The model with Sortblock."),
        ]
        return io.Schema(
            node_id="Sortblock",
            display_name="Sortblock",
            description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
            category="advanced/debug/model",
            is_experimental=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, model: io.Model.Type, predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
        """Clone ``model``, attach a holder and the three Sortblock wrappers.

        The same ``predict_ratio`` is used as both the start and the end ratio.
        """
        # TODO: check for specific flavors of supported models
        model = model.clone()
        holder = SortblockHolder(
            start_predict_ratio=predict_ratio,
            end_predict_ratio=predict_ratio,
            policy_refresh_interval=policy_refresh_interval,
            start_percent=start_percent,
            end_percent=end_percent,
            subsample_factor=8,
            verbose=verbose,
        )
        model.model_options["transformer_options"]["sortblock"] = holder
        wrapper_types = comfy.patcher_extension.WrappersMP
        model.add_wrapper_with_key(wrapper_types.PREDICT_NOISE, "sortblock", prepare_noise_wrapper)
        model.add_wrapper_with_key(wrapper_types.OUTER_SAMPLE, "sortblock", outer_sample_wrapper)
        model.add_wrapper_with_key(wrapper_types.DIFFUSION_MODEL, "sortblock", model_forward_wrapper)
        return io.NodeOutput(model)
|
||||
|
||||
|
||||
class SortblockScaledNode(io.ComfyNode):
    """Node that attaches Sortblock to a model with separate start/end predict ratios."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node's id, category, inputs and outputs."""
        node_inputs = [
            io.Model.Input("model", tooltip="The model to add Sortblock to."),
            io.Float.Input("start_predict_ratio", min=0.0, default=0.2, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
            io.Float.Input("end_predict_ratio", min=0.0, default=0.9, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
            io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval at which to refresh the policy."),
            io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."),
            io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."),
            io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
        ]
        node_outputs = [
            io.Model.Output(tooltip="The model with Sortblock."),
        ]
        return io.Schema(
            node_id="SortblockScaled",
            display_name="SortblockScaled",
            description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
            category="advanced/debug/model",
            is_experimental=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, model: io.Model.Type, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
        """Clone ``model``, attach a holder and the three Sortblock wrappers."""
        # TODO: check for specific flavors of supported models
        model = model.clone()
        holder = SortblockHolder(
            start_predict_ratio=start_predict_ratio,
            end_predict_ratio=end_predict_ratio,
            policy_refresh_interval=policy_refresh_interval,
            start_percent=start_percent,
            end_percent=end_percent,
            subsample_factor=8,
            verbose=verbose,
        )
        model.model_options["transformer_options"]["sortblock"] = holder
        wrapper_types = comfy.patcher_extension.WrappersMP
        model.add_wrapper_with_key(wrapper_types.PREDICT_NOISE, "sortblock", prepare_noise_wrapper)
        model.add_wrapper_with_key(wrapper_types.OUTER_SAMPLE, "sortblock", outer_sample_wrapper)
        model.add_wrapper_with_key(wrapper_types.DIFFUSION_MODEL, "sortblock", model_forward_wrapper)
        return io.NodeOutput(model)
|
||||
|
||||
|
||||
class SortblockExtension(ComfyExtension):
    """Extension object that lists the Sortblock node classes of this module."""

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        provided_nodes = [SortblockNode, SortblockScaledNode]
        return provided_nodes
|
||||
|
||||
def comfy_entrypoint():
    """Create and return a new ``SortblockExtension`` instance."""
    extension = SortblockExtension()
    return extension
|
||||
@ -17,55 +17,61 @@
|
||||
"""
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.utils
|
||||
import nodes
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class StableCascade_EmptyLatentImage:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
class StableCascade_EmptyLatentImage(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_EmptyLatentImage",
|
||||
category="latent/stable_cascade",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("compression", default=42, min=4, max=128, step=1),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="stage_c"),
|
||||
io.Latent.Output(display_name="stage_b"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT", "LATENT")
|
||||
RETURN_NAMES = ("stage_c", "stage_b")
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/stable_cascade"
|
||||
|
||||
def generate(self, width, height, compression, batch_size=1):
|
||||
def execute(cls, width, height, compression, batch_size=1):
|
||||
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
|
||||
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
|
||||
return ({
|
||||
return io.NodeOutput({
|
||||
"samples": c_latent,
|
||||
}, {
|
||||
"samples": b_latent,
|
||||
})
|
||||
|
||||
class StableCascade_StageC_VAEEncode:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
class StableCascade_StageC_VAEEncode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_StageC_VAEEncode",
|
||||
category="latent/stable_cascade",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("compression", default=42, min=4, max=128, step=1),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="stage_c"),
|
||||
io.Latent.Output(display_name="stage_b"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"image": ("IMAGE",),
|
||||
"vae": ("VAE", ),
|
||||
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT", "LATENT")
|
||||
RETURN_NAMES = ("stage_c", "stage_b")
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/stable_cascade"
|
||||
|
||||
def generate(self, image, vae, compression):
|
||||
def execute(cls, image, vae, compression):
|
||||
width = image.shape[-2]
|
||||
height = image.shape[-3]
|
||||
out_width = (width // compression) * vae.downscale_ratio
|
||||
@ -75,51 +81,59 @@ class StableCascade_StageC_VAEEncode:
|
||||
|
||||
c_latent = vae.encode(s[:,:,:,:3])
|
||||
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
|
||||
return ({
|
||||
return io.NodeOutput({
|
||||
"samples": c_latent,
|
||||
}, {
|
||||
"samples": b_latent,
|
||||
})
|
||||
|
||||
class StableCascade_StageB_Conditioning:
|
||||
|
||||
class StableCascade_StageB_Conditioning(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "conditioning": ("CONDITIONING",),
|
||||
"stage_c": ("LATENT",),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_StageB_Conditioning",
|
||||
category="conditioning/stable_cascade",
|
||||
inputs=[
|
||||
io.Conditioning.Input("conditioning"),
|
||||
io.Latent.Input("stage_c"),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
FUNCTION = "set_prior"
|
||||
|
||||
CATEGORY = "conditioning/stable_cascade"
|
||||
|
||||
def set_prior(self, conditioning, stage_c):
|
||||
@classmethod
|
||||
def execute(cls, conditioning, stage_c):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
d = t[1].copy()
|
||||
d['stable_cascade_prior'] = stage_c['samples']
|
||||
d["stable_cascade_prior"] = stage_c["samples"]
|
||||
n = [t[0], d]
|
||||
c.append(n)
|
||||
return (c, )
|
||||
return io.NodeOutput(c)
|
||||
|
||||
class StableCascade_SuperResolutionControlnet:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
class StableCascade_SuperResolutionControlnet(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_SuperResolutionControlnet",
|
||||
category="_for_testing/stable_cascade",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Vae.Input("vae"),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="controlnet_input"),
|
||||
io.Latent.Output(display_name="stage_c"),
|
||||
io.Latent.Output(display_name="stage_b"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"image": ("IMAGE",),
|
||||
"vae": ("VAE", ),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
|
||||
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
|
||||
FUNCTION = "generate"
|
||||
|
||||
EXPERIMENTAL = True
|
||||
CATEGORY = "_for_testing/stable_cascade"
|
||||
|
||||
def generate(self, image, vae):
|
||||
def execute(cls, image, vae):
|
||||
width = image.shape[-2]
|
||||
height = image.shape[-3]
|
||||
batch_size = image.shape[0]
|
||||
@ -127,15 +141,22 @@ class StableCascade_SuperResolutionControlnet:
|
||||
|
||||
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
|
||||
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
|
||||
return (controlnet_input, {
|
||||
return io.NodeOutput(controlnet_input, {
|
||||
"samples": c_latent,
|
||||
}, {
|
||||
"samples": b_latent,
|
||||
})
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
|
||||
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
|
||||
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
|
||||
"StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet,
|
||||
}
|
||||
|
||||
class StableCascadeExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
StableCascade_EmptyLatentImage,
|
||||
StableCascade_StageB_Conditioning,
|
||||
StableCascade_StageC_VAEEncode,
|
||||
StableCascade_SuperResolutionControlnet,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> StableCascadeExtension:
|
||||
return StableCascadeExtension()
|
||||
|
||||
@ -5,52 +5,49 @@ import av
|
||||
import torch
|
||||
import folder_paths
|
||||
import json
|
||||
from typing import Optional, Literal
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
from fractions import Fraction
|
||||
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
|
||||
from comfy_api.latest import Input, InputImpl, Types
|
||||
from comfy_api.input import AudioInput, ImageInput, VideoInput
|
||||
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
|
||||
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from comfy.cli_args import args
|
||||
|
||||
class SaveWEBM:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
class SaveWEBM(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveWEBM",
|
||||
category="image/video",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Image.Input("images"),
|
||||
io.String.Input("filename_prefix", default="ComfyUI"),
|
||||
io.Combo.Input("codec", options=["vp9", "av1"]),
|
||||
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
|
||||
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"images": ("IMAGE", ),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
|
||||
"codec": (["vp9", "av1"],),
|
||||
"fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
|
||||
"crf": ("FLOAT", {"default": 32.0, "min": 0, "max": 63.0, "step": 1, "tooltip": "Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_images"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "image/video"
|
||||
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
def execute(cls, images, codec, fps, filename_prefix, crf) -> io.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||
filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
|
||||
file = f"{filename}_{counter:05}_.webm"
|
||||
container = av.open(os.path.join(full_output_folder, file), mode="w")
|
||||
|
||||
if prompt is not None:
|
||||
container.metadata["prompt"] = json.dumps(prompt)
|
||||
if cls.hidden.prompt is not None:
|
||||
container.metadata["prompt"] = json.dumps(cls.hidden.prompt)
|
||||
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
container.metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||
|
||||
codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
|
||||
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
|
||||
@ -69,63 +66,46 @@ class SaveWEBM:
|
||||
container.mux(stream.encode())
|
||||
container.close()
|
||||
|
||||
results: list[FileLocator] = [{
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
}]
|
||||
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
|
||||
|
||||
return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side
|
||||
|
||||
class SaveVideo(ComfyNodeABC):
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type: Literal["output"] = "output"
|
||||
self.prefix_append = ""
|
||||
class SaveVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveVideo",
|
||||
display_name="Save Video",
|
||||
category="image/video",
|
||||
description="Saves the input images to your ComfyUI output directory.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to save."),
|
||||
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
|
||||
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
||||
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"video": (IO.VIDEO, {"tooltip": "The video to save."}),
|
||||
"filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
|
||||
"format": (Types.VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
|
||||
"codec": (Types.VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": "PROMPT",
|
||||
"extra_pnginfo": "EXTRA_PNGINFO"
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_video"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
|
||||
|
||||
def save_video(self, video: Input.Video, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput:
|
||||
width, height = video.get_dimensions()
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||
filename_prefix,
|
||||
self.output_dir,
|
||||
folder_paths.get_output_directory(),
|
||||
width,
|
||||
height
|
||||
)
|
||||
results: list[FileLocator] = list()
|
||||
saved_metadata = None
|
||||
if not args.disable_metadata:
|
||||
metadata = {}
|
||||
if extra_pnginfo is not None:
|
||||
metadata.update(extra_pnginfo)
|
||||
if prompt is not None:
|
||||
metadata["prompt"] = prompt
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
metadata.update(cls.hidden.extra_pnginfo)
|
||||
if cls.hidden.prompt is not None:
|
||||
metadata["prompt"] = cls.hidden.prompt
|
||||
if len(metadata) > 0:
|
||||
saved_metadata = metadata
|
||||
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
|
||||
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
|
||||
video.save_to(
|
||||
os.path.join(full_output_folder, file),
|
||||
format=format,
|
||||
@ -133,83 +113,82 @@ class SaveVideo(ComfyNodeABC):
|
||||
metadata=saved_metadata
|
||||
)
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
|
||||
|
||||
return { "ui": { "images": results, "animated": (True,) } }
|
||||
|
||||
class CreateVideo(ComfyNodeABC):
|
||||
class CreateVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"images": (IO.IMAGE, {"tooltip": "The images to create a video from."}),
|
||||
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}),
|
||||
},
|
||||
"optional": {
|
||||
"audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CreateVideo",
|
||||
display_name="Create Video",
|
||||
category="image/video",
|
||||
description="Create a video from images.",
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="The images to create a video from."),
|
||||
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
|
||||
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
|
||||
],
|
||||
outputs=[
|
||||
io.Video.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
FUNCTION = "create_video"
|
||||
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Create a video from images."
|
||||
|
||||
def create_video(self, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None):
|
||||
return (InputImpl.VideoFromComponents(
|
||||
Types.VideoComponents(
|
||||
images=images,
|
||||
audio=audio,
|
||||
frame_rate=Fraction(fps),
|
||||
)
|
||||
),)
|
||||
|
||||
class GetVideoComponents(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"video": (IO.VIDEO, {"tooltip": "The video to extract components from."}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT)
|
||||
RETURN_NAMES = ("images", "audio", "fps")
|
||||
FUNCTION = "get_components"
|
||||
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
|
||||
return io.NodeOutput(
|
||||
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
||||
)
|
||||
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."
|
||||
class GetVideoComponents(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="GetVideoComponents",
|
||||
display_name="Get Video Components",
|
||||
category="image/video",
|
||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to extract components from."),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="images"),
|
||||
io.Audio.Output(display_name="audio"),
|
||||
io.Float.Output(display_name="fps"),
|
||||
],
|
||||
)
|
||||
|
||||
def get_components(self, video: Input.Video):
|
||||
@classmethod
|
||||
def execute(cls, video: VideoInput) -> io.NodeOutput:
|
||||
components = video.get_components()
|
||||
|
||||
return (components.images, components.audio, float(components.frame_rate))
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
|
||||
|
||||
class LoadVideo(ComfyNodeABC):
|
||||
class LoadVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
def define_schema(cls):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
files = folder_paths.filter_files_content_types(files, ["video"])
|
||||
return {"required":
|
||||
{"file": (sorted(files), {"video_upload": True})},
|
||||
}
|
||||
|
||||
CATEGORY = "image/video"
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
FUNCTION = "load_video"
|
||||
def load_video(self, file):
|
||||
video_path = folder_paths.get_annotated_filepath(file)
|
||||
return (InputImpl.VideoFromFile(video_path),)
|
||||
return io.Schema(
|
||||
node_id="LoadVideo",
|
||||
display_name="Load Video",
|
||||
category="image/video",
|
||||
inputs=[
|
||||
io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video),
|
||||
],
|
||||
outputs=[
|
||||
io.Video.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, file):
|
||||
def execute(cls, file) -> io.NodeOutput:
|
||||
video_path = folder_paths.get_annotated_filepath(file)
|
||||
return io.NodeOutput(VideoFromFile(video_path))
|
||||
|
||||
@classmethod
|
||||
def fingerprint_inputs(s, file):
|
||||
video_path = folder_paths.get_annotated_filepath(file)
|
||||
mod_time = os.path.getmtime(video_path)
|
||||
# Instead of hashing the file, we can just use the modification time to avoid
|
||||
@ -217,24 +196,23 @@ class LoadVideo(ComfyNodeABC):
|
||||
return mod_time
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, file):
|
||||
def validate_inputs(s, file):
|
||||
if not folder_paths.exists_annotated_filepath(file):
|
||||
return "Invalid video file: {}".format(file)
|
||||
|
||||
return True
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SaveWEBM": SaveWEBM,
|
||||
"SaveVideo": SaveVideo,
|
||||
"CreateVideo": CreateVideo,
|
||||
"GetVideoComponents": GetVideoComponents,
|
||||
"LoadVideo": LoadVideo,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SaveVideo": "Save Video",
|
||||
"CreateVideo": "Create Video",
|
||||
"GetVideoComponents": "Get Video Components",
|
||||
"LoadVideo": "Load Video",
|
||||
}
|
||||
class VideoExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
SaveWEBM,
|
||||
SaveVideo,
|
||||
CreateVideo,
|
||||
GetVideoComponents,
|
||||
LoadVideo,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> VideoExtension:
|
||||
return VideoExtension()
|
||||
|
||||
1
main.py
1
main.py
@ -113,7 +113,6 @@ import gc
|
||||
|
||||
if os.name == "nt":
|
||||
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.default_device is not None:
|
||||
|
||||
2
nodes.py
2
nodes.py
@ -2325,8 +2325,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_model_patch.py",
|
||||
"nodes_easycache.py",
|
||||
"nodes_audio_encoder.py",
|
||||
"nodes_sortblock.py",
|
||||
"nodes_easysortblock.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@ -6,6 +6,7 @@ def pytest_addoption(parser):
|
||||
parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images')
|
||||
parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||
parser.addoption("--port", type=int, default=8188, help="Set the listen port.")
|
||||
parser.addoption("--skip-timing-checks", action="store_true", default=False, help="Skip timing-related assertions in tests (useful for CI environments with variable performance)")
|
||||
|
||||
# This initializes args at the beginning of the test session
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@ -19,6 +20,11 @@ def args_pytest(pytestconfig):
|
||||
|
||||
return args
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def skip_timing_checks(pytestconfig):
|
||||
"""Fixture that returns whether timing checks should be skipped."""
|
||||
return pytestconfig.getoption("--skip-timing-checks")
|
||||
|
||||
def pytest_collection_modifyitems(items):
|
||||
# Modifies items so tests run in the correct order
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ import subprocess
|
||||
|
||||
from pytest import fixture
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from tests.inference.test_execution import ComfyClient, run_warmup
|
||||
from tests.execution.test_execution import ComfyClient, run_warmup
|
||||
|
||||
|
||||
@pytest.mark.execution
|
||||
@ -23,7 +23,7 @@ class TestAsyncNodes:
|
||||
'--output-directory', args_pytest["output_dir"],
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
|
||||
'--cpu',
|
||||
]
|
||||
use_lru, lru_size = request.param
|
||||
@ -81,7 +81,7 @@ class TestAsyncNodes:
|
||||
assert len(result_images) == 1, "Should have 1 image"
|
||||
assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black"
|
||||
|
||||
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
||||
"""Test that multiple async nodes execute in parallel."""
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client)
|
||||
@ -104,7 +104,8 @@ class TestAsyncNodes:
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Should take ~0.5s (max duration) not 1.2s (sum of durations)
|
||||
assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"
|
||||
if not skip_timing_checks:
|
||||
assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"
|
||||
|
||||
# Verify all nodes executed
|
||||
assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)
|
||||
@ -150,7 +151,7 @@ class TestAsyncNodes:
|
||||
with pytest.raises(urllib.error.HTTPError):
|
||||
client.run(g)
|
||||
|
||||
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
|
||||
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
||||
"""Test async nodes with lazy evaluation."""
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client, prefix="warmup_lazy")
|
||||
@ -173,7 +174,8 @@ class TestAsyncNodes:
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Should only execute sleep1, not sleep2
|
||||
assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
|
||||
if not skip_timing_checks:
|
||||
assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
|
||||
assert result.did_run(sleep1), "Sleep1 should have executed"
|
||||
assert not result.did_run(sleep2), "Sleep2 should have been skipped"
|
||||
|
||||
@ -310,7 +312,7 @@ class TestAsyncNodes:
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 1, "Should have blocked second image"
|
||||
|
||||
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
|
||||
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
||||
"""Test that async nodes are properly cached."""
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client, prefix="warmup_cache")
|
||||
@ -330,9 +332,10 @@ class TestAsyncNodes:
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
assert not result2.did_run(sleep_node), "Should be cached"
|
||||
assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"
|
||||
if not skip_timing_checks:
|
||||
assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"
|
||||
|
||||
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
|
||||
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
||||
"""Test async nodes within dynamically generated prompts."""
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client, prefix="warmup_dynamic")
|
||||
@ -345,8 +348,8 @@ class TestAsyncNodes:
|
||||
dynamic_async = g.node("TestDynamicAsyncGeneration",
|
||||
image1=image1.out(0),
|
||||
image2=image2.out(0),
|
||||
num_async_nodes=3,
|
||||
sleep_duration=0.2)
|
||||
num_async_nodes=5,
|
||||
sleep_duration=0.4)
|
||||
g.node("SaveImage", images=dynamic_async.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
@ -354,7 +357,8 @@ class TestAsyncNodes:
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Should execute async nodes in parallel within dynamic prompt
|
||||
assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
|
||||
if not skip_timing_checks:
|
||||
assert elapsed_time < 1.0, f"Dynamic async execution took {elapsed_time}s"
|
||||
assert result.did_run(dynamic_async)
|
||||
|
||||
def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
|
||||
@ -149,7 +149,7 @@ class TestExecution:
|
||||
'--output-directory', args_pytest["output_dir"],
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
|
||||
'--cpu',
|
||||
]
|
||||
use_lru, lru_size = request.param
|
||||
@ -518,7 +518,7 @@ class TestExecution:
|
||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||
assert not result.did_run(test_node), "The execution should have been cached"
|
||||
|
||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client)
|
||||
|
||||
@ -541,14 +541,15 @@ class TestExecution:
|
||||
|
||||
# The test should take around 3.0 seconds (the longest sleep duration)
|
||||
# plus some overhead, but definitely less than the sum of all sleeps (9.0s)
|
||||
assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"
|
||||
if not skip_timing_checks:
|
||||
assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"
|
||||
|
||||
# Verify that all nodes executed
|
||||
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
|
||||
assert result.did_run(sleep_node2), "Sleep node 2 should have run"
|
||||
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
|
||||
|
||||
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
|
||||
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client)
|
||||
|
||||
@ -574,7 +575,9 @@ class TestExecution:
|
||||
|
||||
# Similar to the previous test, expect parallel execution of the sleep nodes
|
||||
# which should complete in less than the sum of all sleeps
|
||||
assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 5.5s"
|
||||
# Lots of leeway here since Windows CI is slow
|
||||
if not skip_timing_checks:
|
||||
assert elapsed_time < 13.0, f"Expansion execution took {elapsed_time}s"
|
||||
|
||||
# Verify the parallel sleep node executed
|
||||
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
|
||||
233
tests/execution/test_progress_isolation.py
Normal file
233
tests/execution/test_progress_isolation.py
Normal file
@ -0,0 +1,233 @@
|
||||
"""Test that progress updates are properly isolated between WebSocket clients."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
import time
|
||||
import threading
|
||||
import uuid
|
||||
import websocket
|
||||
from typing import List, Dict, Any
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from tests.execution.test_execution import ComfyClient
|
||||
|
||||
|
||||
class ProgressTracker:
    """Tracks progress messages received by a WebSocket client.

    Messages are kept in arrival order in ``progress_messages``. All reads
    and writes of that list go through ``lock`` so that one thread can append
    while another inspects the collected messages.
    """

    def __init__(self, client_id: str):
        self.client_id = client_id
        self.progress_messages: List[Dict[str, Any]] = []
        self.lock = threading.Lock()

    @staticmethod
    def _prompt_id_of(message: Dict[str, Any]) -> Any:
        """Return ``message['data']['prompt_id']`` or None when unavailable.

        A missing ``data`` key, an explicit ``None`` payload, or any non-dict
        payload yields None instead of raising, so one malformed message
        cannot break the isolation checks.
        """
        data = message.get('data')
        if not isinstance(data, dict):
            return None
        return data.get('prompt_id')

    def add_message(self, message: Dict[str, Any]):
        """Thread-safe addition of progress messages."""
        with self.lock:
            self.progress_messages.append(message)

    def get_messages_for_prompt(self, prompt_id: str) -> List[Dict[str, Any]]:
        """Get all progress messages for a specific prompt_id."""
        with self.lock:
            return [
                msg for msg in self.progress_messages
                if self._prompt_id_of(msg) == prompt_id
            ]

    def has_cross_contamination(self, own_prompt_id: str) -> bool:
        """Check if this client received progress for other prompts.

        Messages without a (truthy) prompt_id are ignored; only a message
        that names a prompt different from ``own_prompt_id`` counts.
        """
        with self.lock:
            for msg in self.progress_messages:
                msg_prompt_id = self._prompt_id_of(msg)
                if msg_prompt_id and msg_prompt_id != own_prompt_id:
                    return True
            return False
|
||||
|
||||
|
||||
class IsolatedClient(ComfyClient):
    """Extended ComfyClient that tracks all WebSocket messages.

    Every decoded text frame is appended to ``all_messages``; frames whose
    ``type`` is ``'progress_state'`` are additionally handed to
    ``progress_tracker``.
    """

    def __init__(self):
        super().__init__()
        # Created in connect(), once the client_id is known.
        self.progress_tracker = None
        self.all_messages: List[Dict[str, Any]] = []
        # Unexpected exception that ended listen_for_messages early, if any.
        self.listen_error = None

    def connect(self, listen='127.0.0.1', port=8188, client_id=None):
        """Connect with a specific client_id and set up message tracking.

        A random UUID is used when ``client_id`` is None.
        """
        if client_id is None:
            client_id = str(uuid.uuid4())
        super().connect(listen, port, client_id)
        self.progress_tracker = ProgressTracker(client_id)

    def listen_for_messages(self, duration: float = 5.0):
        """Listen for WebSocket messages for a specified duration.

        Receive timeouts are expected and simply trigger another poll. Any
        other exception stops listening (best effort, as before), but is now
        stored in ``listen_error`` and printed so that a dead listener is
        visible instead of surfacing only as missing progress messages.
        """
        # Monotonic clock: the deadline must not move if wall time is adjusted.
        end_time = time.monotonic() + duration
        # recv() blocks for at most 0.5s so the deadline is re-checked regularly.
        # NOTE(review): self.ws is presumably set by ComfyClient.connect — confirm.
        self.ws.settimeout(0.5)

        while time.monotonic() < end_time:
            try:
                out = self.ws.recv()
                # Only text frames are decoded; other frame types are ignored.
                if isinstance(out, str):
                    message = json.loads(out)
                    self.all_messages.append(message)

                    # Track progress_state messages
                    if message.get('type') == 'progress_state':
                        self.progress_tracker.add_message(message)
            except websocket.WebSocketTimeoutException:
                continue
            except Exception as exc:
                self.listen_error = exc
                print(f"listen_for_messages stopped early: {exc!r}")  # noqa: T201
                break
|
||||
|
||||
|
||||
@pytest.mark.execution
|
||||
class TestProgressIsolation:
|
||||
"""Test suite for verifying progress update isolation between clients."""
|
||||
|
||||
@pytest.fixture(scope="class", autouse=True)
def _server(self, args_pytest):
    """Start the ComfyUI server for testing and stop it on teardown.

    Launches ``python main.py`` in CPU mode using the output directory,
    listen address and port taken from ``args_pytest``, yields while the
    tests of the class run, then kills the process.
    """
    import subprocess
    pargs = [
        'python', 'main.py',
        '--output-directory', args_pytest["output_dir"],
        '--listen', args_pytest["listen"],
        '--port', str(args_pytest["port"]),
        '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
        '--cpu',
    ]
    p = subprocess.Popen(pargs)
    try:
        yield
    finally:
        p.kill()
        # Reap the child so it does not linger as a zombie and its exit has
        # completed before a later fixture starts another server.
        p.wait()
|
||||
|
||||
def start_client_with_retry(self, listen: str, port: int, client_id: str = None):
    """Start client with connection retries.

    Makes up to five attempts, sleeping four seconds before each one, and
    returns the connected ``IsolatedClient`` on the first success.

    Raises:
        ConnectionRefusedError: if every attempt is refused; the last refusal
            is attached as the exception's cause.
    """
    client = IsolatedClient()
    # Connect to server (with retries)
    n_tries = 5
    last_error = None
    for i in range(n_tries):
        # The wait precedes every attempt, including the first one —
        # presumably to give a freshly spawned server time to come up.
        time.sleep(4)
        try:
            client.connect(listen, port, client_id)
            return client
        except ConnectionRefusedError as e:
            # `e` is unbound once the except block ends, so keep a reference.
            last_error = e
            print(e)  # noqa: T201
            print(f"({i+1}/{n_tries}) Retrying...")  # noqa: T201
    raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts") from last_error
|
||||
|
||||
def test_progress_isolation_between_clients(self, args_pytest):
    """Test that progress updates are isolated between different clients.

    Two clients with distinct IDs each queue a small StubImage ->
    PreviewImage workflow, then listen on their own connection
    (``duration=10.0``) in parallel threads. Afterwards each client's
    progress tracker must contain messages for its own prompt and none for
    the other client's prompt.

    Args:
        args_pytest: Mapping providing the "listen" and "port" values that
            are handed to ``start_client_with_retry``.
    """
    listen = args_pytest["listen"]
    port = args_pytest["port"]

    # Create two separate clients with unique IDs
    client_a_id = "client_a_" + str(uuid.uuid4())
    client_b_id = "client_b_" + str(uuid.uuid4())

    # Bind both names before the try block: if a connection attempt raises,
    # the cleanup in `finally` must not hit an unbound local, which would
    # replace the real connection error with an UnboundLocalError.
    client_a = None
    client_b = None

    try:
        # Connect both clients with retries
        client_a = self.start_client_with_retry(listen, port, client_a_id)
        client_b = self.start_client_with_retry(listen, port, client_b_id)

        # Create simple workflows for both clients
        graph_a = GraphBuilder(prefix="client_a")
        image_a = graph_a.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
        graph_a.node("PreviewImage", images=image_a.out(0))

        graph_b = GraphBuilder(prefix="client_b")
        image_b = graph_b.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
        graph_b.node("PreviewImage", images=image_b.out(0))

        # Submit workflows from both clients
        prompt_a = graph_a.finalize()
        prompt_b = graph_b.finalize()

        response_a = client_a.queue_prompt(prompt_a)
        prompt_id_a = response_a['prompt_id']

        response_b = client_b.queue_prompt(prompt_b)
        prompt_id_b = response_b['prompt_id']

        # Listen on both connections concurrently so neither client's
        # listening window delays the other.
        thread_a = threading.Thread(
            target=client_a.listen_for_messages, kwargs={"duration": 10.0}
        )
        thread_b = threading.Thread(
            target=client_b.listen_for_messages, kwargs={"duration": 10.0}
        )

        thread_a.start()
        thread_b.start()

        # Wait for threads to complete
        thread_a.join()
        thread_b.join()

        # Verify isolation
        # Client A should only receive progress for prompt_id_a
        assert not client_a.progress_tracker.has_cross_contamination(prompt_id_a), \
            f"Client A received progress updates for other clients' workflows. " \
            f"Expected only {prompt_id_a}, but got messages for multiple prompts."

        # Client B should only receive progress for prompt_id_b
        assert not client_b.progress_tracker.has_cross_contamination(prompt_id_b), \
            f"Client B received progress updates for other clients' workflows. " \
            f"Expected only {prompt_id_b}, but got messages for multiple prompts."

        # Verify each client received their own progress updates
        client_a_messages = client_a.progress_tracker.get_messages_for_prompt(prompt_id_a)
        client_b_messages = client_b.progress_tracker.get_messages_for_prompt(prompt_id_b)

        assert len(client_a_messages) > 0, \
            "Client A did not receive any progress updates for its own workflow"
        assert len(client_b_messages) > 0, \
            "Client B did not receive any progress updates for its own workflow"

        # Ensure no cross-contamination
        client_a_other = client_a.progress_tracker.get_messages_for_prompt(prompt_id_b)
        client_b_other = client_b.progress_tracker.get_messages_for_prompt(prompt_id_a)

        assert len(client_a_other) == 0, \
            f"Client A incorrectly received {len(client_a_other)} progress updates for Client B's workflow"
        assert len(client_b_other) == 0, \
            f"Client B incorrectly received {len(client_b_other)} progress updates for Client A's workflow"

    finally:
        # Clean up only the connections that were actually established
        for connected in (client_a, client_b):
            if connected is not None and hasattr(connected, 'ws'):
                connected.ws.close()
|
||||
|
||||
def test_progress_with_missing_client_id(self, args_pytest):
    """Test that progress updates handle missing client_id gracefully.

    A client is connected without passing an explicit client ID, queues a
    small StubImage -> PreviewImage workflow, listens on its connection
    (``duration=5.0``), and must end up with at least one tracked progress
    message for the prompt it submitted.

    Args:
        args_pytest: Mapping providing the "listen" and "port" values that
            are handed to ``start_client_with_retry``.
    """
    listen = args_pytest["listen"]
    port = args_pytest["port"]

    # Bind the name before the try block: if the connection attempt raises,
    # the cleanup in `finally` must not hit an unbound local, which would
    # replace the real connection error with an UnboundLocalError.
    client = None

    try:
        # Connect client with retries
        client = self.start_client_with_retry(listen, port)

        # Create a simple workflow
        graph = GraphBuilder(prefix="test_missing_id")
        image = graph.node("StubImage", content="BLACK", height=128, width=128, batch_size=1)
        graph.node("PreviewImage", images=image.out(0))

        # Submit workflow
        prompt = graph.finalize()
        response = client.queue_prompt(prompt)
        prompt_id = response['prompt_id']

        # Listen for messages
        client.listen_for_messages(duration=5.0)

        # Should still receive progress updates for own workflow
        messages = client.progress_tracker.get_messages_for_prompt(prompt_id)
        assert len(messages) > 0, \
            "Client did not receive progress updates even though it initiated the workflow"

    finally:
        # Close the connection only if one was actually established
        if client is not None and hasattr(client, 'ws'):
            client.ws.close()
|
||||
|
||||
Reference in New Issue
Block a user