From e5b2576c73d02f5991711275d1c202828133c035 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Tue, 12 Aug 2025 13:19:12 +0200 Subject: [PATCH] add split sigmas node --- __init__.py | 8 +++++--- nodes.py | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/__init__.py b/__init__.py index 1b761cf..3e66e6d 100644 --- a/__init__.py +++ b/__init__.py @@ -1,11 +1,13 @@ -from .nodes import WanMoeKSampler,WanMoeKSamplerAdvanced +from .nodes import WanMoeKSampler,WanMoeKSamplerAdvanced,SplitSigmasAtT NODE_CLASS_MAPPINGS = { "WanMoeKSampler":WanMoeKSampler, - "WanMoeKSamplerAdvanced":WanMoeKSamplerAdvanced + "WanMoeKSamplerAdvanced":WanMoeKSamplerAdvanced, + "SplitSigmasAtT":SplitSigmasAtT } NODE_DISPLAY_NAME_MAPPINGS = { "WanMoeKSampler": "Wan MoE KSampler", - "WanMoeKSamplerAdvanced": "Wan MoE KSampler (Advanced)" + "WanMoeKSamplerAdvanced": "Wan MoE KSampler (Advanced)", + "SplitSigmasAtT": "Split sigmas at timestep" } \ No newline at end of file diff --git a/nodes.py b/nodes.py index ad96a29..17676b4 100644 --- a/nodes.py +++ b/nodes.py @@ -154,3 +154,44 @@ class WanMoeKSamplerAdvanced: if add_noise == "disable": disable_noise = True return wan_ksampler(model_high_noise, model_low_noise, noise_seed, steps, (cfg_high_noise, cfg_low_noise), sampler_name, scheduler, positive, negative, latent_image, boundary=boundary, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) + +class SplitSigmasAtT: + @classmethod + def INPUT_TYPES(s): + return {"required": + { + "boundary": ("FLOAT", {"default": 0.875, "min": 0.0, "max": 1.0, "step": 0.001, "round":0.001}), + "sigmas": ("SIGMAS", ), + }, + "optional": + { + "model": ("MODEL", {"tooltip": "Used to determine the model type. Assumes FLOW model by default if not provided"}), + } + } + + RETURN_TYPES = ("SIGMAS", "SIGMAS", "INT", ) + RETURN_NAMES = ("high noise sigmas", "low noise sigmas", "split at", ) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "split" + + def split(self, boundary, sigmas:torch.Tensor, model = None): + if model is None: + sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow + sampling_type = comfy.model_sampling.CONST + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + sampling = ModelSamplingAdvanced() + else: + sampling = model.get_model_object("model_sampling") + timesteps = [sampling.timestep(sigma)/1000 for sigma in sigmas.tolist()] + switching_step = sigmas.size(0) + for (i,t) in enumerate(timesteps[1:]): + if t < boundary: + switching_step = i + break + print(f"splitting sigmas at index {switching_step}") + return (sigmas[:switching_step + 1], sigmas[switching_step:], switching_step, ) + + \ No newline at end of file