Support Z Image alibaba pai fun controlnets. (#11062)
These are not actual controlnets, so put them in the models/model_patches folder and use them through the ModelPatchLoader + QwenImageDiffsynthControlnet nodes.
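As a quick sanity check before wiring the nodes, the loader in the diff below recognizes these checkpoints by a characteristic key. A minimal sketch, assuming a hypothetical file name under models/model_patches and reading the file with comfy.utils.load_torch_file the way ComfyUI normally loads checkpoints:

import comfy.utils

# Hypothetical path; the file only needs to sit under models/model_patches
# so that ModelPatchLoader can list and load it.
sd = comfy.utils.load_torch_file("models/model_patches/z_image_fun_control.safetensors")

# 'control_all_x_embedder.2-1.weight' is the marker key ModelPatchLoader checks
# (see the hunk below) to route the checkpoint to the Z Image fun controlnet branch.
if "control_all_x_embedder.2-1.weight" in sd:
    print("recognized as a Z Image (alibaba pai) fun controlnet patch")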
@@ -6,6 +6,7 @@ import comfy.ops
 import comfy.model_management
 import comfy.ldm.common_dit
 import comfy.latent_formats
+import comfy.ldm.lumina.controlnet


 class BlockWiseControlBlock(torch.nn.Module):
@@ -189,6 +190,35 @@ class SigLIPMultiFeatProjModel(torch.nn.Module):

         return embedding

+def z_image_convert(sd):
+    replace_keys = {".attention.to_out.0.bias": ".attention.out.bias",
+                    ".attention.norm_k.weight": ".attention.k_norm.weight",
+                    ".attention.norm_q.weight": ".attention.q_norm.weight",
+                    ".attention.to_out.0.weight": ".attention.out.weight"
+                    }
+
+    out_sd = {}
+    for k in sorted(sd.keys()):
+        w = sd[k]
+
+        k_out = k
+        if k_out.endswith(".attention.to_k.weight"):
+            cc = [w]
+            continue
+        if k_out.endswith(".attention.to_q.weight"):
+            cc = [w] + cc
+            continue
+        if k_out.endswith(".attention.to_v.weight"):
+            cc = cc + [w]
+            w = torch.cat(cc, dim=0)
+            k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")
+
+        for r, rr in replace_keys.items():
+            k_out = k_out.replace(r, rr)
+        out_sd[k_out] = w
+
+    return out_sd
+
 class ModelPatchLoader:
     @classmethod
     def INPUT_TYPES(s):
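To see what z_image_convert does, here is a small sketch on a toy single-block state dict: the separate to_q/to_k/to_v weights are fused into one qkv weight (concatenated along dim 0 in q, k, v order thanks to the sorted key iteration) and the output-projection keys are renamed. The "layers.0" prefix and the 64-wide shapes are made up for illustration; only the key suffixes matter, and z_image_convert from the hunk above is assumed to be in scope.

import torch

# Toy state dict in the source naming; prefix and shapes are arbitrary.
sd = {
    "layers.0.attention.to_q.weight": torch.randn(64, 64),
    "layers.0.attention.to_k.weight": torch.randn(64, 64),
    "layers.0.attention.to_v.weight": torch.randn(64, 64),
    "layers.0.attention.to_out.0.weight": torch.randn(64, 64),
    "layers.0.attention.to_out.0.bias": torch.randn(64),
}

out = z_image_convert(sd)
print(sorted(out.keys()))
# ['layers.0.attention.out.bias', 'layers.0.attention.out.weight',
#  'layers.0.attention.qkv.weight']
print(out["layers.0.attention.qkv.weight"].shape)  # q, k, v stacked -> torch.Size([192, 64])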
@@ -211,6 +241,9 @@ class ModelPatchLoader:
         elif 'feature_embedder.mid_layer_norm.bias' in sd:
             sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
             model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+        elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
+            sd = z_image_convert(sd)
+            model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)

         model.load_state_dict(sd)
         model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@@ -263,6 +296,69 @@ class DiffSynthCnetPatch:
     def models(self):
         return [self.model_patch]

+class ZImageControlPatch:
+    def __init__(self, model_patch, vae, image, strength):
+        self.model_patch = model_patch
+        self.vae = vae
+        self.image = image
+        self.strength = strength
+        self.encoded_image = self.encode_latent_cond(image)
+        self.encoded_image_size = (image.shape[1], image.shape[2])
+        self.temp_data = None
+
+    def encode_latent_cond(self, image):
+        latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
+        return latent_image
+
+    def __call__(self, kwargs):
+        x = kwargs.get("x")
+        img = kwargs.get("img")
+        txt = kwargs.get("txt")
+        pe = kwargs.get("pe")
+        vec = kwargs.get("vec")
+        block_index = kwargs.get("block_index")
+        spacial_compression = self.vae.spacial_compression_encode()
+        if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
+            image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
+            loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
+            self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
+            self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
+            comfy.model_management.load_models_gpu(loaded_models)
+
+        cnet_index = (block_index // 5)
+        cnet_index_float = (block_index / 5)
+
+        kwargs.pop("img") # we do ops in place
+        kwargs.pop("txt")
+
+        cnet_blocks = self.model_patch.model.n_control_layers
+        if cnet_index_float > (cnet_blocks - 1):
+            self.temp_data = None
+            return kwargs
+
+        if self.temp_data is None or self.temp_data[0] > cnet_index:
+            self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
+
+        while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
+            next_layer = self.temp_data[0] + 1
+            self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))

+        if cnet_index_float == self.temp_data[0]:
+            img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
+            if cnet_blocks == self.temp_data[0] + 1:
+                self.temp_data = None
+
+        return kwargs
+
+    def to(self, device_or_dtype):
+        if isinstance(device_or_dtype, torch.device):
+            self.encoded_image = self.encoded_image.to(device_or_dtype)
+            self.temp_data = None
+        return self
+
+    def models(self):
+        return [self.model_patch]
+
 class QwenImageDiffsynthControlnet:
     @classmethod
     def INPUT_TYPES(s):
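The block_index arithmetic in ZImageControlPatch.__call__ implies a simple schedule: the control blocks are run lazily, and the output of control layer i is added to the hidden states at double block 5 * i; blocks past the last control layer are left untouched. A sketch of that schedule, with n_control_layers and the double-block count as assumed example values rather than numbers read from the model:

# Assumed example sizes, not taken from an actual checkpoint.
n_control_layers = 6
num_double_blocks = 30

for block_index in range(num_double_blocks):
    cnet_index = block_index // 5
    if block_index % 5 == 0 and cnet_index <= n_control_layers - 1:
        print(f"double block {block_index:2d}: add control layer {cnet_index} output (scaled by strength)")
    # every other block index passes through unchanged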
@@ -289,7 +385,10 @@ class QwenImageDiffsynthControlnet:
             mask = mask.unsqueeze(2)
             mask = 1.0 - mask

-        model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
+        if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
+            model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
+        else:
+            model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
         return (model_patched,)