add split sigmas node

This commit is contained in:
Stéphane du Hamel
2025-08-12 13:19:12 +02:00
parent 57987b58aa
commit e5b2576c73
2 changed files with 46 additions and 3 deletions

View File

@ -1,11 +1,13 @@
from .nodes import WanMoeKSampler,WanMoeKSamplerAdvanced
from .nodes import WanMoeKSampler,WanMoeKSamplerAdvanced,SplitSigmasAtT
NODE_CLASS_MAPPINGS = {
"WanMoeKSampler":WanMoeKSampler,
"WanMoeKSamplerAdvanced":WanMoeKSamplerAdvanced
"WanMoeKSamplerAdvanced":WanMoeKSamplerAdvanced,
"SplitSigmasAtT":SplitSigmasAtT
}
NODE_DISPLAY_NAME_MAPPINGS = {
"WanMoeKSampler": "Wan MoE KSampler",
"WanMoeKSamplerAdvanced": "Wan MoE KSampler (Advanced)"
"WanMoeKSamplerAdvanced": "Wan MoE KSampler (Advanced)",
"SplitSigmasAtT": "Split sigmas at timestep"
}

View File

@ -154,3 +154,44 @@ class WanMoeKSamplerAdvanced:
if add_noise == "disable":
disable_noise = True
return wan_ksampler(model_high_noise, model_low_noise, noise_seed, steps, (cfg_high_noise, cfg_low_noise), sampler_name, scheduler, positive, negative, latent_image, boundary=boundary, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
class SplitSigmasAtT:
@classmethod
def INPUT_TYPES(s):
return {"required":
{
"boundary": ("FLOAT", {"default": 0.875, "min": 0.0, "max": 1.0, "step": 0.001, "round":0.001}),
"sigmas": ("SIGMAS", ),
},
"optional":
{
"model": ("MODEL", {"tooltip": "Used to determine the model type. Assumes FLOW model by default if not provided"}),
}
}
RETURN_TYPES = ("SIGMAS", "SIGMAS", "INT", )
RETURN_NAMES = ("high noise sigmas", "low noise sigmas", "split at", )
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "split"
def split(self, boundary, sigmas:torch.Tensor, model = None):
if model is None:
sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow
sampling_type = comfy.model_sampling.CONST
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
sampling = ModelSamplingAdvanced()
else:
sampling = model.get_model_object("model_sampling")
timesteps = [sampling.timestep(sigma)/1000 for sigma in sigmas.tolist()]
switching_step = sigmas.size(0)
for (i,t) in enumerate(timesteps[1:]):
if t < boundary:
switching_step = i
break
print(f"splitting sigmas at index {switching_step}")
return (sigmas[:switching_step + 1], sigmas[switching_step:], switching_step, )