answer stream output support

This commit is contained in:
takatost
2024-03-14 20:49:53 +08:00
parent f35ae2355f
commit e6b8b13f2e
10 changed files with 413 additions and 90 deletions

View File

@ -2,7 +2,7 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Optional, Union, cast
from pydantic import BaseModel, Extra
@ -13,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueMessageFileEvent,
@ -34,6 +35,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
from core.moderation.output_moderation import ModerationRule, OutputModeration
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created
from extensions.ext_database import db
from models.account import Account
@ -51,15 +54,26 @@ from services.annotation_service import AppAnnotationService
logger = logging.getLogger(__name__)
class StreamGenerateRoute(BaseModel):
"""
StreamGenerateRoute entity
"""
answer_node_id: str
generate_route: list[GenerateRouteChunk]
current_route_position: int = 0
class TaskState(BaseModel):
"""
TaskState entity
"""
class NodeExecutionInfo(BaseModel):
"""
NodeExecutionInfo entity
"""
workflow_node_execution_id: str
node_type: NodeType
start_at: float
class Config:
@ -77,9 +91,11 @@ class TaskState(BaseModel):
total_tokens: int = 0
total_steps: int = 0
running_node_execution_infos: dict[str, NodeExecutionInfo] = {}
ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
latest_node_execution_info: Optional[NodeExecutionInfo] = None
current_stream_generate_state: Optional[StreamGenerateRoute] = None
class Config:
"""Configuration for this pydantic object."""
@ -122,6 +138,11 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream
if stream:
self._stream_generate_routes = self._get_stream_generate_routes()
else:
self._stream_generate_routes = None
def process(self) -> Union[dict, Generator]:
"""
Process generate task pipeline.
@ -290,6 +311,11 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
yield self._yield_response(data)
break
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
workflow_run_response = {
'event': 'workflow_finished',
'task_id': self._application_generate_entity.task_id,
@ -309,7 +335,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
}
yield self._yield_response(workflow_run_response)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
@ -390,6 +416,11 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
yield self._yield_response(response)
elif isinstance(event, QueueTextChunkEvent):
if not self._is_stream_out_support(
event=event
):
continue
delta_text = event.text
if delta_text is None:
continue
@ -467,20 +498,28 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
latest_node_execution_info = TaskState.NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=event.node_type,
start_at=time.perf_counter()
)
self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._task_state.total_steps += 1
db.session.close()
# check whether this node is the start node of an answer's stream generate route
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
# stream outputs from start
self._generate_stream_outputs_when_node_start()
return workflow_node_execution
def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
if isinstance(event, QueueNodeSucceededEvent):
@ -508,8 +547,8 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
error=event.error
)
# remove running node execution info
del self._task_state.running_node_execution_infos[event.node_id]
# stream outputs when node finished
self._generate_stream_outputs_when_node_finished()
db.session.close()
@ -517,7 +556,8 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
-> WorkflowRun:
workflow_run = (db.session.query(WorkflowRun)
.filter(WorkflowRun.id == self._task_state.workflow_run_id).first())
if isinstance(event, QueueStopEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
@ -642,7 +682,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
QuotaExceededError: {
'code': 'provider_quota_exceeded',
'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.",
"Please go to Settings -> Model Provider to complete your own provider credentials.",
'status': 400
},
ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
@ -660,10 +700,10 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
else:
logging.error(e)
data = {
'code': 'internal_server_error',
'message': 'Internal Server Error, please contact support.',
'status': 500
}
return {
'event': 'error',
@ -730,3 +770,218 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
),
queue_manager=self._queue_manager
)
def _get_stream_generate_routes(self) -> dict[str, StreamGenerateRoute]:
"""
Get stream generate routes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
answer_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.ANSWER.value
]
# parse stream output node value selectors of answer nodes
stream_generate_routes = {}
for node_config in answer_node_configs:
# get generate route for stream output
answer_node_id = node_config['id']
generate_route = AnswerNode.extract_generate_route_selectors(node_config)
start_node_id = self._get_answer_start_at_node_id(graph, answer_node_id)
if not start_node_id:
continue
stream_generate_routes[start_node_id] = StreamGenerateRoute(
answer_node_id=answer_node_id,
generate_route=generate_route
)
return stream_generate_routes
def _get_answer_start_at_node_id(self, graph: dict, target_node_id: str) \
-> Optional[str]:
"""
Get the node id from which streaming for the given answer node can start.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# find the first incoming edge of the target node
ingoing_edge = None
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edge = edge
break
if not ingoing_edge:
return None
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
return None
node_type = source_node.get('data', {}).get('type')
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
]:
start_node_id = target_node_id
elif node_type == NodeType.START.value:
start_node_id = source_node_id
else:
start_node_id = self._get_answer_start_at_node_id(graph, source_node_id)
return start_node_id
def _generate_stream_outputs_when_node_start(self) -> None:
"""
Generate stream outputs when a node starts.
:return:
"""
if not self._task_state.current_stream_generate_state:
return
for route_chunk in self._task_state.current_stream_generate_state.generate_route:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
for token in route_chunk.text:
self._queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.TASK_PIPELINE
)
time.sleep(0.01)
self._task_state.current_stream_generate_state.current_route_position += 1
else:
break
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route):
self._task_state.current_stream_generate_state = None
def _generate_stream_outputs_when_node_finished(self) -> None:
"""
Generate stream outputs when a node finishes.
:return:
"""
if not self._task_state.current_stream_generate_state:
return
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
for token in route_chunk.text:
self._queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.TASK_PIPELINE
)
time.sleep(0.01)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
route_chunk_node_id = value_selector[0]
# stop if the route chunk references a node that has not run yet
if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
break
latest_node_execution_info = self._task_state.latest_node_execution_info
# get route chunk node execution info
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
if (route_chunk_node_execution_info.node_type == NodeType.LLM
and latest_node_execution_info.node_type == NodeType.LLM):
# LLM output has already been streamed chunk by chunk, so skip this route chunk
self._task_state.current_stream_generate_state.current_route_position += 1
continue
# get route chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id).first()
outputs = route_chunk_node_execution.outputs_dict
# get value from outputs
value = None
for key in value_selector[1:]:
if not value:
value = outputs.get(key)
else:
value = value.get(key)
if value:
text = None
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, dict):  # TODO FILE
# a file value is a dict; convert it to markdown image syntax
text = f'![]({value.get("url")})'
if text:
for token in text:
self._queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.TASK_PIPELINE
)
time.sleep(0.01)
self._task_state.current_stream_generate_state.current_route_position += 1
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route):
self._task_state.current_stream_generate_state = None
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Check whether stream output is supported for this text chunk.
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return True
if 'node_id' not in event.metadata:
return True
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
route_chunk = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position]
if route_chunk.type != 'var':
return False
if node_type != NodeType.LLM:
# only LLM nodes support chunked stream output
return False
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# the event must match the route chunk currently being streamed
if value_selector != stream_output_value_selector:
return False
return True
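
In outline, the streaming mechanism this file introduces works like this: each answer node's template is compiled into a list of route chunks, where fixed text streams immediately and variable references stream once the node they point at has finished. Below is a minimal sketch of that replay loop under simplified, hypothetical types (TextChunk, VarChunk and replay_route are illustrative names, not the pipeline's actual API):

from dataclasses import dataclass
from typing import Iterator, Union


@dataclass
class TextChunk:
    text: str


@dataclass
class VarChunk:
    value_selector: list[str]  # e.g. ['llm_node_id', 'text']


def replay_route(route: list[Union[TextChunk, VarChunk]],
                 node_outputs: dict[str, dict]) -> Iterator[str]:
    """Yield tokens for each chunk whose data is already available."""
    for chunk in route:
        if isinstance(chunk, TextChunk):
            # fixed template text can stream right away
            yield from chunk.text
        else:
            node_id, key = chunk.value_selector[0], chunk.value_selector[1]
            outputs = node_outputs.get(node_id)
            if outputs is None:
                # the referenced node has not finished yet; resume later
                return
            yield from str(outputs.get(key, ''))


# stream "Answer: hello" once the LLM node's output is known
route = [TextChunk('Answer: '), VarChunk(['llm', 'text'])]
print(''.join(replay_route(route, {'llm': {'text': 'hello'}})))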

View File

@ -20,7 +20,6 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
def on_workflow_run_started(self) -> None:
"""
@ -114,34 +113,16 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
PublishFrom.APPLICATION_MANAGER
)
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
if node_id in self._streamable_node_ids:
self._queue_manager.publish(
QueueTextChunkEvent(
text=text
), PublishFrom.APPLICATION_MANAGER
)
def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
"""
Fetch streamable node ids
When the workflow type is chat, only LLM or Direct Answer nodes immediately before the END node can stream output
When the workflow type is workflow, only LLM nodes immediately before the END node (Plain Text mode only) can stream output
:param graph: workflow graph
:return:
"""
streamable_node_ids = []
end_node_ids = []
for node_config in graph.get('nodes'):
if node_config.get('data', {}).get('type') == NodeType.END.value:
end_node_ids.append(node_config.get('id'))
for edge_config in graph.get('edges'):
if edge_config.get('target') in end_node_ids:
streamable_node_ids.append(edge_config.get('source'))
return streamable_node_ids
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**(metadata or {})
}
), PublishFrom.APPLICATION_MANAGER
)
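
With this change, every text chunk is tagged with metadata identifying its source node, and the pipeline's _is_stream_out_support decides whether to forward it. A minimal sketch of that producer/consumer contract, using hypothetical helper names (make_chunk, should_stream) rather than the callback's actual API:

from typing import Optional


def make_chunk(text: str, node_id: str,
               metadata: Optional[dict] = None) -> dict:
    # merge caller metadata defensively: `metadata` defaults to None
    return {'text': text,
            'metadata': {'node_id': node_id, **(metadata or {})}}


def should_stream(chunk: dict, expected_selector: list[str]) -> bool:
    # forward only chunks whose value selector matches the route chunk
    # currently being streamed
    return chunk['metadata'].get('value_selector') == expected_selector


chunk = make_chunk('hel', 'llm', {'value_selector': ['llm', 'text']})
assert should_stream(chunk, ['llm', 'text'])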

View File

@ -3,12 +3,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueAdvancedChatMessageEndEvent,
QueueErrorEvent,
QueueMessage,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowSucceededEvent,
)
@ -54,8 +53,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
if isinstance(event, QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueAdvancedChatMessageEndEvent):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
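
A side note on the isinstance check above: PEP 604 union syntax (Python 3.10+) is valid inside isinstance, so isinstance(event, A | B) is equivalent to isinstance(event, (A, B)). A quick self-contained illustration with stand-in event classes:

class QueueStopEvent: ...
class QueueErrorEvent: ...
class QueueAdvancedChatMessageEndEvent: ...


event = QueueAdvancedChatMessageEndEvent()
assert isinstance(event, QueueStopEvent
                  | QueueErrorEvent
                  | QueueAdvancedChatMessageEndEvent)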

View File

@ -112,7 +112,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
PublishFrom.APPLICATION_MANAGER
)
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""