This commit is contained in:
takatost
2024-03-07 20:50:02 +08:00
parent f4f7cfd45a
commit 90bcb241cc
3 changed files with 54 additions and 5 deletions

View File

@ -47,6 +47,7 @@ class TaskState(BaseModel):
answer: str = ""
metadata: dict = {}
usage: LLMUsage
workflow_run_id: Optional[str] = None
class AdvancedChatAppGenerateTaskPipeline:
@ -110,6 +111,8 @@ class AdvancedChatAppGenerateTaskPipeline:
}
self._task_state.answer = annotation.content
elif isinstance(event, QueueWorkflowStartedEvent):
self._task_state.workflow_run_id = event.workflow_run_id
elif isinstance(event, QueueNodeFinishedEvent):
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
@ -171,6 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline:
break
elif isinstance(event, QueueWorkflowStartedEvent):
workflow_run = self._get_workflow_run(event.workflow_run_id)
self._task_state.workflow_run_id = workflow_run.id
response = {
'event': 'workflow_started',
'task_id': self._application_generate_entity.task_id,
@ -234,7 +238,7 @@ class AdvancedChatAppGenerateTaskPipeline:
if isinstance(event, QueueWorkflowFinishedEvent):
workflow_run = self._get_workflow_run(event.workflow_run_id)
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs
outputs = workflow_run.outputs_dict
self._task_state.answer = outputs.get('text', '')
else:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
@ -389,7 +393,13 @@ class AdvancedChatAppGenerateTaskPipeline:
:param workflow_run_id: workflow run id
:return:
"""
return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
if workflow_run:
# Because the workflow_run will be modified in the sub-thread,
# and the first query in the main thread will cache the entity,
# you need to expire the entity after the query
db.session.expire(workflow_run)
return workflow_run
def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
"""
@ -397,7 +407,14 @@ class AdvancedChatAppGenerateTaskPipeline:
:param workflow_node_execution_id: workflow node execution id
:return:
"""
return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()
workflow_node_execution = (db.session.query(WorkflowNodeExecution)
.filter(WorkflowNodeExecution.id == workflow_node_execution_id).first())
if workflow_node_execution:
# Because the workflow_node_execution will be modified in the sub-thread,
# and the first query in the main thread will cache the entity,
# you need to expire the entity after the query
db.session.expire(workflow_node_execution)
return workflow_node_execution
def _save_message(self) -> None:
"""
@ -408,6 +425,7 @@ class AdvancedChatAppGenerateTaskPipeline:
self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.workflow_run_id = self._task_state.workflow_run_id
if self._task_state.metadata and self._task_state.metadata.get('usage'):
usage = LLMUsage(**self._task_state.metadata['usage'])