mirror of
https://github.com/langgenius/dify.git
synced 2026-04-21 19:27:40 +08:00
optimize
This commit is contained in:
@ -21,10 +21,9 @@ class AnswerNode(BaseNode):
|
||||
_node_data_cls = AnswerNodeData
|
||||
node_type = NodeType.ANSWER
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
@ -38,7 +37,7 @@ class AnswerNode(BaseNode):
|
||||
if part.type == "var":
|
||||
part = cast(VarGenerateRouteChunk, part)
|
||||
value_selector = part.value_selector
|
||||
value = variable_pool.get_variable_value(
|
||||
value = self.graph_runtime_state.variable_pool.get_variable_value(
|
||||
variable_selector=value_selector
|
||||
)
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
|
||||
from core.workflow.nodes.iterable_node import IterableNodeMixin
|
||||
from core.workflow.nodes.iterable_node_mixin import IterableNodeMixin
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
@ -104,21 +104,19 @@ class BaseNode(ABC):
|
||||
|
||||
class BaseIterationNode(BaseNode, IterableNodeMixin):
|
||||
@abstractmethod
|
||||
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
|
||||
def _run(self) -> BaseIterationState:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self, variable_pool: VariablePool) -> BaseIterationState:
|
||||
def run(self) -> BaseIterationState:
|
||||
"""
|
||||
Run node entry
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
return self._run(variable_pool=variable_pool)
|
||||
return self._run(variable_pool=self.graph_runtime_state.variable_pool)
|
||||
|
||||
def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
|
||||
"""
|
||||
|
||||
@ -42,14 +42,13 @@ class CodeNode(BaseNode):
|
||||
|
||||
return code_provider.get_default_config()
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run code
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data: CodeNodeData = cast(self._node_data_cls, node_data)
|
||||
node_data = cast(CodeNodeData, node_data)
|
||||
|
||||
# Get code language
|
||||
code_language = node_data.code_language
|
||||
@ -59,7 +58,7 @@ class CodeNode(BaseNode):
|
||||
variables = {}
|
||||
for variable_selector in node_data.variables:
|
||||
variable = variable_selector.variable
|
||||
value = variable_pool.get_variable_value(
|
||||
value = self.graph_runtime_state.variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
)
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@ from typing import cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
@ -12,19 +11,18 @@ class EndNode(BaseNode):
|
||||
_node_data_cls = EndNodeData
|
||||
node_type = NodeType.END
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(self._node_data_cls, node_data)
|
||||
node_data = cast(EndNodeData, node_data)
|
||||
output_variables = node_data.outputs
|
||||
|
||||
outputs = {}
|
||||
for variable_selector in output_variables:
|
||||
value = variable_pool.get_variable_value(
|
||||
value = self.graph_runtime_state.variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
)
|
||||
|
||||
|
||||
@ -49,14 +49,16 @@ class HttpRequestNode(BaseNode):
|
||||
},
|
||||
}
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)
|
||||
|
||||
# init http executor
|
||||
http_executor = None
|
||||
try:
|
||||
http_executor = HttpExecutor(
|
||||
node_data=node_data, timeout=self._get_request_timeout(node_data), variable_pool=variable_pool
|
||||
node_data=node_data,
|
||||
timeout=self._get_request_timeout(node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool
|
||||
)
|
||||
|
||||
# invoke http executor
|
||||
|
||||
@ -4,6 +4,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@ -30,11 +31,16 @@ class IfElseNode(BaseNode):
|
||||
input_conditions = []
|
||||
final_result = False
|
||||
selected_case_id = None
|
||||
condition_processor = ConditionProcessor()
|
||||
try:
|
||||
# Check if the new cases structure is used
|
||||
if node_data.cases:
|
||||
for case in node_data.cases:
|
||||
input_conditions, group_result = self.process_conditions(self.graph_runtime_state.variable_pool, case.conditions)
|
||||
input_conditions, group_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=case.conditions
|
||||
)
|
||||
|
||||
# Apply the logical operator for the current case
|
||||
final_result = all(group_result) if case.logical_operator == "and" else any(group_result)
|
||||
|
||||
@ -53,7 +59,10 @@ class IfElseNode(BaseNode):
|
||||
|
||||
else:
|
||||
# Fallback to old structure if cases are not defined
|
||||
input_conditions, group_result = self.process_conditions(variable_pool, node_data.conditions)
|
||||
input_conditions, group_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=node_data.conditions
|
||||
)
|
||||
|
||||
final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result)
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ class IterationNode(BaseIterationNode):
|
||||
_node_data_cls = IterationNodeData
|
||||
_node_type = NodeType.ITERATION
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
|
||||
def _run(self) -> BaseIterationState:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
@ -32,7 +32,7 @@ class IterationNode(BaseIterationNode):
|
||||
iterator_length=len(iterator) if iterator is not None else 0
|
||||
))
|
||||
|
||||
self._set_current_iteration_variable(variable_pool, state)
|
||||
self._set_current_iteration_variable(self.graph_runtime_state.variable_pool, state)
|
||||
return state
|
||||
|
||||
def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:
|
||||
|
||||
@ -14,7 +14,6 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from extensions.ext_database import db
|
||||
@ -37,11 +36,11 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
_node_data_cls = KnowledgeRetrievalNodeData
|
||||
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
|
||||
|
||||
# extract variables
|
||||
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
|
||||
query = self.graph_runtime_state.variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
|
||||
variables = {
|
||||
'query': query
|
||||
}
|
||||
|
||||
@ -27,7 +27,6 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import NodeRunRetrieverResourceEvent
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
@ -85,9 +84,7 @@ class LLMNode(BaseNode):
|
||||
for event in generator:
|
||||
if isinstance(event, RunRetrieverResourceEvent):
|
||||
context = event.context
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
retriever_resources=event.retriever_resources
|
||||
)
|
||||
yield event
|
||||
|
||||
if context:
|
||||
node_inputs['#context#'] = context # type: ignore
|
||||
@ -170,7 +167,7 @@ class LLMNode(BaseNode):
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[list[str]] = None) \
|
||||
-> Generator["RunStreamChunkEvent | ModelInvokeCompleted", None, None]:
|
||||
-> Generator[RunEvent, None, None]:
|
||||
"""
|
||||
Invoke large language model
|
||||
:param node_data_model: node data model
|
||||
@ -204,7 +201,7 @@ class LLMNode(BaseNode):
|
||||
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
|
||||
-> Generator["RunStreamChunkEvent | ModelInvokeCompleted", None, None]:
|
||||
-> Generator[RunEvent, None, None]:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
|
||||
@ -14,8 +14,8 @@ class LoopNode(BaseIterationNode):
|
||||
_node_data_cls = LoopNodeData
|
||||
_node_type = NodeType.LOOP
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> LoopState:
|
||||
return super()._run(variable_pool)
|
||||
def _run(self) -> LoopState:
|
||||
return super()._run()
|
||||
|
||||
def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str:
|
||||
"""
|
||||
|
||||
@ -66,12 +66,12 @@ class ParameterExtractorNode(LLMNode):
|
||||
}
|
||||
}
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
node_data = cast(ParameterExtractorNodeData, self.node_data)
|
||||
query = variable_pool.get_variable_value(node_data.query)
|
||||
query = self.graph_runtime_state.variable_pool.get_variable_value(node_data.query)
|
||||
if not query:
|
||||
raise ValueError("Input variable content not found or is empty")
|
||||
|
||||
@ -91,17 +91,20 @@ class ParameterExtractorNode(LLMNode):
|
||||
raise ValueError("Model schema not found")
|
||||
|
||||
# fetch memory
|
||||
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
|
||||
memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance)
|
||||
|
||||
if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
|
||||
and node_data.reasoning_mode == 'function_call':
|
||||
# use function call
|
||||
prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
|
||||
node_data, query, variable_pool, model_config, memory
|
||||
node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
|
||||
)
|
||||
else:
|
||||
# use prompt engineering
|
||||
prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config,
|
||||
prompt_messages = self._generate_prompt_engineering_prompt(node_data,
|
||||
query,
|
||||
self.graph_runtime_state.variable_pool,
|
||||
model_config,
|
||||
memory)
|
||||
prompt_message_tools = []
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
@ -11,17 +10,16 @@ class StartNode(BaseNode):
|
||||
_node_data_cls = StartNodeData
|
||||
node_type = NodeType.START
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
# Get cleaned inputs
|
||||
cleaned_inputs = variable_pool.user_inputs
|
||||
cleaned_inputs = self.graph_runtime_state.variable_pool.user_inputs
|
||||
|
||||
for var in variable_pool.system_variables:
|
||||
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]
|
||||
for var in self.graph_runtime_state.variable_pool.system_variables:
|
||||
cleaned_inputs['sys.' + var.value] = self.graph_runtime_state.variable_pool.system_variables[var]
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
|
||||
@ -3,13 +3,13 @@ from typing import Optional, cast
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000'))
|
||||
|
||||
|
||||
class TemplateTransformNode(BaseNode):
|
||||
_node_data_cls = TemplateTransformNodeData
|
||||
_node_type = NodeType.TEMPLATE_TRANSFORM
|
||||
@ -34,7 +34,7 @@ class TemplateTransformNode(BaseNode):
|
||||
}
|
||||
}
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
"""
|
||||
@ -45,7 +45,7 @@ class TemplateTransformNode(BaseNode):
|
||||
variables = {}
|
||||
for variable_selector in node_data.variables:
|
||||
variable = variable_selector.variable
|
||||
value = variable_pool.get_variable_value(
|
||||
value = self.graph_runtime_state.variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
)
|
||||
|
||||
@ -63,7 +63,7 @@ class TemplateTransformNode(BaseNode):
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
|
||||
if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
return NodeRunResult(
|
||||
inputs=variables,
|
||||
@ -78,9 +78,10 @@ class TemplateTransformNode(BaseNode):
|
||||
'output': result['result']
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[
|
||||
str, list[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param node_data: node data
|
||||
@ -88,4 +89,4 @@ class TemplateTransformNode(BaseNode):
|
||||
"""
|
||||
return {
|
||||
variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,8 +0,0 @@
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class TestNodeData(BaseNodeData):
|
||||
"""
|
||||
Test Node Data.
|
||||
"""
|
||||
pass
|
||||
@ -1,33 +0,0 @@
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.test.entities import TestNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class TestNode(BaseNode):
|
||||
_node_data_cls = TestNodeData
|
||||
node_type = NodeType.ANSWER
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"content": "abc"
|
||||
},
|
||||
edge_source_handle="1"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
@ -23,7 +23,7 @@ class ToolNode(BaseNode):
|
||||
_node_data_cls = ToolNodeData
|
||||
_node_type = NodeType.TOOL
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the tool node
|
||||
"""
|
||||
@ -52,7 +52,7 @@ class ToolNode(BaseNode):
|
||||
)
|
||||
|
||||
# get parameters
|
||||
parameters = self._generate_parameters(variable_pool, node_data, tool_runtime)
|
||||
parameters = self._generate_parameters(self.graph_runtime_state.variable_pool, node_data, tool_runtime)
|
||||
|
||||
try:
|
||||
messages = ToolEngine.workflow_invoke(
|
||||
@ -136,7 +136,8 @@ class ToolNode(BaseNode):
|
||||
|
||||
return files
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) \
|
||||
-> tuple[str, list[FileVar], list[dict]]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
|
||||
@ -2,7 +2,6 @@ from typing import cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
@ -12,7 +11,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
_node_data_cls = VariableAssignerNodeData
|
||||
_node_type = NodeType.VARIABLE_AGGREGATOR
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_data = cast(VariableAssignerNodeData, self.node_data)
|
||||
# Get variables
|
||||
outputs = {}
|
||||
@ -20,7 +19,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
|
||||
if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled:
|
||||
for variable in node_data.variables:
|
||||
value = variable_pool.get_variable_value(variable)
|
||||
value = self.graph_runtime_state.variable_pool.get_variable_value(variable)
|
||||
|
||||
if value is not None:
|
||||
outputs = {
|
||||
@ -34,7 +33,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
else:
|
||||
for group in node_data.advanced_settings.groups:
|
||||
for variable in group.variables:
|
||||
value = variable_pool.get_variable_value(variable)
|
||||
value = self.graph_runtime_state.variable_pool.get_variable_value(variable)
|
||||
|
||||
if value is not None:
|
||||
outputs[group.group_name] = {
|
||||
|
||||
Reference in New Issue
Block a user