mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-28 00:46:52 +08:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4badc89490 |
130
comfy_extras/nodes_lora_stack.py
Normal file
130
comfy_extras/nodes_lora_stack.py
Normal file
@ -0,0 +1,130 @@
|
||||
"""LoRA stacking loaders built on io.DynamicGroup.
|
||||
|
||||
Two nodes that let you stack any number of LoRAs in a single node, each row
|
||||
carrying only a LoRA name and a strength:
|
||||
|
||||
LoadLoraModel
|
||||
Applies a stack of LoRAs to a diffusion MODEL.
|
||||
|
||||
LoadLoraTextEncoder
|
||||
Applies a stack of LoRAs to a CLIP text encoder.
|
||||
|
||||
Both are modelled on DynamicGroupLoraStyleTest in nodes_dynamic_group_test.py,
|
||||
but operate on real models and real LoRA files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
# Module-level cache so repeated executions don't re-read the same file from disk.
|
||||
_LORA_CACHE: dict[str, tuple] = {}
|
||||
|
||||
|
||||
def _load_lora_file(lora_name: str):
|
||||
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
|
||||
cached = _LORA_CACHE.get(lora_path)
|
||||
if cached is not None:
|
||||
return cached
|
||||
lora, metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
|
||||
_LORA_CACHE[lora_path] = (lora, metadata)
|
||||
return lora, metadata
|
||||
|
||||
|
||||
def _lora_template() -> list[io.Input]:
|
||||
return [
|
||||
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras"),
|
||||
tooltip="The name of the LoRA file to apply."),
|
||||
io.Float.Input("strength", default=1.0, min=-100.0, max=100.0, step=0.01,
|
||||
tooltip="How strongly to apply this LoRA. 0 = off, negative inverts the effect."),
|
||||
]
|
||||
|
||||
|
||||
class LoadLoraModel(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadLoraModel",
|
||||
display_name="Load LoRA (Model)",
|
||||
search_aliases=["lora", "load lora", "apply lora", "lora model", "lora stack"],
|
||||
category="model/loaders",
|
||||
description="Apply a stack of LoRAs to a diffusion model. Add one row per LoRA; "
|
||||
"each row picks a LoRA file and its strength.",
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The diffusion model the LoRAs will be applied to."),
|
||||
io.DynamicGroup.Input(
|
||||
"loras",
|
||||
template=_lora_template(),
|
||||
min=1,
|
||||
max=50,
|
||||
tooltip="Each row applies one LoRA to the model.",
|
||||
group_name="LoRA",
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output(tooltip="The modified diffusion model.")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, loras: list[dict]) -> io.NodeOutput:
|
||||
for row in loras:
|
||||
lora_name = row.get("lora_name")
|
||||
strength = row.get("strength", 1.0)
|
||||
if not lora_name or lora_name == "none" or strength == 0:
|
||||
continue
|
||||
lora, metadata = _load_lora_file(lora_name)
|
||||
model, _ = comfy.sd.load_lora_for_models(model, None, lora, strength, 0, lora_metadata=metadata)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
class LoadLoraTextEncoder(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadLoraTextEncoder",
|
||||
display_name="Load LoRA (Text Encoder)",
|
||||
search_aliases=["lora", "load lora", "apply lora", "clip lora", "lora stack"],
|
||||
category="model/loaders",
|
||||
description="Apply a stack of LoRAs to a CLIP text encoder. Add one row per LoRA; "
|
||||
"each row picks a LoRA file and its strength.",
|
||||
inputs=[
|
||||
io.Clip.Input("clip", tooltip="The CLIP text encoder the LoRAs will be applied to."),
|
||||
io.DynamicGroup.Input(
|
||||
"loras",
|
||||
template=_lora_template(),
|
||||
min=1,
|
||||
max=50,
|
||||
tooltip="Each row applies one LoRA to the text encoder.",
|
||||
group_name="LoRA",
|
||||
),
|
||||
],
|
||||
outputs=[io.Clip.Output(tooltip="The modified CLIP text encoder.")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, loras: list[dict]) -> io.NodeOutput:
|
||||
for row in loras:
|
||||
lora_name = row.get("lora_name")
|
||||
strength = row.get("strength", 1.0)
|
||||
if not lora_name or lora_name == "none" or strength == 0:
|
||||
continue
|
||||
lora, metadata = _load_lora_file(lora_name)
|
||||
_, clip = comfy.sd.load_lora_for_models(None, clip, lora, 0, strength, lora_metadata=metadata)
|
||||
return io.NodeOutput(clip)
|
||||
|
||||
|
||||
class LoraStackExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
LoadLoraModel,
|
||||
LoadLoraTextEncoder,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> LoraStackExtension:
|
||||
return LoraStackExtension()
|
||||
Reference in New Issue
Block a user