diff --git a/nodes.py b/nodes.py index d827637..ad96a29 100644 --- a/nodes.py +++ b/nodes.py @@ -68,17 +68,16 @@ def wan_ksampler(model_high_noise, model_low_noise, seed, steps, cfgs, sampler_n return (out, ) def set_shift(model,sigma_shift): - - sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow - sampling_type = comfy.model_sampling.CONST + model_sampling = model.get_model_object("model_sampling") + if not model_sampling: + sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow + sampling_type = comfy.model_sampling.CONST + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass - class ModelSamplingAdvanced(sampling_base, sampling_type): - pass - - model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling = ModelSamplingAdvanced(model.model.model_config) model_sampling.set_parameters(shift=sigma_shift, multiplier=1000) model.add_object_patch("model_sampling", model_sampling) - model.add_object_patch("model_sampling", model_sampling) return model class WanMoeKSampler: