Compare commits

..

12 Commits

Author SHA1 Message Date
0141af0786 refactor: rename _is_cacheable_value to _is_external_cacheable_value
Clearer name since objects are also cached locally - this specifically
checks for external caching eligibility.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-30 00:46:10 +05:30
0440ebcf6e feat: add optional ui field to CacheValue
- Add ui field to CacheValue dataclass (default None)
- Pass ui when creating CacheValue for external providers
- Use result.ui (or default {}) when returning from external cache lookup

This allows external cache implementations to store/retrieve UI data
if desired, while remaining optional for implementations that skip it.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-29 20:22:04 +05:30
4afa80dc07 docs: make should_cache docstring implementation-agnostic
Remove prescriptive filtering suggestions - let implementations
decide their own caching logic based on their use case.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-29 19:52:06 +05:30
d755f7ca19 docs: clarify should_cache filtering criteria
Change docstring from "Skip large values" to "Skip if download time > compute time"
which better captures the cost/benefit tradeoff for external caching.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-29 19:50:12 +05:30
2049066cff refactor: expose CacheProvider API via comfy_api.latest.Caching
- Add Caching class to comfy_api/latest/__init__.py that re-exports
  from comfy_execution.cache_provider (source of truth)
- Fix docstring: "Skip large values" instead of "Skip small values"
  (small compute-heavy values are good cache targets)
- Maintain backward compatibility: comfy_execution.cache_provider
  imports still work

Usage:
    from comfy_api.latest import Caching

    class MyProvider(Caching.CacheProvider):
        def on_lookup(self, context): ...
        def on_store(self, context, value): ...

    Caching.register_provider(MyProvider())

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-29 19:42:34 +05:30
9b0ca8b95c Merge remote-tracking branch 'origin/master' into feat/cache-provider-api 2026-01-28 13:30:03 +05:30
dcf686857c fix: use hashable types in frozenset test and add dict test
Frozensets can only contain hashable types, so use nested frozensets
instead of dicts. Added separate test for dict handling via serialize_cache_key.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-24 15:47:53 +05:30
17eed38750 fix: move _torch_available before usage and use importlib.util.find_spec
Fixes ruff F821 (undefined name) and F401 (unused import) errors.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-24 15:34:29 +05:30
f4623c0e1b style: remove unused imports in test_cache_provider.py
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-24 14:25:00 +05:30
5e4bbca1ad test: add unit tests for CacheProvider API
- Add comprehensive tests for _canonicalize deterministic ordering
- Add tests for serialize_cache_key hash consistency
- Add tests for contains_nan utility
- Add tests for estimate_value_size
- Add tests for provider registry (register, unregister, clear)
- Move json import to top-level (fix inline import)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-24 14:20:49 +05:30
e17571d9be fix: use deterministic hash for cache keys instead of pickle
Pickle serialization is NOT deterministic across Python sessions due
to hash randomization affecting frozenset iteration order. This causes
distributed caching to fail because different pods compute different
hashes for identical cache keys.

Fix: Use _canonicalize() + JSON serialization which ensures deterministic
ordering regardless of Python's hash randomization.

This is critical for cross-pod cache key consistency in Kubernetes
deployments.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-24 14:07:46 +05:30
6540aa0400 feat: Add CacheProvider API for external distributed caching
Introduces a public API for external cache providers, enabling distributed
caching across multiple ComfyUI instances (e.g., Kubernetes pods).

New files:
- comfy_execution/cache_provider.py: CacheProvider ABC, CacheContext/CacheValue
  dataclasses, thread-safe provider registry, serialization utilities

Modified files:
- comfy_execution/caching.py: Add provider hooks to BasicCache (_notify_providers_store,
  _check_providers_lookup), subcache exclusion, prompt ID propagation
- execution.py: Add prompt lifecycle hooks (on_prompt_start/on_prompt_end) to
  PromptExecutor, set _current_prompt_id on caches

Key features:
- Local-first caching (check local before external for performance)
- NaN detection to prevent incorrect external cache hits
- Subcache exclusion (ephemeral subgraph results not cached externally)
- Thread-safe provider snapshot caching
- Graceful error handling (provider errors logged, never break execution)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-19 16:43:13 +05:30
14 changed files with 923 additions and 552 deletions

View File

@ -81,7 +81,6 @@ class SD_X4(LatentFormat):
class SC_Prior(LatentFormat):
latent_channels = 16
spacial_downscale_ratio = 42
def __init__(self):
self.scale_factor = 1.0
self.latent_rgb_factors = [
@ -104,7 +103,6 @@ class SC_Prior(LatentFormat):
]
class SC_B(LatentFormat):
spacial_downscale_ratio = 4
def __init__(self):
self.scale_factor = 1.0 / 0.43
self.latent_rgb_factors = [
@ -276,7 +274,6 @@ class Mochi(LatentFormat):
class LTXV(LatentFormat):
latent_channels = 128
latent_dimensions = 3
spacial_downscale_ratio = 32
def __init__(self):
self.latent_rgb_factors = [
@ -520,7 +517,6 @@ class Wan21(LatentFormat):
class Wan22(Wan21):
latent_channels = 48
latent_dimensions = 3
spacial_downscale_ratio = 16
latent_rgb_factors = [
[ 0.0119, 0.0103, 0.0046],

View File

@ -106,6 +106,42 @@ class Types:
MESH = MESH
VOXEL = VOXEL
class Caching:
"""
External cache provider API for distributed caching.
Enables sharing cached results across multiple ComfyUI instances
(e.g., Kubernetes pods) without monkey-patching internal methods.
Example usage:
from comfy_api.latest import Caching
class MyRedisProvider(Caching.CacheProvider):
def on_lookup(self, context):
# Check Redis for cached result
...
def on_store(self, context, value):
# Store to Redis (can be async internally)
...
Caching.register_provider(MyRedisProvider())
"""
# Import from comfy_execution.cache_provider (source of truth)
from comfy_execution.cache_provider import (
CacheProvider,
CacheContext,
CacheValue,
register_cache_provider as register_provider,
unregister_cache_provider as unregister_provider,
get_cache_providers as get_providers,
has_cache_providers as has_providers,
clear_cache_providers as clear_providers,
estimate_value_size,
)
ComfyAPI = ComfyAPI_latest
# Create a synchronous version of the API
@ -125,6 +161,7 @@ __all__ = [
"Input",
"InputImpl",
"Types",
"Caching",
"ComfyExtension",
"io",
"IO",

View File

@ -1,67 +0,0 @@
from pydantic import BaseModel, Field
class ImageGenerationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
aspect_ratio: str = Field(...)
n: int = Field(...)
seed: int = Field(...)
response_for: str = Field("url")
class InputUrlObject(BaseModel):
url: str = Field(...)
class ImageEditRequest(BaseModel):
model: str = Field(...)
image: InputUrlObject = Field(...)
prompt: str = Field(...)
resolution: str = Field(...)
n: int = Field(...)
seed: int = Field(...)
response_for: str = Field("url")
class VideoGenerationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
image: InputUrlObject | None = Field(...)
duration: int = Field(...)
aspect_ratio: str | None = Field(...)
resolution: str = Field(...)
seed: int = Field(...)
class VideoEditRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
video: InputUrlObject = Field(...)
seed: int = Field(...)
class ImageResponseObject(BaseModel):
url: str | None = Field(None)
b64_json: str | None = Field(None)
revised_prompt: str | None = Field(None)
class ImageGenerationResponse(BaseModel):
data: list[ImageResponseObject] = Field(...)
class VideoGenerationResponse(BaseModel):
request_id: str = Field(...)
class VideoResponseObject(BaseModel):
url: str = Field(...)
upsampled_prompt: str | None = Field(None)
duration: int = Field(...)
class VideoStatusResponse(BaseModel):
status: str | None = Field(None)
video: VideoResponseObject | None = Field(None)
model: str | None = Field(None)

View File

@ -1,417 +0,0 @@
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.grok import (
ImageEditRequest,
ImageGenerationRequest,
ImageGenerationResponse,
InputUrlObject,
VideoEditRequest,
VideoGenerationRequest,
VideoGenerationResponse,
VideoStatusResponse,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
download_url_to_video_output,
get_fs_object_size,
get_number_of_images,
poll_op,
sync_op,
tensor_to_base64_string,
upload_video_to_comfyapi,
validate_string,
validate_video_duration,
)
class GrokImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GrokImageNode",
display_name="Grok Image",
category="api node/image/Grok",
description="Generate images using Grok based on a text prompt",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="The text prompt used to generate the image",
),
IO.Combo.Input(
"aspect_ratio",
options=[
"1:1",
"2:3",
"3:2",
"3:4",
"4:3",
"9:16",
"16:9",
"9:19.5",
"19.5:9",
"9:20",
"20:9",
"1:2",
"2:1",
],
),
IO.Int.Input(
"number_of_images",
default=1,
min=1,
max=10,
step=1,
tooltip="Number of images to generate",
display_mode=IO.NumberDisplay.number,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""",
),
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
aspect_ratio: str,
number_of_images: int,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/images/generations", method="POST"),
data=ImageGenerationRequest(
model=model,
prompt=prompt,
aspect_ratio=aspect_ratio,
n=number_of_images,
seed=seed,
),
response_model=ImageGenerationResponse,
)
if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
return IO.NodeOutput(
torch.cat(
[await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]],
)
)
class GrokImageEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GrokImageEditNode",
display_name="Grok Image Edit",
category="api node/image/Grok",
description="Modify an existing image based on a text prompt",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
tooltip="The text prompt used to generate the image",
),
IO.Combo.Input("resolution", options=["1K"]),
IO.Int.Input(
"number_of_images",
default=1,
min=1,
max=10,
step=1,
tooltip="Number of edited images to generate",
display_mode=IO.NumberDisplay.number,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""",
),
)
@classmethod
async def execute(
cls,
model: str,
image: Input.Image,
prompt: str,
resolution: str,
number_of_images: int,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
if get_number_of_images(image) != 1:
raise ValueError("Only one input image is supported.")
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"),
data=ImageEditRequest(
model=model,
image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"),
prompt=prompt,
resolution=resolution.lower(),
n=number_of_images,
seed=seed,
),
response_model=ImageGenerationResponse,
)
if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
return IO.NodeOutput(
torch.cat(
[await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]],
)
)
class GrokVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GrokVideoNode",
display_name="Grok Video",
category="api node/video/Grok",
description="Generate video from a prompt or an image",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text description of the desired video.",
),
IO.Combo.Input(
"resolution",
options=["480p", "720p"],
tooltip="The resolution of the output video.",
),
IO.Combo.Input(
"aspect_ratio",
options=["auto", "16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
tooltip="The aspect ratio of the output video.",
),
IO.Int.Input(
"duration",
default=6,
min=1,
max=15,
step=1,
tooltip="The duration of the output video in seconds.",
display_mode=IO.NumberDisplay.slider,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Image.Input("image", optional=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]),
expr="""
(
$base := 0.181 * widgets.duration;
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
)
""",
),
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
resolution: str,
aspect_ratio: str,
duration: int,
seed: int,
image: Input.Image | None = None,
) -> IO.NodeOutput:
image_url = None
if image is not None:
if get_number_of_images(image) != 1:
raise ValueError("Only one input image is supported.")
image_url = InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}")
validate_string(prompt, strip_whitespace=True, min_length=1)
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
data=VideoGenerationRequest(
model=model,
image=image_url,
prompt=prompt,
resolution=resolution,
duration=duration,
aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio,
seed=seed,
),
response_model=VideoGenerationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
class GrokVideoEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GrokVideoEditNode",
display_name="Grok Video Edit",
category="api node/video/Grok",
description="Edit an existing video based on a text prompt.",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text description of the desired video.",
),
IO.Video.Input("video", tooltip="Maximum supported duration is 8.7 seconds and 50MB file size."),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""",
),
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
video: Input.Video,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
validate_video_duration(video, min_duration=1, max_duration=8.7)
video_stream = video.get_stream_source()
video_size = get_fs_object_size(video_stream)
if video_size > 50 * 1024 * 1024:
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/videos/edits", method="POST"),
data=VideoEditRequest(
model=model,
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
prompt=prompt,
seed=seed,
),
response_model=VideoGenerationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
class GrokExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
GrokImageNode,
GrokImageEditNode,
GrokVideoNode,
GrokVideoEditNode,
]
async def comfy_entrypoint() -> GrokExtension:
return GrokExtension()

View File

@ -0,0 +1,319 @@
"""
External Cache Provider API for distributed caching.
This module provides a public API for external cache providers, enabling
distributed caching across multiple ComfyUI instances (e.g., Kubernetes pods).
Public API is also available via:
from comfy_api.latest import Caching
Example usage:
from comfy_execution.cache_provider import (
CacheProvider, CacheContext, CacheValue, register_cache_provider
)
class MyRedisProvider(CacheProvider):
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
# Check Redis/GCS for cached result
...
def on_store(self, context: CacheContext, value: CacheValue) -> None:
# Store to Redis/GCS (can be async internally)
...
register_cache_provider(MyRedisProvider())
"""
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple, List
from dataclasses import dataclass
import hashlib
import json
import logging
import math
import pickle
import threading
logger = logging.getLogger(__name__)
# ============================================================
# Data Classes
# ============================================================
@dataclass
class CacheContext:
"""Context passed to provider methods."""
prompt_id: str # Current prompt execution ID
node_id: str # Node being cached
class_type: str # Node class type (e.g., "KSampler")
cache_key: Any # Raw cache key (frozenset structure)
cache_key_bytes: bytes # SHA256 hash for external storage key
@dataclass
class CacheValue:
"""
Value stored/retrieved from external cache.
The ui field is optional - implementations may choose to skip it
(e.g., if it contains non-portable data like local file paths).
"""
outputs: list # The tensor/value outputs
ui: dict = None # Optional UI data (may be skipped by implementations)
# ============================================================
# Provider Interface
# ============================================================
class CacheProvider(ABC):
"""
Abstract base class for external cache providers.
Thread Safety:
Providers may be called from multiple threads. Implementations
must be thread-safe.
Error Handling:
All methods are wrapped in try/except by the caller. Exceptions
are logged but never propagate to break execution.
Performance Guidelines:
- on_lookup: Should complete in <500ms (including network)
- on_store: Can be async internally (fire-and-forget)
- should_cache: Should be fast (<1ms), called frequently
"""
@abstractmethod
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
"""
Check external storage for cached result.
Called AFTER local cache miss (local-first for performance).
Returns:
CacheValue if found externally, None otherwise.
Important:
- Return None on any error (don't raise)
- Validate data integrity before returning
"""
pass
@abstractmethod
def on_store(self, context: CacheContext, value: CacheValue) -> None:
"""
Store value to external cache.
Called AFTER value is stored in local cache.
Important:
- Can be fire-and-forget (async internally)
- Should never block execution
- Handle serialization failures gracefully
"""
pass
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
"""
Filter which nodes should be externally cached.
Called before on_lookup (value=None) and on_store (value provided).
Return False to skip external caching for this node.
Implementations can filter based on context.class_type, value size,
or any custom logic. Use estimate_value_size() to get value size.
Default: Returns True (cache everything).
"""
return True
def on_prompt_start(self, prompt_id: str) -> None:
"""Called when prompt execution begins. Optional."""
pass
def on_prompt_end(self, prompt_id: str) -> None:
"""Called when prompt execution ends. Optional."""
pass
# ============================================================
# Provider Registry
# ============================================================
_providers: List[CacheProvider] = []
_providers_lock = threading.Lock()
_providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None
def register_cache_provider(provider: CacheProvider) -> None:
"""
Register an external cache provider.
Providers are called in registration order. First provider to return
a result from on_lookup wins.
"""
global _providers_snapshot
with _providers_lock:
if provider in _providers:
logger.warning(f"Provider {provider.__class__.__name__} already registered")
return
_providers.append(provider)
_providers_snapshot = None # Invalidate cache
logger.info(f"Registered cache provider: {provider.__class__.__name__}")
def unregister_cache_provider(provider: CacheProvider) -> None:
"""Remove a previously registered provider."""
global _providers_snapshot
with _providers_lock:
try:
_providers.remove(provider)
_providers_snapshot = None
logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
except ValueError:
logger.warning(f"Provider {provider.__class__.__name__} was not registered")
def get_cache_providers() -> Tuple[CacheProvider, ...]:
"""Get registered providers (cached for performance)."""
global _providers_snapshot
snapshot = _providers_snapshot
if snapshot is not None:
return snapshot
with _providers_lock:
if _providers_snapshot is not None:
return _providers_snapshot
_providers_snapshot = tuple(_providers)
return _providers_snapshot
def has_cache_providers() -> bool:
"""Fast check if any providers registered (no lock)."""
return bool(_providers)
def clear_cache_providers() -> None:
"""Remove all providers. Useful for testing."""
global _providers_snapshot
with _providers_lock:
_providers.clear()
_providers_snapshot = None
# ============================================================
# Utilities
# ============================================================
def _canonicalize(obj: Any) -> Any:
"""
Convert an object to a canonical, JSON-serializable form.
This ensures deterministic ordering regardless of Python's hash randomization,
which is critical for cross-pod cache key consistency. Frozensets in particular
have non-deterministic iteration order between Python sessions.
"""
if isinstance(obj, frozenset):
# Sort frozenset items for deterministic ordering
return ("__frozenset__", sorted(
[_canonicalize(item) for item in obj],
key=lambda x: json.dumps(x, sort_keys=True)
))
elif isinstance(obj, set):
return ("__set__", sorted(
[_canonicalize(item) for item in obj],
key=lambda x: json.dumps(x, sort_keys=True)
))
elif isinstance(obj, tuple):
return ("__tuple__", [_canonicalize(item) for item in obj])
elif isinstance(obj, list):
return [_canonicalize(item) for item in obj]
elif isinstance(obj, dict):
return {str(k): _canonicalize(v) for k, v in sorted(obj.items())}
elif isinstance(obj, (int, float, str, bool, type(None))):
return obj
elif isinstance(obj, bytes):
return ("__bytes__", obj.hex())
elif hasattr(obj, 'value'):
# Handle Unhashable class from ComfyUI
return ("__unhashable__", _canonicalize(getattr(obj, 'value', None)))
else:
# For other types, use repr as fallback
return ("__repr__", repr(obj))
def serialize_cache_key(cache_key: Any) -> bytes:
"""
Serialize cache key to bytes for external storage.
Returns SHA256 hash suitable for Redis/database keys.
Note: Uses canonicalize + JSON serialization instead of pickle because
pickle is NOT deterministic across Python sessions due to hash randomization
affecting frozenset iteration order. This is critical for distributed caching
where different pods need to compute the same hash for identical inputs.
"""
try:
canonical = _canonicalize(cache_key)
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
return hashlib.sha256(json_str.encode('utf-8')).digest()
except Exception as e:
logger.warning(f"Failed to serialize cache key: {e}")
# Fallback to pickle (non-deterministic but better than nothing)
try:
serialized = pickle.dumps(cache_key, protocol=4)
return hashlib.sha256(serialized).digest()
except Exception:
return hashlib.sha256(str(id(cache_key)).encode()).digest()
def contains_nan(obj: Any) -> bool:
"""
Check if cache key contains NaN (indicates uncacheable node).
NaN != NaN in Python, so local cache never hits. But serialized
NaN would match, causing incorrect external hits. Must skip these.
"""
if isinstance(obj, float):
try:
return math.isnan(obj)
except (TypeError, ValueError):
return False
if hasattr(obj, 'value'): # Unhashable class
val = getattr(obj, 'value', None)
if isinstance(val, float):
try:
return math.isnan(val)
except (TypeError, ValueError):
return False
if isinstance(obj, (frozenset, tuple, list, set)):
return any(contains_nan(item) for item in obj)
if isinstance(obj, dict):
return any(contains_nan(k) or contains_nan(v) for k, v in obj.items())
return False
def estimate_value_size(value: CacheValue) -> int:
"""Estimate serialized size in bytes. Useful for size-based filtering."""
try:
import torch
except ImportError:
return 0
total = 0
def estimate(obj):
nonlocal total
if isinstance(obj, torch.Tensor):
total += obj.numel() * obj.element_size()
elif isinstance(obj, dict):
for v in obj.values():
estimate(v)
elif isinstance(obj, (list, tuple)):
for item in obj:
estimate(item)
for output in value.outputs:
estimate(output)
return total

View File

@ -155,6 +155,10 @@ class BasicCache:
self.cache = {}
self.subcaches = {}
# External cache provider support
self._is_subcache = False
self._current_prompt_id = ''
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
@ -201,20 +205,123 @@ class BasicCache:
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
# Notify external providers
self._notify_providers_store(node_id, cache_key, value)
def _get_immediate(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
# Check local cache first (fast path)
if cache_key in self.cache:
return self.cache[cache_key]
else:
# Check external providers on local miss
external_result = self._check_providers_lookup(node_id, cache_key)
if external_result is not None:
self.cache[cache_key] = external_result # Warm local cache
return external_result
return None
def _notify_providers_store(self, node_id, cache_key, value):
"""Notify external providers of cache store."""
from comfy_execution.cache_provider import (
has_cache_providers, get_cache_providers,
CacheContext, CacheValue,
serialize_cache_key, contains_nan, logger
)
# Fast exit conditions
if self._is_subcache:
return
if not has_cache_providers():
return
if not self._is_external_cacheable_value(value):
return
if contains_nan(cache_key):
return
context = CacheContext(
prompt_id=self._current_prompt_id,
node_id=node_id,
class_type=self._get_class_type(node_id),
cache_key=cache_key,
cache_key_bytes=serialize_cache_key(cache_key)
)
cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
for provider in get_cache_providers():
try:
if provider.should_cache(context, cache_value):
provider.on_store(context, cache_value)
except Exception as e:
logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
def _check_providers_lookup(self, node_id, cache_key):
"""Check external providers for cached result."""
from comfy_execution.cache_provider import (
has_cache_providers, get_cache_providers,
CacheContext, CacheValue,
serialize_cache_key, contains_nan, logger
)
if self._is_subcache:
return None
if not has_cache_providers():
return None
if contains_nan(cache_key):
return None
context = CacheContext(
prompt_id=self._current_prompt_id,
node_id=node_id,
class_type=self._get_class_type(node_id),
cache_key=cache_key,
cache_key_bytes=serialize_cache_key(cache_key)
)
for provider in get_cache_providers():
try:
if not provider.should_cache(context):
continue
result = provider.on_lookup(context)
if result is not None:
if not isinstance(result, CacheValue):
logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
continue
if not isinstance(result.outputs, (list, tuple)):
logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
continue
# Import CacheEntry here to avoid circular import at module level
from execution import CacheEntry
return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs))
except Exception as e:
logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
return None
def _is_external_cacheable_value(self, value):
"""Check if value is a CacheEntry suitable for external caching (not objects cache)."""
return hasattr(value, 'outputs') and hasattr(value, 'ui')
def _get_class_type(self, node_id):
"""Get class_type for a node."""
if not self.initialized or not self.dynprompt:
return ''
try:
return self.dynprompt.get_node(node_id).get('class_type', '')
except Exception:
return ''
async def _ensure_subcache(self, node_id, children_ids):
subcache_key = self.cache_key_set.get_subcache_key(node_id)
subcache = self.subcaches.get(subcache_key, None)
if subcache is None:
subcache = BasicCache(self.key_class)
subcache._is_subcache = True # Mark as subcache - excludes from external caching
subcache._current_prompt_id = self._current_prompt_id # Propagate prompt ID
self.subcaches[subcache_key] = subcache
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
return subcache

View File

@ -171,10 +171,9 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
continue
for item in items:
count += 1
if not isinstance(item, dict):
continue
count += 1
if preview_output is None and is_previewable(media_type, item):
enriched = {

View File

@ -56,7 +56,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
@classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
return io.NodeOutput({"samples":latent})
generate = execute # TODO: remove
@ -73,7 +73,7 @@ class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo):
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
# Using scale factor of 16 instead of 8
latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 16})
return io.NodeOutput({"samples": latent})
class HunyuanVideo15ImageToVideo(io.ComfyNode):

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.11.1"
__version__ = "0.11.0"

View File

@ -669,6 +669,22 @@ class PromptExecutor:
}
self.add_message("execution_error", mes, broadcast=False)
def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
"""Notify external cache providers of prompt lifecycle events."""
from comfy_execution.cache_provider import has_cache_providers, get_cache_providers, logger
if not has_cache_providers():
return
for provider in get_cache_providers():
try:
if event == "start":
provider.on_prompt_start(prompt_id)
elif event == "end":
provider.on_prompt_end(prompt_id)
except Exception as e:
logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
@ -685,66 +701,77 @@ class PromptExecutor:
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused()
# Set prompt ID on caches for external provider integration
for cache in self.caches.all:
cache._current_prompt_id = prompt_id
cached_nodes = []
for node_id in prompt:
if self.caches.outputs.get(node_id) is not None:
cached_nodes.append(node_id)
# Notify external cache providers of prompt start
self._notify_prompt_lifecycle("start", prompt_id)
comfy.model_management.cleanup_models_gc()
self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
for node_id in list(execute_outputs):
execution_list.add_node(node_id)
try:
with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused()
while not execution_list.is_empty():
node_id, error, ex = await execution_list.stage_node_execution()
if error is not None:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
cached_nodes = []
for node_id in prompt:
if self.caches.outputs.get(node_id) is not None:
cached_nodes.append(node_id)
assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
elif result == ExecutionResult.PENDING:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
comfy.model_management.cleanup_models_gc()
self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
for node_id in list(execute_outputs):
execution_list.add_node(node_id)
ui_outputs = {}
meta_outputs = {}
for node_id, ui_info in ui_node_outputs.items():
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
}
self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
while not execution_list.is_empty():
node_id, error, ex = await execution_list.stage_node_execution()
if error is not None:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
elif result == ExecutionResult.PENDING:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
ui_outputs = {}
meta_outputs = {}
for node_id, ui_info in ui_node_outputs.items():
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
}
self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
finally:
# Notify external cache providers of prompt end
self._notify_prompt_lifecycle("end", prompt_id)
async def validate_inputs(prompt_id, prompt, item, validated):

View File

@ -1 +1 @@
comfyui_manager==4.1b1
comfyui_manager==4.0.5

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.11.1"
version = "0.11.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.37.11
comfyui-workflow-templates==0.8.27
comfyui-workflow-templates==0.8.24
comfyui-embedded-docs==0.4.0
torch
torchsde

View File

@ -0,0 +1,370 @@
"""Tests for external cache provider API."""
import importlib.util
import pytest
from typing import Optional
def _torch_available() -> bool:
"""Check if PyTorch is available."""
return importlib.util.find_spec("torch") is not None
from comfy_execution.cache_provider import (
CacheProvider,
CacheContext,
CacheValue,
register_cache_provider,
unregister_cache_provider,
get_cache_providers,
has_cache_providers,
clear_cache_providers,
serialize_cache_key,
contains_nan,
estimate_value_size,
_canonicalize,
)
class TestCanonicalize:
"""Test _canonicalize function for deterministic ordering."""
def test_frozenset_ordering_is_deterministic(self):
"""Frozensets should produce consistent canonical form regardless of iteration order."""
# Create two frozensets with same content
fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)])
fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)])
result1 = _canonicalize(fs1)
result2 = _canonicalize(fs2)
assert result1 == result2
def test_nested_frozenset_ordering(self):
"""Nested frozensets should also be deterministically ordered."""
inner1 = frozenset([1, 2, 3])
inner2 = frozenset([3, 2, 1])
fs1 = frozenset([("key", inner1)])
fs2 = frozenset([("key", inner2)])
result1 = _canonicalize(fs1)
result2 = _canonicalize(fs2)
assert result1 == result2
def test_dict_ordering(self):
"""Dicts should be sorted by key."""
d1 = {"z": 1, "a": 2, "m": 3}
d2 = {"a": 2, "m": 3, "z": 1}
result1 = _canonicalize(d1)
result2 = _canonicalize(d2)
assert result1 == result2
def test_tuple_preserved(self):
"""Tuples should be marked and preserved."""
t = (1, 2, 3)
result = _canonicalize(t)
assert result[0] == "__tuple__"
assert result[1] == [1, 2, 3]
def test_list_preserved(self):
"""Lists should be recursively canonicalized."""
lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
result = _canonicalize(lst)
# First element should be dict with sorted keys
assert result[0] == {"a": 1, "b": 2}
# Second element should be canonicalized frozenset
assert result[1][0] == "__frozenset__"
def test_primitives_unchanged(self):
"""Primitive types should pass through unchanged."""
assert _canonicalize(42) == 42
assert _canonicalize(3.14) == 3.14
assert _canonicalize("hello") == "hello"
assert _canonicalize(True) is True
assert _canonicalize(None) is None
def test_bytes_converted(self):
"""Bytes should be converted to hex string."""
b = b"\x00\xff"
result = _canonicalize(b)
assert result[0] == "__bytes__"
assert result[1] == "00ff"
def test_set_ordering(self):
"""Sets should be sorted like frozensets."""
s1 = {3, 1, 2}
s2 = {1, 2, 3}
result1 = _canonicalize(s1)
result2 = _canonicalize(s2)
assert result1 == result2
assert result1[0] == "__set__"
class TestSerializeCacheKey:
"""Test serialize_cache_key for deterministic hashing."""
def test_same_content_same_hash(self):
"""Same content should produce same hash."""
key1 = frozenset([("node_1", frozenset([("input", "value")]))])
key2 = frozenset([("node_1", frozenset([("input", "value")]))])
hash1 = serialize_cache_key(key1)
hash2 = serialize_cache_key(key2)
assert hash1 == hash2
def test_different_content_different_hash(self):
"""Different content should produce different hash."""
key1 = frozenset([("node_1", "value_a")])
key2 = frozenset([("node_1", "value_b")])
hash1 = serialize_cache_key(key1)
hash2 = serialize_cache_key(key2)
assert hash1 != hash2
def test_returns_bytes(self):
"""Should return bytes (SHA256 digest)."""
key = frozenset([("test", 123)])
result = serialize_cache_key(key)
assert isinstance(result, bytes)
assert len(result) == 32 # SHA256 produces 32 bytes
def test_complex_nested_structure(self):
"""Complex nested structures should hash deterministically."""
# Note: frozensets can only contain hashable types, so we use
# nested frozensets of tuples to represent dict-like structures
key = frozenset([
("node_1", frozenset([
("input_a", ("tuple", "value")),
("input_b", frozenset([("nested", "dict")])),
])),
("node_2", frozenset([
("param", 42),
])),
])
# Hash twice to verify determinism
hash1 = serialize_cache_key(key)
hash2 = serialize_cache_key(key)
assert hash1 == hash2
def test_dict_in_cache_key(self):
"""Dicts passed directly to serialize_cache_key should work."""
# This tests the _canonicalize function's ability to handle dicts
key = {"node_1": {"input": "value"}, "node_2": 42}
hash1 = serialize_cache_key(key)
hash2 = serialize_cache_key(key)
assert hash1 == hash2
assert isinstance(hash1, bytes)
assert len(hash1) == 32
class TestContainsNan:
"""Test contains_nan utility function."""
def test_nan_float_detected(self):
"""NaN floats should be detected."""
assert contains_nan(float('nan')) is True
def test_regular_float_not_nan(self):
"""Regular floats should not be detected as NaN."""
assert contains_nan(3.14) is False
assert contains_nan(0.0) is False
assert contains_nan(-1.5) is False
def test_infinity_not_nan(self):
"""Infinity is not NaN."""
assert contains_nan(float('inf')) is False
assert contains_nan(float('-inf')) is False
def test_nan_in_list(self):
"""NaN in list should be detected."""
assert contains_nan([1, 2, float('nan'), 4]) is True
assert contains_nan([1, 2, 3, 4]) is False
def test_nan_in_tuple(self):
"""NaN in tuple should be detected."""
assert contains_nan((1, float('nan'))) is True
assert contains_nan((1, 2, 3)) is False
def test_nan_in_frozenset(self):
"""NaN in frozenset should be detected."""
assert contains_nan(frozenset([1, float('nan')])) is True
assert contains_nan(frozenset([1, 2, 3])) is False
def test_nan_in_dict_value(self):
"""NaN in dict value should be detected."""
assert contains_nan({"key": float('nan')}) is True
assert contains_nan({"key": 42}) is False
def test_nan_in_nested_structure(self):
"""NaN in deeply nested structure should be detected."""
nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
assert contains_nan(nested) is True
def test_non_numeric_types(self):
"""Non-numeric types should not be NaN."""
assert contains_nan("string") is False
assert contains_nan(None) is False
assert contains_nan(True) is False
class TestEstimateValueSize:
"""Test estimate_value_size utility function."""
def test_empty_outputs(self):
"""Empty outputs should have zero size."""
value = CacheValue(outputs=[])
assert estimate_value_size(value) == 0
@pytest.mark.skipif(
not _torch_available(),
reason="PyTorch not available"
)
def test_tensor_size_estimation(self):
"""Tensor size should be estimated correctly."""
import torch
# 1000 float32 elements = 4000 bytes
tensor = torch.zeros(1000, dtype=torch.float32)
value = CacheValue(outputs=[[tensor]])
size = estimate_value_size(value)
assert size == 4000
@pytest.mark.skipif(
not _torch_available(),
reason="PyTorch not available"
)
def test_nested_tensor_in_dict(self):
"""Tensors nested in dicts should be counted."""
import torch
tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes
value = CacheValue(outputs=[[{"samples": tensor}]])
size = estimate_value_size(value)
assert size == 400
class TestProviderRegistry:
"""Test cache provider registration and retrieval."""
def setup_method(self):
"""Clear providers before each test."""
clear_cache_providers()
def teardown_method(self):
"""Clear providers after each test."""
clear_cache_providers()
def test_register_provider(self):
"""Provider should be registered successfully."""
provider = MockCacheProvider()
register_cache_provider(provider)
assert has_cache_providers() is True
providers = get_cache_providers()
assert len(providers) == 1
assert providers[0] is provider
def test_unregister_provider(self):
"""Provider should be unregistered successfully."""
provider = MockCacheProvider()
register_cache_provider(provider)
unregister_cache_provider(provider)
assert has_cache_providers() is False
def test_multiple_providers(self):
"""Multiple providers can be registered."""
provider1 = MockCacheProvider()
provider2 = MockCacheProvider()
register_cache_provider(provider1)
register_cache_provider(provider2)
providers = get_cache_providers()
assert len(providers) == 2
def test_duplicate_registration_ignored(self):
"""Registering same provider twice should be ignored."""
provider = MockCacheProvider()
register_cache_provider(provider)
register_cache_provider(provider) # Should be ignored
providers = get_cache_providers()
assert len(providers) == 1
def test_clear_providers(self):
"""clear_cache_providers should remove all providers."""
provider1 = MockCacheProvider()
provider2 = MockCacheProvider()
register_cache_provider(provider1)
register_cache_provider(provider2)
clear_cache_providers()
assert has_cache_providers() is False
assert len(get_cache_providers()) == 0
class TestCacheContext:
"""Test CacheContext dataclass."""
def test_context_creation(self):
"""CacheContext should be created with all fields."""
context = CacheContext(
prompt_id="prompt-123",
node_id="node-456",
class_type="KSampler",
cache_key=frozenset([("test", "value")]),
cache_key_bytes=b"hash_bytes",
)
assert context.prompt_id == "prompt-123"
assert context.node_id == "node-456"
assert context.class_type == "KSampler"
assert context.cache_key == frozenset([("test", "value")])
assert context.cache_key_bytes == b"hash_bytes"
class TestCacheValue:
"""Test CacheValue dataclass."""
def test_value_creation(self):
"""CacheValue should be created with outputs."""
outputs = [[{"samples": "tensor_data"}]]
value = CacheValue(outputs=outputs)
assert value.outputs == outputs
class MockCacheProvider(CacheProvider):
"""Mock cache provider for testing."""
def __init__(self):
self.lookups = []
self.stores = []
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
self.lookups.append(context)
return None
def on_store(self, context: CacheContext, value: CacheValue) -> None:
self.stores.append((context, value))