mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
Merge remote-tracking branch 'upstream/main' into feat/rag-2
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
import decimal
|
||||
import hashlib
|
||||
from threading import Lock
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
@ -100,7 +99,7 @@ class AIModel(BaseModel):
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
# get price info from predefined model schema
|
||||
price_config: Optional[PriceConfig] = None
|
||||
price_config: PriceConfig | None = None
|
||||
if model_schema and model_schema.pricing:
|
||||
price_config = model_schema.pricing
|
||||
|
||||
@ -133,7 +132,7 @@ class AIModel(BaseModel):
|
||||
currency=price_config.currency,
|
||||
)
|
||||
|
||||
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
|
||||
def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema by model name and credentials
|
||||
|
||||
@ -174,7 +173,7 @@ class AIModel(BaseModel):
|
||||
|
||||
return schema
|
||||
|
||||
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
Get customizable model schema from credentials
|
||||
|
||||
@ -232,7 +231,7 @@ class AIModel(BaseModel):
|
||||
|
||||
return schema
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
Get customizable model schema
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
@ -93,12 +93,12 @@ class LargeLanguageModel(AIModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
model_parameters: dict | None = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -244,11 +244,11 @@ class LargeLanguageModel(AIModel):
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Invoke result generator
|
||||
@ -329,7 +329,7 @@ class LargeLanguageModel(AIModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@ -357,7 +357,7 @@ class LargeLanguageModel(AIModel):
|
||||
)
|
||||
return 0
|
||||
|
||||
def _calc_response_usage(
|
||||
def calc_response_usage(
|
||||
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
|
||||
) -> LLMUsage:
|
||||
"""
|
||||
@ -406,11 +406,11 @@ class LargeLanguageModel(AIModel):
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
):
|
||||
"""
|
||||
Trigger before invoke callbacks
|
||||
@ -454,11 +454,11 @@ class LargeLanguageModel(AIModel):
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
):
|
||||
"""
|
||||
Trigger new chunk callbacks
|
||||
@ -501,11 +501,11 @@ class LargeLanguageModel(AIModel):
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
):
|
||||
"""
|
||||
Trigger after invoke callbacks
|
||||
@ -551,11 +551,11 @@ class LargeLanguageModel(AIModel):
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
):
|
||||
"""
|
||||
Trigger invoke error callbacks
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
@ -17,7 +16,7 @@ class ModerationModel(AIModel):
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
|
||||
def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
@ -18,9 +16,9 @@ class RerankModel(AIModel):
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import IO, Optional
|
||||
from typing import IO
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
@ -16,7 +16,7 @@ class Speech2TextModel(AIModel):
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||
def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str:
|
||||
"""
|
||||
Invoke speech to text model
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
@ -23,7 +21,7 @@ class TextEmbeddingModel(AIModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
user: str | None = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
@ -48,7 +46,7 @@ class TextEmbeddingModel(AIModel):
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
input_type=input_type.value,
|
||||
input_type=input_type,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
import logging
|
||||
from threading import Lock
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_tokenizer: Optional[Any] = None
|
||||
_tokenizer: Any | None = None
|
||||
_lock = Lock()
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
@ -27,7 +26,7 @@ class TTSModel(AIModel):
|
||||
credentials: dict,
|
||||
content_text: str,
|
||||
voice: str,
|
||||
user: Optional[str] = None,
|
||||
user: str | None = None,
|
||||
) -> Iterable[bytes]:
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -57,7 +56,7 @@ class TTSModel(AIModel):
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None):
|
||||
def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None):
|
||||
"""
|
||||
Retrieves the list of voices supported by a given text-to-speech (TTS) model.
|
||||
|
||||
|
||||
@ -1,14 +1,9 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from threading import Lock
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import contexts
|
||||
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
@ -26,15 +21,10 @@ from models.provider_ids import ModelProviderID
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelProviderExtension(BaseModel):
|
||||
plugin_model_provider_entity: PluginModelProviderEntity
|
||||
position: Optional[int] = None
|
||||
|
||||
|
||||
class ModelProviderFactory:
|
||||
provider_position_map: dict[str, int]
|
||||
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
def __init__(self, tenant_id: str):
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
self.provider_position_map = {}
|
||||
@ -42,34 +32,15 @@ class ModelProviderFactory:
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_model_manager = PluginModelClient()
|
||||
|
||||
if not self.provider_position_map:
|
||||
# get the path of current classes
|
||||
current_path = os.path.abspath(__file__)
|
||||
model_providers_path = os.path.dirname(current_path)
|
||||
|
||||
# get _position.yaml file path
|
||||
self.provider_position_map = get_provider_position_map(model_providers_path)
|
||||
|
||||
def get_providers(self) -> Sequence[ProviderEntity]:
|
||||
"""
|
||||
Get all providers
|
||||
:return: list of providers
|
||||
"""
|
||||
# Fetch plugin model providers
|
||||
# FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server
|
||||
# The plugin server should return providers in the desired order
|
||||
plugin_providers = self.get_plugin_model_providers()
|
||||
|
||||
# Convert PluginModelProviderEntity to ModelProviderExtension
|
||||
model_provider_extensions = []
|
||||
for provider in plugin_providers:
|
||||
model_provider_extensions.append(ModelProviderExtension(plugin_model_provider_entity=provider))
|
||||
|
||||
sorted_extensions = sort_to_dict_by_position_map(
|
||||
position_map=self.provider_position_map,
|
||||
data=model_provider_extensions,
|
||||
name_func=lambda x: x.plugin_model_provider_entity.declaration.provider,
|
||||
)
|
||||
|
||||
return [extension.plugin_model_provider_entity.declaration for extension in sorted_extensions.values()]
|
||||
return [provider.declaration for provider in plugin_providers]
|
||||
|
||||
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
|
||||
"""
|
||||
@ -238,9 +209,9 @@ class ModelProviderFactory:
|
||||
def get_models(
|
||||
self,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
provider_configs: Optional[list[ProviderConfig]] = None,
|
||||
provider: str | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
provider_configs: list[ProviderConfig] | None = None,
|
||||
) -> list[SimpleProviderEntity]:
|
||||
"""
|
||||
Get all models for given model type
|
||||
|
||||
Reference in New Issue
Block a user