from __future__ import annotations import json import logging import os from aiohttp import web from typing import TYPE_CHECKING, TypedDict if TYPE_CHECKING: from comfy_api.latest._io_public import NodeReplace from comfy_execution.graph_utils import is_link class NodeStruct(TypedDict): inputs: dict[str, str | int | float | bool | tuple[str, int]] class_type: str _meta: dict[str, str] def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct: new_node_struct = node_struct.copy() if empty_inputs: new_node_struct["inputs"] = {} else: new_node_struct["inputs"] = node_struct["inputs"].copy() new_node_struct["_meta"] = node_struct["_meta"].copy() return new_node_struct class NodeReplaceManager: """Manages node replacement registrations.""" def __init__(self): self._replacements: dict[str, list[NodeReplace]] = {} def register(self, node_replace: NodeReplace): """Register a node replacement mapping.""" self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace) def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None: """Get replacements for an old node ID.""" return self._replacements.get(old_node_id) def has_replacement(self, old_node_id: str) -> bool: """Check if a replacement exists for an old node ID.""" return old_node_id in self._replacements def apply_replacements(self, prompt: dict[str, NodeStruct]): import nodes connections: dict[str, list[tuple[str, str, int]]] = {} need_replacement: set[str] = set() for node_number, node_struct in prompt.items(): if "class_type" not in node_struct or "inputs" not in node_struct: continue class_type = node_struct["class_type"] # need replacement if not in NODE_CLASS_MAPPINGS and has replacement if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type): need_replacement.add(node_number) # keep track of connections for input_id, input_value in node_struct["inputs"].items(): if is_link(input_value): conn_number = input_value[0] connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1])) for node_number in need_replacement: node_struct = prompt[node_number] class_type = node_struct["class_type"] replacements = self.get_replacement(class_type) if replacements is None: continue # just use the first replacement replacement = replacements[0] new_node_id = replacement.new_node_id # if replacement is not a valid node, skip trying to replace it as will only cause confusion if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys(): continue # first, replace node id (class_type) new_node_struct = copy_node_struct(node_struct, empty_inputs=True) new_node_struct["class_type"] = new_node_id # TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema # second, replace inputs if replacement.input_mapping is not None: for input_map in replacement.input_mapping: if "set_value" in input_map: new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"] elif "old_id" in input_map: new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]] # finalize input replacement prompt[node_number] = new_node_struct # third, replace outputs if replacement.output_mapping is not None: # re-mapping outputs requires changing the input values of nodes that receive connections from this one if node_number in connections: for conns in connections[node_number]: conn_node_number, conn_input_id, old_output_idx = conns for output_map in replacement.output_mapping: if output_map["old_idx"] == old_output_idx: new_output_idx = output_map["new_idx"] previous_input = prompt[conn_node_number]["inputs"][conn_input_id] previous_input[1] = new_output_idx def load_from_json(self, module_dir: str, module_name: str, _node_replace_class=None): """Load node_replacements.json from a custom node directory and register replacements. Custom node authors can ship a node_replacements.json file in their repo root to define node replacements declaratively. The file format matches the output of NodeReplace.as_dict(), keyed by old_node_id. Fail-open: all errors are logged and skipped so a malformed file never prevents the custom node from loading. """ replacements_path = os.path.join(module_dir, "node_replacements.json") if not os.path.isfile(replacements_path): return try: with open(replacements_path, "r", encoding="utf-8") as f: data = json.load(f) if not isinstance(data, dict): logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.") return if _node_replace_class is None: from comfy_api.latest._io import NodeReplace _node_replace_class = NodeReplace count = 0 for old_node_id, replacements in data.items(): if not isinstance(replacements, list): logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.") continue for entry in replacements: if not isinstance(entry, dict): continue new_node_id = entry.get("new_node_id", "") if not new_node_id: logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.") continue self.register(_node_replace_class( new_node_id=new_node_id, old_node_id=entry.get("old_node_id", old_node_id), old_widget_ids=entry.get("old_widget_ids"), input_mapping=entry.get("input_mapping"), output_mapping=entry.get("output_mapping"), )) count += 1 if count > 0: logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json") except json.JSONDecodeError as e: logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}") except Exception as e: logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}") def as_dict(self): """Serialize all replacements to dict.""" return { k: [v.as_dict() for v in v_list] for k, v_list in self._replacements.items() } def add_routes(self, routes): @routes.get("/node_replacements") async def get_node_replacements(request): return web.json_response(self.as_dict())