Merge remote-tracking branch 'upstream/main' into feat/rag-2

This commit is contained in:
QuantumGhost
2025-09-16 14:59:35 +08:00
791 changed files with 24372 additions and 7085 deletions

View File

@@ -1,7 +1,6 @@
import decimal
import hashlib
from threading import Lock
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
@@ -100,7 +99,7 @@ class AIModel(BaseModel):
model_schema = self.get_model_schema(model, credentials)
# get price info from predefined model schema
price_config: Optional[PriceConfig] = None
price_config: PriceConfig | None = None
if model_schema and model_schema.pricing:
price_config = model_schema.pricing
@@ -133,7 +132,7 @@ class AIModel(BaseModel):
currency=price_config.currency,
)
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None:
"""
Get model schema by model name and credentials
@@ -174,7 +173,7 @@ class AIModel(BaseModel):
return schema
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get customizable model schema from credentials
@@ -232,7 +231,7 @@ class AIModel(BaseModel):
return schema
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get customizable model schema

View File

@@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Sequence
from typing import Optional, Union
from typing import Union
from pydantic import ConfigDict
@@ -93,12 +93,12 @@ class LargeLanguageModel(AIModel):
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
model_parameters: dict | None = None,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
"""
Invoke large language model
@@ -244,11 +244,11 @@ class LargeLanguageModel(AIModel):
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Generator[LLMResultChunk, None, None]:
"""
Invoke result generator
@@ -329,7 +329,7 @@ class LargeLanguageModel(AIModel):
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
tools: list[PromptMessageTool] | None = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -357,7 +357,7 @@ class LargeLanguageModel(AIModel):
)
return 0
def _calc_response_usage(
def calc_response_usage(
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
) -> LLMUsage:
"""
@@ -406,11 +406,11 @@ class LargeLanguageModel(AIModel):
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
user: str | None = None,
callbacks: list[Callback] | None = None,
):
"""
Trigger before invoke callbacks
@@ -454,11 +454,11 @@ class LargeLanguageModel(AIModel):
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
user: str | None = None,
callbacks: list[Callback] | None = None,
):
"""
Trigger new chunk callbacks
@@ -501,11 +501,11 @@ class LargeLanguageModel(AIModel):
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
user: str | None = None,
callbacks: list[Callback] | None = None,
):
"""
Trigger after invoke callbacks
@@ -551,11 +551,11 @@ class LargeLanguageModel(AIModel):
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
user: str | None = None,
callbacks: list[Callback] | None = None,
):
"""
Trigger invoke error callbacks

View File

@@ -1,5 +1,4 @@
import time
from typing import Optional
from pydantic import ConfigDict
@@ -17,7 +16,7 @@ class ModerationModel(AIModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool:
"""
Invoke moderation model

View File

@@ -1,5 +1,3 @@
from typing import Optional
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
@@ -18,9 +16,9 @@ class RerankModel(AIModel):
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model

View File

@@ -1,4 +1,4 @@
from typing import IO, Optional
from typing import IO
from pydantic import ConfigDict
@@ -16,7 +16,7 @@ class Speech2TextModel(AIModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str:
"""
Invoke speech to text model

View File

@@ -1,5 +1,3 @@
from typing import Optional
from pydantic import ConfigDict
from core.entities.embedding_type import EmbeddingInputType
@@ -23,7 +21,7 @@ class TextEmbeddingModel(AIModel):
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
@@ -48,7 +46,7 @@ class TextEmbeddingModel(AIModel):
model=model,
credentials=credentials,
texts=texts,
input_type=input_type.value,
input_type=input_type,
)
except Exception as e:
raise self._transform_invoke_error(e)

View File

@@ -1,10 +1,10 @@
import logging
from threading import Lock
from typing import Any, Optional
from typing import Any
logger = logging.getLogger(__name__)
_tokenizer: Optional[Any] = None
_tokenizer: Any | None = None
_lock = Lock()

View File

@@ -1,6 +1,5 @@
import logging
from collections.abc import Iterable
from typing import Optional
from pydantic import ConfigDict
@@ -27,7 +26,7 @@ class TTSModel(AIModel):
credentials: dict,
content_text: str,
voice: str,
user: Optional[str] = None,
user: str | None = None,
) -> Iterable[bytes]:
"""
Invoke large language model
@@ -57,7 +56,7 @@ class TTSModel(AIModel):
except Exception as e:
raise self._transform_invoke_error(e)
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None):
def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None):
"""
Retrieves the list of voices supported by a given text-to-speech (TTS) model.

View File

@@ -1,14 +1,9 @@
import hashlib
import logging
import os
from collections.abc import Sequence
from threading import Lock
from typing import Optional
from pydantic import BaseModel
import contexts
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
@@ -26,15 +21,10 @@ from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
class ModelProviderExtension(BaseModel):
plugin_model_provider_entity: PluginModelProviderEntity
position: Optional[int] = None
class ModelProviderFactory:
provider_position_map: dict[str, int]
def __init__(self, tenant_id: str) -> None:
def __init__(self, tenant_id: str):
from core.plugin.impl.model import PluginModelClient
self.provider_position_map = {}
@@ -42,34 +32,15 @@ class ModelProviderFactory:
self.tenant_id = tenant_id
self.plugin_model_manager = PluginModelClient()
if not self.provider_position_map:
# get the path of current classes
current_path = os.path.abspath(__file__)
model_providers_path = os.path.dirname(current_path)
# get _position.yaml file path
self.provider_position_map = get_provider_position_map(model_providers_path)
def get_providers(self) -> Sequence[ProviderEntity]:
"""
Get all providers
:return: list of providers
"""
# Fetch plugin model providers
# FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server
# The plugin server should return providers in the desired order
plugin_providers = self.get_plugin_model_providers()
# Convert PluginModelProviderEntity to ModelProviderExtension
model_provider_extensions = []
for provider in plugin_providers:
model_provider_extensions.append(ModelProviderExtension(plugin_model_provider_entity=provider))
sorted_extensions = sort_to_dict_by_position_map(
position_map=self.provider_position_map,
data=model_provider_extensions,
name_func=lambda x: x.plugin_model_provider_entity.declaration.provider,
)
return [extension.plugin_model_provider_entity.declaration for extension in sorted_extensions.values()]
return [provider.declaration for provider in plugin_providers]
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
"""
@@ -238,9 +209,9 @@ class ModelProviderFactory:
def get_models(
self,
*,
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
provider_configs: Optional[list[ProviderConfig]] = None,
provider: str | None = None,
model_type: ModelType | None = None,
provider_configs: list[ProviderConfig] | None = None,
) -> list[SimpleProviderEntity]:
"""
Get all models for given model type