feat: backward invocation of LLM from plugins

This commit is contained in:
Yeuoly
2024-07-29 22:08:14 +08:00
parent d52476c1c9
commit 31e8b134d1
4 changed files with 119 additions and 7 deletions

View File

@ -1,10 +1,13 @@
import time import time
from collections.abc import Generator
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.inner_api import api from controllers.inner_api import api
from controllers.inner_api.plugin.wraps import get_tenant, plugin_data from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
from controllers.inner_api.wraps import plugin_inner_api_only from controllers.inner_api.wraps import plugin_inner_api_only
from core.plugin.backwards_invocation.model import PluginBackwardsInvocation
from core.plugin.entities.request import ( from core.plugin.entities.request import (
RequestInvokeLLM, RequestInvokeLLM,
RequestInvokeModeration, RequestInvokeModeration,
@ -17,7 +20,6 @@ from core.plugin.entities.request import (
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
from libs.helper import compact_generate_response from libs.helper import compact_generate_response
from models.account import Tenant from models.account import Tenant
from services.plugin.plugin_invoke_service import PluginInvokeService
class PluginInvokeLLMApi(Resource): class PluginInvokeLLMApi(Resource):
@ -26,7 +28,15 @@ class PluginInvokeLLMApi(Resource):
@get_tenant @get_tenant
@plugin_data(payload_type=RequestInvokeLLM) @plugin_data(payload_type=RequestInvokeLLM)
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM):
    """
    Relay a plugin's LLM request to the backwards-invocation layer and
    return the result as a compact generate response.

    Streaming responses are emitted as one JSON document per chunk,
    separated by double newlines; blocking responses are a single document.
    """
    def generator():
        result = PluginBackwardsInvocation.invoke_llm(user_id, tenant_model, payload)
        if isinstance(result, Generator):
            # streaming path: forward each chunk as it arrives
            for piece in result:
                yield piece.model_dump_json().encode() + b'\n\n'
        else:
            # blocking path: a single complete result
            yield result.model_dump_json().encode() + b'\n\n'

    return compact_generate_response(generator())
class PluginInvokeTextEmbeddingApi(Resource): class PluginInvokeTextEmbeddingApi(Resource):

View File

@ -7,7 +7,7 @@ from core.entities.provider_configuration import ProviderConfiguration, Provider
from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.entities.rerank_entities import RerankResult
@ -103,7 +103,7 @@ class ModelInstance:
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-> Union[LLMResult, Generator]: -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
""" """
Invoke large language model Invoke large language model

View File

@ -0,0 +1,49 @@
from collections.abc import Generator
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.plugin.entities.request import RequestInvokeLLM
from core.workflow.nodes.llm.llm_node import LLMNode
from models.account import Tenant
class PluginBackwardsInvocation:
    @classmethod
    def invoke_llm(
        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
        """
        Invoke an LLM on behalf of a plugin.

        :param user_id: id of the end user the invocation is attributed to
        :param tenant: tenant whose credentials and quota are used
        :param payload: validated request carrying provider/model/prompt data
        :return: a generator of LLMResultChunk when streaming, otherwise a
                 complete LLMResult
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        # FIX: the original passed `stream=payload.stream or True`, which is
        # always True and silently ignored an explicit stream=False in the
        # request. Honor the client's flag instead (RequestInvokeLLM defaults
        # it to False).
        response = model_instance.invoke_llm(
            prompt_messages=payload.prompt_messages,
            model_parameters=payload.model_parameters,
            tools=payload.tools,
            stop=payload.stop,
            stream=bool(payload.stream),
            user=user_id,
        )

        if isinstance(response, Generator):
            def handle() -> Generator[LLMResultChunk, None, None]:
                # deduct quota incrementally as usage is reported per chunk
                for chunk in response:
                    if chunk.delta.usage:
                        LLMNode.deduct_llm_quota(
                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                        )
                    yield chunk

            return handle()
        else:
            # blocking mode: deduct the whole usage once
            if response.usage:
                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
            return response

View File

@ -1,4 +1,17 @@
from pydantic import BaseModel from typing import Any, Optional
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageRole,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelType
class RequestInvokeTool(BaseModel): class RequestInvokeTool(BaseModel):
@ -6,37 +19,77 @@ class RequestInvokeTool(BaseModel):
Request to invoke a tool Request to invoke a tool
""" """
class RequestInvokeLLM(BaseModel):
class BaseRequestInvokeModel(BaseModel):
    """
    Common fields shared by every model-invocation request sent by a plugin.
    """
    # provider identifier, e.g. the model provider plugin/vendor name
    provider: str
    # name of the concrete model to invoke
    model: str
    # which model capability is being invoked (LLM, embedding, rerank, ...)
    model_type: ModelType
class RequestInvokeLLM(BaseRequestInvokeModel):
    """
    Request to invoke LLM
    """
    model_type: ModelType = ModelType.LLM
    mode: str
    model_parameters: dict[str, Any] = Field(default_factory=dict)
    prompt_messages: list[PromptMessage]
    tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
    stop: Optional[list[str]] = Field(default_factory=list)
    stream: Optional[bool] = False

    @field_validator('prompt_messages', mode='before')
    @classmethod
    def convert_prompt_messages(cls, v):
        """
        Coerce raw prompt-message dicts into the concrete PromptMessage
        subclass matching their 'role'; unknown roles fall back to the
        generic PromptMessage.

        FIX: the original if/elif chain indexed `v[i]['role']` on every
        element and raised TypeError when an element was already a
        PromptMessage instance (e.g. an internal caller passing models
        instead of JSON dicts); such elements are now passed through.
        """
        if not isinstance(v, list):
            raise ValueError('prompt_messages must be a list')

        # role -> concrete message class dispatch table
        role_to_class = {
            PromptMessageRole.USER.value: UserPromptMessage,
            PromptMessageRole.ASSISTANT.value: AssistantPromptMessage,
            PromptMessageRole.SYSTEM.value: SystemPromptMessage,
            PromptMessageRole.TOOL.value: ToolPromptMessage,
        }
        for i, item in enumerate(v):
            if isinstance(item, PromptMessage):
                # already a constructed message — keep as-is
                continue
            message_class = role_to_class.get(item['role'], PromptMessage)
            v[i] = message_class(**item)
        return v
class RequestInvokeTextEmbedding(BaseModel): class RequestInvokeTextEmbedding(BaseModel):
""" """
Request to invoke text embedding Request to invoke text embedding
""" """
class RequestInvokeRerank(BaseModel): class RequestInvokeRerank(BaseModel):
""" """
Request to invoke rerank Request to invoke rerank
""" """
class RequestInvokeTTS(BaseModel): class RequestInvokeTTS(BaseModel):
""" """
Request to invoke TTS Request to invoke TTS
""" """
class RequestInvokeSpeech2Text(BaseModel): class RequestInvokeSpeech2Text(BaseModel):
""" """
Request to invoke speech2text Request to invoke speech2text
""" """
class RequestInvokeModeration(BaseModel): class RequestInvokeModeration(BaseModel):
""" """
Request to invoke moderation Request to invoke moderation
""" """
class RequestInvokeNode(BaseModel): class RequestInvokeNode(BaseModel):
""" """
Request to invoke node Request to invoke node
""" """