mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-15 22:27:49 +08:00
Compare commits
16 Commits
v0.20.0
...
deepme987/
| Author | SHA1 | Date | |
|---|---|---|---|
| 5225f109a6 | |||
| 24de8dc01b | |||
| c0d77a5d53 | |||
| ed201fff08 | |||
| b47f15f25a | |||
| 3cbf015578 | |||
| 64b8457f55 | |||
| e35fe5bc09 | |||
| 77054cd49e | |||
| 1cd2730b25 | |||
| d4351f77f8 | |||
| 9837dd368a | |||
| 62ec9a3238 | |||
| b20cb7892e | |||
| b9b24d425b | |||
| d731cb6ae1 |
45
.github/workflows/tag-dispatch-cloud.yml
vendored
Normal file
45
.github/workflows/tag-dispatch-cloud.yml
vendored
Normal file
@ -0,0 +1,45 @@
|
||||
name: Tag Dispatch to Cloud
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
jobs:
|
||||
dispatch-cloud:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Send repository dispatch to cloud
|
||||
env:
|
||||
DISPATCH_TOKEN: ${{ secrets.CLOUD_REPO_DISPATCH_TOKEN }}
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [ -z "${DISPATCH_TOKEN:-}" ]; then
|
||||
echo "::error::CLOUD_REPO_DISPATCH_TOKEN is required but not set."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
RELEASE_URL="https://github.com/${{ github.repository }}/releases/tag/${RELEASE_TAG}"
|
||||
|
||||
PAYLOAD="$(jq -n \
|
||||
--arg release_tag "$RELEASE_TAG" \
|
||||
--arg release_url "$RELEASE_URL" \
|
||||
'{
|
||||
event_type: "comfyui_tag_pushed",
|
||||
client_payload: {
|
||||
release_tag: $release_tag,
|
||||
release_url: $release_url
|
||||
}
|
||||
}')"
|
||||
|
||||
curl -fsSL \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
|
||||
https://api.github.com/repos/Comfy-Org/cloud/dispatches \
|
||||
-d "$PAYLOAD"
|
||||
|
||||
echo "✅ Dispatched ComfyUI tag ${RELEASE_TAG} to Comfy-Org/cloud"
|
||||
@ -1,5 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
@ -7,7 +11,6 @@ if TYPE_CHECKING:
|
||||
from comfy_api.latest._io_public import NodeReplace
|
||||
|
||||
from comfy_execution.graph_utils import is_link
|
||||
import nodes
|
||||
|
||||
class NodeStruct(TypedDict):
|
||||
inputs: dict[str, str | int | float | bool | tuple[str, int]]
|
||||
@ -43,6 +46,7 @@ class NodeReplaceManager:
|
||||
return old_node_id in self._replacements
|
||||
|
||||
def apply_replacements(self, prompt: dict[str, NodeStruct]):
|
||||
import nodes
|
||||
connections: dict[str, list[tuple[str, str, int]]] = {}
|
||||
need_replacement: set[str] = set()
|
||||
for node_number, node_struct in prompt.items():
|
||||
@ -94,6 +98,60 @@ class NodeReplaceManager:
|
||||
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
|
||||
previous_input[1] = new_output_idx
|
||||
|
||||
def load_from_json(self, module_dir: str, module_name: str, _node_replace_class=None):
|
||||
"""Load node_replacements.json from a custom node directory and register replacements.
|
||||
|
||||
Custom node authors can ship a node_replacements.json file in their repo root
|
||||
to define node replacements declaratively. The file format matches the output
|
||||
of NodeReplace.as_dict(), keyed by old_node_id.
|
||||
|
||||
Fail-open: all errors are logged and skipped so a malformed file never
|
||||
prevents the custom node from loading.
|
||||
"""
|
||||
replacements_path = os.path.join(module_dir, "node_replacements.json")
|
||||
if not os.path.isfile(replacements_path):
|
||||
return
|
||||
|
||||
try:
|
||||
with open(replacements_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
if not isinstance(data, dict):
|
||||
logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.")
|
||||
return
|
||||
|
||||
if _node_replace_class is None:
|
||||
from comfy_api.latest._io import NodeReplace
|
||||
_node_replace_class = NodeReplace
|
||||
|
||||
count = 0
|
||||
for old_node_id, replacements in data.items():
|
||||
if not isinstance(replacements, list):
|
||||
logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.")
|
||||
continue
|
||||
for entry in replacements:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
new_node_id = entry.get("new_node_id", "")
|
||||
if not new_node_id:
|
||||
logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.")
|
||||
continue
|
||||
self.register(_node_replace_class(
|
||||
new_node_id=new_node_id,
|
||||
old_node_id=entry.get("old_node_id", old_node_id),
|
||||
old_widget_ids=entry.get("old_widget_ids"),
|
||||
input_mapping=entry.get("input_mapping"),
|
||||
output_mapping=entry.get("output_mapping"),
|
||||
))
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json")
|
||||
except json.JSONDecodeError as e:
|
||||
logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}")
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}")
|
||||
|
||||
def as_dict(self):
|
||||
"""Serialize all replacements to dict."""
|
||||
return {
|
||||
|
||||
@ -31,6 +31,7 @@ import comfy.float
|
||||
import comfy.hooks
|
||||
import comfy.lora
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
@ -856,7 +857,9 @@ class ModelPatcher:
|
||||
if m.comfy_patched_weights == True:
|
||||
continue
|
||||
|
||||
for param in params:
|
||||
for param, param_value in params.items():
|
||||
if hasattr(m, "comfy_cast_weights") and getattr(param_value, "is_meta", False):
|
||||
comfy.ops.disable_weight_init._zero_init_parameter(m, param)
|
||||
key = key_param_name_to_key(n, param)
|
||||
self.unpin_weight(key)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
|
||||
16
comfy/ops.py
16
comfy/ops.py
@ -79,14 +79,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||
def materialize_meta_param(s, param_keys):
|
||||
for param_key in param_keys:
|
||||
param = getattr(s, param_key, None)
|
||||
if param is not None and getattr(param, "is_meta", False):
|
||||
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
|
||||
|
||||
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
@ -108,6 +115,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||
|
||||
if not resident:
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||
cast_dest = None
|
||||
|
||||
@ -306,6 +314,12 @@ class CastWeightBiasOp:
|
||||
bias_function = []
|
||||
|
||||
class disable_weight_init:
|
||||
@staticmethod
|
||||
def _zero_init_parameter(module, name):
|
||||
param = getattr(module, name)
|
||||
device = None if getattr(param, "is_meta", False) else param.device
|
||||
setattr(module, name, torch.nn.Parameter(torch.zeros(param.shape, device=device, dtype=param.dtype), requires_grad=False))
|
||||
|
||||
@staticmethod
|
||||
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
|
||||
missing_keys, unexpected_keys, weight_shape,
|
||||
|
||||
@ -12,6 +12,7 @@ import numpy as np
|
||||
import math
|
||||
import torch
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
import logging
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
@ -238,32 +239,86 @@ class VideoFromFile(VideoInput):
|
||||
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
||||
else:
|
||||
start_time = self.__start_time
|
||||
|
||||
# Get video frames
|
||||
frames = []
|
||||
audio_frames = []
|
||||
alphas = None
|
||||
start_pts = int(start_time / video_stream.time_base)
|
||||
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
image_format = 'gbrpf32le'
|
||||
for frame in container.decode(video_stream):
|
||||
if alphas is None:
|
||||
for comp in frame.format.components:
|
||||
if comp.is_alpha:
|
||||
alphas = []
|
||||
image_format = 'gbrapf32le'
|
||||
break
|
||||
|
||||
if frame.pts < start_pts:
|
||||
continue
|
||||
if self.__duration and frame.pts >= end_pts:
|
||||
if start_pts != 0:
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
|
||||
image_format = 'gbrpf32le'
|
||||
audio = None
|
||||
|
||||
streams = [video_stream]
|
||||
has_first_audio_frame = False
|
||||
checked_alpha = False
|
||||
|
||||
# Default to False so we decode until EOF if duration is 0
|
||||
video_done = False
|
||||
audio_done = True
|
||||
|
||||
if len(container.streams.audio):
|
||||
audio_stream = container.streams.audio[-1]
|
||||
streams += [audio_stream]
|
||||
resampler = av.audio.resampler.AudioResampler(format='fltp')
|
||||
audio_done = False
|
||||
|
||||
for packet in container.demux(*streams):
|
||||
if video_done and audio_done:
|
||||
break
|
||||
|
||||
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
|
||||
if alphas is None:
|
||||
frames.append(torch.from_numpy(img))
|
||||
else:
|
||||
frames.append(torch.from_numpy(img[..., :-1]))
|
||||
alphas.append(torch.from_numpy(img[..., -1:]))
|
||||
if packet.stream.type == "video":
|
||||
if video_done:
|
||||
continue
|
||||
try:
|
||||
for frame in packet.decode():
|
||||
if frame.pts < start_pts:
|
||||
continue
|
||||
if self.__duration and frame.pts >= end_pts:
|
||||
video_done = True
|
||||
break
|
||||
|
||||
if not checked_alpha:
|
||||
for comp in frame.format.components:
|
||||
if comp.is_alpha:
|
||||
alphas = []
|
||||
image_format = 'gbrapf32le'
|
||||
break
|
||||
checked_alpha = True
|
||||
|
||||
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
|
||||
if alphas is None:
|
||||
frames.append(torch.from_numpy(img))
|
||||
else:
|
||||
frames.append(torch.from_numpy(img[..., :-1]))
|
||||
alphas.append(torch.from_numpy(img[..., -1:]))
|
||||
except av.error.InvalidDataError:
|
||||
logging.info("pyav decode error")
|
||||
|
||||
elif packet.stream.type == "audio":
|
||||
if audio_done:
|
||||
continue
|
||||
|
||||
aframes = itertools.chain.from_iterable(
|
||||
map(resampler.resample, packet.decode())
|
||||
)
|
||||
for frame in aframes:
|
||||
if self.__duration and frame.time > start_time + self.__duration:
|
||||
audio_done = True
|
||||
break
|
||||
|
||||
if not has_first_audio_frame:
|
||||
offset_seconds = start_time - frame.pts * audio_stream.time_base
|
||||
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
|
||||
if to_skip < frame.samples:
|
||||
has_first_audio_frame = True
|
||||
audio_frames.append(frame.to_ndarray()[..., to_skip:])
|
||||
else:
|
||||
audio_frames.append(frame.to_ndarray())
|
||||
|
||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)
|
||||
if alphas is not None:
|
||||
@ -272,42 +327,16 @@ class VideoFromFile(VideoInput):
|
||||
# Get frame rate
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
|
||||
|
||||
# Get audio if available
|
||||
audio = None
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
# Use last stream for consistency
|
||||
if len(container.streams.audio):
|
||||
audio_stream = container.streams.audio[-1]
|
||||
audio_frames = []
|
||||
resample = av.audio.resampler.AudioResampler(format='fltp').resample
|
||||
frames = itertools.chain.from_iterable(
|
||||
map(resample, container.decode(audio_stream))
|
||||
)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
if self.__duration:
|
||||
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
|
||||
|
||||
has_first_frame = False
|
||||
for frame in frames:
|
||||
offset_seconds = start_time - frame.pts * audio_stream.time_base
|
||||
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
|
||||
if to_skip < frame.samples:
|
||||
has_first_frame = True
|
||||
break
|
||||
if has_first_frame:
|
||||
audio_frames.append(frame.to_ndarray()[..., to_skip:])
|
||||
|
||||
for frame in frames:
|
||||
if self.__duration and frame.time > start_time + self.__duration:
|
||||
break
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
if self.__duration:
|
||||
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
|
||||
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
|
||||
})
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
|
||||
})
|
||||
|
||||
metadata = container.metadata
|
||||
return VideoComponents(images=images, alpha=alphas, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
|
||||
@ -637,7 +637,7 @@ class SaveGLB(IO.ComfyNode):
|
||||
],
|
||||
tooltip="Mesh or 3D file to save",
|
||||
),
|
||||
IO.String.Input("filename_prefix", default="mesh/ComfyUI"),
|
||||
IO.String.Input("filename_prefix", default="3d/ComfyUI"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
|
||||
)
|
||||
|
||||
@ -2,6 +2,7 @@ import numpy as np
|
||||
import scipy.ndimage
|
||||
import torch
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import node_helpers
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, UI
|
||||
@ -188,7 +189,7 @@ class SolidMask(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, value, width, height) -> IO.NodeOutput:
|
||||
out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
|
||||
out = torch.full((1, height, width), value, dtype=torch.float32, device=comfy.model_management.intermediate_device())
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
solid = execute # TODO: remove
|
||||
@ -262,6 +263,7 @@ class MaskComposite(IO.ComfyNode):
|
||||
def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput:
|
||||
output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
|
||||
source = source.reshape((-1, source.shape[-2], source.shape[-1]))
|
||||
source = source.to(output.device)
|
||||
|
||||
left, top = (x, y,)
|
||||
right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.20.0"
|
||||
__version__ = "0.20.1"
|
||||
|
||||
6
nodes.py
6
nodes.py
@ -2228,6 +2228,12 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
|
||||
LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
|
||||
|
||||
# Only load node_replacements.json from directory-based custom nodes (proper packs).
|
||||
# Single-file .py nodes share a parent dir, so checking there would be incorrect.
|
||||
if os.path.isdir(module_path):
|
||||
from server import PromptServer
|
||||
PromptServer.instance.node_replace_manager.load_from_json(module_dir, module_name)
|
||||
|
||||
try:
|
||||
from comfy_config import config_parser
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.20.0"
|
||||
version = "0.20.1"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
217
tests/test_node_replacements_json.py
Normal file
217
tests/test_node_replacements_json.py
Normal file
@ -0,0 +1,217 @@
|
||||
"""Tests for NodeReplaceManager.load_from_json — auto-registration of
|
||||
node_replacements.json from custom node directories."""
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from app.node_replace_manager import NodeReplaceManager
|
||||
|
||||
|
||||
class SimpleNodeReplace:
|
||||
"""Lightweight stand-in for comfy_api.latest._io.NodeReplace (avoids torch import)."""
|
||||
def __init__(self, new_node_id, old_node_id, old_widget_ids=None,
|
||||
input_mapping=None, output_mapping=None):
|
||||
self.new_node_id = new_node_id
|
||||
self.old_node_id = old_node_id
|
||||
self.old_widget_ids = old_widget_ids
|
||||
self.input_mapping = input_mapping
|
||||
self.output_mapping = output_mapping
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"new_node_id": self.new_node_id,
|
||||
"old_node_id": self.old_node_id,
|
||||
"old_widget_ids": self.old_widget_ids,
|
||||
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
|
||||
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
|
||||
}
|
||||
|
||||
|
||||
class TestLoadFromJson(unittest.TestCase):
|
||||
"""Test auto-registration of node_replacements.json from custom node directories."""
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
self.manager = NodeReplaceManager()
|
||||
|
||||
def _write_json(self, data):
|
||||
path = os.path.join(self.tmpdir, "node_replacements.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
def _load(self):
|
||||
self.manager.load_from_json(self.tmpdir, "test-node-pack", _node_replace_class=SimpleNodeReplace)
|
||||
|
||||
def test_no_file_does_nothing(self):
|
||||
"""No node_replacements.json — should silently do nothing."""
|
||||
self._load()
|
||||
self.assertEqual(self.manager.as_dict(), {})
|
||||
|
||||
def test_empty_object(self):
|
||||
"""Empty {} — should do nothing."""
|
||||
self._write_json({})
|
||||
self._load()
|
||||
self.assertEqual(self.manager.as_dict(), {})
|
||||
|
||||
def test_single_replacement(self):
|
||||
"""Single replacement entry registers correctly."""
|
||||
self._write_json({
|
||||
"OldNode": [{
|
||||
"new_node_id": "NewNode",
|
||||
"old_node_id": "OldNode",
|
||||
"input_mapping": [{"new_id": "model", "old_id": "ckpt_name"}],
|
||||
"output_mapping": [{"new_idx": 0, "old_idx": 0}],
|
||||
}]
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertIn("OldNode", result)
|
||||
self.assertEqual(len(result["OldNode"]), 1)
|
||||
entry = result["OldNode"][0]
|
||||
self.assertEqual(entry["new_node_id"], "NewNode")
|
||||
self.assertEqual(entry["old_node_id"], "OldNode")
|
||||
self.assertEqual(entry["input_mapping"], [{"new_id": "model", "old_id": "ckpt_name"}])
|
||||
self.assertEqual(entry["output_mapping"], [{"new_idx": 0, "old_idx": 0}])
|
||||
|
||||
def test_multiple_replacements(self):
|
||||
"""Multiple old_node_ids each with entries."""
|
||||
self._write_json({
|
||||
"NodeA": [{"new_node_id": "NodeB", "old_node_id": "NodeA"}],
|
||||
"NodeC": [{"new_node_id": "NodeD", "old_node_id": "NodeC"}],
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertIn("NodeA", result)
|
||||
self.assertIn("NodeC", result)
|
||||
|
||||
def test_multiple_alternatives_for_same_node(self):
|
||||
"""Multiple replacement options for the same old node."""
|
||||
self._write_json({
|
||||
"OldNode": [
|
||||
{"new_node_id": "AltA", "old_node_id": "OldNode"},
|
||||
{"new_node_id": "AltB", "old_node_id": "OldNode"},
|
||||
]
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertEqual(len(result["OldNode"]), 2)
|
||||
|
||||
def test_null_mappings(self):
|
||||
"""Null input/output mappings (trivial replacement)."""
|
||||
self._write_json({
|
||||
"OldNode": [{
|
||||
"new_node_id": "NewNode",
|
||||
"old_node_id": "OldNode",
|
||||
"input_mapping": None,
|
||||
"output_mapping": None,
|
||||
}]
|
||||
})
|
||||
self._load()
|
||||
entry = self.manager.as_dict()["OldNode"][0]
|
||||
self.assertIsNone(entry["input_mapping"])
|
||||
self.assertIsNone(entry["output_mapping"])
|
||||
|
||||
def test_old_node_id_defaults_to_key(self):
|
||||
"""If old_node_id is missing from entry, uses the dict key."""
|
||||
self._write_json({
|
||||
"OldNode": [{"new_node_id": "NewNode"}]
|
||||
})
|
||||
self._load()
|
||||
entry = self.manager.as_dict()["OldNode"][0]
|
||||
self.assertEqual(entry["old_node_id"], "OldNode")
|
||||
|
||||
def test_invalid_json_skips(self):
|
||||
"""Invalid JSON file — should warn and skip, not crash."""
|
||||
path = os.path.join(self.tmpdir, "node_replacements.json")
|
||||
with open(path, "w") as f:
|
||||
f.write("{invalid json")
|
||||
self._load()
|
||||
self.assertEqual(self.manager.as_dict(), {})
|
||||
|
||||
def test_non_object_json_skips(self):
|
||||
"""JSON array instead of object — should warn and skip."""
|
||||
self._write_json([1, 2, 3])
|
||||
self._load()
|
||||
self.assertEqual(self.manager.as_dict(), {})
|
||||
|
||||
def test_non_list_value_skips(self):
|
||||
"""Value is not a list — should warn and skip that key."""
|
||||
self._write_json({
|
||||
"OldNode": "not a list",
|
||||
"GoodNode": [{"new_node_id": "NewNode", "old_node_id": "GoodNode"}],
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertNotIn("OldNode", result)
|
||||
self.assertIn("GoodNode", result)
|
||||
|
||||
def test_with_old_widget_ids(self):
|
||||
"""old_widget_ids are passed through."""
|
||||
self._write_json({
|
||||
"OldNode": [{
|
||||
"new_node_id": "NewNode",
|
||||
"old_node_id": "OldNode",
|
||||
"old_widget_ids": ["width", "height"],
|
||||
}]
|
||||
})
|
||||
self._load()
|
||||
entry = self.manager.as_dict()["OldNode"][0]
|
||||
self.assertEqual(entry["old_widget_ids"], ["width", "height"])
|
||||
|
||||
def test_set_value_in_input_mapping(self):
|
||||
"""input_mapping with set_value entries."""
|
||||
self._write_json({
|
||||
"OldNode": [{
|
||||
"new_node_id": "NewNode",
|
||||
"old_node_id": "OldNode",
|
||||
"input_mapping": [
|
||||
{"new_id": "method", "set_value": "lanczos"},
|
||||
{"new_id": "size", "old_id": "dimension"},
|
||||
],
|
||||
}]
|
||||
})
|
||||
self._load()
|
||||
entry = self.manager.as_dict()["OldNode"][0]
|
||||
self.assertEqual(len(entry["input_mapping"]), 2)
|
||||
|
||||
def test_missing_new_node_id_skipped(self):
|
||||
"""Entry without new_node_id is skipped."""
|
||||
self._write_json({
|
||||
"OldNode": [
|
||||
{"old_node_id": "OldNode"},
|
||||
{"new_node_id": "", "old_node_id": "OldNode"},
|
||||
{"new_node_id": "ValidNew", "old_node_id": "OldNode"},
|
||||
]
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertEqual(len(result["OldNode"]), 1)
|
||||
self.assertEqual(result["OldNode"][0]["new_node_id"], "ValidNew")
|
||||
|
||||
def test_non_dict_entry_skipped(self):
|
||||
"""Non-dict entries in the list are silently skipped."""
|
||||
self._write_json({
|
||||
"OldNode": [
|
||||
"not a dict",
|
||||
{"new_node_id": "NewNode", "old_node_id": "OldNode"},
|
||||
]
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertEqual(len(result["OldNode"]), 1)
|
||||
|
||||
def test_has_replacement_after_load(self):
|
||||
"""Manager reports has_replacement correctly after JSON load."""
|
||||
self._write_json({
|
||||
"OldNode": [{"new_node_id": "NewNode", "old_node_id": "OldNode"}],
|
||||
})
|
||||
self.assertFalse(self.manager.has_replacement("OldNode"))
|
||||
self._load()
|
||||
self.assertTrue(self.manager.has_replacement("OldNode"))
|
||||
self.assertFalse(self.manager.has_replacement("UnknownNode"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user