mirror of
https://github.com/langgenius/dify.git
synced 2026-02-28 21:46:27 +08:00
optimize
This commit is contained in:
@ -21,10 +21,12 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||
|
||||
# process condition
|
||||
condition_processor = ConditionProcessor()
|
||||
compare_result, _ = condition_processor.process(
|
||||
input_conditions, group_result = condition_processor.process_conditions(
|
||||
variable_pool=graph_runtime_state.variable_pool,
|
||||
logical_operator="and",
|
||||
conditions=self.condition.conditions
|
||||
)
|
||||
|
||||
# Apply the logical operator for the current case
|
||||
compare_result = all(group_result)
|
||||
|
||||
return compare_result
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Optional
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@ -40,7 +41,7 @@ class GraphRunFailedEvent(BaseGraphEvent):
|
||||
|
||||
|
||||
class BaseNodeEvent(GraphEngineEvent):
|
||||
node_id: str = Field(..., description="node id")
|
||||
route_node_state: RouteNodeState = Field(..., description="route node state")
|
||||
parallel_id: Optional[str] = Field(None, description="parallel id if node is in parallel")
|
||||
# iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
|
||||
|
||||
@ -60,21 +61,11 @@ class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
||||
|
||||
|
||||
class NodeRunSucceededEvent(BaseNodeEvent):
|
||||
run_result: NodeRunResult = Field(..., description="run result")
|
||||
pass
|
||||
|
||||
|
||||
class NodeRunFailedEvent(BaseNodeEvent):
|
||||
run_result: NodeRunResult = Field(
|
||||
default=NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED),
|
||||
description="run result"
|
||||
)
|
||||
reason: str = Field("", description="failed reason")
|
||||
|
||||
@model_validator(mode='before')
|
||||
def init_reason(cls, values: dict) -> dict:
|
||||
if not values.get("reason"):
|
||||
values["reason"] = values.get("run_result").error or "Unknown error"
|
||||
return values
|
||||
pass
|
||||
|
||||
|
||||
###########################################
|
||||
|
||||
@ -16,7 +16,7 @@ class RouteNodeState(BaseModel):
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
|
||||
state_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""node state id"""
|
||||
|
||||
node_id: str
|
||||
@ -45,11 +45,15 @@ class RouteNodeState(BaseModel):
|
||||
|
||||
|
||||
class RuntimeRouteState(BaseModel):
|
||||
routes: dict[str, list[str]] = Field(default_factory=dict)
|
||||
"""graph state routes (source_node_state_id: target_node_state_id)"""
|
||||
routes: dict[str, list[str]] = Field(
|
||||
default_factory=dict,
|
||||
description="graph state routes (source_node_state_id: target_node_state_id)"
|
||||
)
|
||||
|
||||
node_state_mapping: dict[str, RouteNodeState] = Field(default_factory=dict)
|
||||
"""node state mapping (route_node_state_id: route_node_state)"""
|
||||
node_state_mapping: dict[str, RouteNodeState] = Field(
|
||||
default_factory=dict,
|
||||
description="node state mapping (route_node_state_id: route_node_state)"
|
||||
)
|
||||
|
||||
def create_node_state(self, node_id: str) -> RouteNodeState:
|
||||
"""
|
||||
@ -58,7 +62,7 @@ class RuntimeRouteState(BaseModel):
|
||||
:param node_id: node id
|
||||
"""
|
||||
state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None))
|
||||
self.node_state_mapping[state.state_id] = state
|
||||
self.node_state_mapping[state.id] = state
|
||||
return state
|
||||
|
||||
def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
|
||||
|
||||
@ -10,7 +10,7 @@ from flask import Flask, current_app
|
||||
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.node_entities import NodeType, UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
@ -29,9 +29,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes import node_classes
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.test.test_node import TestNode
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
@ -86,7 +84,7 @@ class GraphEngine:
|
||||
for item in generator:
|
||||
yield item
|
||||
if isinstance(item, NodeRunFailedEvent):
|
||||
yield GraphRunFailedEvent(reason=item.reason)
|
||||
yield GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.')
|
||||
return
|
||||
|
||||
# trigger graph run success event
|
||||
@ -100,7 +98,7 @@ class GraphEngine:
|
||||
|
||||
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
|
||||
next_node_id = start_node_id
|
||||
previous_node_id = None
|
||||
previous_route_node_state: Optional[RouteNodeState] = None
|
||||
while True:
|
||||
# max steps reached
|
||||
if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
|
||||
@ -108,23 +106,42 @@ class GraphEngine:
|
||||
|
||||
# or max execution time reached
|
||||
if self._is_timed_out(
|
||||
start_at=self.graph_runtime_state.start_at,
|
||||
max_execution_time=self.max_execution_time
|
||||
start_at=self.graph_runtime_state.start_at,
|
||||
max_execution_time=self.max_execution_time
|
||||
):
|
||||
raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
|
||||
|
||||
# init route node state
|
||||
route_node_state = self.graph_runtime_state.create_node_state(
|
||||
node_id=next_node_id
|
||||
)
|
||||
|
||||
try:
|
||||
# run node
|
||||
yield from self._run_node(
|
||||
node_id=next_node_id,
|
||||
previous_node_id=previous_node_id,
|
||||
route_node_state=route_node_state,
|
||||
previous_node_id=previous_route_node_state.node_id if previous_route_node_state else None,
|
||||
parallel_id=in_parallel_id
|
||||
)
|
||||
|
||||
self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state
|
||||
|
||||
# append route
|
||||
if previous_route_node_state:
|
||||
if previous_route_node_state.id not in self.graph_runtime_state.node_run_state.routes:
|
||||
self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id] = []
|
||||
|
||||
self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id].append(
|
||||
route_node_state.id
|
||||
)
|
||||
except Exception as e:
|
||||
yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e))
|
||||
yield NodeRunFailedEvent(
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=in_parallel_id
|
||||
)
|
||||
raise e
|
||||
|
||||
previous_node_id = next_node_id
|
||||
previous_route_node_state = route_node_state
|
||||
|
||||
# get next node ids
|
||||
edge_mappings = self.graph.edge_mapping.get(next_node_id)
|
||||
@ -227,24 +244,32 @@ class GraphEngine:
|
||||
in_parallel_id=parallel_id
|
||||
)
|
||||
|
||||
if generator:
|
||||
for item in generator:
|
||||
q.put(item)
|
||||
except Exception:
|
||||
for item in generator:
|
||||
q.put(item)
|
||||
if isinstance(item, NodeRunFailedEvent):
|
||||
q.put(GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.'))
|
||||
return
|
||||
|
||||
# trigger graph run success event
|
||||
q.put(GraphRunSucceededEvent())
|
||||
except (GraphRunFailedError, NodeRunFailedError) as e:
|
||||
q.put(GraphRunFailedEvent(reason=e.error))
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when generating in parallel")
|
||||
q.put(GraphRunFailedEvent(reason=str(e)))
|
||||
finally:
|
||||
q.put(None)
|
||||
db.session.remove()
|
||||
|
||||
def _run_node(self,
|
||||
node_id: str,
|
||||
route_node_state: RouteNodeState,
|
||||
previous_node_id: Optional[str] = None,
|
||||
parallel_id: Optional[str] = None
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Run node
|
||||
"""
|
||||
# get node config
|
||||
node_id = route_node_state.node_id
|
||||
node_config = self.graph.node_id_config_mapping.get(node_id)
|
||||
if not node_config:
|
||||
raise GraphRunFailedError(f'Node {node_id} config not found.')
|
||||
@ -256,16 +281,7 @@ class GraphEngine:
|
||||
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
|
||||
|
||||
# init workflow run state
|
||||
# node_instance = node_cls( # type: ignore
|
||||
# config=node_config,
|
||||
# graph_init_params=self.init_params,
|
||||
# graph=self.graph,
|
||||
# graph_runtime_state=self.graph_runtime_state,
|
||||
# previous_node_id=previous_node_id
|
||||
# )
|
||||
|
||||
# init workflow run state
|
||||
node_instance = TestNode(
|
||||
node_instance = node_cls( # type: ignore
|
||||
config=node_config,
|
||||
graph_init_params=self.init_params,
|
||||
graph=self.graph,
|
||||
@ -274,7 +290,10 @@ class GraphEngine:
|
||||
)
|
||||
|
||||
# trigger node run start event
|
||||
yield NodeRunStartedEvent(node_id=node_id, parallel_id=parallel_id)
|
||||
yield NodeRunStartedEvent(
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
@ -283,60 +302,50 @@ class GraphEngine:
|
||||
self.graph_runtime_state.node_run_steps += 1
|
||||
|
||||
try:
|
||||
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
# run node
|
||||
generator = node_instance.run()
|
||||
run_result = None
|
||||
for item in generator:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
run_result = item.run_result
|
||||
route_node_state.status = RouteNodeState.Status.SUCCESS \
|
||||
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED \
|
||||
else RouteNodeState.Status.FAILED
|
||||
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
route_node_state.node_run_result = run_result
|
||||
route_node_state.failed_reason = run_result.error \
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED else None
|
||||
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
yield NodeRunFailedEvent(
|
||||
node_id=node_id,
|
||||
parallel_id=parallel_id,
|
||||
run_result=run_result,
|
||||
reason=run_result.error
|
||||
route_node_state=route_node_state
|
||||
)
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
yield NodeRunSucceededEvent(
|
||||
node_id=node_id,
|
||||
parallel_id=parallel_id,
|
||||
run_result=run_result
|
||||
route_node_state=route_node_state
|
||||
)
|
||||
|
||||
self.graph_runtime_state.node_run_state.node_state_mapping[node_id] = RouteNodeState(
|
||||
node_id=node_id,
|
||||
start_at=start_at,
|
||||
status=RouteNodeState.Status.SUCCESS if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
else RouteNodeState.Status.FAILED,
|
||||
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
node_run_result=run_result,
|
||||
failed_reason=run_result.error
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED else None
|
||||
)
|
||||
|
||||
# todo append self.graph_runtime_state.node_run_state.routes
|
||||
break
|
||||
elif isinstance(item, RunStreamChunkEvent):
|
||||
yield NodeRunStreamChunkEvent(
|
||||
node_id=node_id,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
chunk_content=item.chunk_content,
|
||||
from_variable_selector=item.from_variable_selector,
|
||||
)
|
||||
elif isinstance(item, RunRetrieverResourceEvent):
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
node_id=node_id,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
retriever_resources=item.retriever_resources,
|
||||
context=item.context
|
||||
)
|
||||
|
||||
# todo record state
|
||||
except GenerateTaskStoppedException as e:
|
||||
# trigger node run failed event
|
||||
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
|
||||
yield NodeRunFailedEvent(
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}")
|
||||
|
||||
Reference in New Issue
Block a user