from __future__ import annotations

import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Union

from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from graphon.model_runtime.runtime import ModelRuntime
from pydantic import ValidationError
from redis import RedisError

from configs import dify_config
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from extensions.ext_redis import redis_client
from models.provider_ids import ModelProviderID

logger = logging.getLogger(__name__)

# `TS` means tenant scope
TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"


class PluginModelRuntime(ModelRuntime):
    """Plugin-backed runtime adapter bound to tenant context and optional caller scope."""

    tenant_id: str
    user_id: str | None
    client: PluginModelClient
    _provider_entities: tuple[ProviderEntity, ...] | None
    _provider_entities_lock: Lock

    def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None:
        if client is None:
            raise ValueError("client is required.")
        self.tenant_id = tenant_id
        self.user_id = user_id
        self.client = client
        self._provider_entities = None
        self._provider_entities_lock = Lock()

    def fetch_model_providers(self) -> Sequence[ProviderEntity]:
        if self._provider_entities is not None:
            return self._provider_entities
        with self._provider_entities_lock:
            if self._provider_entities is None:
                self._provider_entities = tuple(
                    self._to_provider_entity(provider)
                    for provider in self.client.fetch_model_providers(self.tenant_id)
                )
        return self._provider_entities

    def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
        provider_schema = self._get_provider_schema(provider)

        if icon_type.lower() == "icon_small":
            if not provider_schema.icon_small:
                raise ValueError(f"Provider {provider} does not have small icon.")
            file_name = (
                provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
            )
        elif icon_type.lower() == "icon_small_dark":
            if not provider_schema.icon_small_dark:
                raise ValueError(f"Provider {provider} does not have small dark icon.")
            file_name = (
                provider_schema.icon_small_dark.zh_Hans
                if lang.lower() == "zh_hans"
                else provider_schema.icon_small_dark.en_US
            )
        else:
            raise ValueError(f"Unsupported icon type: {icon_type}.")

        if not file_name:
            raise ValueError(f"Provider {provider} does not have icon.")

        image_mime_types = {
            "jpg": "image/jpeg",
            "jpeg": "image/jpeg",
            "png": "image/png",
            "gif": "image/gif",
            "bmp": "image/bmp",
            "tiff": "image/tiff",
            "tif": "image/tiff",
            "webp": "image/webp",
            "svg": "image/svg+xml",
            "ico": "image/vnd.microsoft.icon",
            "heif": "image/heif",
            "heic": "image/heic",
        }
        extension = file_name.split(".")[-1]
        mime_type = image_mime_types.get(extension, "image/png")

        return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type

    def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
        plugin_id, provider_name = self._split_provider(provider)
        self.client.validate_provider_credentials(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            credentials=credentials,
        )

    def validate_model_credentials(
        self,
        *,
        provider: str,
        model_type: ModelType,
        model: str,
        credentials: dict[str, Any],
    ) -> None:
        plugin_id, provider_name = self._split_provider(provider)
        self.client.validate_model_credentials(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model_type=model_type.value,
            model=model,
            credentials=credentials,
        )

    def get_model_schema(
        self,
        *,
        provider: str,
        model_type: ModelType,
        model: str,
        credentials: dict[str, Any],
    ) -> AIModelEntity | None:
        cache_key = self._get_schema_cache_key(
            provider=provider,
            model_type=model_type,
            model=model,
            credentials=credentials,
        )

        cached_schema_json = None
        try:
            cached_schema_json = redis_client.get(cache_key)
        except (RedisError, RuntimeError) as exc:
            logger.warning(
                "Failed to read plugin model schema cache for model %s: %s",
                model,
                str(exc),
                exc_info=True,
            )

        if cached_schema_json:
            try:
                return AIModelEntity.model_validate_json(cached_schema_json)
            except ValidationError:
                logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True)
                try:
                    redis_client.delete(cache_key)
                except (RedisError, RuntimeError) as exc:
                    logger.warning(
                        "Failed to delete invalid plugin model schema cache for model %s: %s",
                        model,
                        str(exc),
                        exc_info=True,
                    )

        plugin_id, provider_name = self._split_provider(provider)
        schema = self.client.get_model_schema(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model_type=model_type.value,
            model=model,
            credentials=credentials,
        )
        if schema:
            try:
                redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
            except (RedisError, RuntimeError) as exc:
                logger.warning(
                    "Failed to write plugin model schema cache for model %s: %s",
                    model,
                    str(exc),
                    exc_info=True,
                )

        return schema

    def invoke_llm(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        model_parameters: dict[str, Any],
        prompt_messages: Sequence[PromptMessage],
        tools: list[PromptMessageTool] | None,
        stop: Sequence[str] | None,
        stream: bool,
    ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_llm(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            model_parameters=model_parameters,
            prompt_messages=list(prompt_messages),
            tools=tools,
            stop=list(stop) if stop else None,
            stream=stream,
        )

    def get_llm_num_tokens(
        self,
        *,
        provider: str,
        model_type: ModelType,
        model: str,
        credentials: dict[str, Any],
        prompt_messages: Sequence[PromptMessage],
        tools: Sequence[PromptMessageTool] | None,
    ) -> int:
        if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
            return 0
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.get_llm_num_tokens(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model_type=model_type.value,
            model=model,
            credentials=credentials,
            prompt_messages=list(prompt_messages),
            tools=list(tools) if tools else None,
        )

    def invoke_text_embedding(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        texts: list[str],
        input_type: EmbeddingInputType,
    ) -> EmbeddingResult:
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_text_embedding(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            texts=texts,
            input_type=input_type,
        )

    def invoke_multimodal_embedding(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        documents: list[dict[str, Any]],
        input_type: EmbeddingInputType,
    ) -> EmbeddingResult:
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_multimodal_embedding(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            documents=documents,
            input_type=input_type,
        )

    def get_text_embedding_num_tokens(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        texts: list[str],
    ) -> list[int]:
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.get_text_embedding_num_tokens(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            texts=texts,
        )

    def invoke_rerank(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        query: str,
        docs: list[str],
        score_threshold: float | None,
        top_n: int | None,
    ) -> RerankResult:
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_rerank(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            query=query,
            docs=docs,
            score_threshold=score_threshold,
            top_n=top_n,
        )

    def invoke_multimodal_rerank(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        query: MultimodalRerankInput,
        docs: list[MultimodalRerankInput],
        score_threshold: float | None,
        top_n: int | None,
    ) -> RerankResult:
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_multimodal_rerank(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            query=query,
            docs=docs,
            score_threshold=score_threshold,
            top_n=top_n,
        )

    def invoke_tts(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        content_text: str,
        voice: str,
    ) -> Iterable[bytes]:
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_tts(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            content_text=content_text,
            voice=voice,
        )

    def get_tts_model_voices(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        language: str | None,
    ) -> Any:
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.get_tts_model_voices(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            language=language,
        )

    def invoke_speech_to_text(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        file: IO[bytes],
    ) -> str:
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_speech_to_text(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            file=file,
        )

    def invoke_moderation(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        text: str,
    ) -> bool:
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_moderation(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            text=text,
        )

    def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
        """
        Expose a bare provider alias only for the canonical provider mapping.

        Multiple plugins can publish the same short provider slug. If every provider entity
        keeps that slug in ``provider_name``, callers that still resolve by short name become
        order-dependent. Restrict the alias to the provider selected by ``ModelProviderID`` so
        legacy short-name lookups remain deterministic while the runtime surface stays canonical.
        """
        try:
            canonical_provider_id = ModelProviderID(provider.provider)
        except ValueError:
            return ""
        if canonical_provider_id.plugin_id != provider.plugin_id:
            return ""
        if canonical_provider_id.provider_name != provider.provider:
            return ""
        return provider.provider

    def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
        declaration = provider.declaration.model_copy(deep=True)
        declaration.provider = f"{provider.plugin_id}/{provider.provider}"
        declaration.provider_name = self._get_provider_short_name_alias(provider)
        return declaration

    def _get_provider_schema(self, provider: str) -> ProviderEntity:
        providers = self.fetch_model_providers()
        provider_entity = next((item for item in providers if item.provider == provider), None)
        if provider_entity is None:
            provider_entity = next((item for item in providers if provider == item.provider_name), None)
        if provider_entity is None:
            raise ValueError(f"Invalid provider: {provider}")
        return provider_entity

    def _get_schema_cache_key(
        self,
        *,
        provider: str,
        model_type: ModelType,
        model: str,
        credentials: dict[str, Any],
    ) -> str:
        # The plugin daemon distinguishes ``None`` from an explicit empty-string
        # caller id, so the cache must only collapse ``None`` into tenant scope.
        cache_user_id = TENANT_SCOPE_SCHEMA_CACHE_USER_ID if self.user_id is None else self.user_id
        cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}:{cache_user_id}"
        sorted_credentials = sorted(credentials.items()) if credentials else []
        if not sorted_credentials:
            return cache_key
        hashed_credentials = ":".join(
            [hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials]
        )
        return f"{cache_key}:{hashed_credentials}"

    def _split_provider(self, provider: str) -> tuple[str, str]:
        provider_id = ModelProviderID(provider)
        return provider_id.plugin_id, provider_id.provider_name
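

# Example usage (illustrative sketch only, kept as comments so the module stays side-effect free;
# the provider id, model name, and credential key below are assumptions, and constructing
# PluginModelClient without arguments is also an assumption):
#
#     client = PluginModelClient()
#     runtime = PluginModelRuntime(tenant_id="tenant-1", user_id=None, client=client)
#
#     # Provider identifiers are fully-qualified "plugin_id/provider" strings built by
#     # _to_provider_entity; _split_provider parses them back via ModelProviderID before
#     # each call is delegated to the plugin daemon.
#     schema = runtime.get_model_schema(
#         provider="langgenius/openai/openai",          # hypothetical provider id
#         model_type=ModelType.LLM,
#         model="gpt-4o",                               # hypothetical model name
#         credentials={"openai_api_key": "sk-..."},     # hypothetical credential key
#     )
#
# Because user_id is None here, _get_schema_cache_key collapses the caller scope into the
# tenant-scope sentinel TENANT_SCOPE_SCHEMA_CACHE_USER_ID when building the Redis cache key.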