mirror of
https://github.com/langgenius/dify.git
synced 2026-03-22 14:57:58 +08:00
add end stream output test
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import SystemVariable, UserFrom
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, SystemVariable, UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseNodeEvent,
|
||||
@ -16,12 +16,267 @@ from core.workflow.graph_engine.entities.event import (
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from models.workflow import WorkflowType
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
@patch('extensions.ext_database.db.session.remove')
|
||||
@patch('extensions.ext_database.db.session.close')
|
||||
def test_run_parallel(mock_close, mock_remove):
|
||||
def test_run_parallel_in_workflow(mock_close, mock_remove, mocker):
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "1",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"source": "llm1",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "3",
|
||||
"source": "llm1",
|
||||
"target": "llm3",
|
||||
},
|
||||
{
|
||||
"id": "4",
|
||||
"source": "llm2",
|
||||
"target": "end1",
|
||||
},
|
||||
{
|
||||
"id": "5",
|
||||
"source": "llm3",
|
||||
"target": "end2",
|
||||
}
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"type": "start",
|
||||
"title": "start",
|
||||
"variables": [{
|
||||
"label": "query",
|
||||
"max_length": 48,
|
||||
"options": [],
|
||||
"required": True,
|
||||
"type": "text-input",
|
||||
"variable": "query"
|
||||
}]
|
||||
},
|
||||
"id": "start"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
"title": "llm1",
|
||||
"context": {
|
||||
"enabled": False,
|
||||
"variable_selector": []
|
||||
},
|
||||
"model": {
|
||||
"completion_params": {
|
||||
"temperature": 0.7
|
||||
},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai"
|
||||
},
|
||||
"prompt_template": [{
|
||||
"role": "system",
|
||||
"text": "say hi"
|
||||
}, {
|
||||
"role": "user",
|
||||
"text": "{{#start.query#}}"
|
||||
}],
|
||||
"vision": {
|
||||
"configs": {
|
||||
"detail": "high"
|
||||
},
|
||||
"enabled": False
|
||||
}
|
||||
},
|
||||
"id": "llm1"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
"title": "llm2",
|
||||
"context": {
|
||||
"enabled": False,
|
||||
"variable_selector": []
|
||||
},
|
||||
"model": {
|
||||
"completion_params": {
|
||||
"temperature": 0.7
|
||||
},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai"
|
||||
},
|
||||
"prompt_template": [{
|
||||
"role": "system",
|
||||
"text": "say bye"
|
||||
}, {
|
||||
"role": "user",
|
||||
"text": "{{#start.query#}}"
|
||||
}],
|
||||
"vision": {
|
||||
"configs": {
|
||||
"detail": "high"
|
||||
},
|
||||
"enabled": False
|
||||
}
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
"title": "llm3",
|
||||
"context": {
|
||||
"enabled": False,
|
||||
"variable_selector": []
|
||||
},
|
||||
"model": {
|
||||
"completion_params": {
|
||||
"temperature": 0.7
|
||||
},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai"
|
||||
},
|
||||
"prompt_template": [{
|
||||
"role": "system",
|
||||
"text": "say good morning"
|
||||
}, {
|
||||
"role": "user",
|
||||
"text": "{{#start.query#}}"
|
||||
}],
|
||||
"vision": {
|
||||
"configs": {
|
||||
"detail": "high"
|
||||
},
|
||||
"enabled": False
|
||||
}
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "end",
|
||||
"title": "end1",
|
||||
"outputs": [{
|
||||
"value_selector": ["llm2", "text"],
|
||||
"variable": "result2"
|
||||
}, {
|
||||
"value_selector": ["start", "query"],
|
||||
"variable": "query"
|
||||
}],
|
||||
},
|
||||
"id": "end1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "end",
|
||||
"title": "end2",
|
||||
"outputs": [{
|
||||
"value_selector": ["llm1", "text"],
|
||||
"variable": "result1"
|
||||
}, {
|
||||
"value_selector": ["llm3", "text"],
|
||||
"variable": "result3"
|
||||
}],
|
||||
},
|
||||
"id": "end2",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
}, user_inputs={
|
||||
"query": "hi"
|
||||
})
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="333",
|
||||
user_id="444",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200
|
||||
)
|
||||
|
||||
def llm_generator(self):
|
||||
contents = [
|
||||
'hi',
|
||||
'bye',
|
||||
'good morning'
|
||||
]
|
||||
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=contents[int(self.node_id[-1]) - 1],
|
||||
from_variable_selector=[self.node_id, 'text']
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={},
|
||||
process_data={},
|
||||
outputs={},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: 1,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: 1,
|
||||
NodeRunMetadataKey.CURRENCY: 'USD'
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
print("")
|
||||
|
||||
with patch.object(LLMNode, '_run', new=llm_generator):
|
||||
items = []
|
||||
generator = graph_engine.run()
|
||||
for item in generator:
|
||||
print(type(item), item)
|
||||
items.append(item)
|
||||
if isinstance(item, NodeRunSucceededEvent):
|
||||
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
|
||||
|
||||
assert not isinstance(item, NodeRunFailedEvent)
|
||||
assert not isinstance(item, GraphRunFailedEvent)
|
||||
|
||||
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [
|
||||
'llm2', 'llm3', 'end1', 'end2'
|
||||
]:
|
||||
assert item.parallel_id is not None
|
||||
|
||||
assert len(items) == 17
|
||||
assert isinstance(items[0], GraphRunStartedEvent)
|
||||
assert isinstance(items[1], NodeRunStartedEvent)
|
||||
assert items[1].route_node_state.node_id == 'start'
|
||||
assert isinstance(items[2], NodeRunSucceededEvent)
|
||||
assert items[2].route_node_state.node_id == 'start'
|
||||
|
||||
|
||||
@patch('extensions.ext_database.db.session.remove')
|
||||
@patch('extensions.ext_database.db.session.close')
|
||||
def test_run_parallel_in_chatflow(mock_close, mock_remove):
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
@ -291,7 +546,7 @@ def test_run_branch(mock_close, mock_remove):
|
||||
items = []
|
||||
generator = graph_engine.run()
|
||||
for item in generator:
|
||||
print(type(item), item)
|
||||
# print(type(item), item)
|
||||
items.append(item)
|
||||
|
||||
assert len(items) == 10
|
||||
|
||||
Reference in New Issue
Block a user