From 69a76ea18120c75b16c2ae327c87bdbe7cc3365f Mon Sep 17 00:00:00 2001
From: logtd
Date: Sun, 11 Aug 2024 12:09:48 -0600
Subject: [PATCH] init

---
 .gitignore  |   1 +
 README      |   0
 __init__.py |   5 +++
 seg_node.py | 101 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 107 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 README
 create mode 100644 __init__.py
 create mode 100644 seg_node.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..0d20b64
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+*.pyc
diff --git a/README b/README
new file mode 100644
index 0000000..e69de29
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..1fe1175
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,5 @@
+from .seg_node import SEGAttention
+
+NODE_CLASS_MAPPINGS = {
+    "SEGAttention": SEGAttention,
+}
diff --git a/seg_node.py b/seg_node.py
new file mode 100644
index 0000000..cb58475
--- /dev/null
+++ b/seg_node.py
@@ -0,0 +1,101 @@
+
+import math
+
+from einops import rearrange
+import torch
+import torch.nn.functional as F
+
+from comfy.ldm.modules.attention import optimized_attention
+import comfy.model_patcher
+import comfy.samplers
+
+
+def gaussian_blur_2d(img, kernel_size, sigma):
+    height = img.shape[-1]
+    kernel_size = min(kernel_size, height - (height % 2 - 1))
+    ksize_half = (kernel_size - 1) * 0.5
+
+    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
+
+    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+
+    x_kernel = pdf / pdf.sum()
+    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
+
+    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
+    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
+
+    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
+
+    img = F.pad(img, padding, mode="reflect")
+    img = F.conv2d(img, kernel2d, groups=img.shape[-3])
+
+    return img
+
+
+class SEGAttention:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
+                "blur": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 999.0, "step": 0.01, "round": 0.01}),
+                "inf_blur": ("BOOLEAN", {"default": False})
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "model_patches/unet"
+
+    def patch(self, model, scale, blur, inf_blur):
+        m = model.clone()
+
+        def seg_attention(q, k, v, extra_options, mask=None):
+            _, sequence_length, _ = q.shape
+            shape = extra_options['original_shape']
+            oh, ow = shape[-2:]
+            ratio = oh/ow
+            d = sequence_length
+            w = int((d/ratio)**(0.5))
+            h = int(d/w)
+            q = rearrange(q, 'b (h w) d -> b d w h', h=h)
+            if not inf_blur:
+                kernel_size = math.ceil(6 * blur) + 1 - math.ceil(6 * blur) % 2
+                q = gaussian_blur_2d(q, kernel_size, blur)
+            else:
+                q = q.mean(dim=(-2, -1), keepdim=True)
+            q = rearrange(q, 'b d w h -> b (h w) d')
+            return optimized_attention(q, k, v, extra_options['n_heads'])
+
+        def post_cfg_function(args):
+            model = args["model"]
+
+            cond_pred = args["cond_denoised"]
+            uncond_pred = args["uncond_denoised"]
+
+            if scale == 0 or blur == 0:
+                return uncond_pred + (cond_pred - uncond_pred)
+
+            cond = args["cond"]
+            sigma = args["sigma"]
+            model_options = args["model_options"].copy()
+            x = args["input"]
+            # Hack since comfy doesn't pass in conditionals and unconditionals to cfg_function
+            # and doesn't pass in cond_scale to post_cfg_function
+            len_conds = 1 if args.get('uncond', None) is None else 2
+
+            model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, seg_attention, "attn1", "middle", 0)
+            (seg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
+
+            if len_conds == 1:
+                return cond_pred + scale * (cond_pred - seg)
+
+            return cond_pred + (scale-1.0) * (cond_pred - uncond_pred) + scale * (cond_pred - seg)
+
+        m.set_model_sampler_post_cfg_function(post_cfg_function)
+
+        return (m,)
+
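Review note: outside of the ComfyUI graph, the node added above can be exercised directly, since patch() only needs a MODEL object plus the three widget values. Below is a minimal sketch, not part of this commit; it assumes the comfy package is importable and that `model` was already produced by a checkpoint-loading node elsewhere.

    # Hypothetical driver code for seg_node.py (not part of this patch).
    from seg_node import SEGAttention

    node = SEGAttention()
    # Defaults mirror INPUT_TYPES above. inf_blur=True skips the Gaussian kernel
    # and instead collapses the attention queries to their spatial mean (the
    # infinite-blur limit); blur is ignored in that case.
    (patched_model,) = node.patch(model, scale=3.0, blur=10.0, inf_blur=False)
    # patched_model is a clone of `model` with the SEG post-CFG function attached;
    # pass it to a sampler node in place of the original MODEL.

The returned model re-runs the conditional batch with attn1 in the middle block replaced by seg_attention, then blends that "seg" prediction into the final output; as the in-code comment notes, the single `scale` value is reused for both terms because the real cond_scale is not exposed to the post-CFG callback.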