mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
refactor(api): continue decoupling dify_graph from API concerns (#33580)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
@ -30,10 +30,27 @@ from dify_graph.model_runtime.entities.message_entities import (
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from models.account import Tenant
|
||||
|
||||
|
||||
class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
@staticmethod
def _get_bound_model_instance(
    *,
    tenant_id: str,
    user_id: str | None,
    provider: str,
    model_type: ModelType,
    model: str,
):
    """
    Resolve a model instance through a tenant/user-scoped ModelManager.

    All arguments are keyword-only.

    :param tenant_id: tenant id, used both to build the manager and for the lookup
    :param user_id: id of the calling user, or None; only passed to ``ModelManager.for_tenant``
    :param provider: provider identifier forwarded to ``get_model_instance``
    :param model_type: model type forwarded to ``get_model_instance``
    :param model: model name forwarded to ``get_model_instance``
    :return: whatever ``ModelManager.get_model_instance`` returns
        (presumably a model instance exposing the ``invoke_*`` methods — confirm against ModelManager)
    """
    # Build the manager first so the user binding happens once, then look up the model.
    scoped_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
    return scoped_manager.get_model_instance(
        tenant_id=tenant_id,
        provider=provider,
        model_type=model_type,
        model=model,
    )
|
||||
|
||||
@classmethod
|
||||
def invoke_llm(
|
||||
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
|
||||
@ -41,8 +58,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke llm
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
@ -55,7 +73,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
tools=payload.tools,
|
||||
stop=payload.stop,
|
||||
stream=True if payload.stream is None else payload.stream,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
if isinstance(response, Generator):
|
||||
@ -94,8 +111,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke llm with structured output
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
@ -115,7 +133,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
tools=payload.tools,
|
||||
stop=payload.stop,
|
||||
stream=True if payload.stream is None else payload.stream,
|
||||
user=user_id,
|
||||
model_parameters=payload.completion_params,
|
||||
)
|
||||
|
||||
@ -156,18 +173,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke text embedding
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
response = model_instance.invoke_text_embedding(
|
||||
texts=payload.texts,
|
||||
user=user_id,
|
||||
)
|
||||
response = model_instance.invoke_text_embedding(texts=payload.texts)
|
||||
|
||||
return response
|
||||
|
||||
@ -176,8 +191,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke rerank
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
@ -189,7 +205,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
docs=payload.docs,
|
||||
score_threshold=payload.score_threshold,
|
||||
top_n=payload.top_n,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
return response
|
||||
@ -199,20 +214,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke tts
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
response = model_instance.invoke_tts(
|
||||
content_text=payload.content_text,
|
||||
tenant_id=tenant.id,
|
||||
voice=payload.voice,
|
||||
user=user_id,
|
||||
)
|
||||
response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice)
|
||||
|
||||
def handle() -> Generator[dict, None, None]:
|
||||
for chunk in response:
|
||||
@ -225,8 +236,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke speech2text
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
@ -238,10 +250,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
temp.flush()
|
||||
temp.seek(0)
|
||||
|
||||
response = model_instance.invoke_speech2text(
|
||||
file=temp,
|
||||
user=user_id,
|
||||
)
|
||||
response = model_instance.invoke_speech2text(file=temp)
|
||||
|
||||
return {
|
||||
"result": response,
|
||||
@ -252,36 +261,38 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke moderation
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
response = model_instance.invoke_moderation(
|
||||
text=payload.text,
|
||||
user=user_id,
|
||||
)
|
||||
response = model_instance.invoke_moderation(text=payload.text)
|
||||
|
||||
return {
|
||||
"result": response,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_system_model_max_tokens(cls, tenant_id: str) -> int:
|
||||
def get_system_model_max_tokens(cls, tenant_id: str, user_id: str | None = None) -> int:
|
||||
"""
|
||||
get system model max tokens
|
||||
"""
|
||||
return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)
|
||||
return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id, user_id=user_id)
|
||||
|
||||
@classmethod
|
||||
def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
|
||||
def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int:
|
||||
"""
|
||||
get prompt tokens
|
||||
"""
|
||||
return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)
|
||||
return ModelInvocationUtils.calculate_tokens(
|
||||
tenant_id=tenant_id,
|
||||
prompt_messages=prompt_messages,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def invoke_system_model(
|
||||
@ -299,6 +310,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
tool_type=ToolProviderType.PLUGIN,
|
||||
tool_name="plugin",
|
||||
prompt_messages=prompt_messages,
|
||||
caller_user_id=user_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -306,7 +318,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke summary
|
||||
"""
|
||||
max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
|
||||
max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id, user_id=user_id)
|
||||
content = payload.text
|
||||
|
||||
SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
|
||||
@ -325,6 +337,7 @@ Here is the extra instruction you need to follow:
|
||||
cls.get_prompt_tokens(
|
||||
tenant_id=tenant.id,
|
||||
prompt_messages=[UserPromptMessage(content=content)],
|
||||
user_id=user_id,
|
||||
)
|
||||
< max_tokens * 0.6
|
||||
):
|
||||
@ -337,6 +350,7 @@ Here is the extra instruction you need to follow:
|
||||
SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
|
||||
UserPromptMessage(content=content),
|
||||
],
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
def summarize(content: str) -> str:
|
||||
@ -394,6 +408,7 @@ Here is the extra instruction you need to follow:
|
||||
cls.get_prompt_tokens(
|
||||
tenant_id=tenant.id,
|
||||
prompt_messages=[UserPromptMessage(content=result)],
|
||||
user_id=user_id,
|
||||
)
|
||||
> max_tokens * 0.7
|
||||
):
|
||||
|
||||
@ -31,7 +31,13 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
|
||||
# get tool runtime
|
||||
try:
|
||||
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
|
||||
tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
|
||||
tool_type,
|
||||
tenant_id,
|
||||
provider,
|
||||
tool_name,
|
||||
tool_parameters,
|
||||
user_id=user_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
response = ToolEngine.generic_invoke(
|
||||
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1
|
||||
|
||||
Reference in New Issue
Block a user