mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
refactor apps
This commit is contained in:
@ -1,101 +0,0 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentLLMCallback(Callback):
|
||||
|
||||
def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None:
|
||||
self.agent_callback = agent_callback
|
||||
|
||||
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Before invoke callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
self.agent_callback.on_llm_before_invoke(
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None):
|
||||
"""
|
||||
On new chunk callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param chunk: chunk
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
After invoke callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param result: result
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
self.agent_callback.on_llm_after_invoke(
|
||||
result=result
|
||||
)
|
||||
|
||||
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Invoke error callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param ex: exception
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
self.agent_callback.on_llm_error(
|
||||
error=ex
|
||||
)
|
||||
@ -5,19 +5,17 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
|
||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.rag.retrieval.agent.fake_llm import FakeLLM
|
||||
|
||||
|
||||
class LLMChain(LCLLMChain):
|
||||
model_config: ModelConfigEntity
|
||||
model_config: EasyUIBasedModelConfigEntity
|
||||
"""The language model instance to use."""
|
||||
llm: BaseLanguageModel = FakeLLM(response="")
|
||||
parameters: dict[str, Any] = {}
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@ -38,7 +36,6 @@ class LLMChain(LCLLMChain):
|
||||
prompt_messages=prompt_messages,
|
||||
stream=False,
|
||||
stop=stop,
|
||||
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None,
|
||||
model_parameters=self.parameters
|
||||
)
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
|
||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
@ -21,7 +21,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
"""
|
||||
An Multi Dataset Retrieve Agent driven by Router.
|
||||
"""
|
||||
model_config: ModelConfigEntity
|
||||
model_config: EasyUIBasedModelConfigEntity
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -156,7 +156,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_config: ModelConfigEntity,
|
||||
model_config: EasyUIBasedModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
|
||||
|
||||
@ -12,7 +12,7 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
|
||||
from core.rag.retrieval.agent.llm_chain import LLMChain
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
@ -206,7 +206,7 @@ Thought: {agent_scratchpad}
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_config: ModelConfigEntity,
|
||||
model_config: EasyUIBasedModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
|
||||
@ -7,13 +7,12 @@ from langchain.callbacks.manager import Callbacks
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.entities.message_entities import prompt_messages_to_lc_messages
|
||||
from core.helper import moderation
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
|
||||
from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
|
||||
@ -23,15 +22,14 @@ from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetr
|
||||
|
||||
class AgentConfiguration(BaseModel):
|
||||
strategy: PlanningStrategy
|
||||
model_config: ModelConfigEntity
|
||||
model_config: EasyUIBasedModelConfigEntity
|
||||
tools: list[BaseTool]
|
||||
summary_model_config: Optional[ModelConfigEntity] = None
|
||||
summary_model_config: Optional[EasyUIBasedModelConfigEntity] = None
|
||||
memory: Optional[TokenBufferMemory] = None
|
||||
callbacks: Callbacks = None
|
||||
max_iterations: int = 6
|
||||
max_execution_time: Optional[float] = None
|
||||
early_stopping_method: str = "generate"
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None
|
||||
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
|
||||
|
||||
class Config:
|
||||
|
||||
@ -2,9 +2,10 @@ from typing import Optional, cast
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, EasyUIBasedModelConfigEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
@ -17,7 +18,7 @@ from models.dataset import Dataset
|
||||
|
||||
class DatasetRetrieval:
|
||||
def retrieve(self, tenant_id: str,
|
||||
model_config: ModelConfigEntity,
|
||||
model_config: EasyUIBasedModelConfigEntity,
|
||||
config: DatasetEntity,
|
||||
query: str,
|
||||
invoke_from: InvokeFrom,
|
||||
|
||||
Reference in New Issue
Block a user