feat(cache-provider): add on_set_prompt lifecycle hook for providers

Adds a new on_set_prompt() lifecycle hook on CacheProvider that fires
after the cache key set is prepared for a new prompt. Dispatched via
asyncio.create_task with errors swallowed (same fail-safe pattern as
on_store / on_lookup).

Why: BasicCache's lifecycle notifications to external providers were
incomplete. set_prompt is a key per-prompt event that providers need
visibility into — for example, to reset per-prompt timing/state used
for cost-aware caching policies (a provider can set t=0 here, then
measure elapsed at each on_store to estimate compute saved by a hit).

Backward-compatible: default implementation is a no-op, existing
providers compile and run unchanged. Providers that need the per-prompt
boundary override on_set_prompt().
This commit is contained in:
Deep Mehta
2026-05-20 21:15:35 -07:00
parent 95fdc6cf91
commit fcbe7db46f
2 changed files with 31 additions and 0 deletions

View File

@ -21,6 +21,10 @@ class CacheProvider(ABC):
Exceptions from provider methods are caught by the caller and never break execution.
"""
async def on_set_prompt(self) -> None:
"""Called after prompt cache keys are prepared. Dispatched via asyncio.create_task."""
pass
@abstractmethod
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
"""Called on local cache miss. Return CacheValue if found, None otherwise."""

View File

@ -164,6 +164,7 @@ class BasicCache:
await self.cache_key_set.add_keys(node_ids)
self.is_changed_cache = is_changed_cache
self.initialized = True
await self._notify_providers_set_prompt()
def all_node_ids(self):
assert self.initialized
@ -263,6 +264,24 @@ class BasicCache:
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
async def _notify_providers_set_prompt(self):
from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers, _logger
)
if not self.enable_providers:
return
if not _has_cache_providers():
return
for provider in _get_cache_providers():
try:
task = asyncio.create_task(self._safe_provider_set_prompt(provider))
self._pending_store_tasks.add(task)
task.add_done_callback(self._pending_store_tasks.discard)
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} error on set_prompt: {e}")
@staticmethod
async def _safe_provider_store(provider, context, cache_value):
from comfy_execution.cache_provider import _logger
@ -271,6 +290,14 @@ class BasicCache:
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
@staticmethod
async def _safe_provider_set_prompt(provider):
from comfy_execution.cache_provider import _logger
try:
await provider.on_set_prompt()
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} async set_prompt error: {e}")
async def _check_providers_lookup(self, node_id, cache_key):
from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers,