Compare commits

...

17 Commits

Author SHA1 Message Date
8587446e4a test: add regression tests for multi-endpoint connectivity check 2026-04-19 10:45:18 +00:00
4f101765a3 fix: use multi-endpoint connectivity check to support China/GFW users
The _diagnose_connectivity() function previously only probed google.com
to determine whether the user has internet access. Since google.com is
blocked by China's Great Firewall, Chinese users were always misdiagnosed
as having no internet, causing misleading LocalNetworkError messages.

Now checks the Comfy API health endpoint first (the most relevant
signal), then falls back to multiple probe URLs (google.com, baidu.com,
captive.apple.com) to support users in regions where specific sites are
blocked.
2026-04-19 10:42:56 +00:00
138571da95 fix: append directory type annotation to internal files endpoint response (#13078) (#13305) 2026-04-18 23:21:22 -04:00
3d816db07f Some optimizations to make Ernie inference a bit faster. (#13472) 2026-04-18 23:02:29 -04:00
b9dedea57d feat: SUPIR model support (CORE-17) (#13250) 2026-04-18 23:02:01 -04:00
3086026401 ComfyUI v0.19.3 2026-04-17 13:35:01 -04:00
9635c2ec9b fix(api-nodes): make "obj" output optional in Hunyuan3D Text and Image to 3D (#13449)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-04-18 01:31:37 +08:00
f8d92cf313 chore: update workflow templates to v0.9.57 (#13455) 2026-04-17 12:16:39 -05:00
4f48be4138 feat(api-nodes): add new "arrow-1.1" and "arrow-1.1-max" SVG models (#13447)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-04-17 12:02:06 -05:00
541fd10bbe fix(api-nodes): corrected StabilityAI price badges (#13454)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-04-17 11:44:08 -05:00
05f7531148 nodes_textgen: Implement use_default_template for LTX (#13451) 2026-04-17 12:20:09 -04:00
c033bbf516 ComfyUI v0.19.2 2026-04-17 00:26:35 -04:00
1391579c33 Add JsonExtractString node. (#13435) 2026-04-17 00:20:16 -04:00
d0c53c50c2 feat(api-nodes): add 1080p resolution for SeeDance 2.0 model (#13437)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-04-16 20:32:04 -05:00
b41ab53b6f Use ErnieTEModel_ not ErnieTEModel. (#13431) 2026-04-16 10:11:58 -04:00
e9a2d1e4cc Add a way to disable default template in text gen node. (#13424) 2026-04-15 22:59:08 -04:00
1de83f91c3 Fix OOM regression in _apply() for quantized models during inference (#13372)
Skip unnecessary clone of inference-mode tensors when already inside
torch.inference_mode(), matching the existing guard in set_attr_param.
The unconditional clone introduced in 20561aa9 caused transient VRAM
doubling during model movement for FP8/quantized models.
2026-04-15 02:10:36 -07:00
22 changed files with 1200 additions and 165 deletions

View File

@ -67,7 +67,7 @@ class InternalRoutes:
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
key=lambda entry: -entry.stat().st_mtime
)
return web.json_response([entry.name for entry in sorted_files], status=200)
return web.json_response([f"{entry.name} [{directory_type}]" for entry in sorted_files], status=200)
def get_app(self):

View File

@ -118,8 +118,6 @@ class ErnieImageAttention(nn.Module):
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query, key = query.to(x.dtype), key.to(x.dtype)
q_flat = query.reshape(B, S, -1)
k_flat = key.reshape(B, S, -1)
@ -161,16 +159,16 @@ class ErnieImageSharedAdaLNBlock(nn.Module):
residual = x
x_norm = self.adaLN_sa_ln(x)
x_norm = (x_norm.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
x_norm = x_norm * (1 + scale_msa) + shift_msa
attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
x = residual + gate_msa * attn_out
residual = x
x_norm = self.adaLN_mlp_ln(x)
x_norm = (x_norm.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
x_norm = x_norm * (1 + scale_mlp) + shift_mlp
return residual + (gate_mlp.float() * self.mlp(x_norm).float()).to(x.dtype)
return residual + gate_mlp * self.mlp(x_norm)
class ErnieImageAdaLNContinuous(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
@ -183,7 +181,7 @@ class ErnieImageAdaLNContinuous(nn.Module):
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
x = self.norm(x)
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
return x
class ErnieImageModel(nn.Module):

View File

@ -34,6 +34,16 @@ class TimestepBlock(nn.Module):
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
for layer in ts:
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
found_patched = False
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
if isinstance(layer, class_type):
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
found_patched = True
break
if found_patched:
continue
if isinstance(layer, VideoResBlock):
x = layer(x, emb, num_video_frames, image_only_indicator)
elif isinstance(layer, TimestepBlock):
@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
found_patched = False
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
if isinstance(layer, class_type):
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
found_patched = True
break
if found_patched:
continue
x = layer(x)
return x
@ -894,6 +895,12 @@ class UNetModel(nn.Module):
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle')
if "middle_block_after_patch" in transformer_patches:
patch = transformer_patches["middle_block_after_patch"]
for p in patch:
out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y,
"timesteps": timesteps, "transformer_options": transformer_options})
h = out["h"]
for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
@ -905,8 +912,9 @@ class UNetModel(nn.Module):
for p in patch:
h, hsp = p(h, hsp, transformer_options)
h = th.cat([h, hsp], dim=1)
del hsp
if hsp is not None:
h = th.cat([h, hsp], dim=1)
del hsp
if len(hs) > 0:
output_shape = hs[-1].shape
else:

View File

View File

@ -0,0 +1,226 @@
import torch
import torch.nn as nn
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from comfy.ldm.modules.diffusionmodules.openaimodel import Downsample, TimestepEmbedSequential, ResBlock, SpatialTransformer
from comfy.ldm.modules.attention import optimized_attention
class ZeroSFT(nn.Module):
    """Feature-transform adapter that modulates an activation with control features.

    The control tensor ``c`` is added to ``h`` through a 1x1 convolution, the
    sum is group-normalised, and the result is scaled and shifted by ``gamma``
    and ``beta`` maps predicted from ``c``. The output is a linear blend
    between the untouched input and the modulated activation.

    Args:
        label_nc: channel count of the control tensor ``c``.
        norm_nc: channel count of ``h``.
        concat_channels: channel count of ``h_ori``; when non-zero ``h_ori`` is
            concatenated in front of ``h`` *before* normalisation.
        dtype, device: forwarded to every layer constructor.
        operations: namespace providing ``GroupNorm`` and ``Conv2d``.
    """

    def __init__(self, label_nc, norm_nc, concat_channels=0, dtype=None, device=None, operations=None):
        super().__init__()
        ks = 3
        pw = ks // 2  # "same" padding for the 3x3 convolutions

        self.param_free_norm = operations.GroupNorm(32, norm_nc + concat_channels, dtype=dtype, device=device)

        nhidden = 128
        self.mlp_shared = nn.Sequential(
            operations.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw, dtype=dtype, device=device),
            nn.SiLU(),
        )
        self.zero_mul = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device)
        self.zero_add = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device)

        self.zero_conv = operations.Conv2d(label_nc, norm_nc, 1, 1, 0, dtype=dtype, device=device)
        # The comparison already yields a bool; no extra bool() wrapper needed.
        self.pre_concat = concat_channels != 0

    def forward(self, c, h, h_ori=None, control_scale=1):
        """Blend ``h`` (optionally joined with ``h_ori``) with its modulated version.

        Args:
            c: control features; spatial size must match ``h``.
            h: activation to modulate.
            h_ori: optional second activation concatenated on the channel axis.
            control_scale: blend weight; 0 returns the unmodified input,
                1 returns the fully modulated activation.

        Returns:
            ``torch.lerp(h_raw, h_modulated, control_scale)``.
        """
        concat_first = h_ori is not None and self.pre_concat

        # Reference signal for the final blend (what the output is at scale 0).
        h_raw = torch.cat([h_ori, h], dim=1) if concat_first else h

        h = h + self.zero_conv(c)
        if concat_first:
            h = torch.cat([h_ori, h], dim=1)

        actv = self.mlp_shared(c)
        gamma = self.zero_mul(actv)
        beta = self.zero_add(actv)

        h = self.param_free_norm(h)
        # Equivalent to h * (1 + gamma) + beta.
        h = torch.addcmul(h + beta, h, gamma)

        if h_ori is not None and not self.pre_concat:
            # NOTE(review): in this branch h gains h_ori's channels while h_raw
            # does not, so the lerp below cannot broadcast. Presumably callers
            # never pass h_ori to a module built with concat_channels=0 - confirm.
            h = torch.cat([h_ori, h], dim=1)

        return torch.lerp(h_raw, h, control_scale)
class _CrossAttnInner(nn.Module):
    """Cross-attention core whose parameter names (to_q/to_k/to_v/to_out) mirror the original CrossAttention state_dict."""

    def __init__(self, query_dim, context_dim, heads, dim_head, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = heads
        width = dim_head * heads
        factory = {"dtype": dtype, "device": device}

        # Projections are bias-free; only the output layer keeps its bias.
        self.to_q = operations.Linear(query_dim, width, bias=False, **factory)
        self.to_k = operations.Linear(context_dim, width, bias=False, **factory)
        self.to_v = operations.Linear(context_dim, width, bias=False, **factory)
        # Wrapped in a Sequential so the weight lives under "to_out.0".
        self.to_out = nn.Sequential(
            operations.Linear(width, query_dim, **factory),
        )

    def forward(self, x, context):
        """Attend from ``x`` (queries) to ``context`` (keys and values) and project back to ``query_dim``."""
        queries = self.to_q(x)
        keys = self.to_k(context)
        values = self.to_v(context)
        attended = optimized_attention(queries, keys, values, self.heads)
        return self.to_out(attended)
class ZeroCrossAttn(nn.Module):
    """Residual cross-attention adapter: ``x`` attends to ``context`` and the result is added back scaled."""

    def __init__(self, context_dim, query_dim, dtype=None, device=None, operations=None):
        super().__init__()
        head_width = 64
        head_count = query_dim // head_width
        self.attn = _CrossAttnInner(
            query_dim, context_dim, head_count, head_width,
            dtype=dtype, device=device, operations=operations,
        )
        self.norm1 = operations.GroupNorm(32, query_dim, dtype=dtype, device=device)
        self.norm2 = operations.GroupNorm(32, context_dim, dtype=dtype, device=device)

    def forward(self, context, x, control_scale=1):
        """Return ``x + attention(x, context) * control_scale``.

        Both inputs are group-normalised and flattened from (B, C, H, W) to
        (B, H*W, C) for attention; the result is folded back to ``x``'s
        spatial size.
        """
        _, _, height, width = x.shape

        query_tokens = self.norm1(x).flatten(2).transpose(1, 2)
        context_tokens = self.norm2(context).flatten(2).transpose(1, 2)

        attended = self.attn(query_tokens, context_tokens)
        delta = attended.transpose(1, 2).unflatten(2, (height, width))
        return x + delta * control_scale
class GLVControl(nn.Module):
    """SUPIR's Guided Latent Vector control encoder. Truncated UNet (input + middle blocks only).

    ``forward`` returns a list holding the activation after every input block
    followed by the middle block activation.
    """

    def __init__(
        self,
        in_channels=4,
        model_channels=320,
        num_res_blocks=2,
        attention_resolutions=(4, 2),
        channel_mult=(1, 2, 4),
        num_head_channels=64,
        transformer_depth=(1, 2, 10),
        context_dim=2048,
        adm_in_channels=2816,
        use_linear_in_transformer=True,
        use_checkpoint=False,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        factory = {"dtype": dtype, "device": device}

        self.model_channels = model_channels
        time_embed_dim = model_channels * 4

        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, time_embed_dim, **factory),
            nn.SiLU(),
            operations.Linear(time_embed_dim, time_embed_dim, **factory),
        )

        # Nested Sequential keeps the parameters under "label_emb.0.*".
        self.label_emb = nn.Sequential(
            nn.Sequential(
                operations.Linear(adm_in_channels, time_embed_dim, **factory),
                nn.SiLU(),
                operations.Linear(time_embed_dim, time_embed_dim, **factory),
            )
        )

        def res_block(channels, out_channels=None):
            # ResBlock with dropout fixed at 0.
            if out_channels is None:
                return ResBlock(channels, time_embed_dim, 0, operations=operations, **factory)
            return ResBlock(channels, time_embed_dim, 0, out_channels=out_channels,
                            operations=operations, **factory)

        def transformer(channels, depth):
            return SpatialTransformer(
                channels, channels // num_head_channels, num_head_channels,
                depth=depth, context_dim=context_dim,
                use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint,
                operations=operations, **factory,
            )

        stem = operations.Conv2d(in_channels, model_channels, 3, padding=1, **factory)
        self.input_blocks = nn.ModuleList([TimestepEmbedSequential(stem)])

        ch = model_channels
        ds = 1  # cumulative downsampling factor, matched against attention_resolutions
        final_level = len(channel_mult) - 1
        for level, mult in enumerate(channel_mult):
            level_channels = mult * model_channels
            for _ in range(num_res_blocks):
                stage = [res_block(ch, out_channels=level_channels)]
                ch = level_channels
                if ds in attention_resolutions:
                    stage.append(transformer(ch, transformer_depth[level]))
                self.input_blocks.append(TimestepEmbedSequential(*stage))

            # Every level except the last ends with a downsampling block.
            if level != final_level:
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        Downsample(ch, True, out_channels=ch, operations=operations, **factory)
                    )
                )
                ds *= 2

        self.middle_block = TimestepEmbedSequential(
            res_block(ch),
            transformer(ch, transformer_depth[-1]),
            res_block(ch),
        )

        self.input_hint_block = TimestepEmbedSequential(
            operations.Conv2d(in_channels, model_channels, 3, padding=1, **factory)
        )

    def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
        """Encode ``xt`` while injecting the hint ``x`` after the first input block.

        Args:
            x: hint tensor, passed through ``input_hint_block``.
            timesteps: timesteps used for the sinusoidal embedding.
            xt: tensor fed through the input and middle blocks.
            context: conditioning forwarded to every block.
            y: vector fed to ``label_emb`` and summed with the time embedding.

        Returns:
            List of activations: one per input block, then the middle block.
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb) + self.label_emb(y)

        guided_hint = self.input_hint_block(x, emb, context)

        features = []
        h = xt
        for block in self.input_blocks:
            h = block(h, emb, context)
            if guided_hint is not None:
                # The hint is added once, right after the first block.
                h += guided_hint
                guided_hint = None
            features.append(h)

        h = self.middle_block(h, emb, context)
        features.append(h)
        return features
class SUPIR(nn.Module):
    """
    SUPIR model containing GLVControl (control encoder) and project_modules (adapters).

    State dict keys match the original SUPIR checkpoint layout:
        control_model.*   -> GLVControl
        project_modules.* -> nn.ModuleList of ZeroSFT/ZeroCrossAttn
    """

    def __init__(self, device=None, dtype=None, operations=None):
        super().__init__()
        shared = {"dtype": dtype, "device": device, "operations": operations}

        self.control_model = GLVControl(**shared)

        project_channel_scale = 2
        cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
        base_project_channels = [160] * 4 + [320] * 3 + [640] * 3
        project_channels = [int(width * project_channel_scale) for width in base_project_channels]
        concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
        cross_attn_insert_idx = [6, 3]

        # One ZeroSFT per entry of the channel tables, in table order.
        self.project_modules = nn.ModuleList()
        for label_nc, norm_nc, extra_nc in zip(project_channels, cond_output_channels, concat_channels):
            self.project_modules.append(
                ZeroSFT(label_nc, norm_nc, concat_channels=extra_nc, **shared)
            )

        # Cross-attention adapters are spliced in afterwards; inserting the
        # higher index first keeps the lower index pointing at the same slot.
        for position in cross_attn_insert_idx:
            self.project_modules.insert(
                position,
                ZeroCrossAttn(cond_output_channels[position], concat_channels[position], **shared),
            )

View File

@ -0,0 +1,103 @@
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import Upsample
class SUPIRPatch:
    """
    Holds GLVControl (control encoder) + project_modules (ZeroSFT/ZeroCrossAttn adapters).

    Runs GLVControl lazily on first patch invocation per step, applies adapters through
    middle_block_after_patch, output_block_merge_patch, and forward_timestep_embed_patch.
    Adapters and cached control features are consumed from the end of their
    lists towards the start as the UNet proceeds through its output blocks.
    """

    # Sigma at or above which the control scale equals strength_start
    # (presumably the sampler's maximum sigma - confirm against the model).
    SIGMA_MAX = 14.6146

    def __init__(self, model_patch, project_modules, hint_latent, strength_start, strength_end):
        self.model_patch = model_patch  # CoreModelPatcher wrapping GLVControl
        self.project_modules = project_modules  # nn.ModuleList of ZeroSFT/ZeroCrossAttn
        self.hint_latent = hint_latent  # encoded LQ image latent
        self.strength_start = strength_start
        self.strength_end = strength_end
        self.cached_features = None
        self.adapter_idx = 0
        self.control_idx = 0
        self.current_control_idx = 0
        # Placeholder so the attribute always exists; middle_after overwrites
        # it with the per-step value before any adapter reads it.
        self.current_scale = strength_end
        self.active = True

    def _ensure_features(self, kwargs):
        """Run GLVControl on first call per step, cache results."""
        if self.cached_features is not None:
            return

        x = kwargs["x"]
        batch = x.shape[0]
        hint = self.hint_latent.to(device=x.device, dtype=x.dtype)

        hint_batch = hint.shape[0]
        if hint_batch != batch:
            if hint_batch == 1:
                # Broadcast a single hint across the batch without copying.
                hint = hint.expand(batch, -1, -1, -1)
            else:
                # Tile the hints often enough to cover the batch, then trim.
                repeats = (batch + hint_batch - 1) // hint_batch
                hint = hint.repeat(repeats, 1, 1, 1)[:batch]

        self.cached_features = self.model_patch.model.control_model(
            hint, kwargs["timesteps"], x,
            kwargs["context"], kwargs["y"]
        )
        # Both lists are walked backwards, starting from their last entry.
        self.adapter_idx = len(self.project_modules) - 1
        self.control_idx = len(self.cached_features) - 1

    def _get_control_scale(self, kwargs):
        """Interpolate the control strength from the current sigma.

        Returns strength_end when both strengths are equal or no "sigmas"
        entry is present in transformer_options; otherwise blends linearly
        from strength_end (sigma 0) to strength_start (sigma >= SIGMA_MAX).
        """
        if self.strength_start == self.strength_end:
            return self.strength_end
        sigma = kwargs["transformer_options"].get("sigmas")
        if sigma is None:
            return self.strength_end
        s = sigma[0].item() if sigma.dim() > 0 else sigma.item()
        t = min(s / self.SIGMA_MAX, 1.0)
        return t * (self.strength_start - self.strength_end) + self.strength_end

    def middle_after(self, kwargs):
        """middle_block_after_patch: run GLVControl lazily, apply last adapter after middle block."""
        self.cached_features = None  # reset from previous step
        self.current_scale = self._get_control_scale(kwargs)
        self.active = self.current_scale > 0
        if not self.active:
            return {"h": kwargs["h"]}

        self._ensure_features(kwargs)

        h = kwargs["h"]
        h = self.project_modules[self.adapter_idx](
            self.cached_features[self.control_idx], h, control_scale=self.current_scale
        )
        self.adapter_idx -= 1
        self.control_idx -= 1
        return {"h": h}

    def output_block(self, h, hsp, transformer_options):
        """output_block_patch: ZeroSFT adapter fusion replaces cat([h, hsp]). Returns (h, None) to skip cat."""
        # Same guard as pre_upsample: without cached features there is
        # nothing to fuse, so hand both tensors back for the regular cat.
        if not self.active or self.cached_features is None:
            return h, hsp

        # Remembered so pre_upsample can reuse the feature of this block.
        self.current_control_idx = self.control_idx
        h = self.project_modules[self.adapter_idx](
            self.cached_features[self.control_idx], hsp, h, control_scale=self.current_scale
        )
        self.adapter_idx -= 1
        self.control_idx -= 1
        return h, None

    def pre_upsample(self, layer, x, emb, context, transformer_options, output_shape, *args, **kw):
        """forward_timestep_embed_patch for Upsample: extra cross-attn adapter before upsample."""
        block_type, _ = transformer_options["block"]
        if block_type == "output" and self.active and self.cached_features is not None:
            x = self.project_modules[self.adapter_idx](
                self.cached_features[self.current_control_idx], x, control_scale=self.current_scale
            )
            self.adapter_idx -= 1
        return layer(x, output_shape=output_shape)

    def to(self, device_or_dtype):
        """Move the hint latent when given a torch.device; other arguments are ignored."""
        if isinstance(device_or_dtype, torch.device):
            # Cached features live on the old device and must be recomputed.
            self.cached_features = None
            if self.hint_latent is not None:
                self.hint_latent = self.hint_latent.to(device_or_dtype)
        return self

    def models(self):
        """Return the model patchers this patch depends on."""
        return [self.model_patch]

    def register(self, model_patcher):
        """Register all patches on a cloned model patcher."""
        model_patcher.set_model_patch(self.middle_after, "middle_block_after_patch")
        model_patcher.set_model_output_block_patch(self.output_block)
        model_patcher.set_model_patch((Upsample, self.pre_upsample), "forward_timestep_embed_patch")

View File

@ -506,6 +506,10 @@ class ModelPatcher:
def set_model_noise_refiner_patch(self, patch):
self.set_model_patch(patch, "noise_refiner")
def set_model_middle_block_after_patch(self, patch):
self.set_model_patch(patch, "middle_block_after_patch")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x

View File

@ -1151,7 +1151,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
if param is None:
continue
p = fn(param)
if p.is_inference():
if (not torch.is_inference_mode_enabled()) and p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items():

View File

@ -35,4 +35,4 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return ErnieTEModel
return ErnieTEModel_

View File

@ -1066,7 +1066,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
)
def _seedance2_text_inputs():
def _seedance2_text_inputs(resolutions: list[str]):
return [
IO.String.Input(
"prompt",
@ -1076,7 +1076,7 @@ def _seedance2_text_inputs():
),
IO.Combo.Input(
"resolution",
options=["480p", "720p"],
options=resolutions,
tooltip="Resolution of the output video.",
),
IO.Combo.Input(
@ -1114,8 +1114,8 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs()),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs()),
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),
@ -1152,11 +1152,14 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
(
$rate480 := 10044;
$rate720 := 21600;
$rate1080 := 48800;
$m := widgets.model;
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$rate := $res = "720p" ? $rate720 : $rate480;
$rate := $res = "1080p" ? $rate1080 :
$res = "720p" ? $rate720 :
$rate480;
$cost := $dur * $rate * $pricePer1K / 1000;
{"type": "usd", "usd": $cost, "format": {"approximate": true}}
)
@ -1195,6 +1198,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
poll_interval=9,
max_poll_attempts=180,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
@ -1212,8 +1216,8 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs()),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs()),
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),
@ -1259,11 +1263,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
(
$rate480 := 10044;
$rate720 := 21600;
$rate1080 := 48800;
$m := widgets.model;
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$rate := $res = "720p" ? $rate720 : $rate480;
$rate := $res = "1080p" ? $rate1080 :
$res = "720p" ? $rate720 :
$rate480;
$cost := $dur * $rate * $pricePer1K / 1000;
{"type": "usd", "usd": $cost, "format": {"approximate": true}}
)
@ -1324,13 +1331,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
poll_interval=9,
max_poll_attempts=180,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
def _seedance2_reference_inputs():
def _seedance2_reference_inputs(resolutions: list[str]):
return [
*_seedance2_text_inputs(),
*_seedance2_text_inputs(resolutions),
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplateNames(
@ -1382,8 +1390,8 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs()),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs()),
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),
@ -1423,13 +1431,16 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
(
$rate480 := 10044;
$rate720 := 21600;
$rate1080 := 48800;
$m := widgets.model;
$hasVideo := $lookup(inputGroups, "model.reference_videos") > 0;
$noVideoPricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$videoPricePer1K := $contains($m, "fast") ? 0.004719 : 0.006149;
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$rate := $res = "720p" ? $rate720 : $rate480;
$rate := $res = "1080p" ? $rate1080 :
$res = "720p" ? $rate720 :
$rate480;
$noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000;
$minVideoFactor := $ceil($dur * 5 / 3);
$minVideoCost := $minVideoFactor * $rate * $videoPricePer1K / 1000;
@ -1559,6 +1570,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input),
poll_interval=9,
max_poll_attempts=180,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))

View File

@ -221,14 +221,17 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
obj_file_response = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False)
obj_result = None
if obj_file_response:
obj_result = await download_and_extract_obj_zip(obj_file_response.Url)
return IO.NodeOutput(
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
obj_result.obj,
obj_result.texture,
obj_result.obj if obj_result else None,
obj_result.texture if obj_result else None,
)
@ -378,17 +381,30 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
obj_file_response = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False)
if obj_file_response:
obj_result = await download_and_extract_obj_zip(obj_file_response.Url)
return IO.NodeOutput(
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
obj_result.obj,
obj_result.texture,
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
)
return IO.NodeOutput(
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
obj_result.obj,
obj_result.texture,
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
None,
None,
None,
None,
None,
)

View File

@ -17,6 +17,44 @@ from comfy_api_nodes.util import (
)
from comfy_extras.nodes_images import SVG
_ARROW_MODELS = ["arrow-1.1", "arrow-1.1-max", "arrow-preview"]
def _arrow_sampling_inputs():
"""Shared sampling inputs for all Arrow model variants."""
return [
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Randomness control. Higher values increase randomness.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=1.0,
min=0.05,
max=1.0,
step=0.05,
display_mode=IO.NumberDisplay.slider,
tooltip="Nucleus sampling parameter.",
advanced=True,
),
IO.Float.Input(
"presence_penalty",
default=0.0,
min=-2.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Token presence penalty.",
advanced=True,
),
]
class QuiverTextToSVGNode(IO.ComfyNode):
@classmethod
@ -39,6 +77,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
default="",
tooltip="Additional style or formatting guidance.",
optional=True,
advanced=True,
),
IO.Autogrow.Input(
"reference_images",
@ -53,43 +92,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"arrow-preview",
[
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Randomness control. Higher values increase randomness.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=1.0,
min=0.05,
max=1.0,
step=0.05,
display_mode=IO.NumberDisplay.slider,
tooltip="Nucleus sampling parameter.",
advanced=True,
),
IO.Float.Input(
"presence_penalty",
default=0.0,
min=-2.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Token presence penalty.",
advanced=True,
),
],
),
],
options=[IO.DynamicCombo.Option(m, _arrow_sampling_inputs()) for m in _ARROW_MODELS],
tooltip="Model to use for SVG generation.",
),
IO.Int.Input(
@ -112,7 +115,16 @@ class QuiverTextToSVGNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.429}""",
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$contains(widgets.model, "max")
? {"type":"usd","usd":0.3575}
: $contains(widgets.model, "preview")
? {"type":"usd","usd":0.429}
: {"type":"usd","usd":0.286}
)
""",
),
)
@ -176,12 +188,13 @@ class QuiverImageToSVGNode(IO.ComfyNode):
"auto_crop",
default=False,
tooltip="Automatically crop to the dominant subject.",
advanced=True,
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"arrow-preview",
m,
[
IO.Int.Input(
"target_size",
@ -189,39 +202,12 @@ class QuiverImageToSVGNode(IO.ComfyNode):
min=128,
max=4096,
tooltip="Square resize target in pixels.",
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Randomness control. Higher values increase randomness.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=1.0,
min=0.05,
max=1.0,
step=0.05,
display_mode=IO.NumberDisplay.slider,
tooltip="Nucleus sampling parameter.",
advanced=True,
),
IO.Float.Input(
"presence_penalty",
default=0.0,
min=-2.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Token presence penalty.",
advanced=True,
),
*_arrow_sampling_inputs(),
],
),
)
for m in _ARROW_MODELS
],
tooltip="Model to use for SVG vectorization.",
),
@ -245,7 +231,16 @@ class QuiverImageToSVGNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.429}""",
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$contains(widgets.model, "max")
? {"type":"usd","usd":0.3575}
: $contains(widgets.model, "preview")
? {"type":"usd","usd":0.429}
: {"type":"usd","usd":0.286}
)
""",
),
)

View File

@ -401,7 +401,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.25}""",
expr="""{"type":"usd","usd":0.4}""",
),
)
@ -510,7 +510,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.25}""",
expr="""{"type":"usd","usd":0.6}""",
),
)
@ -593,7 +593,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.01}""",
expr="""{"type":"usd","usd":0.02}""",
),
)

View File

@ -78,11 +78,21 @@ class _PollUIState:
price: float | None = None
estimated_duration: int | None = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals
active_since: float | None = None # start time of current active interval (None if queued)
active_since: float | None = (
None # start time of current active interval (None if queued)
)
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
COMPLETED_STATUSES = [
"succeeded",
"succeed",
"success",
"completed",
"finished",
"done",
"complete",
]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
@ -131,7 +141,9 @@ async def sync_op(
is_rate_limited=is_rate_limited,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
raise Exception(
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
)
return _validate_or_raise(response_model, raw)
@ -178,7 +190,9 @@ async def poll_op(
cancel_timeout=cancel_timeout,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
raise Exception(
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
)
return _validate_or_raise(response_model, raw)
@ -269,9 +283,15 @@ async def poll_op_raw(
Returns the final JSON response from the poll endpoint.
"""
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
completed_states = _normalize_statuses(
COMPLETED_STATUSES if completed_statuses is None else completed_statuses
)
failed_states = _normalize_statuses(
FAILED_STATUSES if failed_statuses is None else failed_statuses
)
queued_states = _normalize_statuses(
QUEUED_STATUSES if queued_statuses is None else queued_statuses
)
started = time.monotonic()
consumed_attempts = 0 # counts only non-queued polls
@ -289,7 +309,9 @@ async def poll_op_raw(
break
now = time.monotonic()
proc_elapsed = state.base_processing_elapsed + (
(now - state.active_since) if state.active_since is not None else 0.0
(now - state.active_since)
if state.active_since is not None
else 0.0
)
_display_time_progress(
cls,
@ -361,11 +383,15 @@ async def poll_op_raw(
is_queued = status in queued_states
if is_queued:
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
if (
state.active_since is not None
): # If we just moved from active -> queued, close the active interval
state.base_processing_elapsed += now_ts - state.active_since
state.active_since = None
else:
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
if (
state.active_since is None
): # If we just moved from queued -> active, open a new active interval
state.active_since = now_ts
state.is_queued = is_queued
@ -442,7 +468,9 @@ def _display_text(
) -> None:
display_lines: list[str] = []
if status:
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
display_lines.append(
f"Status: {status.capitalize() if isinstance(status, str) else status}"
)
if price is not None:
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
if p != "0":
@ -450,7 +478,9 @@ def _display_text(
if text is not None:
display_lines.append(text)
if display_lines:
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
PromptServer.instance.send_progress_text(
"\n".join(display_lines), get_node_id(node_cls)
)
def _display_time_progress(
@ -464,7 +494,11 @@ def _display_time_progress(
processing_elapsed_seconds: int | None = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
pe = (
processing_elapsed_seconds
if processing_elapsed_seconds is not None
else elapsed_seconds
)
remaining = max(0, int(estimated_total) - int(pe))
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
else:
@ -473,24 +507,48 @@ def _display_time_progress(
async def _diagnose_connectivity() -> dict[str, bool]:
    """Best-effort connectivity diagnostics to distinguish local vs. server issues.

    Checks the Comfy API health endpoint first (the most relevant signal),
    then falls back to several public probe URLs so the check still works in
    regions where an individual site is blocked (e.g. ``google.com`` behind
    China's Great Firewall).

    Returns:
        A dict with two boolean flags:

        * ``api_accessible`` -- the API ``/health`` endpoint answered with a
          status code below 500.
        * ``internet_accessible`` -- the API, or at least one probe URL,
          answered with a status code below 500.

    Request failures (aiohttp client errors, OS-level errors and timeouts)
    are suppressed and simply leave the corresponding flag ``False``.
    """
    results = {
        "internet_accessible": False,
        "api_accessible": False,
    }
    # asyncio.TimeoutError is listed explicitly: before Python 3.11 it is NOT a
    # subclass of OSError, so a probe that merely hangs until the total timeout
    # (the usual symptom of a firewalled site) would otherwise escape this
    # best-effort diagnostic as an unhandled exception.
    probe_errors = (ClientError, OSError, asyncio.TimeoutError)
    timeout = aiohttp.ClientTimeout(total=5.0)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        # 1. Check the Comfy API health endpoint first -- if it responds,
        #    both the internet and the API are reachable and we can return
        #    immediately without hitting any external probe.
        parsed = urlparse(default_base_url())
        health_url = f"{parsed.scheme}://{parsed.netloc}/health"
        with contextlib.suppress(*probe_errors):
            async with session.get(health_url) as resp:
                results["api_accessible"] = resp.status < 500
        if results["api_accessible"]:
            results["internet_accessible"] = True
            return results

        # 2. API endpoint is down -- determine whether the problem is local
        #    (no internet at all) or remote (API server issue). Several URLs
        #    are tried in order; the first one that answers settles it.
        probe_urls = (
            "https://www.google.com",
            "https://www.baidu.com",
            "https://captive.apple.com",
        )
        for probe_url in probe_urls:
            with contextlib.suppress(*probe_errors):
                async with session.get(probe_url) as resp:
                    if resp.status < 500:
                        results["internet_accessible"] = True
                        break
    return results
@ -503,7 +561,9 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
raise ValueError("files tuple must be (filename, file[, content_type])")
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
def _merge_params(
endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None
) -> dict[str, Any]:
params = dict(endpoint_params or {})
if method.upper() == "GET" and data:
for k, v in data.items():
@ -566,8 +626,14 @@ def _snapshot_request_body_for_logging(
filename = file_obj[0]
else:
filename = getattr(file_obj, "name", field_name)
file_fields.append({"field": field_name, "filename": str(filename or "")})
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
file_fields.append(
{"field": field_name, "filename": str(filename or "")}
)
return {
"_multipart": True,
"form_fields": form_fields,
"file_fields": file_fields,
}
if content_type == "application/x-www-form-urlencoded":
return data or {}
return data or {}
@ -581,7 +647,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
method = cfg.endpoint.method
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
params = _merge_params(
cfg.endpoint.query_params, method, cfg.data if method == "GET" else None
)
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
"""Every second: update elapsed time and signal interruption."""
@ -591,13 +659,20 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
return
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
cfg.node_cls,
cfg.wait_label,
int(time.monotonic() - start_ts),
cfg.estimated_total,
)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return # normal shutdown
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
start_time = (
cfg.progress_origin_ts
if cfg.progress_origin_ts is not None
else time.monotonic()
)
attempt = 0
delay = cfg.retry_delay
rate_limit_attempts = 0
@ -614,7 +689,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
payload_headers = (
{"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls))
if cfg.endpoint.headers:
@ -623,7 +700,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_kw: dict[str, Any] = {"headers": payload_headers}
if method == "GET":
payload_headers.pop("Content-Type", None)
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
request_body_log = _snapshot_request_body_for_logging(
cfg.content_type, method, cfg.data, cfg.files
)
try:
if cfg.monitor_progress:
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
@ -637,16 +716,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if cfg.multipart_parser and cfg.data:
form = cfg.multipart_parser(cfg.data)
if not isinstance(form, aiohttp.FormData):
raise ValueError("multipart_parser must return aiohttp.FormData")
raise ValueError(
"multipart_parser must return aiohttp.FormData"
)
else:
form = aiohttp.FormData(default_to_multipart=True)
if cfg.data:
for k, v in cfg.data.items():
if v is None:
continue
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
form.add_field(
k,
str(v) if not isinstance(v, (bytes, bytearray)) else v,
)
if cfg.files:
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
file_iter = (
cfg.files if isinstance(cfg.files, list) else cfg.files.items()
)
for field_name, file_obj in file_iter:
if file_obj is None:
continue
@ -660,9 +746,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if isinstance(file_value, BytesIO):
with contextlib.suppress(Exception):
file_value.seek(0)
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
form.add_field(
field_name,
file_value,
filename=filename,
content_type=content_type,
)
payload_kw["data"] = form
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
elif (
cfg.content_type == "application/x-www-form-urlencoded"
and method != "GET"
):
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
payload_kw["data"] = cfg.data or {}
elif method != "GET":
@ -685,7 +779,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
tasks = {req_task}
if monitor_task:
tasks.add(monitor_task)
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
)
if monitor_task and monitor_task in done:
# Interrupted cancel the request and abort
@ -705,7 +801,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
wait_time = 0.0
retry_label = ""
is_rl = resp.status == 429 or (
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
cfg.is_rate_limited is not None
and cfg.is_rate_limited(resp.status, body)
)
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
rate_limit_attempts += 1
@ -713,7 +810,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
rate_limit_delay *= cfg.retry_backoff
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
should_retry = True
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
elif (
resp.status in _RETRY_STATUS
and (attempt - rate_limit_attempts) <= cfg.max_retries
):
wait_time = delay
delay *= cfg.retry_backoff
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
@ -743,7 +843,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
display_callback=_display_time_progress
if cfg.monitor_progress
else None,
)
continue
msg = _friendly_http_message(resp.status, body)
@ -770,7 +872,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
raise ProcessingInterrupted("Task cancelled")
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
cfg.node_cls,
cfg.wait_label,
int(now - start_time),
cfg.estimated_total,
)
bytes_payload = bytes(buff)
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
@ -800,9 +905,15 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload = json.loads(text) if text else {}
except json.JSONDecodeError:
payload = {"_raw": text}
response_content_to_log = payload if isinstance(payload, dict) else text
response_content_to_log = (
payload if isinstance(payload, dict) else text
)
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
extracted_price = (
cfg.price_extractor(payload)
if cfg.price_extractor
else None
)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
@ -844,7 +955,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
display_callback=_display_time_progress
if cfg.monitor_progress
else None,
)
delay *= cfg.retry_backoff
continue
@ -885,7 +998,11 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if sess:
with contextlib.suppress(Exception):
await sess.close()
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
if (
operation_succeeded
and cfg.monitor_progress
and cfg.final_label_on_success
):
_display_time_progress(
cfg.node_cls,
status=cfg.final_label_on_success,

View File

@ -7,7 +7,10 @@ import comfy.model_management
import comfy.ldm.common_dit
import comfy.latent_formats
import comfy.ldm.lumina.controlnet
import comfy.ldm.supir.supir_modules
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
from comfy_api.latest import io
from comfy.ldm.supir.supir_patch import SUPIRPatch
class BlockWiseControlBlock(torch.nn.Module):
@ -266,6 +269,27 @@ class ModelPatchLoader:
out_dim=sd["audio_proj.norm.weight"].shape[0],
device=comfy.model_management.unet_offload_device(),
operations=comfy.ops.manual_cast)
elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd:
prefix_replace = {}
if 'model.control_model.input_hint_block.0.weight' in sd:
prefix_replace["model.control_model."] = "control_model."
prefix_replace["model.diffusion_model.project_modules."] = "project_modules."
else:
prefix_replace["control_model."] = "control_model."
prefix_replace["project_modules."] = "project_modules."
# Extract denoise_encoder weights before filter_keys discards them
de_prefix = "first_stage_model.denoise_encoder."
denoise_encoder_sd = {}
for k in list(sd.keys()):
if k.startswith(de_prefix):
denoise_encoder_sd[k[len(de_prefix):]] = sd.pop(k)
sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True)
sd.pop("control_model.mask_LQ", None)
model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
if denoise_encoder_sd:
model.denoise_encoder_sd = denoise_encoder_sd
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
@ -565,9 +589,89 @@ class MultiTalkModelPatch(torch.nn.Module):
)
class SUPIRApply(io.ComfyNode):
    """Node that attaches a SUPIR model patch, and optional restore guidance, to a model."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node id, category, inputs and the single model output."""
        return io.Schema(
            node_id="SUPIRApply",
            category="model_patches/supir",
            is_experimental=True,
            inputs=[
                io.Model.Input("model"),
                io.ModelPatch.Input("model_patch"),
                io.Vae.Input("vae"),
                io.Image.Input("image"),
                io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01,
                               tooltip="Control strength at the start of sampling (high sigma)."),
                io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01,
                               tooltip="Control strength at the end of sampling (low sigma). Linearly interpolated from start."),
                io.Float.Input("restore_cfg", default=4.0, min=0.0, max=20.0, step=0.1, advanced=True,
                               tooltip="Pulls denoised output toward the input latent. Higher = stronger fidelity to input. 0 to disable."),
                io.Float.Input("restore_cfg_s_tmin", default=0.05, min=0.0, max=1.0, step=0.01, advanced=True,
                               tooltip="Sigma threshold below which restore_cfg is disabled."),
            ],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def _encode_with_denoise_encoder(cls, vae, model_patch, image):
        """Encode using denoise_encoder weights from SUPIR checkpoint if available.

        When ``model_patch.model`` carries a non-empty ``denoise_encoder_sd``
        mapping, the VAE's patcher is temporarily replaced by a patched clone
        for the duration of the encode; otherwise the VAE is used as-is.
        The original patcher is always restored afterwards.
        """
        denoise_sd = getattr(model_patch.model, 'denoise_encoder_sd', None)
        if not denoise_sd:
            return vae.encode(image)

        # Build and patch the clone BEFORE touching vae.patcher, so a failure
        # while cloning/patching cannot leave the shared VAE pointing at a
        # half-configured patcher.
        # NOTE(review): strength_model=0.0 presumably makes the patch replace
        # the encoder weights entirely -- confirm against add_patches.
        orig_patcher = vae.patcher
        patched = orig_patcher.clone()
        patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()}
        patched.add_patches(patches, strength_patch=1.0, strength_model=0.0)

        vae.patcher = patched
        try:
            return vae.encode(image)
        finally:
            vae.patcher = orig_patcher

    @classmethod
    def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type,
                strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput:
        """Clone ``model``, register the SUPIR patch on the clone and return it.

        The hint latent is the encoding of the first three channels of
        ``image``. When ``restore_cfg`` is positive, a post-CFG function is
        also installed that pulls the denoised output toward a reference
        latent while sigma is above ``restore_cfg_s_tmin``.
        """
        model_patched = model.clone()
        latent_format = model.get_model_object("latent_format")
        hint_latent = latent_format.process_in(
            cls._encode_with_denoise_encoder(vae, model_patch, image[:, :, :, :3]))
        patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end)
        patch.register(model_patched)

        if restore_cfg > 0.0:
            # Round-trip to match original pipeline: decode hint, re-encode with regular VAE
            decoded = vae.decode(latent_format.process_out(hint_latent))
            x_center = latent_format.process_in(vae.encode(decoded[:, :, :, :3]))
            # NOTE(review): presumably the sampler's maximum sigma -- confirm.
            sigma_max = 14.6146

            def restore_cfg_function(args):
                """Blend the denoised prediction toward ``x_center`` for sigmas above the threshold."""
                denoised = args["denoised"]
                sigma = args["sigma"]
                # Only the first sigma decides whether the correction is active.
                if sigma.dim() > 0:
                    s = sigma[0].item()
                else:
                    s = sigma.item()
                if s > restore_cfg_s_tmin:
                    ref = x_center.to(device=denoised.device, dtype=denoised.dtype)
                    b = denoised.shape[0]
                    if ref.shape[0] != b:
                        # Match the batch size: broadcast a single reference,
                        # otherwise tile the references and truncate.
                        ref = ref.expand(b, -1, -1, -1) if ref.shape[0] == 1 else ref.repeat((b + ref.shape[0] - 1) // ref.shape[0], 1, 1, 1)[:b]
                    sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma
                    d_center = denoised - ref
                    denoised = denoised - d_center * ((sigma_val / sigma_max) ** restore_cfg)
                return denoised

            model_patched.set_model_sampler_post_cfg_function(restore_cfg_function)

        return io.NodeOutput(model_patched)
# Maps each node identifier string to the class that implements it.
# NOTE(review): presumably read by the node loader to register these nodes -- confirm.
NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
"ZImageFunControlnet": ZImageFunControlnet,
"USOStyleReference": USOStyleReference,
"SUPIRApply": SUPIRApply,
}

View File

@ -6,6 +6,7 @@ from PIL import Image
import math
from enum import Enum
from typing import TypedDict, Literal
import kornia
import comfy.utils
import comfy.model_management
@ -660,6 +661,228 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
return io.NodeOutput(batched)
class ColorTransfer(io.ComfyNode):
    """Node that recolors target images so their color statistics match reference images.

    ``reinhard_lab`` matches per-channel mean and standard deviation in Lab,
    ``mkl_lab`` matches mean and covariance in Lab, and ``histogram`` matches
    per-channel cumulative histograms of the raw channel values.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, search aliases, inputs and image output."""
        target_index_input = io.Int.Input(
            "target_index", default=0, min=0, max=10000,
            tooltip="Frame index used as the source baseline for computing the transform to image_ref",
        )
        source_stats_input = io.DynamicCombo.Input(
            "source_stats",
            tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",
            options=[
                io.DynamicCombo.Option("per_frame", []),
                io.DynamicCombo.Option("uniform", []),
                io.DynamicCombo.Option("target_frame", [target_index_input]),
            ],
        )
        return io.Schema(
            node_id="ColorTransfer",
            category="image/postprocessing",
            description="Match the colors of one image to another using various algorithms.",
            search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
            inputs=[
                io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
                io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"),
                io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
                source_stats_input,
                io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
            ],
            outputs=[
                io.Image.Output(display_name="image"),
            ],
        )

    @staticmethod
    def _to_lab(images, i, device):
        """Move frame ``i`` to ``device`` as float32, make it channels-first and convert it to Lab."""
        frame = images[i:i+1].to(device, dtype=torch.float32)
        return kornia.color.rgb_to_lab(frame.permute(0, 3, 1, 2))

    @staticmethod
    def _pool_stats(images, device, is_reinhard, eps):
        """Pooled Lab mean plus std (Reinhard) or covariance (MKL), in two passes over the frames."""
        frame_count, channels = images.shape[0], images.shape[3]
        pixels = images.shape[1] * images.shape[2]

        # Pass 1: average of the per-frame channel means.
        pooled_mean = torch.zeros(channels, 1, device=device, dtype=torch.float32)
        for idx in range(frame_count):
            lab_flat = ColorTransfer._to_lab(images, idx, device).view(channels, -1)
            pooled_mean += lab_flat.mean(dim=-1, keepdim=True)
        pooled_mean /= frame_count

        # Pass 2: accumulate second moments around the pooled mean.
        spread_cols = 1 if is_reinhard else channels
        spread = torch.zeros(channels, spread_cols, device=device, dtype=torch.float32)
        for idx in range(frame_count):
            deviation = ColorTransfer._to_lab(images, idx, device).view(channels, -1) - pooled_mean
            if is_reinhard:
                spread += (deviation * deviation).mean(dim=-1, keepdim=True)
            else:
                spread += deviation @ deviation.T / pixels

        if not is_reinhard:
            return pooled_mean, spread / frame_count
        return pooled_mean, torch.sqrt(spread / frame_count).clamp_min_(eps)

    @staticmethod
    def _frame_stats(lab_flat, hw, is_reinhard, eps):
        """Mean plus std (Reinhard) or covariance (MKL) of a single flattened frame."""
        frame_mean = lab_flat.mean(dim=-1, keepdim=True)
        if is_reinhard:
            frame_std = lab_flat.std(dim=-1, keepdim=True, unbiased=False).clamp_min_(eps)
            return frame_mean, frame_std
        deviation = lab_flat - frame_mean
        return frame_mean, deviation @ deviation.T / hw

    @staticmethod
    def _mkl_matrix(cov_s, cov_r, eps):
        """Build the MKL transform matrix from the source and reference covariances."""
        src_vals, src_vecs = torch.linalg.eigh(cov_s)
        src_sqrt = torch.sqrt(src_vals.clamp_min(0)).clamp_min_(eps)
        src_scaled = src_vecs * src_sqrt.unsqueeze(0)

        middle = src_scaled.T @ cov_r @ src_scaled
        mid_vals, mid_vecs = torch.linalg.eigh(middle)
        mid_sqrt = torch.sqrt(mid_vals.clamp_min(0))

        src_inv_sqrt = 1.0 / src_sqrt
        src_inv_scaled = src_vecs * src_inv_sqrt.unsqueeze(0)
        middle_half = (mid_vecs * mid_sqrt.unsqueeze(0)) @ mid_vecs.T
        return src_inv_scaled @ middle_half @ src_inv_scaled.T

    @staticmethod
    def _histogram_lut(src, ref, bins=256):
        """Build per-channel LUT from source and ref histograms. src/ref: (C, HW) in [0,1]."""
        top_bin = bins - 1
        channels = src.shape[0]
        src_idx = (src * top_bin).long().clamp(0, top_bin)
        ref_idx = (ref * top_bin).long().clamp(0, top_bin)

        src_hist = torch.zeros(channels, bins, device=src.device, dtype=src.dtype)
        ref_hist = torch.zeros(channels, bins, device=src.device, dtype=src.dtype)
        src_hist.scatter_add_(1, src_idx, torch.ones_like(src))
        ref_hist.scatter_add_(1, ref_idx, torch.ones_like(ref))

        # Normalize each cumulative histogram by its final (total) count.
        src_cdf = src_hist.cumsum(1)
        src_cdf = src_cdf / src_cdf[:, -1:]
        ref_cdf = ref_hist.cumsum(1)
        ref_cdf = ref_cdf / ref_cdf[:, -1:]
        return torch.searchsorted(ref_cdf, src_cdf).clamp_max_(top_bin).float() / top_bin

    @classmethod
    def _pooled_cdf(cls, images, device, num_bins=256):
        """Normalized per-channel CDF pooled over all frames, accumulated one frame at a time."""
        channels = images.shape[3]
        top_bin = num_bins - 1
        counts = torch.zeros(channels, num_bins, device=device, dtype=torch.float32)
        for idx in range(images.shape[0]):
            flat = images[idx].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(channels, -1)
            bin_idx = (flat * top_bin).long().clamp(0, top_bin)
            counts.scatter_add_(1, bin_idx, torch.ones_like(flat))
        cumulative = counts.cumsum(1)
        return cumulative / cumulative[:, -1:]

    @classmethod
    def _build_histogram_transform(cls, image_target, image_ref, device, stats_mode, target_index, B):
        """Return one LUT shared by all frames, or None when LUTs are built per frame."""
        if stats_mode == 'per_frame':
            return None  # LUT computed per-frame in the apply loop

        ref_cdf = cls._pooled_cdf(image_ref, device)
        if stats_mode == 'target_frame':
            # Out-of-range indices fall back to the last frame.
            chosen = min(target_index, B - 1)
            baseline = image_target[chosen:chosen+1]
        else:
            baseline = image_target
        src_cdf = cls._pooled_cdf(baseline, device)
        return torch.searchsorted(ref_cdf, src_cdf).clamp_max_(255).float() / 255.0

    @classmethod
    def _build_lab_transform(cls, image_target, image_ref, device, stats_mode, target_index, is_reinhard):
        """Return a callable ``f(src_flat, frame_idx=...)`` that maps flattened Lab pixels to matched ones."""
        eps = 1e-6
        B, H, W, C = image_target.shape
        ref_frames = image_ref.shape[0]
        one_ref = ref_frames == 1
        pixels = H * W
        ref_pixels = image_ref.shape[1] * image_ref.shape[2]
        shared_source = stats_mode in ('uniform', 'target_frame')

        # Reference statistics are precomputed whenever they do not vary per frame.
        if one_ref or shared_source:
            ref_mean, ref_sc = cls._pool_stats(image_ref, device, is_reinhard, eps)

        # Uniform/target_frame: one affine transform reused for every frame.
        if shared_source:
            if stats_mode == 'target_frame':
                chosen = min(target_index, B - 1)
                chosen_lab = cls._to_lab(image_target, chosen, device).view(C, -1)
                src_mean, src_sc = cls._frame_stats(chosen_lab, pixels, is_reinhard, eps)
            else:
                src_mean, src_sc = cls._pool_stats(image_target, device, is_reinhard, eps)

            if is_reinhard:
                scale = ref_sc / src_sc
                offset = ref_mean - scale * src_mean

                def reinhard_uniform(src_flat, **_):
                    return src_flat * scale + offset

                return reinhard_uniform

            T = cls._mkl_matrix(src_sc, ref_sc, eps)
            offset = ref_mean - T @ src_mean

            def mkl_uniform(src_flat, **_):
                return T @ src_flat + offset

            return mkl_uniform

        # per_frame: statistics are recomputed for each source frame.
        def per_frame_transform(src_flat, frame_idx):
            src_mean, src_sc = cls._frame_stats(src_flat, pixels, is_reinhard, eps)
            if one_ref:
                tgt_mean, tgt_sc = ref_mean, ref_sc
            else:
                # Pair frames by index; extra source frames reuse the last reference frame.
                ref_idx = min(frame_idx, ref_frames - 1)
                ref_lab = cls._to_lab(image_ref, ref_idx, device).view(C, -1)
                tgt_mean, tgt_sc = cls._frame_stats(ref_lab, ref_pixels, is_reinhard, eps)
            deviation = src_flat - src_mean
            if is_reinhard:
                return deviation * (tgt_sc / src_sc) + tgt_mean
            T = cls._mkl_matrix(deviation @ deviation.T / pixels, tgt_sc, eps)
            return T @ deviation + tgt_mean

        return per_frame_transform

    @classmethod
    def execute(cls, image_target, image_ref, method, source_stats, strength=1.0) -> io.NodeOutput:
        """Apply the selected color-matching method to every target frame.

        Returns ``image_target`` unchanged when ``strength`` is 0 or no
        reference is given; otherwise returns a new batch clamped to [0, 1].
        """
        stats_mode = source_stats["source_stats"]
        target_index = source_stats.get("target_index", 0)

        if strength == 0 or image_ref is None:
            return io.NodeOutput(image_target)

        device = comfy.model_management.get_torch_device()
        out_device = comfy.model_management.intermediate_device()
        out_dtype = comfy.model_management.intermediate_dtype()

        B, H, W, C = image_target.shape
        ref_frames = image_ref.shape[0]
        pbar = comfy.utils.ProgressBar(B)
        out = torch.empty(B, H, W, C, device=out_device, dtype=out_dtype)

        if method == 'histogram':
            shared_lut = cls._build_histogram_transform(
                image_target, image_ref, device, stats_mode, target_index, B)

            for idx in range(B):
                src = image_target[idx].to(device, dtype=torch.float32).permute(2, 0, 1)
                src_flat = src.reshape(C, -1)
                if shared_lut is None:
                    ref_idx = min(idx, ref_frames - 1)
                    ref_flat = image_ref[ref_idx].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1)
                    lut = cls._histogram_lut(src_flat, ref_flat)
                else:
                    lut = shared_lut
                lookup = (src_flat * 255).long().clamp(0, 255)
                matched = lut.gather(1, lookup).view(C, H, W)
                if strength == 1.0:
                    result = matched
                else:
                    result = torch.lerp(src, matched, strength)
                out[idx] = result.permute(1, 2, 0).clamp_(0, 1).to(device=out_device, dtype=out_dtype)
                pbar.update(1)
        else:
            transform = cls._build_lab_transform(image_target, image_ref, device, stats_mode, target_index, is_reinhard=method == "reinhard_lab")

            for idx in range(B):
                src_lab = cls._to_lab(image_target, idx, device)
                matched_lab = transform(src_lab.view(C, -1), frame_idx=idx)
                if strength == 1.0:
                    result = kornia.color.lab_to_rgb(matched_lab.view(1, C, H, W))
                else:
                    # Blend in Lab space before converting back to RGB.
                    result = kornia.color.lab_to_rgb(torch.lerp(src_lab, matched_lab.view(1, C, H, W), strength))
                out[idx] = result.squeeze(0).permute(1, 2, 0).clamp_(0, 1).to(device=out_device, dtype=out_dtype)
                pbar.update(1)

        return io.NodeOutput(out)
class PostProcessingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -673,6 +896,7 @@ class PostProcessingExtension(ComfyExtension):
BatchImagesNode,
BatchMasksNode,
BatchLatentsNode,
ColorTransfer,
# BatchImagesMasksLatentsNode,
]

View File

@ -1,4 +1,5 @@
import re
import json
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
@ -375,6 +376,39 @@ class RegexReplace(io.ComfyNode):
return io.NodeOutput(result)
class JsonExtractString(io.ComfyNode):
    """Node that reads one top-level key from a JSON object and returns its value as a string."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, display name, category, search aliases, inputs and output."""
        json_input = io.String.Input("json_string", multiline=True)
        key_input = io.String.Input("key", multiline=False)
        return io.Schema(
            node_id="JsonExtractString",
            display_name="Extract String from JSON",
            category="utils/string",
            search_aliases=["json", "extract json", "parse json", "json value", "read json"],
            inputs=[json_input, key_input],
            outputs=[io.String.Output()],
        )

    @classmethod
    def execute(cls, json_string, key):
        """Return ``str(value)`` for ``key`` in the parsed JSON object.

        An empty string is returned when the text is not valid JSON, the
        parsed value is not an object, the key is absent, or the value is null.
        """
        extracted = ""
        try:
            parsed = json.loads(json_string)
            # Only objects (dicts) are searched; arrays and scalars yield "".
            if isinstance(parsed, dict) and key in parsed:
                found = parsed[key]
                if found is not None:
                    extracted = str(found)
        except (json.JSONDecodeError, TypeError):
            extracted = ""
        return io.NodeOutput(extracted)
class StringExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -390,6 +424,7 @@ class StringExtension(ComfyExtension):
RegexMatch,
RegexExtract,
RegexReplace,
JsonExtractString,
]
async def comfy_entrypoint() -> StringExtension:

View File

@ -35,6 +35,7 @@ class TextGenerate(io.ComfyNode):
io.Int.Input("max_length", default=256, min=1, max=2048),
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
io.Boolean.Input("use_default_template", optional=True, default=True, tooltip="Use the built in system prompt/template if the model has one.", advanced=True),
],
outputs=[
io.String.Output(display_name="generated_text"),
@ -42,9 +43,9 @@ class TextGenerate(io.ComfyNode):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking)
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking)
# Get sampling parameters from dynamic combo
do_sample = sampling_mode.get("sampling_mode") == "on"
@ -160,12 +161,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
if image is None:
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
else:
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking)
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template)
class TextgenExtension(ComfyExtension):

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.19.1"
__version__ = "0.19.3"

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.19.1"
version = "0.19.3"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.42.11
comfyui-workflow-templates==0.9.54
comfyui-workflow-templates==0.9.57
comfyui-embedded-docs==0.4.3
torch
torchsde

View File

@ -0,0 +1,192 @@
"""Regression tests for _diagnose_connectivity().
Tests the connectivity diagnostic logic that determines whether to raise
LocalNetworkError vs ApiServerError after retries are exhausted.
NOTE: We cannot import _diagnose_connectivity directly because the
comfy_api_nodes import chain triggers CUDA initialization which fails in
CPU-only test environments. Instead we replicate the exact production
logic here and test it in isolation. These tests exercise only this copy,
so nothing in this file detects drift from the production code: whenever
_diagnose_connectivity in client.py changes, update this mirror to match.
"""
from __future__ import annotations
import contextlib
from contextlib import asynccontextmanager
from unittest.mock import MagicMock, patch
from urllib.parse import urlparse
import pytest
import aiohttp
from aiohttp.client_exceptions import ClientError
# Base URL used by the mirrored diagnostic below; only its scheme and host
# are kept when the "/health" probe URL is built.
_TEST_BASE_URL = "https://api.comfy.org"
# Generic probe targets, tried in this order once the health check has failed.
_INTERNET_PROBE_URLS = [
    "https://www.google.com",
    "https://www.baidu.com",
    "https://captive.apple.com",
]
async def _diagnose_connectivity() -> dict[str, bool]:
    """Mirror of production _diagnose_connectivity from client.py.

    Probes the API health endpoint first and, only when that probe does not
    succeed, walks the generic probe URLs in order.

    Returns:
        A dict with two boolean keys, "internet_accessible" and
        "api_accessible". Both start as False and are flipped to True only
        by a probe that answers with a status below 500.
    """
    results = {
        "internet_accessible": False,
        "api_accessible": False,
    }
    # NOTE(review): on Python 3.10 asyncio.TimeoutError is not an OSError
    # subclass, so a probe that hits this timeout may escape the suppress()
    # blocks below -- confirm how the production copy handles that.
    timeout = aiohttp.ClientTimeout(total=5.0)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        # Only scheme and host of the base URL are reused; any path is dropped.
        parsed = urlparse(_TEST_BASE_URL)
        health_url = f"{parsed.scheme}://{parsed.netloc}/health"
        # A ClientError/OSError here simply leaves api_accessible at False.
        with contextlib.suppress(ClientError, OSError):
            async with session.get(health_url) as resp:
                # Any status below 500 counts as "the API answered".
                results["api_accessible"] = resp.status < 500
        if results["api_accessible"]:
            # A reachable API implies internet access; skip the other probes.
            results["internet_accessible"] = True
            return results
        for probe_url in _INTERNET_PROBE_URLS:
            # A failing probe is ignored so the next URL still gets tried.
            with contextlib.suppress(ClientError, OSError):
                async with session.get(probe_url) as resp:
                    if resp.status < 500:
                        results["internet_accessible"] = True
                        # One successful probe is enough.
                        break
    return results
class _FakeResponse:
    """Bare-bones response double that carries nothing but ``status``.

    It can also be used directly in ``async with``, in which case it yields
    itself and does nothing on exit.
    """

    def __init__(self, status: int):
        self.status = status

    async def __aenter__(self):
        # Entering the context hands back the response object itself.
        return self

    async def __aexit__(self, *exc_info):
        # Nothing to release; the falsy result lets any exception propagate.
        return None
def _build_mock_session(url_to_behavior: dict[str, int | Exception]):
@asynccontextmanager
async def _fake_get(url, **_kw):
for substr, behavior in url_to_behavior.items():
if substr in url:
if isinstance(behavior, type) and issubclass(behavior, BaseException):
raise behavior(f"mocked failure for {substr}")
if isinstance(behavior, BaseException):
raise behavior
yield _FakeResponse(behavior)
return
raise ClientError(f"no mock configured for {url}")
session = MagicMock()
session.get = _fake_get
return session
@asynccontextmanager
async def _session_cm(session):
    """Async context manager that yields ``session`` unchanged.

    The tests assign the result as the return value of the patched
    ``aiohttp.ClientSession``, so ``async with aiohttp.ClientSession(...)``
    binds the prepared fake session.
    """
    yield session
class TestDiagnoseConnectivity:
    """Scenario tests for the mirrored ``_diagnose_connectivity``.

    Each test describes, per URL substring, how the fake session should
    respond, then checks the two flags in the returned dict.
    """

    @staticmethod
    async def _diagnose_with(behaviors):
        """Run the diagnostic with ``aiohttp.ClientSession`` swapped for a fake."""
        fake_session = _build_mock_session(behaviors)
        with patch("aiohttp.ClientSession") as session_factory:
            session_factory.return_value = _session_cm(fake_session)
            return await _diagnose_connectivity()

    @pytest.mark.asyncio
    async def test_api_healthy_returns_immediately(self):
        # Only /health is configured: any further probe would raise inside
        # the fake and be suppressed, leaving the flags unaffected.
        outcome = await self._diagnose_with({"/health": 200})
        assert outcome["internet_accessible"] is True
        assert outcome["api_accessible"] is True

    @pytest.mark.asyncio
    async def test_google_blocked_but_api_healthy(self):
        behaviors = {
            "/health": 200,
            "google.com": ClientError,
        }
        outcome = await self._diagnose_with(behaviors)
        assert outcome["internet_accessible"] is True
        assert outcome["api_accessible"] is True

    @pytest.mark.asyncio
    async def test_api_down_google_blocked_baidu_accessible(self):
        behaviors = {
            "/health": ClientError,
            "google.com": ClientError,
            "baidu.com": 200,
        }
        outcome = await self._diagnose_with(behaviors)
        assert outcome["internet_accessible"] is True
        assert outcome["api_accessible"] is False

    @pytest.mark.asyncio
    async def test_api_down_google_accessible(self):
        behaviors = {
            "/health": ClientError,
            "google.com": 200,
        }
        outcome = await self._diagnose_with(behaviors)
        assert outcome["internet_accessible"] is True
        assert outcome["api_accessible"] is False

    @pytest.mark.asyncio
    async def test_all_probes_fail(self):
        behaviors = {
            "/health": ClientError,
            "google.com": ClientError,
            "baidu.com": ClientError,
            "apple.com": ClientError,
        }
        outcome = await self._diagnose_with(behaviors)
        assert outcome["internet_accessible"] is False
        assert outcome["api_accessible"] is False

    @pytest.mark.asyncio
    async def test_api_returns_500_falls_through_to_probes(self):
        # A 5xx from /health must not count as "API reachable".
        behaviors = {
            "/health": 500,
            "google.com": 200,
        }
        outcome = await self._diagnose_with(behaviors)
        assert outcome["api_accessible"] is False
        assert outcome["internet_accessible"] is True

    @pytest.mark.asyncio
    async def test_captive_apple_fallback(self):
        behaviors = {
            "/health": ClientError,
            "google.com": ClientError,
            "baidu.com": ClientError,
            "apple.com": 200,
        }
        outcome = await self._diagnose_with(behaviors)
        assert outcome["internet_accessible"] is True
        assert outcome["api_accessible"] is False