mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 11:53:04 +08:00
feat(cache-provider): add on_set_prompt lifecycle hook for providers
Adds a new on_set_prompt() lifecycle hook on CacheProvider that fires after the cache key set is prepared for a new prompt. Dispatched via asyncio.create_task with errors swallowed (same fail-safe pattern as on_store / on_lookup). Why: BasicCache's lifecycle notifications to external providers were incomplete. set_prompt is a key per-prompt event that providers need visibility into — for example, to reset per-prompt timing/state used for cost-aware caching policies (a provider can set t=0 here, then measure elapsed at each on_store to estimate compute saved by a hit). Backward-compatible: default implementation is a no-op, existing providers compile and run unchanged. Providers that need the per-prompt boundary override on_set_prompt().
This commit is contained in:
@ -21,6 +21,10 @@ class CacheProvider(ABC):
|
||||
Exceptions from provider methods are caught by the caller and never break execution.
|
||||
"""
|
||||
|
||||
async def on_set_prompt(self) -> None:
|
||||
"""Called after prompt cache keys are prepared. Dispatched via asyncio.create_task."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
|
||||
|
||||
@ -164,6 +164,7 @@ class BasicCache:
|
||||
await self.cache_key_set.add_keys(node_ids)
|
||||
self.is_changed_cache = is_changed_cache
|
||||
self.initialized = True
|
||||
await self._notify_providers_set_prompt()
|
||||
|
||||
def all_node_ids(self):
|
||||
assert self.initialized
|
||||
@ -263,6 +264,24 @@ class BasicCache:
|
||||
except Exception as e:
|
||||
_logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
|
||||
|
||||
async def _notify_providers_set_prompt(self):
|
||||
from comfy_execution.cache_provider import (
|
||||
_has_cache_providers, _get_cache_providers, _logger
|
||||
)
|
||||
|
||||
if not self.enable_providers:
|
||||
return
|
||||
if not _has_cache_providers():
|
||||
return
|
||||
|
||||
for provider in _get_cache_providers():
|
||||
try:
|
||||
task = asyncio.create_task(self._safe_provider_set_prompt(provider))
|
||||
self._pending_store_tasks.add(task)
|
||||
task.add_done_callback(self._pending_store_tasks.discard)
|
||||
except Exception as e:
|
||||
_logger.warning(f"Cache provider {provider.__class__.__name__} error on set_prompt: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def _safe_provider_store(provider, context, cache_value):
|
||||
from comfy_execution.cache_provider import _logger
|
||||
@ -271,6 +290,14 @@ class BasicCache:
|
||||
except Exception as e:
|
||||
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def _safe_provider_set_prompt(provider):
|
||||
from comfy_execution.cache_provider import _logger
|
||||
try:
|
||||
await provider.on_set_prompt()
|
||||
except Exception as e:
|
||||
_logger.warning(f"Cache provider {provider.__class__.__name__} async set_prompt error: {e}")
|
||||
|
||||
async def _check_providers_lookup(self, node_id, cache_key):
|
||||
from comfy_execution.cache_provider import (
|
||||
_has_cache_providers, _get_cache_providers,
|
||||
|
||||
Reference in New Issue
Block a user