refactor pipeline and remove node run run_args

This commit is contained in:
takatost
2024-03-09 19:05:48 +08:00
parent 4c5822fb6e
commit 2f57d090a1
11 changed files with 201 additions and 114 deletions

View File

@ -41,6 +41,19 @@ class TaskState(BaseModel):
"""
TaskState entity
"""
class NodeExecutionInfo(BaseModel):
"""
NodeExecutionInfo entity
"""
workflow_node_execution: WorkflowNodeExecution
start_at: float
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
answer: str = ""
metadata: dict = {}
@ -49,8 +62,8 @@ class TaskState(BaseModel):
total_tokens: int = 0
total_steps: int = 0
current_node_execution: Optional[WorkflowNodeExecution] = None
current_node_execution_start_at: Optional[float] = None
running_node_execution_infos: dict[str, NodeExecutionInfo] = {}
latest_node_execution_info: Optional[NodeExecutionInfo] = None
class Config:
"""Configuration for this pydantic object."""
@ -179,7 +192,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
yield self._yield_response(response)
elif isinstance(event, QueueNodeStartedEvent):
self._on_node_start(event)
workflow_node_execution = self._task_state.current_node_execution
workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
response = {
'event': 'node_started',
@ -198,7 +211,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
yield self._yield_response(response)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
self._on_node_finished(event)
workflow_node_execution = self._task_state.current_node_execution
workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
response = {
'event': 'node_finished',
@ -339,15 +352,22 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
predecessor_node_id=event.predecessor_node_id
)
self._task_state.current_node_execution = workflow_node_execution
self._task_state.current_node_execution_start_at = time.perf_counter()
latest_node_execution_info = TaskState.NodeExecutionInfo(
workflow_node_execution=workflow_node_execution,
start_at=time.perf_counter()
)
self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._task_state.total_steps += 1
def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None:
current_node_execution = self._task_state.running_node_execution_infos[event.node_id]
if isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_node_execution_success(
workflow_node_execution=self._task_state.current_node_execution,
start_at=self._task_state.current_node_execution_start_at,
workflow_node_execution=current_node_execution.workflow_node_execution,
start_at=current_node_execution.start_at,
inputs=event.inputs,
process_data=event.process_data,
outputs=event.outputs,
@ -359,12 +379,14 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
else:
workflow_node_execution = self._workflow_node_execution_failed(
workflow_node_execution=self._task_state.current_node_execution,
start_at=self._task_state.current_node_execution_start_at,
workflow_node_execution=current_node_execution.workflow_node_execution,
start_at=current_node_execution.start_at,
error=event.error
)
self._task_state.current_node_execution = workflow_node_execution
# remove running node execution info
del self._task_state.running_node_execution_infos[event.node_id]
self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution
def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None:
if isinstance(event, QueueStopEvent):
@ -391,8 +413,8 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
start_at=self._task_state.start_at,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
outputs=self._task_state.current_node_execution.outputs
if self._task_state.current_node_execution else None
outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs
if self._task_state.latest_node_execution_info else None
)
self._task_state.workflow_run = workflow_run