From da90bc93e486fe953e6959b8ee0e2d735ae12dde Mon Sep 17 00:00:00 2001 From: Glary-Bot Date: Sat, 16 May 2026 03:04:37 +0000 Subject: [PATCH] Add SaveImagePromotable PoC node MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pass-through SaveImage variant with accumulating previews and a promote/lock feature. The node: - Saves images and passes the input tensor through as the output, so it fits naturally mid-graph (unlike core SaveImage which is a sink). - Exposes an 'accumulate' flag, mirroring upstream PR #12647 — the frontend uses this to append previews to a per-node gallery instead of replacing it. - Accepts an optional 'promoted_asset_ref' STRING widget that the frontend writes when the user clicks a 'lock' UI on a preview. When set, the node skips saving, loads the referenced image from output/input/temp, and outputs that image. Stale refs silently fall back to pass-through. - IS_CHANGED returns a ref-derived key (incl. file mtime) when locked, so re-queues with the same lock are cache hits and upstream ancestors are skipped. Unlocked, it defers to normal input-signature caching. Includes unit tests covering ref parsing (incl. path-traversal and symlink-escape rejection), path resolution, pass-through and locked execution, and IS_CHANGED behavior. 24/24 pass; ruff clean. --- comfy_extras/nodes_save_image_promotable.py | 249 +++++++++++++++ .../save_image_promotable_test.py | 289 ++++++++++++++++++ 2 files changed, 538 insertions(+) create mode 100644 comfy_extras/nodes_save_image_promotable.py create mode 100644 tests-unit/comfy_extras_test/save_image_promotable_test.py diff --git a/comfy_extras/nodes_save_image_promotable.py b/comfy_extras/nodes_save_image_promotable.py new file mode 100644 index 000000000..0e69458a4 --- /dev/null +++ b/comfy_extras/nodes_save_image_promotable.py @@ -0,0 +1,249 @@ +""" +SaveImagePromotable: a pass-through SaveImage variant with accumulating previews +and a "promote/lock" feature. + +Modes: +- Pass-through (default): saves incoming images, emits preview UI, returns the + input tensor as output. With `accumulate=True`, the frontend appends previews + to a gallery instead of replacing it. +- Locked: when `promoted_asset_ref` is a non-empty JSON ref to a saved asset, + the node skips saving, loads the referenced image, and outputs that image. + The frontend is expected to write the ref into the widget when the user + clicks the "lock" UI on a preview. + +Caching: IS_CHANGED returns a stable key derived from the ref (+ file mtime) +when locked, so re-queues with the same lock are cache hits and upstream +ancestors are skipped. Unlocked, IS_CHANGED returns False to defer to normal +input-signature caching. +""" + +from __future__ import annotations + +import json +import os + +import numpy as np +import torch +from PIL import Image, ImageOps, ImageSequence +from PIL.PngImagePlugin import PngInfo + +import folder_paths +import node_helpers +from comfy.cli_args import args + + +def _parse_promoted_ref(promoted_asset_ref: str) -> dict | None: + if not promoted_asset_ref: + return None + try: + ref = json.loads(promoted_asset_ref) + except (json.JSONDecodeError, TypeError): + return None + if not isinstance(ref, dict): + return None + filename = ref.get("filename") + if not isinstance(filename, str) or not filename: + return None + subfolder = ref.get("subfolder", "") or "" + asset_type = ref.get("type", "output") or "output" + if not isinstance(subfolder, str) or not isinstance(asset_type, str): + return None + # Reject anything that could escape the base directory. + if os.path.isabs(subfolder) or ".." in subfolder.split(os.sep): + return None + if os.path.isabs(filename) or ".." in filename.split(os.sep): + return None + return {"filename": filename, "subfolder": subfolder, "type": asset_type} + + +def _resolve_ref_path(ref: dict) -> str | None: + asset_type = ref["type"] + if asset_type == "output": + base = folder_paths.get_output_directory() + elif asset_type == "input": + base = folder_paths.get_input_directory() + elif asset_type == "temp": + base = folder_paths.get_temp_directory() + else: + return None + path = os.path.join(base, ref["subfolder"], ref["filename"]) + # Defense-in-depth: ensure the resolved path stays inside the base dir. + base_real = os.path.realpath(base) + path_real = os.path.realpath(path) + if not path_real.startswith(base_real + os.sep) and path_real != base_real: + return None + if not os.path.isfile(path_real): + return None + return path_real + + +def _load_image_tensor(path: str) -> torch.Tensor: + img = node_helpers.pillow(Image.open, path) + output_images: list[torch.Tensor] = [] + w: int | None = None + h: int | None = None + for frame in ImageSequence.Iterator(img): + frame = node_helpers.pillow(ImageOps.exif_transpose, frame) + image = frame.convert("RGB") + if not output_images: + w, h = image.size + if image.size != (w, h): + continue + arr = np.array(image).astype(np.float32) / 255.0 + output_images.append(torch.from_numpy(arr)[None,]) + if not output_images: + raise RuntimeError(f"Failed to decode any frames from {path}") + return torch.cat(output_images, dim=0) + + +class SaveImagePromotable: + """Pass-through SaveImage with accumulating previews and promote/lock. + + Inputs: + images: IMAGE tensor to save + pass through (ignored when locked). + filename_prefix: STRING prefix for saved files. + accumulate: BOOLEAN — when True, frontend appends previews to gallery. + promoted_asset_ref: STRING — JSON ref written by the frontend on lock. + Empty string means "not locked, normal pass-through". + Output: + IMAGE — input pass-through, or the loaded promoted image when locked. + """ + + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type = "output" + self.compress_level = 4 + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": ( + "IMAGE", + { + "tooltip": "Images to save and pass through. Ignored when a promoted asset is locked." + }, + ), + "filename_prefix": ( + "STRING", + {"default": "ComfyUI", "tooltip": "Prefix for saved files."}, + ), + "accumulate": ( + "BOOLEAN", + { + "default": False, + "tooltip": "When enabled, previews append to a per-node gallery instead of replacing it.", + }, + ), + }, + "optional": { + "promoted_asset_ref": ( + "STRING", + { + "default": "", + "multiline": False, + "tooltip": "JSON ref to a saved asset. Set by the UI; do not edit manually.", + }, + ), + }, + "hidden": { + "prompt": "PROMPT", + "extra_pnginfo": "EXTRA_PNGINFO", + }, + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("images",) + FUNCTION = "execute" + OUTPUT_NODE = True + CATEGORY = "image" + DESCRIPTION = "Saves images, shows accumulating previews, and passes the input through. A promoted (locked) preview overrides pass-through to output the chosen image." + + def _save_images(self, images, filename_prefix, prompt, extra_pnginfo): + full_output_folder, filename, counter, subfolder, _ = ( + folder_paths.get_save_image_path( + filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0] + ) + ) + results: list[dict] = [] + for batch_number, image in enumerate(images): + arr = 255.0 * image.cpu().numpy() + img = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8)) + metadata: PngInfo | None = None + if not args.disable_metadata: + metadata = PngInfo() + if prompt is not None: + metadata.add_text("prompt", json.dumps(prompt)) + if extra_pnginfo is not None: + for key in extra_pnginfo: + metadata.add_text(key, json.dumps(extra_pnginfo[key])) + filename_with_batch = filename.replace("%batch_num%", str(batch_number)) + out_name = f"{filename_with_batch}_{counter:05}_.png" + img.save( + os.path.join(full_output_folder, out_name), + pnginfo=metadata, + compress_level=self.compress_level, + ) + results.append( + {"filename": out_name, "subfolder": subfolder, "type": self.type} + ) + counter += 1 + return results + + def execute( + self, + images, + filename_prefix="ComfyUI", + accumulate=False, # noqa: ARG002 + promoted_asset_ref="", + prompt=None, + extra_pnginfo=None, + ): + ref = _parse_promoted_ref(promoted_asset_ref) + if ref is not None: + path = _resolve_ref_path(ref) + if path is not None: + tensor = _load_image_tensor(path) + tensor = tensor.to(device=images.device, dtype=images.dtype) + return { + "ui": {"images": [ref]}, + "result": (tensor,), + } + # Ref is set but stale (file deleted / failed validation): fall + # through to pass-through so the user gets a working graph rather + # than an execution error. + + saved = self._save_images(images, filename_prefix, prompt, extra_pnginfo) + return {"ui": {"images": saved}, "result": (images,)} + + @classmethod + def IS_CHANGED( + cls, + images, # noqa: ARG003 + filename_prefix="ComfyUI", + accumulate=False, # noqa: ARG003 + promoted_asset_ref="", + prompt=None, # noqa: ARG003 + extra_pnginfo=None, # noqa: ARG003 + ): + ref = _parse_promoted_ref(promoted_asset_ref) + if ref is None: + return False + path = _resolve_ref_path(ref) + if path is None: + return f"PROMOTED::MISSING::{promoted_asset_ref}" + try: + stat = os.stat(path) + sig = f"{stat.st_size}:{stat.st_mtime_ns}" + except OSError: + sig = "NOSTAT" + return f"PROMOTED::{promoted_asset_ref}::{sig}" + + +NODE_CLASS_MAPPINGS = { + "SaveImagePromotable": SaveImagePromotable, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "SaveImagePromotable": "Save Image (Promotable, PoC)", +} diff --git a/tests-unit/comfy_extras_test/save_image_promotable_test.py b/tests-unit/comfy_extras_test/save_image_promotable_test.py new file mode 100644 index 000000000..f4df9ceab --- /dev/null +++ b/tests-unit/comfy_extras_test/save_image_promotable_test.py @@ -0,0 +1,289 @@ +import json +import os +from unittest.mock import MagicMock, patch + +import numpy as np +import torch +from PIL import Image + +mock_nodes = MagicMock() +mock_nodes.MAX_RESOLUTION = 16384 +mock_server = MagicMock() + +with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}): + from comfy_extras import nodes_save_image_promotable as mod + + +def _make_image(width=8, height=4): + return torch.rand(1, height, width, 3) + + +def _write_png(path: str, width=8, height=4): + arr = (np.random.rand(height, width, 3) * 255).astype(np.uint8) + Image.fromarray(arr).save(path) + + +class TestParseRef: + def test_empty(self): + assert mod._parse_promoted_ref("") is None + + def test_invalid_json(self): + assert mod._parse_promoted_ref("{not json") is None + + def test_non_object(self): + assert mod._parse_promoted_ref('"a string"') is None + assert mod._parse_promoted_ref("[]") is None + + def test_missing_filename(self): + assert mod._parse_promoted_ref('{"subfolder":"x","type":"output"}') is None + + def test_path_traversal_filename(self): + ref = json.dumps( + {"filename": "../etc/passwd", "subfolder": "", "type": "output"} + ) + assert mod._parse_promoted_ref(ref) is None + + def test_path_traversal_subfolder(self): + ref = json.dumps({"filename": "x.png", "subfolder": "../..", "type": "output"}) + assert mod._parse_promoted_ref(ref) is None + + def test_absolute_filename(self): + ref = json.dumps({"filename": "/etc/passwd", "subfolder": "", "type": "output"}) + assert mod._parse_promoted_ref(ref) is None + + def test_valid(self): + ref = json.dumps({"filename": "x.png", "subfolder": "sub", "type": "output"}) + parsed = mod._parse_promoted_ref(ref) + assert parsed == {"filename": "x.png", "subfolder": "sub", "type": "output"} + + def test_defaults_applied(self): + ref = json.dumps({"filename": "x.png"}) + parsed = mod._parse_promoted_ref(ref) + assert parsed == {"filename": "x.png", "subfolder": "", "type": "output"} + + +class TestResolveRefPath: + def test_unknown_type(self, tmp_path): + with patch.object( + mod.folder_paths, "get_output_directory", return_value=str(tmp_path) + ): + assert ( + mod._resolve_ref_path( + {"filename": "x.png", "subfolder": "", "type": "garbage"} + ) + is None + ) + + def test_missing_file(self, tmp_path): + with patch.object( + mod.folder_paths, "get_output_directory", return_value=str(tmp_path) + ): + assert ( + mod._resolve_ref_path( + {"filename": "missing.png", "subfolder": "", "type": "output"} + ) + is None + ) + + def test_resolves_file(self, tmp_path): + target = tmp_path / "img.png" + _write_png(str(target)) + with patch.object( + mod.folder_paths, "get_output_directory", return_value=str(tmp_path) + ): + resolved = mod._resolve_ref_path( + {"filename": "img.png", "subfolder": "", "type": "output"} + ) + assert resolved is not None + assert os.path.realpath(resolved) == os.path.realpath(str(target)) + + def test_resolves_file_in_subfolder(self, tmp_path): + sub = tmp_path / "nested" + sub.mkdir() + target = sub / "img.png" + _write_png(str(target)) + with patch.object( + mod.folder_paths, "get_output_directory", return_value=str(tmp_path) + ): + resolved = mod._resolve_ref_path( + {"filename": "img.png", "subfolder": "nested", "type": "output"} + ) + assert resolved is not None + assert os.path.realpath(resolved) == os.path.realpath(str(target)) + + def test_symlink_escape_rejected(self, tmp_path): + outside = tmp_path / "outside" + outside.mkdir() + secret = outside / "secret.png" + _write_png(str(secret)) + base = tmp_path / "base" + base.mkdir() + link = base / "link.png" + os.symlink(str(secret), str(link)) + with patch.object( + mod.folder_paths, "get_output_directory", return_value=str(base) + ): + resolved = mod._resolve_ref_path( + {"filename": "link.png", "subfolder": "", "type": "output"} + ) + assert resolved is None + + +class TestNodeContract: + def test_input_types_shape(self): + inp = mod.SaveImagePromotable.INPUT_TYPES() + assert set(inp["required"].keys()) == { + "images", + "filename_prefix", + "accumulate", + } + assert set(inp["optional"].keys()) == {"promoted_asset_ref"} + assert set(inp["hidden"].keys()) == {"prompt", "extra_pnginfo"} + assert inp["required"]["accumulate"][0] == "BOOLEAN" + assert inp["required"]["accumulate"][1]["default"] is False + + def test_class_metadata(self): + cls = mod.SaveImagePromotable + assert cls.RETURN_TYPES == ("IMAGE",) + assert cls.RETURN_NAMES == ("images",) + assert cls.OUTPUT_NODE is True + assert cls.FUNCTION == "execute" + assert "SaveImagePromotable" in mod.NODE_CLASS_MAPPINGS + assert mod.NODE_CLASS_MAPPINGS["SaveImagePromotable"] is cls + + +class TestExecutePassthrough: + def test_passthrough_saves_and_returns_input(self, tmp_path): + node = mod.SaveImagePromotable() + node.output_dir = str(tmp_path) + images = _make_image() + + with ( + patch.object(mod.args, "disable_metadata", True), + patch.object(mod.folder_paths, "get_save_image_path") as get_path, + ): + get_path.return_value = (str(tmp_path), "ComfyUI", 1, "", "ComfyUI") + result = node.execute( + images, + filename_prefix="ComfyUI", + accumulate=False, + promoted_asset_ref="", + ) + + assert "ui" in result + assert "result" in result + assert torch.equal(result["result"][0], images) + assert len(result["ui"]["images"]) == 1 + saved_name = result["ui"]["images"][0]["filename"] + assert os.path.isfile(os.path.join(str(tmp_path), saved_name)) + + def test_stale_ref_falls_through_to_passthrough(self, tmp_path): + node = mod.SaveImagePromotable() + node.output_dir = str(tmp_path) + images = _make_image() + ref = json.dumps( + {"filename": "does_not_exist.png", "subfolder": "", "type": "output"} + ) + + with ( + patch.object(mod.args, "disable_metadata", True), + patch.object(mod.folder_paths, "get_save_image_path") as get_path, + patch.object( + mod.folder_paths, "get_output_directory", return_value=str(tmp_path) + ), + ): + get_path.return_value = (str(tmp_path), "ComfyUI", 1, "", "ComfyUI") + result = node.execute(images, promoted_asset_ref=ref) + + assert torch.equal(result["result"][0], images) + + +class TestExecuteLocked: + def test_locked_outputs_loaded_image(self, tmp_path): + target = tmp_path / "promoted.png" + _write_png(str(target), width=8, height=4) + ref = json.dumps( + {"filename": "promoted.png", "subfolder": "", "type": "output"} + ) + node = mod.SaveImagePromotable() + node.output_dir = str(tmp_path) + upstream = _make_image(width=8, height=4) + + with patch.object( + mod.folder_paths, "get_output_directory", return_value=str(tmp_path) + ): + result = node.execute(upstream, promoted_asset_ref=ref) + + assert result["ui"]["images"] == [ + {"filename": "promoted.png", "subfolder": "", "type": "output"} + ] + out = result["result"][0] + assert out.shape == upstream.shape + assert out.dtype == upstream.dtype + assert not torch.equal(out, upstream) + + def test_locked_does_not_save(self, tmp_path): + target = tmp_path / "promoted.png" + _write_png(str(target)) + ref = json.dumps( + {"filename": "promoted.png", "subfolder": "", "type": "output"} + ) + node = mod.SaveImagePromotable() + node.output_dir = str(tmp_path) + images = _make_image() + + with ( + patch.object( + mod.folder_paths, "get_output_directory", return_value=str(tmp_path) + ), + patch.object(node, "_save_images") as save_mock, + ): + node.execute(images, promoted_asset_ref=ref) + + save_mock.assert_not_called() + + +class TestIsChanged: + def test_unlocked_returns_false(self): + assert ( + mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref="") + is False + ) + + def test_locked_missing_file(self, tmp_path): + ref = json.dumps({"filename": "missing.png", "subfolder": "", "type": "output"}) + with patch.object( + mod.folder_paths, "get_output_directory", return_value=str(tmp_path) + ): + key = mod.SaveImagePromotable.IS_CHANGED( + images=None, promoted_asset_ref=ref + ) + assert isinstance(key, str) + assert key.startswith("PROMOTED::MISSING::") + + def test_locked_stable_key(self, tmp_path): + target = tmp_path / "p.png" + _write_png(str(target)) + ref = json.dumps({"filename": "p.png", "subfolder": "", "type": "output"}) + with patch.object( + mod.folder_paths, "get_output_directory", return_value=str(tmp_path) + ): + k1 = mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref=ref) + k2 = mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref=ref) + assert k1 == k2 + assert k1.startswith("PROMOTED::") + + def test_locked_key_changes_when_file_changes(self, tmp_path): + target = tmp_path / "p.png" + _write_png(str(target), width=8, height=4) + ref = json.dumps({"filename": "p.png", "subfolder": "", "type": "output"}) + with patch.object( + mod.folder_paths, "get_output_directory", return_value=str(tmp_path) + ): + k1 = mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref=ref) + os.utime(str(target), (1234567890, 1234567890)) + with patch.object( + mod.folder_paths, "get_output_directory", return_value=str(tmp_path) + ): + k2 = mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref=ref) + assert k1 != k2