This commit is contained in:
takatost
2024-07-17 01:02:40 +08:00
parent 775e52db4d
commit 4ef3d4e65c
37 changed files with 241 additions and 267 deletions

View File

@ -21,10 +21,9 @@ class AnswerNode(BaseNode):
_node_data_cls = AnswerNodeData
node_type = NodeType.ANSWER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:return:
"""
node_data = self.node_data
@ -38,7 +37,7 @@ class AnswerNode(BaseNode):
if part.type == "var":
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
value = variable_pool.get_variable_value(
value = self.graph_runtime_state.variable_pool.get_variable_value(
variable_selector=value_selector
)

View File

@ -9,7 +9,7 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iterable_node import IterableNodeMixin
from core.workflow.nodes.iterable_node_mixin import IterableNodeMixin
class BaseNode(ABC):
@ -104,21 +104,19 @@ class BaseNode(ABC):
class BaseIterationNode(BaseNode, IterableNodeMixin):
@abstractmethod
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
def _run(self) -> BaseIterationState:
"""
Run node
:param variable_pool: variable pool
:return:
"""
raise NotImplementedError
def run(self, variable_pool: VariablePool) -> BaseIterationState:
def run(self) -> BaseIterationState:
"""
Run node entry
:param variable_pool: variable pool
:return:
"""
return self._run(variable_pool=variable_pool)
return self._run(variable_pool=self.graph_runtime_state.variable_pool)
def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
"""

View File

@ -42,14 +42,13 @@ class CodeNode(BaseNode):
return code_provider.get_default_config()
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run code
:param variable_pool: variable pool
:return:
"""
node_data = self.node_data
node_data: CodeNodeData = cast(self._node_data_cls, node_data)
node_data = cast(CodeNodeData, node_data)
# Get code language
code_language = node_data.code_language
@ -59,7 +58,7 @@ class CodeNode(BaseNode):
variables = {}
for variable_selector in node_data.variables:
variable = variable_selector.variable
value = variable_pool.get_variable_value(
value = self.graph_runtime_state.variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)

View File

@ -2,7 +2,6 @@ from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.entities import EndNodeData
from models.workflow import WorkflowNodeExecutionStatus
@ -12,19 +11,18 @@ class EndNode(BaseNode):
_node_data_cls = EndNodeData
node_type = NodeType.END
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:return:
"""
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
node_data = cast(EndNodeData, node_data)
output_variables = node_data.outputs
outputs = {}
for variable_selector in output_variables:
value = variable_pool.get_variable_value(
value = self.graph_runtime_state.variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)

View File

@ -49,14 +49,16 @@ class HttpRequestNode(BaseNode):
},
}
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)
# init http executor
http_executor = None
try:
http_executor = HttpExecutor(
node_data=node_data, timeout=self._get_request_timeout(node_data), variable_pool=variable_pool
node_data=node_data,
timeout=self._get_request_timeout(node_data),
variable_pool=self.graph_runtime_state.variable_pool
)
# invoke http executor

View File

@ -4,6 +4,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
from models.workflow import WorkflowNodeExecutionStatus
@ -30,11 +31,16 @@ class IfElseNode(BaseNode):
input_conditions = []
final_result = False
selected_case_id = None
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used
if node_data.cases:
for case in node_data.cases:
input_conditions, group_result = self.process_conditions(self.graph_runtime_state.variable_pool, case.conditions)
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions
)
# Apply the logical operator for the current case
final_result = all(group_result) if case.logical_operator == "and" else any(group_result)
@ -53,7 +59,10 @@ class IfElseNode(BaseNode):
else:
# Fallback to old structure if cases are not defined
input_conditions, group_result = self.process_conditions(variable_pool, node_data.conditions)
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=node_data.conditions
)
final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result)

View File

@ -17,7 +17,7 @@ class IterationNode(BaseIterationNode):
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
def _run(self) -> BaseIterationState:
"""
Run the node.
"""
@ -32,7 +32,7 @@ class IterationNode(BaseIterationNode):
iterator_length=len(iterator) if iterator is not None else 0
))
self._set_current_iteration_variable(variable_pool, state)
self._set_current_iteration_variable(self.graph_runtime_state.variable_pool, state)
return state
def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:

View File

@ -14,7 +14,6 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrival_methods import RetrievalMethod
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from extensions.ext_database import db
@ -37,11 +36,11 @@ class KnowledgeRetrievalNode(BaseNode):
_node_data_cls = KnowledgeRetrievalNodeData
node_type = NodeType.KNOWLEDGE_RETRIEVAL
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)
def _run(self) -> NodeRunResult:
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
query = self.graph_runtime_state.variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
variables = {
'query': query
}

View File

@ -27,7 +27,6 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import NodeRunRetrieverResourceEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.entities import (
@ -85,9 +84,7 @@ class LLMNode(BaseNode):
for event in generator:
if isinstance(event, RunRetrieverResourceEvent):
context = event.context
yield NodeRunRetrieverResourceEvent(
retriever_resources=event.retriever_resources
)
yield event
if context:
node_inputs['#context#'] = context # type: ignore
@ -170,7 +167,7 @@ class LLMNode(BaseNode):
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None) \
-> Generator["RunStreamChunkEvent | ModelInvokeCompleted", None, None]:
-> Generator[RunEvent, None, None]:
"""
Invoke large language model
:param node_data_model: node data model
@ -204,7 +201,7 @@ class LLMNode(BaseNode):
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
-> Generator["RunStreamChunkEvent | ModelInvokeCompleted", None, None]:
-> Generator[RunEvent, None, None]:
"""
Handle invoke result
:param invoke_result: invoke result

View File

@ -14,8 +14,8 @@ class LoopNode(BaseIterationNode):
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
def _run(self, variable_pool: VariablePool) -> LoopState:
return super()._run(variable_pool)
def _run(self) -> LoopState:
return super()._run()
def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str:
"""

View File

@ -66,12 +66,12 @@ class ParameterExtractorNode(LLMNode):
}
}
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self.node_data)
query = variable_pool.get_variable_value(node_data.query)
query = self.graph_runtime_state.variable_pool.get_variable_value(node_data.query)
if not query:
raise ValueError("Input variable content not found or is empty")
@ -91,17 +91,20 @@ class ParameterExtractorNode(LLMNode):
raise ValueError("Model schema not found")
# fetch memory
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance)
if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
and node_data.reasoning_mode == 'function_call':
# use function call
prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
node_data, query, variable_pool, model_config, memory
node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
)
else:
# use prompt engineering
prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config,
prompt_messages = self._generate_prompt_engineering_prompt(node_data,
query,
self.graph_runtime_state.variable_pool,
model_config,
memory)
prompt_message_tools = []

View File

@ -1,7 +1,6 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
@ -11,17 +10,16 @@ class StartNode(BaseNode):
_node_data_cls = StartNodeData
node_type = NodeType.START
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:return:
"""
# Get cleaned inputs
cleaned_inputs = variable_pool.user_inputs
cleaned_inputs = self.graph_runtime_state.variable_pool.user_inputs
for var in variable_pool.system_variables:
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]
for var in self.graph_runtime_state.variable_pool.system_variables:
cleaned_inputs['sys.' + var.value] = self.graph_runtime_state.variable_pool.system_variables[var]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,

View File

@ -3,13 +3,13 @@ from typing import Optional, cast
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from models.workflow import WorkflowNodeExecutionStatus
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000'))
class TemplateTransformNode(BaseNode):
_node_data_cls = TemplateTransformNodeData
_node_type = NodeType.TEMPLATE_TRANSFORM
@ -34,7 +34,7 @@ class TemplateTransformNode(BaseNode):
}
}
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run node
"""
@ -45,7 +45,7 @@ class TemplateTransformNode(BaseNode):
variables = {}
for variable_selector in node_data.variables:
variable = variable_selector.variable
value = variable_pool.get_variable_value(
value = self.graph_runtime_state.variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
@ -63,7 +63,7 @@ class TemplateTransformNode(BaseNode):
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e)
)
if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
@ -78,9 +78,10 @@ class TemplateTransformNode(BaseNode):
'output': result['result']
}
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[
str, list[str]]:
"""
Extract variable selector to variable mapping
:param node_data: node data
@ -88,4 +89,4 @@ class TemplateTransformNode(BaseNode):
"""
return {
variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
}
}

View File

@ -1,8 +0,0 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
class TestNodeData(BaseNodeData):
"""
Test Node Data.
"""
pass

View File

@ -1,33 +0,0 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.test.entities import TestNodeData
from models.workflow import WorkflowNodeExecutionStatus
class TestNode(BaseNode):
_node_data_cls = TestNodeData
node_type = NodeType.ANSWER
def _run(self) -> NodeRunResult:
"""
Run node
:return:
"""
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"content": "abc"
},
edge_source_handle="1"
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
return {}

View File

@ -23,7 +23,7 @@ class ToolNode(BaseNode):
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run the tool node
"""
@ -52,7 +52,7 @@ class ToolNode(BaseNode):
)
# get parameters
parameters = self._generate_parameters(variable_pool, node_data, tool_runtime)
parameters = self._generate_parameters(self.graph_runtime_state.variable_pool, node_data, tool_runtime)
try:
messages = ToolEngine.workflow_invoke(
@ -136,7 +136,8 @@ class ToolNode(BaseNode):
return files
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) \
-> tuple[str, list[FileVar], list[dict]]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""

View File

@ -2,7 +2,6 @@ from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
from models.workflow import WorkflowNodeExecutionStatus
@ -12,7 +11,7 @@ class VariableAggregatorNode(BaseNode):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
node_data = cast(VariableAssignerNodeData, self.node_data)
# Get variables
outputs = {}
@ -20,7 +19,7 @@ class VariableAggregatorNode(BaseNode):
if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled:
for variable in node_data.variables:
value = variable_pool.get_variable_value(variable)
value = self.graph_runtime_state.variable_pool.get_variable_value(variable)
if value is not None:
outputs = {
@ -34,7 +33,7 @@ class VariableAggregatorNode(BaseNode):
else:
for group in node_data.advanced_settings.groups:
for variable in group.variables:
value = variable_pool.get_variable_value(variable)
value = self.graph_runtime_state.variable_pool.get_variable_value(variable)
if value is not None:
outputs[group.group_name] = {