mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
feat: backwards invoke llm
This commit is contained in:
@ -1,10 +1,13 @@
|
|||||||
import time
|
import time
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
|
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.inner_api import api
|
from controllers.inner_api import api
|
||||||
from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
|
from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
|
||||||
from controllers.inner_api.wraps import plugin_inner_api_only
|
from controllers.inner_api.wraps import plugin_inner_api_only
|
||||||
|
from core.plugin.backwards_invocation.model import PluginBackwardsInvocation
|
||||||
from core.plugin.entities.request import (
|
from core.plugin.entities.request import (
|
||||||
RequestInvokeLLM,
|
RequestInvokeLLM,
|
||||||
RequestInvokeModeration,
|
RequestInvokeModeration,
|
||||||
@ -17,7 +20,6 @@ from core.plugin.entities.request import (
|
|||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
from libs.helper import compact_generate_response
|
from libs.helper import compact_generate_response
|
||||||
from models.account import Tenant
|
from models.account import Tenant
|
||||||
from services.plugin.plugin_invoke_service import PluginInvokeService
|
|
||||||
|
|
||||||
|
|
||||||
class PluginInvokeLLMApi(Resource):
    """Inner-API endpoint letting a plugin invoke an LLM through the host app."""

    # NOTE(review): the full file imports setup_required and
    # plugin_inner_api_only, which presumably decorate this method above the
    # two decorators visible here — confirm against the complete source.
    @get_tenant
    @plugin_data(payload_type=RequestInvokeLLM)
    def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM):
        """Invoke the requested LLM and stream the result back to the plugin.

        Every chunk (or the single non-streamed result) is serialized to JSON
        bytes and framed with a trailing blank line, SSE-style.
        """

        def stream_frames():
            result = PluginBackwardsInvocation.invoke_llm(user_id, tenant_model, payload)
            if not isinstance(result, Generator):
                # Non-streaming invocation: emit the whole result as one frame.
                yield result.model_dump_json().encode() + b'\n\n'
                return
            for piece in result:
                yield piece.model_dump_json().encode() + b'\n\n'

        return compact_generate_response(stream_frames())
class PluginInvokeTextEmbeddingApi(Resource):
|
class PluginInvokeTextEmbeddingApi(Resource):
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from core.entities.provider_configuration import ProviderConfiguration, Provider
|
|||||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||||
from core.errors.error import ProviderTokenNotInitError
|
from core.errors.error import ProviderTokenNotInitError
|
||||||
from core.model_runtime.callbacks.base_callback import Callback
|
from core.model_runtime.callbacks.base_callback import Callback
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||||
@ -103,7 +103,7 @@ class ModelInstance:
|
|||||||
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||||
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
|
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
|
||||||
-> Union[LLMResult, Generator]:
|
-> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
|
|||||||
49
api/core/plugin/backwards_invocation/model.py
Normal file
49
api/core/plugin/backwards_invocation/model.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
from core.model_manager import ModelManager
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||||
|
from core.plugin.entities.request import RequestInvokeLLM
|
||||||
|
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||||
|
from models.account import Tenant
|
||||||
|
|
||||||
|
|
||||||
|
class PluginBackwardsInvocation:
    """Backwards invocation: lets a plugin call models owned by the host tenant."""

    @classmethod
    def invoke_llm(
        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
        """
        invoke llm

        :param user_id: id of the end user the call is attributed to
        :param tenant: tenant whose model credentials and quota are used
        :param payload: validated request naming provider/model and the prompt
        :return: a generator of LLMResultChunk when streaming, else one LLMResult
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        # BUG FIX: the original passed `stream=payload.stream or True`, which
        # is always True (`False or True` -> True), so callers could never
        # request a non-streaming response. Honor an explicit False and only
        # default to True when the field is unset.
        response = model_instance.invoke_llm(
            prompt_messages=payload.prompt_messages,
            model_parameters=payload.model_parameters,
            tools=payload.tools,
            stop=payload.stop,
            stream=payload.stream if payload.stream is not None else True,
            user=user_id,
        )

        if isinstance(response, Generator):

            def handle() -> Generator[LLMResultChunk, None, None]:
                # Streaming: deduct quota incrementally as usage is reported
                # on individual chunks, then forward each chunk unchanged.
                for chunk in response:
                    if chunk.delta.usage:
                        LLMNode.deduct_llm_quota(
                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                        )
                    yield chunk

            return handle()
        else:
            # Non-streaming: deduct quota once from the final usage, if any.
            if response.usage:
                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
            return response
@ -1,4 +1,17 @@
|
|||||||
from pydantic import BaseModel
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageRole,
|
||||||
|
PromptMessageTool,
|
||||||
|
SystemPromptMessage,
|
||||||
|
ToolPromptMessage,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
|
||||||
|
|
||||||
class RequestInvokeTool(BaseModel):
|
class RequestInvokeTool(BaseModel):
|
||||||
@ -6,36 +19,76 @@ class RequestInvokeTool(BaseModel):
|
|||||||
Request to invoke a tool
|
Request to invoke a tool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class RequestInvokeLLM(BaseModel):
|
|
||||||
|
class BaseRequestInvokeModel(BaseModel):
    """Common fields shared by every model-invocation request."""

    # provider/model identify the configured model; model_type selects the runtime
    provider: str
    model: str
    model_type: ModelType


class RequestInvokeLLM(BaseRequestInvokeModel):
    """
    Request to invoke LLM
    """

    model_type: ModelType = ModelType.LLM
    mode: str
    model_parameters: dict[str, Any] = Field(default_factory=dict)
    prompt_messages: list[PromptMessage]
    tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
    stop: Optional[list[str]] = Field(default_factory=list)
    stream: Optional[bool] = False

    @field_validator('prompt_messages', mode='before')
    @classmethod
    def convert_prompt_messages(cls, v):
        """Coerce raw dicts into concrete PromptMessage subclasses by role.

        Runs before standard validation so callers may send plain JSON
        objects. Fixes over the original: does not mutate the caller's list,
        and entries that are already PromptMessage instances pass through
        instead of crashing on `item['role']`. Unknown roles fall back to the
        base PromptMessage, as before.
        """
        if not isinstance(v, list):
            raise ValueError('prompt_messages must be a list')

        # Dispatch table replaces the original if/elif chain.
        role_map = {
            PromptMessageRole.USER.value: UserPromptMessage,
            PromptMessageRole.ASSISTANT.value: AssistantPromptMessage,
            PromptMessageRole.SYSTEM.value: SystemPromptMessage,
            PromptMessageRole.TOOL.value: ToolPromptMessage,
        }

        converted = []
        for item in v:
            if isinstance(item, PromptMessage):
                # Already a typed message object — nothing to convert.
                converted.append(item)
                continue
            message_cls = role_map.get(item['role'], PromptMessage)
            converted.append(message_cls(**item))
        return converted
# NOTE(review): the five schemas below are empty placeholders in this commit —
# presumably fields are added in a follow-up; confirm against later revisions.
class RequestInvokeTextEmbedding(BaseModel):
    """
    Request to invoke text embedding
    """


class RequestInvokeRerank(BaseModel):
    """
    Request to invoke rerank
    """


class RequestInvokeTTS(BaseModel):
    """
    Request to invoke TTS
    """


class RequestInvokeSpeech2Text(BaseModel):
    """
    Request to invoke speech2text
    """


class RequestInvokeModeration(BaseModel):
    """
    Request to invoke moderation
    """
class RequestInvokeNode(BaseModel):
|
class RequestInvokeNode(BaseModel):
|
||||||
"""
|
"""
|
||||||
Request to invoke node
|
Request to invoke node
|
||||||
|
|||||||
Reference in New Issue
Block a user