diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py index 30c8848cd..44dc67b7a 100644 --- a/comfy_api/latest/_caching.py +++ b/comfy_api/latest/_caching.py @@ -21,6 +21,10 @@ class CacheProvider(ABC): Exceptions from provider methods are caught by the caller and never break execution. """ + async def on_set_prompt(self) -> None: + """Called after prompt cache keys are prepared. Dispatched via asyncio.create_task.""" + pass + @abstractmethod async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: """Called on local cache miss. Return CacheValue if found, None otherwise.""" diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index ba1e8bc84..b823985ac 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -164,6 +164,7 @@ class BasicCache: await self.cache_key_set.add_keys(node_ids) self.is_changed_cache = is_changed_cache self.initialized = True + await self._notify_providers_set_prompt() def all_node_ids(self): assert self.initialized @@ -263,6 +264,24 @@ class BasicCache: except Exception as e: _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}") + async def _notify_providers_set_prompt(self): + from comfy_execution.cache_provider import ( + _has_cache_providers, _get_cache_providers, _logger + ) + + if not self.enable_providers: + return + if not _has_cache_providers(): + return + + for provider in _get_cache_providers(): + try: + task = asyncio.create_task(self._safe_provider_set_prompt(provider)) + self._pending_store_tasks.add(task) + task.add_done_callback(self._pending_store_tasks.discard) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} error on set_prompt: {e}") + @staticmethod async def _safe_provider_store(provider, context, cache_value): from comfy_execution.cache_provider import _logger @@ -271,6 +290,14 @@ class BasicCache: except Exception as e: _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}") + @staticmethod + async def _safe_provider_set_prompt(provider): + from comfy_execution.cache_provider import _logger + try: + await provider.on_set_prompt() + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} async set_prompt error: {e}") + async def _check_providers_lookup(self, node_id, cache_key): from comfy_execution.cache_provider import ( _has_cache_providers, _get_cache_providers,