mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 16:38:04 +08:00
Merge branch 'main' into fix/chore-fix
This commit is contained in:
@ -5,7 +5,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
|
||||
|
||||
class BuiltinToolProviderSort:
|
||||
_position = {}
|
||||
_position: dict[str, int] = {}
|
||||
|
||||
@classmethod
|
||||
def sort(cls, providers: list[ToolProviderApiEntity]) -> list[ToolProviderApiEntity]:
|
||||
|
||||
@ -23,8 +23,10 @@ class TTSTool(BuiltinTool):
|
||||
provider, model = tool_parameters.get("model").split("#") # type: ignore
|
||||
voice = tool_parameters.get(f"voice#{provider}#{model}")
|
||||
model_manager = ModelManager()
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
provider=provider,
|
||||
model_type=ModelType.TTS,
|
||||
model=model,
|
||||
@ -47,8 +49,11 @@ class TTSTool(BuiltinTool):
|
||||
)
|
||||
|
||||
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
|
||||
tid: str = self.runtime.tenant_id or ""
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts")
|
||||
items = []
|
||||
for provider_model in models:
|
||||
provider = provider_model.provider
|
||||
@ -68,6 +73,8 @@ class TTSTool(BuiltinTool):
|
||||
ToolParameter(
|
||||
name=f"voice#{provider}#{model}",
|
||||
label=I18nObject(en_US=f"Voice of {model}({provider})"),
|
||||
human_description=I18nObject(en_US=f"Select a voice for {model} model"),
|
||||
placeholder=I18nObject(en_US="Select a voice"),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
options=[
|
||||
@ -89,6 +96,7 @@ class TTSTool(BuiltinTool):
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
required=True,
|
||||
placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"),
|
||||
options=options,
|
||||
),
|
||||
)
|
||||
|
||||
@ -49,9 +49,12 @@ class BuiltinTool(Tool):
|
||||
:return: the model result
|
||||
"""
|
||||
# invoke model
|
||||
if self.runtime is None or self.identity is None:
|
||||
raise ValueError("runtime and identity are required")
|
||||
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tool_type="builtin",
|
||||
tool_name=self.entity.identity.name,
|
||||
prompt_messages=prompt_messages,
|
||||
@ -67,8 +70,11 @@ class BuiltinTool(Tool):
|
||||
:param model_config: the model config
|
||||
:return: the max tokens
|
||||
"""
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
|
||||
return ModelInvocationUtils.get_max_llm_context_tokens(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
)
|
||||
|
||||
def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
|
||||
@ -78,7 +84,12 @@ class BuiltinTool(Tool):
|
||||
:param prompt_messages: the prompt messages
|
||||
:return: the tokens
|
||||
"""
|
||||
return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
|
||||
return ModelInvocationUtils.calculate_tokens(
|
||||
tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
def summary(self, user_id: str, content: str) -> str:
|
||||
max_tokens = self.get_max_tokens()
|
||||
@ -120,16 +131,16 @@ class BuiltinTool(Tool):
|
||||
|
||||
# merge lines into messages with max tokens
|
||||
messages: list[str] = []
|
||||
for i in new_lines:
|
||||
for j in new_lines:
|
||||
if len(messages) == 0:
|
||||
messages.append(i)
|
||||
messages.append(j)
|
||||
else:
|
||||
if len(messages[-1]) + len(i) < max_tokens * 0.5:
|
||||
messages[-1] += i
|
||||
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
|
||||
messages.append(i)
|
||||
if len(messages[-1]) + len(j) < max_tokens * 0.5:
|
||||
messages[-1] += j
|
||||
if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
|
||||
messages.append(j)
|
||||
else:
|
||||
messages[-1] += i
|
||||
messages[-1] += j
|
||||
|
||||
summaries = []
|
||||
for i in range(len(messages)):
|
||||
|
||||
Reference in New Issue
Block a user