mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 14:08:18 +08:00
refactor(api): continue decoupling dify_graph from API concerns (#33580)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
499
api/core/plugin/impl/model_runtime.py
Normal file
499
api/core/plugin/impl/model_runtime.py
Normal file
@ -0,0 +1,499 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from threading import Lock
|
||||
from typing import IO, Any, Union
|
||||
|
||||
from pydantic import ValidationError
|
||||
from redis import RedisError
|
||||
|
||||
from configs import dify_config
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
|
||||
from dify_graph.model_runtime.runtime import ModelRuntime
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# `TS` means tenant scope
|
||||
TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"
|
||||
|
||||
|
||||
class PluginModelRuntime(ModelRuntime):
|
||||
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
|
||||
|
||||
tenant_id: str
|
||||
user_id: str | None
|
||||
client: PluginModelClient
|
||||
_provider_entities: tuple[ProviderEntity, ...] | None
|
||||
_provider_entities_lock: Lock
|
||||
|
||||
def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None:
|
||||
if client is None:
|
||||
raise ValueError("client is required.")
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self.client = client
|
||||
self._provider_entities = None
|
||||
self._provider_entities_lock = Lock()
|
||||
|
||||
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
|
||||
if self._provider_entities is not None:
|
||||
return self._provider_entities
|
||||
|
||||
with self._provider_entities_lock:
|
||||
if self._provider_entities is None:
|
||||
self._provider_entities = tuple(
|
||||
self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id)
|
||||
)
|
||||
|
||||
return self._provider_entities
|
||||
|
||||
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
|
||||
provider_schema = self._get_provider_schema(provider)
|
||||
|
||||
if icon_type.lower() == "icon_small":
|
||||
if not provider_schema.icon_small:
|
||||
raise ValueError(f"Provider {provider} does not have small icon.")
|
||||
file_name = (
|
||||
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
|
||||
)
|
||||
elif icon_type.lower() == "icon_small_dark":
|
||||
if not provider_schema.icon_small_dark:
|
||||
raise ValueError(f"Provider {provider} does not have small dark icon.")
|
||||
file_name = (
|
||||
provider_schema.icon_small_dark.zh_Hans
|
||||
if lang.lower() == "zh_hans"
|
||||
else provider_schema.icon_small_dark.en_US
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported icon type: {icon_type}.")
|
||||
|
||||
if not file_name:
|
||||
raise ValueError(f"Provider {provider} does not have icon.")
|
||||
|
||||
image_mime_types = {
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"png": "image/png",
|
||||
"gif": "image/gif",
|
||||
"bmp": "image/bmp",
|
||||
"tiff": "image/tiff",
|
||||
"tif": "image/tiff",
|
||||
"webp": "image/webp",
|
||||
"svg": "image/svg+xml",
|
||||
"ico": "image/vnd.microsoft.icon",
|
||||
"heif": "image/heif",
|
||||
"heic": "image/heic",
|
||||
}
|
||||
|
||||
extension = file_name.split(".")[-1]
|
||||
mime_type = image_mime_types.get(extension, "image/png")
|
||||
return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
|
||||
|
||||
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
self.client.validate_provider_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
def validate_model_credentials(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
) -> None:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
self.client.validate_model_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model_type=model_type.value,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
def get_model_schema(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
) -> AIModelEntity | None:
|
||||
cache_key = self._get_schema_cache_key(
|
||||
provider=provider,
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
cached_schema_json = None
|
||||
try:
|
||||
cached_schema_json = redis_client.get(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to read plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if cached_schema_json:
|
||||
try:
|
||||
return AIModelEntity.model_validate_json(cached_schema_json)
|
||||
except ValidationError:
|
||||
logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True)
|
||||
try:
|
||||
redis_client.delete(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to delete invalid plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
schema = self.client.get_model_schema(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model_type=model_type.value,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
if schema:
|
||||
try:
|
||||
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to write plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return schema
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_llm(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
model_parameters=model_parameters,
|
||||
prompt_messages=list(prompt_messages),
|
||||
tools=tools,
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def get_llm_num_tokens(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: Sequence[PromptMessageTool] | None,
|
||||
) -> int:
|
||||
if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
|
||||
return 0
|
||||
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.get_llm_num_tokens(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model_type=model_type.value,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
tools=list(tools) if tools else None,
|
||||
)
|
||||
|
||||
def invoke_text_embedding(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
texts: list[str],
|
||||
input_type: EmbeddingInputType,
|
||||
) -> EmbeddingResult:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_text_embedding(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
def invoke_multimodal_embedding(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
documents: list[dict[str, Any]],
|
||||
input_type: EmbeddingInputType,
|
||||
) -> EmbeddingResult:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_multimodal_embedding(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
documents=documents,
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
def get_text_embedding_num_tokens(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
texts: list[str],
|
||||
) -> list[int]:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.get_text_embedding_num_tokens(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
def invoke_rerank(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: float | None,
|
||||
top_n: int | None,
|
||||
) -> RerankResult:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_rerank(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
def invoke_multimodal_rerank(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
query: MultimodalRerankInput,
|
||||
docs: list[MultimodalRerankInput],
|
||||
score_threshold: float | None,
|
||||
top_n: int | None,
|
||||
) -> RerankResult:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_multimodal_rerank(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
def invoke_tts(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
content_text: str,
|
||||
voice: str,
|
||||
) -> Iterable[bytes]:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_tts(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def get_tts_model_voices(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
language: str | None,
|
||||
) -> Any:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.get_tts_model_voices(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
language=language,
|
||||
)
|
||||
|
||||
def invoke_speech_to_text(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
file: IO[bytes],
|
||||
) -> str:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_speech_to_text(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
file=file,
|
||||
)
|
||||
|
||||
def invoke_moderation(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
text: str,
|
||||
) -> bool:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
text=text,
|
||||
)
|
||||
|
||||
def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
|
||||
"""
|
||||
Expose a bare provider alias only for the canonical provider mapping.
|
||||
|
||||
Multiple plugins can publish the same short provider slug. If every
|
||||
provider entity keeps that slug in ``provider_name``, callers that still
|
||||
resolve by short name become order-dependent. Restrict the alias to the
|
||||
provider selected by ``ModelProviderID`` so legacy short-name lookups
|
||||
remain deterministic while the runtime surface stays canonical.
|
||||
"""
|
||||
try:
|
||||
canonical_provider_id = ModelProviderID(provider.provider)
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
if canonical_provider_id.plugin_id != provider.plugin_id:
|
||||
return ""
|
||||
if canonical_provider_id.provider_name != provider.provider:
|
||||
return ""
|
||||
|
||||
return provider.provider
|
||||
|
||||
def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
|
||||
declaration = provider.declaration.model_copy(deep=True)
|
||||
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
|
||||
declaration.provider_name = self._get_provider_short_name_alias(provider)
|
||||
return declaration
|
||||
|
||||
def _get_provider_schema(self, provider: str) -> ProviderEntity:
|
||||
providers = self.fetch_model_providers()
|
||||
provider_entity = next((item for item in providers if item.provider == provider), None)
|
||||
if provider_entity is None:
|
||||
provider_entity = next((item for item in providers if provider == item.provider_name), None)
|
||||
if provider_entity is None:
|
||||
raise ValueError(f"Invalid provider: {provider}")
|
||||
return provider_entity
|
||||
|
||||
def _get_schema_cache_key(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
) -> str:
|
||||
# The plugin daemon distinguishes ``None`` from an explicit empty-string
|
||||
# caller id, so the cache must only collapse ``None`` into tenant scope.
|
||||
cache_user_id = TENANT_SCOPE_SCHEMA_CACHE_USER_ID if self.user_id is None else self.user_id
|
||||
cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}:{cache_user_id}"
|
||||
sorted_credentials = sorted(credentials.items()) if credentials else []
|
||||
if not sorted_credentials:
|
||||
return cache_key
|
||||
hashed_credentials = ":".join(
|
||||
[hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials]
|
||||
)
|
||||
return f"{cache_key}:{hashed_credentials}"
|
||||
|
||||
def _split_provider(self, provider: str) -> tuple[str, str]:
|
||||
provider_id = ModelProviderID(provider)
|
||||
return provider_id.plugin_id, provider_id.provider_name
|
||||
Reference in New Issue
Block a user