From fcbe7db46fb10d2b900cb57bfb34f1450e776475 Mon Sep 17 00:00:00 2001 From: Deep Mehta Date: Wed, 20 May 2026 21:15:35 -0700 Subject: [PATCH] feat(cache-provider): add on_set_prompt lifecycle hook for providers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a new on_set_prompt() lifecycle hook on CacheProvider that fires after the cache key set is prepared for a new prompt. Dispatched via asyncio.create_task with errors swallowed (same fail-safe pattern as on_store / on_lookup). Why: BasicCache's lifecycle notifications to external providers were incomplete. set_prompt is a key per-prompt event that providers need visibility into — for example, to reset per-prompt timing/state used for cost-aware caching policies (a provider can set t=0 here, then measure elapsed at each on_store to estimate compute saved by a hit). Backward-compatible: default implementation is a no-op, existing providers compile and run unchanged. Providers that need the per-prompt boundary override on_set_prompt(). --- comfy_api/latest/_caching.py | 4 ++++ comfy_execution/caching.py | 27 +++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py index 30c8848cd..44dc67b7a 100644 --- a/comfy_api/latest/_caching.py +++ b/comfy_api/latest/_caching.py @@ -21,6 +21,10 @@ class CacheProvider(ABC): Exceptions from provider methods are caught by the caller and never break execution. """ + async def on_set_prompt(self) -> None: + """Called after prompt cache keys are prepared. Dispatched via asyncio.create_task.""" + pass + @abstractmethod async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: """Called on local cache miss. Return CacheValue if found, None otherwise.""" diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index ba1e8bc84..b823985ac 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -164,6 +164,7 @@ class BasicCache: await self.cache_key_set.add_keys(node_ids) self.is_changed_cache = is_changed_cache self.initialized = True + await self._notify_providers_set_prompt() def all_node_ids(self): assert self.initialized @@ -263,6 +264,24 @@ class BasicCache: except Exception as e: _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}") + async def _notify_providers_set_prompt(self): + from comfy_execution.cache_provider import ( + _has_cache_providers, _get_cache_providers, _logger + ) + + if not self.enable_providers: + return + if not _has_cache_providers(): + return + + for provider in _get_cache_providers(): + try: + task = asyncio.create_task(self._safe_provider_set_prompt(provider)) + self._pending_store_tasks.add(task) + task.add_done_callback(self._pending_store_tasks.discard) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} error on set_prompt: {e}") + @staticmethod async def _safe_provider_store(provider, context, cache_value): from comfy_execution.cache_provider import _logger @@ -271,6 +290,14 @@ class BasicCache: except Exception as e: _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}") + @staticmethod + async def _safe_provider_set_prompt(provider): + from comfy_execution.cache_provider import _logger + try: + await provider.on_set_prompt() + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} async set_prompt error: {e}") + async def _check_providers_lookup(self, node_id, cache_key): from comfy_execution.cache_provider import ( _has_cache_providers, _get_cache_providers,