add split sigmas node
This commit is contained in:
@ -1,11 +1,13 @@
|
||||
from .nodes import WanMoeKSampler,WanMoeKSamplerAdvanced
|
||||
from .nodes import WanMoeKSampler,WanMoeKSamplerAdvanced,SplitSigmasAtT
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanMoeKSampler":WanMoeKSampler,
|
||||
"WanMoeKSamplerAdvanced":WanMoeKSamplerAdvanced
|
||||
"WanMoeKSamplerAdvanced":WanMoeKSamplerAdvanced,
|
||||
"SplitSigmasAtT":SplitSigmasAtT
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"WanMoeKSampler": "Wan MoE KSampler",
|
||||
"WanMoeKSamplerAdvanced": "Wan MoE KSampler (Advanced)"
|
||||
"WanMoeKSamplerAdvanced": "Wan MoE KSampler (Advanced)",
|
||||
"SplitSigmasAtT": "Split sigmas at timestep"
|
||||
}
|
||||
41
nodes.py
41
nodes.py
@ -154,3 +154,44 @@ class WanMoeKSamplerAdvanced:
|
||||
if add_noise == "disable":
|
||||
disable_noise = True
|
||||
return wan_ksampler(model_high_noise, model_low_noise, noise_seed, steps, (cfg_high_noise, cfg_low_noise), sampler_name, scheduler, positive, negative, latent_image, boundary=boundary, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
|
||||
|
||||
class SplitSigmasAtT:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{
|
||||
"boundary": ("FLOAT", {"default": 0.875, "min": 0.0, "max": 1.0, "step": 0.001, "round":0.001}),
|
||||
"sigmas": ("SIGMAS", ),
|
||||
},
|
||||
"optional":
|
||||
{
|
||||
"model": ("MODEL", {"tooltip": "Used to determine the model type. Assumes FLOW model by default if not provided"}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("SIGMAS", "SIGMAS", "INT", )
|
||||
RETURN_NAMES = ("high noise sigmas", "low noise sigmas", "split at", )
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "split"
|
||||
|
||||
def split(self, boundary, sigmas:torch.Tensor, model = None):
|
||||
if model is None:
|
||||
sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow
|
||||
sampling_type = comfy.model_sampling.CONST
|
||||
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
sampling = ModelSamplingAdvanced()
|
||||
else:
|
||||
sampling = model.get_model_object("model_sampling")
|
||||
timesteps = [sampling.timestep(sigma)/1000 for sigma in sigmas.tolist()]
|
||||
switching_step = sigmas.size(0)
|
||||
for (i,t) in enumerate(timesteps[1:]):
|
||||
if t < boundary:
|
||||
switching_step = i
|
||||
break
|
||||
print(f"splitting sigmas at index {switching_step}")
|
||||
return (sigmas[:switching_step + 1], sigmas[switching_step:], switching_step, )
|
||||
|
||||
|
||||
Reference in New Issue
Block a user