refactor workflow generate pipeline

This commit is contained in:
takatost
2024-03-06 22:10:49 +08:00
parent 5963e7d1c5
commit 6372183471
31 changed files with 1175 additions and 445 deletions

View File

@ -5,7 +5,8 @@ from typing import Literal, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit
from core.app.app_queue_manager import PublishFrom
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@ -121,7 +122,9 @@ class CotAgentRunner(BaseAgentRunner):
)
if iteration_step > 1:
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt messages
prompt_messages = self._organize_cot_prompt_messages(
@ -163,7 +166,9 @@ class CotAgentRunner(BaseAgentRunner):
# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
for chunk in react_chunks:
if isinstance(chunk, dict):
@ -225,7 +230,9 @@ class CotAgentRunner(BaseAgentRunner):
llm_usage=usage_dict['usage'])
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
if not scratchpad.action:
# failed to extract action, return final answer directly
@ -255,7 +262,9 @@ class CotAgentRunner(BaseAgentRunner):
observation=answer,
answer=answer,
messages_ids=[])
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
else:
# invoke tool
error_response = None
@ -282,7 +291,9 @@ class CotAgentRunner(BaseAgentRunner):
self.variables_pool.set_file(tool_name=tool_call_name,
value=message_file.id,
name=save_as)
self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file.id
), PublishFrom.APPLICATION_MANAGER)
message_file_ids = [message_file.id for message_file, _ in message_files]
except ToolProviderCredentialValidationError as e:
@ -318,7 +329,9 @@ class CotAgentRunner(BaseAgentRunner):
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt tool message
for prompt_tool in prompt_messages_tools:
@ -352,7 +365,7 @@ class CotAgentRunner(BaseAgentRunner):
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish_message_end(LLMResult(
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
@ -360,7 +373,7 @@ class CotAgentRunner(BaseAgentRunner):
),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
system_fingerprint=''
), PublishFrom.APPLICATION_MANAGER)
)), PublishFrom.APPLICATION_MANAGER)
def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
-> Generator[Union[str, dict], None, None]: