feat(workflow): integrate workflow entry with advanced chat app

This commit is contained in:
takatost
2024-08-13 16:21:10 +08:00
parent 8d27ec364f
commit 8401a11109
38 changed files with 976 additions and 2328 deletions

View File

@ -1,9 +1,9 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMUsage
from models.workflow import WorkflowNodeExecutionStatus
@ -83,10 +83,11 @@ class NodeRunResult(BaseModel):
"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[Mapping[str, Any]] = None # process data
outputs: Optional[Mapping[str, Any]] = None # node outputs
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata
inputs: Optional[dict[str, Any]] = None # node inputs
process_data: Optional[dict[str, Any]] = None # process data
outputs: Optional[dict[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
llm_usage: Optional[LLMUsage] = None # llm usage
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches

View File

@ -1,4 +1,4 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel, Field
@ -26,7 +26,8 @@ class GraphRunStartedEvent(BaseGraphEvent):
class GraphRunSucceededEvent(BaseGraphEvent):
pass
outputs: Optional[dict[str, Any]] = None
"""outputs"""
class GraphRunFailedEvent(BaseGraphEvent):
@ -39,6 +40,7 @@ class GraphRunFailedEvent(BaseGraphEvent):
class BaseNodeEvent(GraphEngineEvent):
id: str = Field(..., description="node execution id")
node_id: str = Field(..., description="node id")
node_type: NodeType = Field(..., description="node type")
node_data: BaseNodeData = Field(..., description="node data")
@ -47,7 +49,8 @@ class BaseNodeEvent(GraphEngineEvent):
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class NodeRunStartedEvent(BaseNodeEvent):
@ -82,7 +85,8 @@ class NodeRunFailedEvent(BaseNodeEvent):
class BaseParallelBranchEvent(GraphEngineEvent):
parallel_id: str = Field(..., description="parallel id")
parallel_start_node_id: str = Field(..., description="parallel start node id")
in_iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
@ -103,6 +107,7 @@ class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
class BaseIterationEvent(GraphEngineEvent):
iteration_id: str = Field(..., description="iteration node execution id")
iteration_node_id: str = Field(..., description="iteration node id")
iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
iteration_node_data: BaseNodeData = Field(..., description="node data")
@ -113,8 +118,9 @@ class BaseIterationEvent(GraphEngineEvent):
class IterationRunStartedEvent(BaseIterationEvent):
inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
predecessor_node_id: Optional[str] = None
@ -124,16 +130,18 @@ class IterationRunNextEvent(BaseIterationEvent):
class IterationRunSucceededEvent(BaseIterationEvent):
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
steps: int = 0
class IterationRunFailedEvent(BaseIterationEvent):
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
steps: int = 0
error: str = Field(..., description="failed reason")

View File

@ -24,6 +24,8 @@ class GraphParallel(BaseModel):
start_from_node_id: str = Field(..., description="start from node id")
parent_parallel_id: Optional[str] = None
"""parent parallel id"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id"""
end_to_node_id: Optional[str] = None
"""end to node id"""
@ -101,7 +103,7 @@ class Graph(BaseModel):
# parse run condition
run_condition = None
if edge_config.get('sourceHandle'):
if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source':
run_condition = RunCondition(
type='branch_identify',
branch_identify=edge_config.get('sourceHandle')
@ -176,7 +178,8 @@ class Graph(BaseModel):
# init end stream param
end_stream_param = EndStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping
reverse_edge_mapping=reverse_edge_mapping,
node_parallel_mapping=node_parallel_mapping
)
# init graph
@ -287,9 +290,17 @@ class Graph(BaseModel):
if all(node_id in node_parallel_mapping for node_id in parallel_node_ids):
parent_parallel_id = node_parallel_mapping[parallel_node_ids[0]]
if not parent_parallel_id:
raise Exception(f"Parent parallel id not found for node ids {parallel_node_ids}")
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
raise Exception(f"Parent parallel {parent_parallel_id} not found")
parallel = GraphParallel(
start_from_node_id=start_node_id,
parent_parallel_id=parent_parallel_id
parent_parallel_id=parent_parallel.id,
parent_parallel_start_node_id=parent_parallel.start_from_node_id
)
parallel_mapping[parallel.id] = parallel

View File

@ -1,15 +1,24 @@
from typing import Any
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
class GraphRuntimeState(BaseModel):
variable_pool: VariablePool = Field(..., description="variable pool")
"""variable pool"""
start_at: float = Field(..., description="start time")
"""start time"""
total_tokens: int = 0
"""total tokens"""
llm_usage: LLMUsage = LLMUsage.empty_usage()
"""llm usage info"""
outputs: dict[str, Any] = {}
"""outputs"""
node_run_steps: int = 0
"""node run steps"""

View File

@ -10,7 +10,11 @@ from uritemplate.variable import VariableValue
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, UserFrom
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeType,
UserFrom,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
@ -108,13 +112,29 @@ class GraphEngine:
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.')
return
elif isinstance(item, NodeRunSucceededEvent):
if item.node_type == NodeType.END:
self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs
if item.route_node_state.node_run_result
and item.route_node_state.node_run_result.outputs
else {})
elif item.node_type == NodeType.ANSWER:
if "answer" not in self.graph_runtime_state.outputs:
self.graph_runtime_state.outputs["answer"] = ""
self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "")
if item.route_node_state.node_run_result
and item.route_node_state.node_run_result.outputs
else "")
self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip()
except Exception as e:
logger.exception(f"Graph run failed: {str(e)}")
yield GraphRunFailedEvent(error=str(e))
return
# trigger graph run success event
yield GraphRunSucceededEvent()
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
except GraphRunFailedError as e:
yield GraphRunFailedEvent(error=e.error)
return
@ -163,6 +183,7 @@ class GraphEngine:
# init workflow run state
node_instance = node_cls( # type: ignore
id=route_node_state.id,
config=node_config,
graph_init_params=self.init_params,
graph=self.graph,
@ -192,6 +213,7 @@ class GraphEngine:
route_node_state.failed_reason = str(e)
yield NodeRunFailedEvent(
error=str(e),
id=node_instance.id,
node_id=next_node_id,
node_type=node_type,
node_data=node_instance.node_data,
@ -291,7 +313,7 @@ class GraphEngine:
continue
elif isinstance(event, ParallelBranchRunFailedEvent):
raise GraphRunFailedError(event.reason)
raise GraphRunFailedError(event.error)
except queue.Empty:
continue
@ -360,6 +382,7 @@ class GraphEngine:
"""
# trigger node run start event
yield NodeRunStartedEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
@ -383,7 +406,8 @@ class GraphEngine:
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason,
error=route_node_state.failed_reason or 'Unknown error.',
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
@ -398,6 +422,10 @@ class GraphEngine:
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
@ -409,6 +437,7 @@ class GraphEngine:
)
yield NodeRunSucceededEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
@ -420,6 +449,7 @@ class GraphEngine:
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
@ -431,6 +461,7 @@ class GraphEngine:
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
@ -450,6 +481,7 @@ class GraphEngine:
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,

View File

@ -16,11 +16,13 @@ class BaseNode(ABC):
_node_type: NodeType
def __init__(self,
id: str,
config: Mapping[str, Any],
graph_init_params: GraphInitParams,
graph: Graph,
graph_runtime_state: GraphRuntimeState,
previous_node_id: Optional[str] = None) -> None:
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
self.workflow_type = graph_init_params.workflow_type

View File

@ -7,7 +7,8 @@ class EndStreamGeneratorRouter:
@classmethod
def init(cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined]
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_parallel_mapping: dict[str, str]
) -> EndStreamParam:
"""
Get stream generate routes.
@ -19,6 +20,10 @@ class EndStreamGeneratorRouter:
if not node_config.get('data', {}).get('type') == NodeType.END.value:
continue
# skip end node in parallel
if end_node_id in node_parallel_mapping:
continue
# get generate route for stream output
stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config)
end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors

View File

@ -1,5 +1,6 @@
import logging
from collections.abc import Generator
from datetime import datetime, timezone
from typing import Any, cast
from configs import dify_config
@ -123,10 +124,14 @@ class IterationNode(BaseNode):
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
yield IterationRunStartedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
metadata={
"iterator_length": 1
@ -135,6 +140,7 @@ class IterationNode(BaseNode):
)
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
@ -186,6 +192,7 @@ class IterationNode(BaseNode):
)
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
@ -197,9 +204,11 @@ class IterationNode(BaseNode):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
@ -222,9 +231,11 @@ class IterationNode(BaseNode):
yield event
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
@ -247,9 +258,11 @@ class IterationNode(BaseNode):
# iteration run failed
logger.exception("Iteration run failed")
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)

View File

@ -1,3 +1,4 @@
import logging
from typing import Any, cast
from sqlalchemy import func
@ -20,6 +21,8 @@ from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__)
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
@ -67,7 +70,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
except Exception as e:
logger.exception("Error when running knowledge retrieval node")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,

View File

@ -168,7 +168,8 @@ class LLMNode(BaseNode):
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
}
},
llm_usage=usage
)
)

View File

@ -175,7 +175,8 @@ class ParameterExtractorNode(LLMNode):
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
}
},
llm_usage=usage
)
def _invoke_llm(self, node_data_model: ModelConfig,

View File

@ -119,7 +119,8 @@ class QuestionClassifierNode(LLMNode):
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
}
},
llm_usage=usage
)
except ValueError as e:
@ -131,7 +132,8 @@ class QuestionClassifierNode(LLMNode):
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
}
},
llm_usage=usage
)
@classmethod

View File

@ -117,10 +117,11 @@ class WorkflowEntry:
graph_runtime_state=graph_engine.graph_runtime_state,
event=event
)
yield event
yield event
except GenerateTaskStoppedException:
pass
except Exception as e:
logger.exception("Unknown Error when workflow entry running")
if callbacks:
for callback in callbacks:
callback.on_event(
@ -205,7 +206,7 @@ class WorkflowEntry:
node_instance=node_instance
)
# run node TODO
# run node
node_run_result = node_instance.run(
variable_pool=variable_pool
)
@ -223,7 +224,7 @@ class WorkflowEntry:
return node_instance, node_run_result
@classmethod
def handle_special_values(cls, value: Optional[dict]) -> Optional[dict]:
def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]:
"""
Handle special values
:param value: value
@ -232,7 +233,7 @@ class WorkflowEntry:
if not value:
return None
new_value = value.copy()
new_value = dict(value) if value else {}
if isinstance(new_value, dict):
for key, val in new_value.items():
if isinstance(val, FileVar):