Merge branch 'main' into fix/chore-fix

This commit is contained in:
Yeuoly
2024-12-24 21:28:56 +08:00
734 changed files with 7911 additions and 5007 deletions

View File

@ -5,7 +5,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity
class BuiltinToolProviderSort:
_position = {}
_position: dict[str, int] = {}
@classmethod
def sort(cls, providers: list[ToolProviderApiEntity]) -> list[ToolProviderApiEntity]:

View File

@ -23,8 +23,10 @@ class TTSTool(BuiltinTool):
provider, model = tool_parameters.get("model").split("#") # type: ignore
voice = tool_parameters.get(f"voice#{provider}#{model}")
model_manager = ModelManager()
if not self.runtime:
raise ValueError("Runtime is required")
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
provider=provider,
model_type=ModelType.TTS,
model=model,
@ -47,8 +49,11 @@ class TTSTool(BuiltinTool):
)
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
if not self.runtime:
raise ValueError("Runtime is required")
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
tid: str = self.runtime.tenant_id or ""
models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts")
items = []
for provider_model in models:
provider = provider_model.provider
@ -68,6 +73,8 @@ class TTSTool(BuiltinTool):
ToolParameter(
name=f"voice#{provider}#{model}",
label=I18nObject(en_US=f"Voice of {model}({provider})"),
human_description=I18nObject(en_US=f"Select a voice for {model} model"),
placeholder=I18nObject(en_US="Select a voice"),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
options=[
@ -89,6 +96,7 @@ class TTSTool(BuiltinTool):
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=True,
placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"),
options=options,
),
)

View File

@ -49,9 +49,12 @@ class BuiltinTool(Tool):
:return: the model result
"""
# invoke model
if self.runtime is None or self.identity is None:
raise ValueError("runtime and identity are required")
return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
tool_type="builtin",
tool_name=self.entity.identity.name,
prompt_messages=prompt_messages,
@ -67,8 +70,11 @@ class BuiltinTool(Tool):
:param model_config: the model config
:return: the max tokens
"""
if self.runtime is None:
raise ValueError("runtime is required")
return ModelInvocationUtils.get_max_llm_context_tokens(
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
)
def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
@ -78,7 +84,12 @@ class BuiltinTool(Tool):
:param prompt_messages: the prompt messages
:return: the tokens
"""
return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
if self.runtime is None:
raise ValueError("runtime is required")
return ModelInvocationUtils.calculate_tokens(
tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
)
def summary(self, user_id: str, content: str) -> str:
max_tokens = self.get_max_tokens()
@ -120,16 +131,16 @@ class BuiltinTool(Tool):
# merge lines into messages with max tokens
messages: list[str] = []
for i in new_lines:
for j in new_lines:
if len(messages) == 0:
messages.append(i)
messages.append(j)
else:
if len(messages[-1]) + len(i) < max_tokens * 0.5:
messages[-1] += i
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
messages.append(i)
if len(messages[-1]) + len(j) < max_tokens * 0.5:
messages[-1] += j
if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
messages.append(j)
else:
messages[-1] += i
messages[-1] += j
summaries = []
for i in range(len(messages)):