mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
fix bugs
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
@ -38,9 +38,21 @@ class VariablePool(BaseModel):
|
||||
description='System variables',
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
for system_variable, value in self.system_variables.items():
|
||||
self.append_variable('sys', [system_variable.value], value)
|
||||
@model_validator(mode='before')
|
||||
def append_system_variables(cls, v: dict) -> dict:
|
||||
"""
|
||||
Append system variables
|
||||
:param v: params
|
||||
:return:
|
||||
"""
|
||||
v['variables_mapping'] = {
|
||||
'sys': {}
|
||||
}
|
||||
system_variables = v['system_variables']
|
||||
for system_variable, value in system_variables.items():
|
||||
variable_key_list_hash = hash((system_variable.value,))
|
||||
v['variables_mapping']['sys'][variable_key_list_hash] = value
|
||||
return v
|
||||
|
||||
def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
|
||||
"""
|
||||
|
||||
@ -4,6 +4,7 @@ from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class RunConditionHandler(ABC):
|
||||
@ -18,13 +19,13 @@ class RunConditionHandler(ABC):
|
||||
@abstractmethod
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
source_node_id: str,
|
||||
previous_route_node_state: RouteNodeState,
|
||||
target_node_id: str) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param source_node_id: source node id
|
||||
:param previous_route_node_state: previous route node state
|
||||
:param target_node_id: target node id
|
||||
:return: bool
|
||||
"""
|
||||
|
||||
@ -1,29 +1,26 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
||||
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
source_node_id: str,
|
||||
previous_route_node_state: RouteNodeState,
|
||||
target_node_id: str) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param source_node_id: source node id
|
||||
:param previous_route_node_state: previous route node state
|
||||
:param target_node_id: target node id
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.branch_identify:
|
||||
raise Exception("Branch identify is required")
|
||||
|
||||
node_route_state = graph_runtime_state.node_run_state.node_state_mapping.get(source_node_id)
|
||||
if not node_route_state:
|
||||
return False
|
||||
|
||||
run_result = node_route_state.node_run_result
|
||||
run_result = previous_route_node_state.node_run_result
|
||||
if not run_result:
|
||||
return False
|
||||
|
||||
|
||||
@ -1,18 +1,19 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
source_node_id: str,
|
||||
previous_route_node_state: RouteNodeState,
|
||||
target_node_id: str) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param source_node_id: source node id
|
||||
:param previous_route_node_state: previous route node state
|
||||
:param target_node_id: target node id
|
||||
:return: bool
|
||||
"""
|
||||
|
||||
@ -167,7 +167,7 @@ class GraphEngine:
|
||||
run_condition=edge.run_condition,
|
||||
).check(
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
source_node_id=edge.source_node_id,
|
||||
previous_route_node_state=previous_route_node_state,
|
||||
target_node_id=edge.target_node_id,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user