This commit is contained in:
takatost
2024-07-22 19:57:32 +08:00
372 changed files with 9779 additions and 1678 deletions

View File

@ -6,7 +6,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
class BaseWorkflowCallback(ABC):
class WorkflowCallback(ABC):
@abstractmethod
def on_workflow_run_started(self) -> None:
"""
@ -78,7 +78,7 @@ class BaseWorkflowCallback(ABC):
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
inputs: Optional[dict] = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""

View File

@ -1,3 +1,4 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
@ -82,9 +83,9 @@ class NodeRunResult(BaseModel):
"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[dict] = None # node inputs
inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[dict] = None # process data
outputs: Optional[dict] = None # node outputs
outputs: Optional[Mapping[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches

View File

@ -1,118 +1,146 @@
from enum import Enum
from typing import Any, Optional, Union
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from typing_extensions import deprecated
from core.app.segments import ArrayVariable, ObjectVariable, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable
VariableValue = Union[str, int, float, dict, list, FileVar]
class ValueType(Enum):
"""
Value Type Enum
"""
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILE = "array[file]"
FILE = "file"
SYSTEM_VARIABLE_NODE_ID = 'sys'
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
class VariablePool(BaseModel):
variables_mapping: dict[str, dict[int, VariableValue]] = Field(
# Variable dictionary is a dictionary for looking up variables by their selector.
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
_variable_dictionary: dict[str, dict[int, Variable]] = Field(
description='Variables mapping',
default={},
default=defaultdict(dict)
)
user_inputs: dict = Field(
# TODO: This user inputs is not used for pool.
user_inputs: Mapping[str, Any] = Field(
description='User inputs',
)
system_variables: dict[SystemVariable, Any] = Field(
system_variables: Mapping[SystemVariable, Any] = Field(
description='System variables',
)
@model_validator(mode='before')
def append_system_variables(cls, v: dict) -> dict:
environment_variables: Sequence[Variable] = Field(
description="Environment variables."
)
def __post_init__(self):
"""
Append system variables
:param v: params
:return:
"""
v['variables_mapping'] = {
'sys': {}
}
system_variables = v['system_variables']
for system_variable, value in system_variables.items():
variable_key_list_hash = hash((system_variable.value,))
v['variables_mapping']['sys'][variable_key_list_hash] = value
return v
# Add system variables to the variable pool
for key, value in self.system_variables.items():
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
# Add environment variables to the variable pool
for var in self.environment_variables or []:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Append variable
:param node_id: node id
:param variable_key_list: variable key list, like: ['result', 'text']
:param value: value
:return:
Adds a variable to the variable pool.
Args:
selector (Sequence[str]): The selector for the variable.
value (VariableValue): The value of the variable.
Raises:
ValueError: If the selector is invalid.
Returns:
None
"""
if node_id not in self.variables_mapping:
self.variables_mapping[node_id] = {}
if len(selector) < 2:
raise ValueError('Invalid selector')
variable_key_list_hash = hash(tuple(variable_key_list))
if value is None:
return
self.variables_mapping[node_id][variable_key_list_hash] = value
if not isinstance(value, Variable):
v = factory.build_anonymous_variable(value)
else:
v = value
def get_variable_value(self, variable_selector: list[str],
target_value_type: Optional[ValueType] = None) -> Optional[VariableValue]:
hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]][hash_key] = v
def get(self, selector: Sequence[str], /) -> Variable | None:
"""
Get variable
:param variable_selector: include node_id and variables
:param target_value_type: target value type
:return:
Retrieves the value from the variable pool based on the given selector.
Args:
selector (Sequence[str]): The selector used to identify the variable.
Returns:
Any: The value associated with the given selector.
Raises:
ValueError: If the selector is invalid.
"""
if len(variable_selector) < 2:
raise ValueError('Invalid value selector')
node_id = variable_selector[0]
if node_id not in self.variables_mapping:
return None
# fetch variable keys, pop node_id
variable_key_list = variable_selector[1:]
variable_key_list_hash = hash(tuple(variable_key_list))
value = self.variables_mapping[node_id].get(variable_key_list_hash)
if target_value_type:
if target_value_type == ValueType.STRING:
return str(value)
elif target_value_type == ValueType.NUMBER:
return int(value)
elif target_value_type == ValueType.OBJECT:
if not isinstance(value, dict):
raise ValueError('Invalid value type: object')
elif target_value_type in [ValueType.ARRAY_STRING,
ValueType.ARRAY_NUMBER,
ValueType.ARRAY_OBJECT,
ValueType.ARRAY_FILE]:
if not isinstance(value, list):
raise ValueError(f'Invalid value type: {target_value_type.value}')
if len(selector) < 2:
raise ValueError('Invalid selector')
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
return value
def clear_node_variables(self, node_id: str) -> None:
@deprecated('This method is deprecated, use `get` instead.')
def get_any(self, selector: Sequence[str], /) -> Any | None:
"""
Clear node variables
:param node_id: node id
:return:
Retrieves the value from the variable pool based on the given selector.
Args:
selector (Sequence[str]): The selector used to identify the variable.
Returns:
Any: The value associated with the given selector.
Raises:
ValueError: If the selector is invalid.
"""
if node_id in self.variables_mapping:
self.variables_mapping.pop(node_id)
if len(selector) < 2:
raise ValueError('Invalid selector')
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
if value is None:
return value
if isinstance(value, ArrayVariable):
return [element.value for element in value.value]
if isinstance(value, ObjectVariable):
return {k: v.value for k, v in value.value.items()}
return value.value if value else None
def remove(self, selector: Sequence[str], /):
"""
Remove variables from the variable pool based on the given selector.
Args:
selector (Sequence[str]): A sequence of strings representing the selector.
Returns:
None
"""
if not selector:
return
if len(selector) == 1:
self._variable_dictionary[selector[0]] = {}
return
hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]].pop(hash_key, None)

View File

@ -1,7 +1,5 @@
import json
from typing import cast
from core.file.file_obj import FileVar
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
@ -18,7 +16,7 @@ from models.workflow import WorkflowNodeExecutionStatus
class AnswerNode(BaseNode):
_node_data_cls = AnswerNodeData
node_type = NodeType.ANSWER
_node_type: NodeType = NodeType.ANSWER
def _run(self) -> NodeRunResult:
"""
@ -36,31 +34,12 @@ class AnswerNode(BaseNode):
if part.type == GenerateRouteChunk.ChunkType.VAR:
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
value = self.graph_runtime_state.variable_pool.get_variable_value(
value = self.graph_runtime_state.variable_pool.get(
variable_selector=value_selector
)
text = ''
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, dict):
# other types
text = json.dumps(value, ensure_ascii=False)
elif isinstance(value, FileVar):
# convert file to markdown
text = value.to_markdown()
elif isinstance(value, list):
for item in value:
if isinstance(item, FileVar):
text += item.to_markdown() + ' '
text = text.strip()
if not text and value:
# other types
text = json.dumps(value, ensure_ascii=False)
answer += text
if value:
answer += value.markdown
else:
part = cast(TextGenerateRouteChunk, part)
answer += part.text

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Optional
from collections.abc import Generator, Mapping
from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
@ -17,7 +17,7 @@ class BaseNode(ABC):
_node_type: NodeType
def __init__(self,
config: dict,
config: Mapping[str, Any],
graph_init_params: GraphInitParams,
graph: Graph,
graph_runtime_state: GraphRuntimeState,

View File

@ -57,11 +57,8 @@ class CodeNode(BaseNode):
variables = {}
for variable_selector in node_data.variables:
variable = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
variables[variable] = value
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
variables[variable] = value.value if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(

View File

@ -9,7 +9,7 @@ from models.workflow import WorkflowNodeExecutionStatus
class EndNode(BaseNode):
_node_data_cls = EndNodeData
node_type = NodeType.END
_node_type = NodeType.END
def _run(self) -> NodeRunResult:
"""
@ -22,11 +22,8 @@ class EndNode(BaseNode):
outputs = {}
for variable_selector in output_variables:
value = self.graph_runtime_state.variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
outputs[variable_selector.variable] = value
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
outputs[variable_selector.variable] = value.value if value else None
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -43,7 +40,7 @@ class EndNode(BaseNode):
:return:
"""
node_data = cls._node_data_cls(**config.get("data", {}))
node_data = cast(cls._node_data_cls, node_data)
node_data = cast(EndNodeData, node_data)
return cls.extract_generate_nodes_from_node_data(graph, node_data)
@ -55,7 +52,7 @@ class EndNode(BaseNode):
:param node_data: node data object
:return:
"""
nodes = graph.get('nodes')
nodes = graph.get('nodes', [])
node_mapping = {node.get('id'): node for node in nodes}
variable_selectors = node_data.outputs

View File

@ -58,4 +58,3 @@ class HttpRequestNodeData(BaseNodeData):
params: str
body: Optional[HttpRequestNodeBody] = None
timeout: Optional[HttpRequestNodeTimeout] = None
mask_authorization_header: Optional[bool] = True

View File

@ -9,7 +9,7 @@ import httpx
import core.helper.ssrf_proxy as ssrf_proxy
from configs import dify_config
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeBody,
@ -212,13 +212,11 @@ class HttpExecutor:
raise ValueError('self.authorization config is required')
if authorization.config is None:
raise ValueError('authorization config is required')
if authorization.config.type != 'bearer' and authorization.config.header is None:
raise ValueError('authorization config header is required')
if self.authorization.config.api_key is None:
raise ValueError('api_key is required')
if not self.authorization.config.header:
if not authorization.config.header:
authorization.config.header = 'Authorization'
if self.authorization.config.type == 'bearer':
@ -283,7 +281,7 @@ class HttpExecutor:
# validate response
return self._validate_and_parse_response(response)
def to_raw_request(self, mask_authorization_header: Optional[bool] = True) -> str:
def to_raw_request(self) -> str:
"""
convert to raw request
"""
@ -295,16 +293,15 @@ class HttpExecutor:
headers = self._assembling_headers()
for k, v in headers.items():
if mask_authorization_header:
# get authorization header
if self.authorization.type == 'api-key':
authorization_header = 'Authorization'
if self.authorization.config and self.authorization.config.header:
authorization_header = self.authorization.config.header
# get authorization header
if self.authorization.type == 'api-key':
authorization_header = 'Authorization'
if self.authorization.config and self.authorization.config.header:
authorization_header = self.authorization.config.header
if k.lower() == authorization_header.lower():
raw_request += f'{k}: {"*" * len(v)}\n'
continue
if k.lower() == authorization_header.lower():
raw_request += f'{k}: {"*" * len(v)}\n'
continue
raw_request += f'{k}: {v}\n'
@ -336,16 +333,13 @@ class HttpExecutor:
if variable_pool:
variable_value_mapping = {}
for variable_selector in variable_selectors:
value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector, target_value_type=ValueType.STRING
)
if value is None:
variable = variable_pool.get(variable_selector.value_selector)
if variable is None:
raise ValueError(f'Variable {variable_selector.variable} not found')
if escape_quotes and isinstance(value, str):
value = value.replace('"', '\\"')
if escape_quotes and isinstance(variable.value, str):
value = variable.value.replace('"', '\\"')
else:
value = variable.value
variable_value_mapping[variable_selector.variable] = value
return variable_template_parser.format(variable_value_mapping), variable_selectors

View File

@ -3,6 +3,7 @@ from mimetypes import guess_extension
from os import path
from typing import cast
from core.app.segments import parser
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.base_node_data_entities import BaseNodeData
@ -50,6 +51,9 @@ class HttpRequestNode(BaseNode):
def _run(self) -> NodeRunResult:
node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)
# TODO: Switch to use segment directly
if node_data.authorization.config and node_data.authorization.config.api_key:
node_data.authorization.config.api_key = parser.convert_template(template=node_data.authorization.config.api_key, variable_pool=variable_pool).text
# init http executor
http_executor = None
@ -66,9 +70,7 @@ class HttpRequestNode(BaseNode):
process_data = {}
if http_executor:
process_data = {
'request': http_executor.to_raw_request(
mask_authorization_header=node_data.mask_authorization_header
),
'request': http_executor.to_raw_request(),
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -87,9 +89,7 @@ class HttpRequestNode(BaseNode):
'files': files,
},
process_data={
'request': http_executor.to_raw_request(
mask_authorization_header=node_data.mask_authorization_header,
),
'request': http_executor.to_raw_request(),
},
)

View File

@ -10,7 +10,7 @@ from models.workflow import WorkflowNodeExecutionStatus
class IfElseNode(BaseNode):
_node_data_cls = IfElseNodeData
node_type = NodeType.IF_ELSE
_node_type = NodeType.IF_ELSE
def _run(self) -> NodeRunResult:
"""

View File

@ -21,7 +21,8 @@ class IterationNode(BaseIterationNode):
"""
Run the node.
"""
iterator = variable_pool.get_variable_value(cast(IterationNodeData, self.node_data).iterator_selector)
self.node_data = cast(IterationNodeData, self.node_data)
iterator = variable_pool.get_any(self.node_data.iterator_selector)
if not isinstance(iterator, list):
raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")
@ -64,15 +65,15 @@ class IterationNode(BaseIterationNode):
"""
node_data = cast(IterationNodeData, self.node_data)
variable_pool.append_variable(self.node_id, ['index'], state.index)
variable_pool.add((self.node_id, 'index'), state.index)
# get the iterator value
iterator = variable_pool.get_variable_value(node_data.iterator_selector)
iterator = variable_pool.get_any(node_data.iterator_selector)
if iterator is None or not isinstance(iterator, list):
return
if state.index < len(iterator):
variable_pool.append_variable(self.node_id, ['item'], iterator[state.index])
variable_pool.add((self.node_id, 'item'), iterator[state.index])
def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
"""
@ -88,7 +89,7 @@ class IterationNode(BaseIterationNode):
:return: True if iteration limit is reached, False otherwise
"""
node_data = cast(IterationNodeData, self.node_data)
iterator = variable_pool.get_variable_value(node_data.iterator_selector)
iterator = variable_pool.get_any(node_data.iterator_selector)
if iterator is None or not isinstance(iterator, list):
return True
@ -101,9 +102,9 @@ class IterationNode(BaseIterationNode):
:param variable_pool: variable pool
"""
output_selector = cast(IterationNodeData, self.node_data).output_selector
output = variable_pool.get_variable_value(output_selector)
output = variable_pool.get_any(output_selector)
# clear the output for this iteration
variable_pool.append_variable(self.node_id, output_selector[1:], None)
variable_pool.remove([self.node_id] + output_selector[1:])
state.current_output = output
if output is not None:
state.outputs.append(output)

View File

@ -21,7 +21,7 @@ from models.dataset import Dataset, Document, DocumentSegment
from models.workflow import WorkflowNodeExecutionStatus
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
@ -40,7 +40,8 @@ class KnowledgeRetrievalNode(BaseNode):
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables
query = self.graph_runtime_state.variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
query = variable.value if variable else None
variables = {
'query': query
}

View File

@ -44,7 +44,7 @@ from models.workflow import WorkflowNodeExecutionStatus
class LLMNode(BaseNode):
_node_data_cls = LLMNodeData
node_type = NodeType.LLM
_node_type = NodeType.LLM
def _run(self) -> Generator[RunEvent, None, None]:
"""
@ -98,7 +98,7 @@ class LLMNode(BaseNode):
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt_messages(
node_data=node_data,
query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value]) # type: ignore
query=variable_pool.get_any(['sys', SystemVariable.QUERY.value])
if node_data.memory else None,
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
inputs=inputs,
@ -276,8 +276,8 @@ class LLMNode(BaseNode):
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable = variable_selector.variable
value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
value = variable_pool.get_any(
variable_selector.value_selector
)
def parse_dict(d: dict) -> str:
@ -340,7 +340,7 @@ class LLMNode(BaseNode):
variable_selectors = variable_template_parser.extract_variable_selectors()
for variable_selector in variable_selectors:
variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
variable_value = variable_pool.get_any(variable_selector.value_selector)
if variable_value is None:
raise ValueError(f'Variable {variable_selector.variable} not found')
@ -351,7 +351,7 @@ class LLMNode(BaseNode):
query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
.extract_variable_selectors())
for variable_selector in query_variable_selectors:
variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
variable_value = variable_pool.get_any(variable_selector.value_selector)
if variable_value is None:
raise ValueError(f'Variable {variable_selector.variable} not found')
@ -369,7 +369,7 @@ class LLMNode(BaseNode):
if not node_data.vision.enabled:
return []
files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
files = variable_pool.get_any(['sys', SystemVariable.FILES.value])
if not files:
return []
@ -388,7 +388,7 @@ class LLMNode(BaseNode):
if not node_data.context.variable_selector:
return
context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
context_value = variable_pool.get_any(node_data.context.variable_selector)
if context_value:
if isinstance(context_value, str):
yield RunRetrieverResourceEvent(
@ -530,7 +530,7 @@ class LLMNode(BaseNode):
return None
# get conversation id
conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION_ID.value])
conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value])
if conversation_id is None:
return None

View File

@ -71,9 +71,10 @@ class ParameterExtractorNode(LLMNode):
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self.node_data)
query = self.graph_runtime_state.variable_pool.get_variable_value(node_data.query)
if not query:
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
if not variable:
raise ValueError("Input variable content not found or is empty")
query = variable.value
inputs = {
'query': query,
@ -567,7 +568,8 @@ class ParameterExtractorNode(LLMNode):
variable_template_parser = VariableTemplateParser(instruction)
inputs = {}
for selector in variable_template_parser.extract_variable_selectors():
inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
variable = variable_pool.get(selector.value_selector)
inputs[selector.variable] = variable.value if variable else None
return variable_template_parser.format(inputs)

View File

@ -43,7 +43,8 @@ class QuestionClassifierNode(LLMNode):
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
variable = variable_pool.get(node_data.query_variable_selector)
query = variable.value if variable else None
variables = {
'query': query
}
@ -305,7 +306,8 @@ class QuestionClassifierNode(LLMNode):
variable_template_parser = VariableTemplateParser(template=instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
variable = variable_pool.get(variable_selector.value_selector)
variable_value = variable.value if variable else None
if variable_value is None:
raise ValueError(f'Variable {variable_selector.variable} not found')

View File

@ -8,7 +8,7 @@ from models.workflow import WorkflowNodeExecutionStatus
class StartNode(BaseNode):
_node_data_cls = StartNodeData
node_type = NodeType.START
_node_type = NodeType.START
def _run(self) -> NodeRunResult:
"""
@ -16,7 +16,7 @@ class StartNode(BaseNode):
:return:
"""
# Get cleaned inputs
cleaned_inputs = self.graph_runtime_state.variable_pool.user_inputs
cleaned_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
for var in self.graph_runtime_state.variable_pool.system_variables:
cleaned_inputs['sys.' + var.value] = self.graph_runtime_state.variable_pool.system_variables[var]

View File

@ -44,12 +44,9 @@ class TemplateTransformNode(BaseNode):
# Get variables
variables = {}
for variable_selector in node_data.variables:
variable = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
variables[variable] = value
variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
variables[variable_name] = value
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(

View File

@ -29,6 +29,7 @@ class ToolEntity(BaseModel):
class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal['mixed', 'variable', 'constant']

View File

@ -1,10 +1,11 @@
from collections.abc import Mapping, Sequence
from os import path
from typing import Optional, cast
from typing import Any, cast
from core.app.segments import parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
@ -20,6 +21,7 @@ class ToolNode(BaseNode):
"""
Tool Node
"""
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
@ -50,23 +52,24 @@ class ToolNode(BaseNode):
},
error=f'Failed to get tool runtime: {str(e)}'
)
# get parameters
parameters = self._generate_parameters(self.graph_runtime_state.variable_pool, node_data, tool_runtime)
tool_parameters = tool_runtime.get_runtime_parameters() or []
parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data)
parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True)
try:
messages = ToolEngine.workflow_invoke(
tool=tool_runtime,
tool_parameters=parameters,
user_id=self.user_id,
workflow_id=self.workflow_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters,
inputs=parameters_for_log,
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
@ -86,21 +89,34 @@ class ToolNode(BaseNode):
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
inputs=parameters
inputs=parameters_for_log
)
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData, tool_runtime: Tool) -> dict:
def _generate_parameters(
self,
*,
tool_parameters: Sequence[ToolParameter],
variable_pool: VariablePool,
node_data: ToolNodeData,
for_log: bool = False,
) -> Mapping[str, Any]:
"""
Generate parameters
"""
tool_parameters = tool_runtime.get_all_runtime_parameters()
Generate parameters based on the given tool parameters, variable pool, and node data.
def fetch_parameter(name: str) -> Optional[ToolParameter]:
return next((parameter for parameter in tool_parameters if parameter.name == name), None)
Args:
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (ToolNodeData): The data associated with the tool node.
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
result = {}
for parameter_name in node_data.tool_parameters:
parameter = fetch_parameter(parameter_name)
parameter = tool_parameters_dictionary.get(parameter_name)
if not parameter:
continue
if parameter.type == ToolParameter.ToolParameterType.FILE:
@ -108,35 +124,21 @@ class ToolNode(BaseNode):
v.to_dict() for v in self._fetch_files(variable_pool)
]
else:
input = node_data.tool_parameters[parameter_name]
if input.type == 'mixed':
result[parameter_name] = self._format_variable_template(input.value, variable_pool)
elif input.type == 'variable':
result[parameter_name] = variable_pool.get_variable_value(input.value)
elif input.type == 'constant':
result[parameter_name] = input.value
tool_input = node_data.tool_parameters[parameter_name]
segment_group = parser.convert_template(
template=str(tool_input.value),
variable_pool=variable_pool,
)
result[parameter_name] = segment_group.log if for_log else segment_group.text
return result
def _format_variable_template(self, template: str, variable_pool: VariablePool) -> str:
"""
Format variable template
"""
inputs = {}
template_parser = VariableTemplateParser(template)
for selector in template_parser.extract_variable_selectors():
inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
return template_parser.format(inputs)
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
if not files:
return []
return files
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) \
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
# FIXME: ensure this is a ArrayVariable contains FileVariable.
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
return [file_var.value for file_var in variable.value] if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\
-> tuple[str, list[FileVar], list[dict]]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]

View File

@ -18,28 +18,27 @@ class VariableAggregatorNode(BaseNode):
inputs = {}
if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled:
for variable in node_data.variables:
value = self.graph_runtime_state.variable_pool.get_variable_value(variable)
if value is not None:
for selector in node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs = {
"output": value
"output": variable.value
}
inputs = {
'.'.join(variable[1:]): value
'.'.join(selector[1:]): variable.value
}
break
else:
for group in node_data.advanced_settings.groups:
for variable in group.variables:
value = self.graph_runtime_state.variable_pool.get_variable_value(variable)
for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if value is not None:
if variable is not None:
outputs[group.group_name] = {
'output': value
'output': variable.value
}
inputs['.'.join(variable[1:])] = value
inputs['.'.join(selector[1:])] = variable.value
break
return NodeRunResult(

View File

@ -1,12 +1,48 @@
import re
from collections.abc import Mapping
from typing import Any
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
REGEX = re.compile(r'\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}')
def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str:
"""
This is an alternative to the VariableTemplateParser class,
offering the same functionality but with better readability and ease of use.
"""
variable_keys = [match[0] for match in re.findall(REGEX, template)]
variable_keys = list(set(variable_keys))
# This key_selector is a tuple of (key, selector) where selector is a list of keys
# e.g. ('#node_id.query.name#', ['node_id', 'query', 'name'])
key_selectors = filter(
lambda t: len(t[1]) >= 2,
((key, selector.replace('#', '').split('.')) for key, selector in zip(variable_keys, variable_keys)),
)
inputs = {key: variable_pool.get_any(selector) for key, selector in key_selectors}
def replacer(match):
key = match.group(1)
# return original matched string if key not found
value = inputs.get(key, match.group(0))
if value is None:
value = ''
value = str(value)
# remove template variables if required
return re.sub(REGEX, r'{\1}', value)
result = re.sub(REGEX, replacer, template)
result = re.sub(r'<\|.*?\|>', '', result)
return result
class VariableTemplateParser:
"""
!NOTE: Consider to use the new `segments` module instead of this class.
A class for parsing and manipulating template variables in a string.
Rules:
@ -70,14 +106,11 @@ class VariableTemplateParser:
if len(split_result) < 2:
continue
variable_selectors.append(VariableSelector(
variable=variable_key,
value_selector=split_result
))
variable_selectors.append(VariableSelector(variable=variable_key, value_selector=split_result))
return variable_selectors
def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
def format(self, inputs: Mapping[str, Any]) -> str:
"""
Formats the template string by replacing the template variables with their corresponding values.
@ -88,17 +121,19 @@ class VariableTemplateParser:
Returns:
The formatted string with template variables replaced by their values.
"""
def replacer(match):
key = match.group(1)
value = inputs.get(key, match.group(0)) # return original matched string if key not found
if value is None:
value = ''
# convert the value to string
if isinstance(value, list | dict | bool | int | float):
value = str(value)
# remove template variables if required
if remove_template_variables:
return VariableTemplateParser.remove_template_variables(value)
return value
return VariableTemplateParser.remove_template_variables(value)
prompt = re.sub(REGEX, replacer, self.template)
return re.sub(r'<\|.*?\|>', '', prompt)

View File

@ -1,10 +1,9 @@
import logging
import time
from collections.abc import Generator
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from flask import current_app
from configs import dify_config
from core.app.app_config.entities import FileExtraConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
@ -13,7 +12,6 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine
@ -37,11 +35,11 @@ class WorkflowEntry:
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
callbacks: list[BaseWorkflowCallback],
user_inputs: dict,
system_inputs: dict[SystemVariable, Any],
user_inputs: Mapping[str, Any],
system_inputs: Mapping[SystemVariable, Any],
callbacks: Sequence[BaseWorkflowCallback],
call_depth: int = 0,
variable_pool: Optional[VariablePool] = None) -> Generator:
variable_pool: Optional[VariablePool] = None) -> None:
"""
:param workflow: Workflow instance
:param user_id: user id
@ -71,9 +69,14 @@ class WorkflowEntry:
if not variable_pool:
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=user_inputs
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
)
workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
if call_depth > workflow_call_max_depth:
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
# init graph
graph = Graph.init(
graph_config=graph_config
@ -124,11 +127,11 @@ class WorkflowEntry:
return rst
def _run_workflow(self, graph_config: dict,
workflow_runtime_state: WorkflowRuntimeState,
callbacks: list[BaseWorkflowCallback],
start_node: Optional[str] = None,
end_node: Optional[str] = None) -> None:
def _run_workflow(self, workflow: Workflow,
workflow_run_state: WorkflowRunState,
callbacks: Sequence[BaseWorkflowCallback],
start_at: Optional[str] = None,
end_at: Optional[str] = None) -> None:
"""
Run workflow
:param graph_config: workflow graph config
@ -149,12 +152,11 @@ class WorkflowEntry:
error='Start node not found in workflow graph.'
)
predecessor_node: Optional[BaseNode] = None
current_iteration_node: Optional[BaseIterationNode] = None
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
max_execution_steps = cast(int, max_execution_steps)
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
max_execution_time = cast(int, max_execution_time)
predecessor_node: BaseNode | None = None
current_iteration_node: BaseIterationNode | None = None
has_entry_node = False
max_execution_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS
max_execution_time = dify_config.WORKFLOW_MAX_EXECUTION_TIME
while True:
# get next nodes
next_nodes = self._get_next_overall_nodes(
@ -212,7 +214,7 @@ class WorkflowEntry:
# move to next iteration
next_node_id = next_iteration
# get next id
next_nodes = [self._get_node(workflow_run_state, graph, next_node_id, callbacks)]
next_nodes = [self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks)]
if not next_nodes:
break
@ -423,7 +425,8 @@ class WorkflowEntry:
# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={}
user_inputs={},
environment_variables=workflow.environment_variables,
)
# variable selector to variable mapping
@ -458,11 +461,11 @@ class WorkflowEntry:
return node_instance, node_run_result
def single_step_run_iteration_workflow_node(self, workflow: Workflow,
node_id: str,
user_id: str,
user_inputs: dict,
callbacks: list[BaseWorkflowCallback] = None,
) -> None:
node_id: str,
user_id: str,
user_inputs: dict,
callbacks: Sequence[BaseWorkflowCallback],
) -> None:
"""
Single iteration run workflow node
"""
@ -488,7 +491,8 @@ class WorkflowEntry:
# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={}
user_inputs={},
environment_variables=workflow.environment_variables,
)
# variable selector to variable mapping
@ -604,7 +608,7 @@ class WorkflowEntry:
for callback in callbacks:
callback.on_workflow_run_started()
def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None:
def _workflow_run_success(self, callbacks: Sequence[WorkflowCallback]) -> None:
"""
Workflow run success
:param callbacks: workflow callbacks
@ -616,7 +620,7 @@ class WorkflowEntry:
callback.on_workflow_run_succeeded()
def _workflow_run_failed(self, error: str,
callbacks: list[BaseWorkflowCallback] = None) -> None:
callbacks: Sequence[WorkflowCallback]) -> None:
"""
Workflow run failed
:param error: error message
@ -629,11 +633,11 @@ class WorkflowEntry:
error=error
)
def _workflow_iteration_started(self, graph: dict,
def _workflow_iteration_started(self, *, graph: Mapping[str, Any],
current_iteration_node: BaseIterationNode,
workflow_run_state: WorkflowRunState,
predecessor_node_id: Optional[str] = None,
callbacks: list[BaseWorkflowCallback] = None) -> None:
callbacks: Sequence[WorkflowCallback]) -> None:
"""
Workflow iteration started
:param current_iteration_node: current iteration node
@ -666,10 +670,10 @@ class WorkflowEntry:
# add steps
workflow_run_state.workflow_node_steps += 1
def _workflow_iteration_next(self, graph: dict,
def _workflow_iteration_next(self, *, graph: Mapping[str, Any],
current_iteration_node: BaseIterationNode,
workflow_run_state: WorkflowRunState,
callbacks: list[BaseWorkflowCallback] = None) -> None:
callbacks: Sequence[BaseWorkflowCallback]) -> None:
"""
Workflow iteration next
:param workflow_run_state: workflow run state
@ -696,11 +700,11 @@ class WorkflowEntry:
nodes = [node for node in nodes if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id]
for node in nodes:
workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id'))
workflow_run_state.variable_pool.remove((node.get('id'),))
def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode,
workflow_run_state: WorkflowRunState,
callbacks: list[BaseWorkflowCallback] = None) -> None:
def _workflow_iteration_completed(self, *, current_iteration_node: BaseIterationNode,
workflow_run_state: WorkflowRunState,
callbacks: Sequence[BaseWorkflowCallback]) -> None:
if callbacks:
if isinstance(workflow_run_state.current_iteration_state, IterationState):
for callback in callbacks:
@ -713,12 +717,12 @@ class WorkflowEntry:
}
)
def _get_next_overall_nodes(self, workflow_run_state: WorkflowRunState,
graph: dict,
callbacks: list[BaseWorkflowCallback],
predecessor_node: Optional[BaseNode] = None,
node_start_at: Optional[str] = None,
node_end_at: Optional[str] = None) -> list[BaseNode]:
def _get_next_overall_node(self, *, workflow_run_state: WorkflowRunState,
graph: Mapping[str, Any],
callbacks: list[BaseWorkflowCallback],
predecessor_node: Optional[BaseNode] = None,
node_start_at: Optional[str] = None,
node_end_at: Optional[str] = None) -> Optional[BaseNode]:
"""
Get next nodes
multiple target nodes in the future.
@ -804,26 +808,26 @@ class WorkflowEntry:
if not target_node_cls:
continue
target_node = target_node_cls(
tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id,
workflow_id=workflow_run_state.workflow_id,
user_id=workflow_run_state.user_id,
user_from=workflow_run_state.user_from,
invoke_from=workflow_run_state.invoke_from,
config=target_node_config,
callbacks=callbacks,
workflow_call_depth=workflow_run_state.workflow_call_depth
)
target_node = target_node_cls(
tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id,
workflow_id=workflow_run_state.workflow_id,
user_id=workflow_run_state.user_id,
user_from=workflow_run_state.user_from,
invoke_from=workflow_run_state.invoke_from,
config=target_node_config,
callbacks=callbacks,
workflow_call_depth=workflow_run_state.workflow_call_depth
)
target_nodes.append(target_node)
target_nodes.append(target_node)
return target_nodes
return target_nodes
def _get_node(self, workflow_run_state: WorkflowRunState,
graph: dict,
graph: Mapping[str, Any],
node_id: str,
callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]:
callbacks: Sequence[WorkflowCallback]):
"""
Get node from graph by node id
"""
@ -834,7 +838,7 @@ class WorkflowEntry:
for node_config in nodes:
if node_config.get('id') == node_id:
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
node_cls = node_classes[node_type]
return node_cls(
tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id,
@ -847,8 +851,6 @@ class WorkflowEntry:
workflow_call_depth=workflow_run_state.workflow_call_depth
)
return None
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
Check timeout
@ -867,10 +869,10 @@ class WorkflowEntry:
if node_and_result.node_id == node_id
])
def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
def _run_workflow_node(self, *, workflow_run_state: WorkflowRunState,
node: BaseNode,
predecessor_node: Optional[BaseNode] = None,
callbacks: list[BaseWorkflowCallback] = None) -> None:
callbacks: Sequence[WorkflowCallback]) -> None:
if callbacks:
for callback in callbacks:
callback.on_workflow_node_execute_started(
@ -973,10 +975,8 @@ class WorkflowEntry:
:param variable_value: variable value
:return:
"""
variable_pool.append_variable(
node_id=node_id,
variable_key_list=variable_key_list,
value=variable_value
variable_pool.add(
[node_id] + variable_key_list, variable_value
)
# if variable_value is a dict, then recursively append variables
@ -1025,7 +1025,7 @@ class WorkflowEntry:
tenant_id: str,
node_instance: BaseNode):
for variable_key, variable_selector in variable_mapping.items():
if variable_key not in user_inputs:
if variable_key not in user_inputs and not variable_pool.get(variable_selector):
raise ValueError(f'Variable key {variable_key} not found in user inputs.')
# fetch variable node id from variable selector
@ -1035,7 +1035,7 @@ class WorkflowEntry:
# get value
value = user_inputs.get(variable_key)
# temp fix for image type
# FIXME: temp fix for image type
if node_instance.node_type == NodeType.LLM:
new_value = []
if isinstance(value, list):
@ -1062,11 +1062,7 @@ class WorkflowEntry:
value = new_value
# append variable and value to variable pool
variable_pool.append_variable(
node_id=variable_node_id,
variable_key_list=variable_key_list,
value=value
)
variable_pool.add([variable_node_id]+variable_key_list, value)
class WorkflowRunFailedError(Exception):