feat: Iteration node support parallel mode (#9493)

This commit is contained in:
Novice
2024-11-05 10:32:49 +08:00
committed by Joel
parent 623b27583b
commit baaa3ae02c
33 changed files with 1284 additions and 192 deletions

View File

@ -12,6 +12,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@ -35,6 +36,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
@ -251,6 +253,12 @@ class WorkflowCycleManage:
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
}
)
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
session.add(workflow_node_execution)
@ -305,7 +313,9 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
@ -318,16 +328,19 @@ class WorkflowCycleManage:
outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,
WorkflowNodeExecution.execution_metadata: execution_metadata,
}
)
@ -342,6 +355,7 @@ class WorkflowCycleManage:
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
@ -448,6 +462,7 @@ class WorkflowCycleManage:
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
parallel_run_id=event.parallel_mode_run_id,
),
)
@ -464,7 +479,7 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
@ -608,6 +623,7 @@ class WorkflowCycleManage:
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
),
)
@ -633,7 +649,9 @@ class WorkflowCycleManage:
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED,
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,