refactor apps

This commit is contained in:
takatost
2024-03-02 02:40:18 +08:00
parent 15c7e0ec2f
commit 3f5d1a79c6
111 changed files with 1979 additions and 1819 deletions

View File

@ -1,101 +0,0 @@
import logging
from typing import Optional
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class AgentLLMCallback(Callback):
    """Adapter that forwards model-runtime LLM lifecycle events to an agent loop.

    Every hook simply relays the relevant payload to the wrapped
    ``AgentLoopGatherCallbackHandler``; streaming chunk events are ignored.
    """

    def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None:
        # Handler that aggregates agent-loop state; all events are delegated to it.
        self.agent_callback = agent_callback

    def on_before_invoke(
        self,
        llm_instance: AIModel,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> None:
        """Relay the outgoing prompt messages before the LLM is invoked.

        Only ``prompt_messages`` is of interest to the agent handler; the
        remaining arguments are part of the ``Callback`` contract and unused.
        """
        self.agent_callback.on_llm_before_invoke(prompt_messages=prompt_messages)

    def on_new_chunk(
        self,
        llm_instance: AIModel,
        chunk: LLMResultChunk,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ):
        """Streaming chunk hook — intentionally a no-op for the agent loop."""
        pass

    def on_after_invoke(
        self,
        llm_instance: AIModel,
        result: LLMResult,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> None:
        """Relay the finished ``LLMResult`` to the agent handler."""
        self.agent_callback.on_llm_after_invoke(result=result)

    def on_invoke_error(
        self,
        llm_instance: AIModel,
        ex: Exception,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> None:
        """Relay an invocation exception to the agent handler."""
        self.agent_callback.on_llm_error(error=ex)

View File

@ -5,19 +5,17 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel
from core.entities.application_entities import ModelConfigEntity
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.rag.retrieval.agent.fake_llm import FakeLLM
class LLMChain(LCLLMChain):
model_config: ModelConfigEntity
model_config: EasyUIBasedModelConfigEntity
"""The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="")
parameters: dict[str, Any] = {}
agent_llm_callback: Optional[AgentLLMCallback] = None
def generate(
self,
@ -38,7 +36,6 @@ class LLMChain(LCLLMChain):
prompt_messages=prompt_messages,
stream=False,
stop=stop,
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None,
model_parameters=self.parameters
)

View File

@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage
from langchain.tools import BaseTool
from pydantic import root_validator
from core.entities.application_entities import ModelConfigEntity
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool
@ -21,7 +21,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
A Multi Dataset Retrieval Agent driven by Router.
"""
model_config: ModelConfigEntity
model_config: EasyUIBasedModelConfigEntity
class Config:
"""Configuration for this pydantic object."""
@ -156,7 +156,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@classmethod
def from_llm_and_tools(
cls,
model_config: ModelConfigEntity,
model_config: EasyUIBasedModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,

View File

@ -12,7 +12,7 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from core.entities.application_entities import ModelConfigEntity
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
from core.rag.retrieval.agent.llm_chain import LLMChain
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@ -206,7 +206,7 @@ Thought: {agent_scratchpad}
@classmethod
def from_llm_and_tools(
cls,
model_config: ModelConfigEntity,
model_config: EasyUIBasedModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,

View File

@ -7,13 +7,12 @@ from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
@ -23,15 +22,14 @@ from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetr
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
model_config: ModelConfigEntity
model_config: EasyUIBasedModelConfigEntity
tools: list[BaseTool]
summary_model_config: Optional[ModelConfigEntity] = None
summary_model_config: Optional[EasyUIBasedModelConfigEntity] = None
memory: Optional[TokenBufferMemory] = None
callbacks: Callbacks = None
max_iterations: int = 6
max_execution_time: Optional[float] = None
early_stopping_method: str = "generate"
agent_llm_callback: Optional[AgentLLMCallback] = None
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
class Config:

View File

@ -2,9 +2,10 @@ from typing import Optional, cast
from langchain.tools import BaseTool
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom, EasyUIBasedModelConfigEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@ -17,7 +18,7 @@ from models.dataset import Dataset
class DatasetRetrieval:
def retrieve(self, tenant_id: str,
model_config: ModelConfigEntity,
model_config: EasyUIBasedModelConfigEntity,
config: DatasetEntity,
query: str,
invoke_from: InvokeFrom,