Merge main

Yeuoly, 2024-09-14 02:47:01 +08:00
959 changed files with 25695 additions and 24057 deletions


@ -29,14 +29,12 @@ class AnswerNode(BaseNode):
# generate routes
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)
answer = ''
answer = ""
for part in generate_routes:
if part.type == GenerateRouteChunk.ChunkType.VAR:
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
value = self.graph_runtime_state.variable_pool.get(
value_selector
)
value = self.graph_runtime_state.variable_pool.get(value_selector)
if value:
answer += value.markdown
@ -44,19 +42,11 @@ class AnswerNode(BaseNode):
part = cast(TextGenerateRouteChunk, part)
answer += part.text
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"answer": answer
}
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer})
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AnswerNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -73,6 +63,6 @@ class AnswerNode(BaseNode):
variable_mapping = {}
for variable_selector in variable_selectors:
variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
return variable_mapping
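
Most hunks in this commit are formatting-only: single quotes become double quotes and short multi-line calls collapse onto one line, consistent with a Black/Ruff-style formatter pass at a long line length (an inference from the shape of the diff, not something the commit message states). One way to convince yourself such a rewrite is behavior-preserving is to compare parse trees; a minimal sketch using the collapsed call from this hunk:

import ast

old = "value = self.graph_runtime_state.variable_pool.get(\n    value_selector\n)"
new = "value = self.graph_runtime_state.variable_pool.get(value_selector)"

# Formatting-only changes leave the AST untouched (ast.dump ignores positions by default).
assert ast.dump(ast.parse(old)) == ast.dump(ast.parse(new))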


@ -1,4 +1,3 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import (
@ -12,12 +11,12 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerStreamGeneratorRouter:
@classmethod
def init(cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined]
) -> AnswerStreamGenerateRoute:
def init(
cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
) -> AnswerStreamGenerateRoute:
"""
Get stream generate routes.
:return:
@ -25,7 +24,7 @@ class AnswerStreamGeneratorRouter:
# parse stream output node value selectors of answer nodes
answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
for answer_node_id, node_config in node_id_config_mapping.items():
if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
if node_config.get("data", {}).get("type") != NodeType.ANSWER.value:
continue
# get generate route for stream output
@ -37,12 +36,11 @@ class AnswerStreamGeneratorRouter:
answer_dependencies = cls._fetch_answers_dependencies(
answer_node_ids=answer_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping
node_id_config_mapping=node_id_config_mapping,
)
return AnswerStreamGenerateRoute(
answer_generate_route=answer_generate_route,
answer_dependencies=answer_dependencies
answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies
)
@classmethod
@ -56,8 +54,7 @@ class AnswerStreamGeneratorRouter:
variable_selectors = variable_template_parser.extract_variable_selectors()
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector
for variable_selector in variable_selectors
variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors
}
variable_keys = list(value_selector_mapping.keys())
@ -71,21 +68,17 @@ class AnswerStreamGeneratorRouter:
template = node_data.answer
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω")
generate_routes: list[GenerateRouteChunk] = []
for part in template.split('Ω'):
for part in template.split("Ω"):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "")
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(
value_selector=value_selector
))
generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector))
else:
generate_routes.append(TextGenerateRouteChunk(
text=part
))
generate_routes.append(TextGenerateRouteChunk(text=part))
return generate_routes
@ -101,15 +94,16 @@ class AnswerStreamGeneratorRouter:
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace('{{', '').replace('}}', '')
return part.startswith('{{') and cleaned_part in variable_keys
cleaned_part = part.replace("{{", "").replace("}}", "")
return part.startswith("{{") and cleaned_part in variable_keys
@classmethod
def _fetch_answers_dependencies(cls,
answer_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict]
) -> dict[str, list[str]]:
def _fetch_answers_dependencies(
cls,
answer_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict],
) -> dict[str, list[str]]:
"""
Fetch answer dependencies
:param answer_node_ids: answer node ids
@ -127,19 +121,20 @@ class AnswerStreamGeneratorRouter:
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies
answer_dependencies=answer_dependencies,
)
return answer_dependencies
@classmethod
def _recursive_fetch_answer_dependencies(cls,
current_node_id: str,
answer_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
answer_dependencies: dict[str, list[str]]
) -> None:
def _recursive_fetch_answer_dependencies(
cls,
current_node_id: str,
answer_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
answer_dependencies: dict[str, list[str]],
) -> None:
"""
Recursive fetch answer dependencies
:param current_node_id: current node id
@ -152,12 +147,12 @@ class AnswerStreamGeneratorRouter:
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
if source_node_type in (
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER,
):
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in {
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value,
}:
answer_dependencies[answer_node_id].append(source_node_id)
else:
cls._recursive_fetch_answer_dependencies(
@ -165,5 +160,5 @@ class AnswerStreamGeneratorRouter:
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies
answer_dependencies=answer_dependencies,
)
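
Two semantic fixes ride along with the reformat in this file: "not x == y" becomes the clearer "x != y", and the branch-type membership test moves from a tuple to a set literal, in the process adding the .value that NodeType.QUESTION_CLASSIFIER was missing. Node types are stored in configs as strings, and a string never compares equal to an enum member, only to its value. A minimal sketch of the pitfall, with a hypothetical one-member enum:

from enum import Enum

class NodeType(Enum):
    QUESTION_CLASSIFIER = "question-classifier"  # value assumed for illustration

source_node_type = "question-classifier"  # node configs carry the plain string
assert source_node_type != NodeType.QUESTION_CLASSIFIER        # enum member: never equal
assert source_node_type == NodeType.QUESTION_CLASSIFIER.value  # compare against .value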


@ -18,7 +18,6 @@ logger = logging.getLogger(__name__)
class AnswerStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
super().__init__(graph, variable_pool)
self.generate_routes = graph.answer_stream_generate_routes
@ -27,9 +26,7 @@ class AnswerStreamProcessor(StreamProcessor):
self.route_position[answer_node_id] = 0
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
@ -47,9 +44,9 @@ class AnswerStreamProcessor(StreamProcessor):
]
else:
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
] = stream_out_answer_node_ids
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
stream_out_answer_node_ids
)
for _ in stream_out_answer_node_ids:
yield event
@ -77,9 +74,9 @@ class AnswerStreamProcessor(StreamProcessor):
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _generate_stream_outputs_when_node_finished(self,
event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
def _generate_stream_outputs_when_node_finished(
self, event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
@ -87,10 +84,13 @@ class AnswerStreamProcessor(StreamProcessor):
"""
for answer_node_id, position in self.route_position.items():
# all depends on answer node id not in rest node ids
if (event.route_node_state.node_id != answer_node_id
and (answer_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
if event.route_node_state.node_id != answer_node_id and (
answer_node_id not in self.rest_node_ids
or not all(
dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
)
):
continue
route_position = self.route_position[answer_node_id]
@ -108,6 +108,7 @@ class AnswerStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
from_variable_selector=[answer_node_id, "answer"],
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
@ -115,9 +116,7 @@ class AnswerStreamProcessor(StreamProcessor):
if not value_selector:
break
value = self.variable_pool.get(
value_selector
)
value = self.variable_pool.get(value_selector)
if value is None:
break
@ -158,8 +157,9 @@ class AnswerStreamProcessor(StreamProcessor):
continue
# all depends on answer node id not in rest node ids
if all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
if all(
dep_id not in self.rest_node_ids for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
):
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
continue
@ -213,7 +213,7 @@ class AnswerStreamProcessor(StreamProcessor):
return None
if isinstance(value, dict):
if '__variant' in value and value['__variant'] == FileVar.__name__:
if "__variant" in value and value["__variant"] == FileVar.__name__:
return value
elif isinstance(value, FileVar):
return value.to_dict()
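
The gating condition above only streams an answer node's route once every one of its dependencies has finished, that is, once no dependency id is left in rest_node_ids. A worked sketch of that all(...) check, with illustrative ids:

rest_node_ids = {"llm-1"}                      # nodes that have not run yet
answer_dependencies = {"answer-1": ["llm-1"]}  # answer node id -> upstream dependency ids

ready = all(dep_id not in rest_node_ids for dep_id in answer_dependencies["answer-1"])
assert ready is False  # llm-1 is still pending, so this answer route is skipped for now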


@ -7,16 +7,13 @@ from core.workflow.graph_engine.entities.graph import Graph
class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
self.rest_node_ids = graph.node_ids.copy()
@abstractmethod
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
@ -35,9 +32,11 @@ class StreamProcessor(ABC):
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
if (
edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify
):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:


@ -9,6 +9,7 @@ class AnswerNodeData(BaseNodeData):
"""
Answer Node Data.
"""
answer: str = Field(..., description="answer template string")
@ -28,6 +29,7 @@ class VarGenerateRouteChunk(GenerateRouteChunk):
"""
Var Generate Route Chunk.
"""
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
"""generate route chunk type"""
value_selector: list[str] = Field(..., description="value selector")
@ -37,6 +39,7 @@ class TextGenerateRouteChunk(GenerateRouteChunk):
"""
Text Generate Route Chunk.
"""
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT
"""generate route chunk type"""
text: str = Field(..., description="text")
@ -52,11 +55,10 @@ class AnswerStreamGenerateRoute(BaseModel):
"""
AnswerStreamGenerateRoute entity
"""
answer_dependencies: dict[str, list[str]] = Field(
...,
description="answer dependencies (answer node id -> dependent answer node ids)"
..., description="answer dependencies (answer node id -> dependent answer node ids)"
)
answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
...,
description="answer generate route (answer node id -> generate route chunks)"
..., description="answer generate route (answer node id -> generate route chunks)"
)


@ -15,14 +15,16 @@ class BaseNode(ABC):
_node_data_cls: type[BaseNodeData]
_node_type: NodeType
def __init__(self,
id: str,
config: Mapping[str, Any],
graph_init_params: GraphInitParams,
graph: Graph,
graph_runtime_state: GraphRuntimeState,
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None) -> None:
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: GraphInitParams,
graph: Graph,
graph_runtime_state: GraphRuntimeState,
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
) -> None:
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
@ -46,8 +48,7 @@ class BaseNode(ABC):
self.node_data = self._node_data_cls(**config.get("data", {}))
@abstractmethod
def _run(self) \
-> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node
:return:
@ -62,14 +63,14 @@ class BaseNode(ABC):
result = self._run()
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(
run_result=result
)
yield RunCompletedEvent(run_result=result)
else:
yield from result
@classmethod
def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]:
def extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], config: dict
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
@ -82,17 +83,12 @@ class BaseNode(ABC):
node_data = cls._node_data_cls(**config.get("data", {}))
return cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config,
node_id=node_id,
node_data=node_data
graph_config=graph_config, node_id=node_id, node_data=node_data
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: BaseNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping


@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union, cast
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
@ -25,11 +25,10 @@ class CodeNode(BaseNode):
"""
code_language = CodeLanguage.PYTHON3
if filters:
code_language = (filters.get("code_language", CodeLanguage.PYTHON3))
code_language = filters.get("code_language", CodeLanguage.PYTHON3)
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] = next(p for p in providers
if p.is_accept_language(code_language))
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
return code_provider.get_default_config()
@ -62,18 +61,10 @@ class CodeNode(BaseNode):
# Transform result
result = self._transform_result(result, node_data.outputs)
except (CodeExecutionException, ValueError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e)
)
except (CodeExecutionError, ValueError) as e:
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
outputs=result
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
def _check_string(self, value: str, variable: str) -> str:
"""
@ -83,16 +74,18 @@ class CodeNode(BaseNode):
:return:
"""
if not isinstance(value, str):
if isinstance(value, type(None)):
if value is None:
return None
else:
raise ValueError(f"Output variable `{variable}` must be a string")
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
raise ValueError(f'The length of output variable `{variable}` must be'
f' less than {dify_config.CODE_MAX_STRING_LENGTH} characters')
return value.replace('\x00', '')
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
raise ValueError(
f"The length of output variable `{variable}` must be"
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
)
return value.replace("\x00", "")
def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
"""
@ -102,26 +95,30 @@ class CodeNode(BaseNode):
:return:
"""
if not isinstance(value, int | float):
if isinstance(value, type(None)):
if value is None:
return None
else:
raise ValueError(f"Output variable `{variable}` must be a number")
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
raise ValueError(f'Output variable `{variable}` is out of range,'
f' it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}.')
raise ValueError(
f"Output variable `{variable}` is out of range,"
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
)
if isinstance(value, float):
# raise error if precision is too high
if len(str(value).split('.')[1]) > dify_config.CODE_MAX_PRECISION:
raise ValueError(f'Output variable `{variable}` has too high precision,'
f' it must be less than {dify_config.CODE_MAX_PRECISION} digits.')
if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION:
raise ValueError(
f"Output variable `{variable}` has too high precision,"
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
)
return value
def _transform_result(self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]],
prefix: str = '',
depth: int = 1) -> dict:
def _transform_result(
self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1
) -> dict:
"""
Transform result
:param result: result
@ -139,183 +136,187 @@ class CodeNode(BaseNode):
self._transform_result(
result=output_value,
output_schema=None,
prefix=f'{prefix}.{output_name}' if prefix else output_name,
depth=depth + 1
prefix=f"{prefix}.{output_name}" if prefix else output_name,
depth=depth + 1,
)
elif isinstance(output_value, int | float):
self._check_number(
value=output_value,
variable=f'{prefix}.{output_name}' if prefix else output_name
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
)
elif isinstance(output_value, str):
self._check_string(
value=output_value,
variable=f'{prefix}.{output_name}' if prefix else output_name
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
)
elif isinstance(output_value, list):
first_element = output_value[0] if len(output_value) > 0 else None
if first_element is not None:
if isinstance(first_element, int | float) and all(value is None or isinstance(value, int | float) for value in output_value):
if isinstance(first_element, int | float) and all(
value is None or isinstance(value, int | float) for value in output_value
):
for i, value in enumerate(output_value):
self._check_number(
value=value,
variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]'
variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
)
elif isinstance(first_element, str) and all(value is None or isinstance(value, str) for value in output_value):
elif isinstance(first_element, str) and all(
value is None or isinstance(value, str) for value in output_value
):
for i, value in enumerate(output_value):
self._check_string(
value=value,
variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]'
variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
)
elif isinstance(first_element, dict) and all(value is None or isinstance(value, dict) for value in output_value):
elif isinstance(first_element, dict) and all(
value is None or isinstance(value, dict) for value in output_value
):
for i, value in enumerate(output_value):
if value is not None:
self._transform_result(
result=value,
output_schema=None,
prefix=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]',
depth=depth + 1
prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
depth=depth + 1,
)
else:
raise ValueError(f'Output {prefix}.{output_name} is not a valid array. make sure all elements are of the same type.')
elif isinstance(output_value, type(None)):
raise ValueError(
f"Output {prefix}.{output_name} is not a valid array."
f" make sure all elements are of the same type."
)
elif output_value is None:
pass
else:
raise ValueError(f'Output {prefix}.{output_name} is not a valid type.')
raise ValueError(f"Output {prefix}.{output_name} is not a valid type.")
return result
parameters_validated = {}
for output_name, output_config in output_schema.items():
dot = '.' if prefix else ''
dot = "." if prefix else ""
if output_name not in result:
raise ValueError(f'Output {prefix}{dot}{output_name} is missing.')
if output_config.type == 'object':
raise ValueError(f"Output {prefix}{dot}{output_name} is missing.")
if output_config.type == "object":
# check if output is object
if not isinstance(result.get(output_name), dict):
if isinstance(result.get(output_name), type(None)):
transformed_result[output_name] = None
else:
raise ValueError(
f'Output {prefix}{dot}{output_name} is not an object, got {type(result.get(output_name))} instead.'
f"Output {prefix}{dot}{output_name} is not an object,"
f" got {type(result.get(output_name))} instead."
)
else:
transformed_result[output_name] = self._transform_result(
result=result[output_name],
output_schema=output_config.children,
prefix=f'{prefix}.{output_name}',
depth=depth + 1
prefix=f"{prefix}.{output_name}",
depth=depth + 1,
)
elif output_config.type == 'number':
elif output_config.type == "number":
# check if number available
transformed_result[output_name] = self._check_number(
value=result[output_name],
variable=f'{prefix}{dot}{output_name}'
value=result[output_name], variable=f"{prefix}{dot}{output_name}"
)
elif output_config.type == 'string':
elif output_config.type == "string":
# check if string available
transformed_result[output_name] = self._check_string(
value=result[output_name],
variable=f'{prefix}{dot}{output_name}',
variable=f"{prefix}{dot}{output_name}",
)
elif output_config.type == 'array[number]':
elif output_config.type == "array[number]":
# check if array of number available
if not isinstance(result[output_name], list):
if isinstance(result[output_name], type(None)):
transformed_result[output_name] = None
else:
raise ValueError(
f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
raise ValueError(
f'The length of output variable `{prefix}{dot}{output_name}` must be'
f' less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements.'
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
)
transformed_result[output_name] = [
self._check_number(
value=value,
variable=f'{prefix}{dot}{output_name}[{i}]'
)
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
for i, value in enumerate(result[output_name])
]
elif output_config.type == 'array[string]':
elif output_config.type == "array[string]":
# check if array of string available
if not isinstance(result[output_name], list):
if isinstance(result[output_name], type(None)):
transformed_result[output_name] = None
else:
raise ValueError(
f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
raise ValueError(
f'The length of output variable `{prefix}{dot}{output_name}` must be'
f' less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements.'
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
)
transformed_result[output_name] = [
self._check_string(
value=value,
variable=f'{prefix}{dot}{output_name}[{i}]'
)
self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
for i, value in enumerate(result[output_name])
]
elif output_config.type == 'array[object]':
elif output_config.type == "array[object]":
# check if array of object available
if not isinstance(result[output_name], list):
if isinstance(result[output_name], type(None)):
transformed_result[output_name] = None
else:
raise ValueError(
f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
raise ValueError(
f'The length of output variable `{prefix}{dot}{output_name}` must be'
f' less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements.'
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
)
for i, value in enumerate(result[output_name]):
if not isinstance(value, dict):
if isinstance(value, type(None)):
if value is None:
pass
else:
raise ValueError(
f'Output {prefix}{dot}{output_name}[{i}] is not an object, got {type(value)} instead at index {i}.'
f"Output {prefix}{dot}{output_name}[{i}] is not an object,"
f" got {type(value)} instead at index {i}."
)
transformed_result[output_name] = [
None if value is None else self._transform_result(
None
if value is None
else self._transform_result(
result=value,
output_schema=output_config.children,
prefix=f'{prefix}{dot}{output_name}[{i}]',
depth=depth + 1
prefix=f"{prefix}{dot}{output_name}[{i}]",
depth=depth + 1,
)
for i, value in enumerate(result[output_name])
]
else:
raise ValueError(f'Output type {output_config.type} is not supported.')
raise ValueError(f"Output type {output_config.type} is not supported.")
parameters_validated[output_name] = True
# check if all output parameters are validated
if len(parameters_validated) != len(result):
raise ValueError('Not all output parameters are validated.')
raise ValueError("Not all output parameters are validated.")
return transformed_result
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: CodeNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: CodeNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -325,5 +326,6 @@ class CodeNode(BaseNode):
:return:
"""
return {
node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables
}
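
Besides quoting, two cleanups recur throughout this file: CodeExecutionException is renamed CodeExecutionError, matching the PEP 8 convention that exception names end in "Error", and isinstance(value, type(None)) becomes the idiomatic identity check. A minimal sketch of the equivalence:

# For any value the two spellings agree; "is None" is simply the idiomatic form.
for value in (None, "text", 3.14):
    assert isinstance(value, type(None)) == (value is None)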


@ -11,9 +11,10 @@ class CodeNodeData(BaseNodeData):
"""
Code Node Data.
"""
class Output(BaseModel):
type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]']
children: Optional[dict[str, 'Output']] = None
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
children: Optional[dict[str, "Output"]] = None
class Dependency(BaseModel):
name: str
@ -23,4 +24,4 @@ class CodeNodeData(BaseNodeData):
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
code: str
outputs: dict[str, Output]
dependencies: Optional[list[Dependency]] = None
dependencies: Optional[list[Dependency]] = None
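
The Output schema refers to itself through the "Output" forward reference, which lets nested object outputs be validated to arbitrary depth. A runnable sketch of the same shape, assuming pydantic v2 as the Field and model usage elsewhere in this diff suggests:

from typing import Literal, Optional
from pydantic import BaseModel

class Output(BaseModel):
    type: Literal["string", "number", "object"]  # trimmed to three variants for brevity
    children: Optional[dict[str, "Output"]] = None  # forward reference to this same model

nested = Output.model_validate({"type": "object", "children": {"name": {"type": "string"}}})
assert nested.children["name"].type == "string"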


@ -25,18 +25,11 @@ class EndNode(BaseNode):
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
outputs[variable_selector.variable] = value
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=outputs,
outputs=outputs
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: EndNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: EndNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping


@ -3,13 +3,13 @@ from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
class EndStreamGeneratorRouter:
@classmethod
def init(cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_parallel_mapping: dict[str, str]
) -> EndStreamParam:
def init(
cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_parallel_mapping: dict[str, str],
) -> EndStreamParam:
"""
Get stream generate routes.
:return:
@ -17,7 +17,7 @@ class EndStreamGeneratorRouter:
# parse stream output node value selector of end nodes
end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
for end_node_id, node_config in node_id_config_mapping.items():
if not node_config.get('data', {}).get('type') == NodeType.END.value:
if node_config.get("data", {}).get("type") != NodeType.END.value:
continue
# skip end node in parallel
@ -33,18 +33,18 @@ class EndStreamGeneratorRouter:
end_dependencies = cls._fetch_ends_dependencies(
end_node_ids=end_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping
node_id_config_mapping=node_id_config_mapping,
)
return EndStreamParam(
end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
end_dependencies=end_dependencies
end_dependencies=end_dependencies,
)
@classmethod
def extract_stream_variable_selector_from_node_data(cls,
node_id_config_mapping: dict[str, dict],
node_data: EndNodeData) -> list[list[str]]:
def extract_stream_variable_selector_from_node_data(
cls, node_id_config_mapping: dict[str, dict], node_data: EndNodeData
) -> list[list[str]]:
"""
Extract stream variable selector from node data
:param node_id_config_mapping: node id config mapping
@ -59,21 +59,22 @@ class EndStreamGeneratorRouter:
continue
node_id = variable_selector.value_selector[0]
if node_id != 'sys' and node_id in node_id_config_mapping:
if node_id != "sys" and node_id in node_id_config_mapping:
node = node_id_config_mapping[node_id]
node_type = node.get('data', {}).get('type')
node_type = node.get("data", {}).get("type")
if (
variable_selector.value_selector not in value_selectors
and node_type == NodeType.LLM.value
and variable_selector.value_selector[1] == 'text'
and node_type == NodeType.LLM.value
and variable_selector.value_selector[1] == "text"
):
value_selectors.append(variable_selector.value_selector)
return value_selectors
@classmethod
def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \
-> list[list[str]]:
def _extract_stream_variable_selector(
cls, node_id_config_mapping: dict[str, dict], config: dict
) -> list[list[str]]:
"""
Extract stream variable selector from node config
:param node_id_config_mapping: node id config mapping
@ -84,11 +85,12 @@ class EndStreamGeneratorRouter:
return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)
@classmethod
def _fetch_ends_dependencies(cls,
end_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict]
) -> dict[str, list[str]]:
def _fetch_ends_dependencies(
cls,
end_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict],
) -> dict[str, list[str]]:
"""
Fetch end dependencies
:param end_node_ids: end node ids
@ -106,20 +108,21 @@ class EndStreamGeneratorRouter:
end_node_id=end_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
end_dependencies=end_dependencies
end_dependencies=end_dependencies,
)
return end_dependencies
@classmethod
def _recursive_fetch_end_dependencies(cls,
current_node_id: str,
end_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]],
# type: ignore[name-defined]
end_dependencies: dict[str, list[str]]
) -> None:
def _recursive_fetch_end_dependencies(
cls,
current_node_id: str,
end_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]],
# type: ignore[name-defined]
end_dependencies: dict[str, list[str]],
) -> None:
"""
Recursive fetch end dependencies
:param current_node_id: current node id
@ -132,11 +135,11 @@ class EndStreamGeneratorRouter:
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
if source_node_type in (
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER,
):
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in {
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER,
}:
end_dependencies[end_node_id].append(source_node_id)
else:
cls._recursive_fetch_end_dependencies(
@ -144,5 +147,5 @@ class EndStreamGeneratorRouter:
end_node_id=end_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
end_dependencies=end_dependencies
end_dependencies=end_dependencies,
)
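
The recursive walk above follows reverse edges backwards from an end node until it reaches a branching node (if-else or question classifier), and records that node as the dependency that gates streaming. A compact sketch with plain dicts standing in for GraphEdge objects (type strings assumed):

def fetch_deps(current: str, end_id: str, types: dict, redges: dict, deps: dict) -> None:
    # Walk reverse edges; record branch nodes, recurse through everything else.
    for src in redges.get(current, []):
        if types[src] in {"if-else", "question-classifier"}:
            deps.setdefault(end_id, []).append(src)
        else:
            fetch_deps(src, end_id, types, redges, deps)

deps: dict[str, list[str]] = {}
fetch_deps("end-1", "end-1",
           types={"llm-1": "llm", "if-1": "if-else"},
           redges={"end-1": ["llm-1"], "llm-1": ["if-1"]},
           deps=deps)
assert deps == {"end-1": ["if-1"]}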


@ -15,7 +15,6 @@ logger = logging.getLogger(__name__)
class EndStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
super().__init__(graph, variable_pool)
self.end_stream_param = graph.end_stream_param
@ -26,9 +25,7 @@ class EndStreamProcessor(StreamProcessor):
self.has_outputed = False
self.outputed_node_ids = set()
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
@ -38,7 +35,7 @@ class EndStreamProcessor(StreamProcessor):
elif isinstance(event, NodeRunStreamChunkEvent):
if event.in_iteration_id:
if self.has_outputed and event.node_id not in self.outputed_node_ids:
event.chunk_content = '\n' + event.chunk_content
event.chunk_content = "\n" + event.chunk_content
self.outputed_node_ids.add(event.node_id)
self.has_outputed = True
@ -51,13 +48,13 @@ class EndStreamProcessor(StreamProcessor):
]
else:
stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
] = stream_out_end_node_ids
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
stream_out_end_node_ids
)
if stream_out_end_node_ids:
if self.has_outputed and event.node_id not in self.outputed_node_ids:
event.chunk_content = '\n' + event.chunk_content
event.chunk_content = "\n" + event.chunk_content
self.outputed_node_ids.add(event.node_id)
self.has_outputed = True
@ -86,9 +83,9 @@ class EndStreamProcessor(StreamProcessor):
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _generate_stream_outputs_when_node_finished(self,
event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
def _generate_stream_outputs_when_node_finished(
self, event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
@ -96,10 +93,12 @@ class EndStreamProcessor(StreamProcessor):
"""
for end_node_id, position in self.route_position.items():
# all depends on end node id not in rest node ids
if (event.route_node_state.node_id != end_node_id
and (end_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.end_stream_param.end_dependencies[end_node_id]))):
if event.route_node_state.node_id != end_node_id and (
end_node_id not in self.rest_node_ids
or not all(
dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]
)
):
continue
route_position = self.route_position[end_node_id]
@ -116,9 +115,7 @@ class EndStreamProcessor(StreamProcessor):
if not value_selector:
continue
value = self.variable_pool.get(
value_selector
)
value = self.variable_pool.get(value_selector)
if value is None:
break
@ -128,7 +125,7 @@ class EndStreamProcessor(StreamProcessor):
if text:
current_node_id = value_selector[0]
if self.has_outputed and current_node_id not in self.outputed_node_ids:
text = '\n' + text
text = "\n" + text
self.outputed_node_ids.add(current_node_id)
self.has_outputed = True
@ -165,8 +162,7 @@ class EndStreamProcessor(StreamProcessor):
continue
# all depends on end node id not in rest node ids
if all(dep_id not in self.rest_node_ids
for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
if all(dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]):
continue
@ -178,7 +174,7 @@ class EndStreamProcessor(StreamProcessor):
break
position += 1
if not value_selector:
continue
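
The has_outputed / outputed_node_ids pair (spelling as in the source) exists so that when a different node starts streaming after another has already produced text, its first chunk is prefixed with a newline instead of running straight into the previous output. A sketch of that separator logic:

has_outputed, outputed_node_ids = True, {"llm-1"}  # llm-1 already streamed something

chunk_content, node_id = "second answer", "llm-2"
if has_outputed and node_id not in outputed_node_ids:
    chunk_content = "\n" + chunk_content  # keep the two nodes' outputs apart
    outputed_node_ids.add(node_id)
assert chunk_content.startswith("\n")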


@ -8,6 +8,7 @@ class EndNodeData(BaseNodeData):
"""
END Node Data.
"""
outputs: list[VariableSelector]
@ -15,11 +16,10 @@ class EndStreamParam(BaseModel):
"""
EndStreamParam entity
"""
end_dependencies: dict[str, list[str]] = Field(
...,
description="end dependencies (end node id -> dependent node ids)"
..., description="end dependencies (end node id -> dependent node ids)"
)
end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field(
...,
description="end stream variable selector mapping (end node id -> stream variable selectors)"
..., description="end stream variable selector mapping (end node id -> stream variable selectors)"
)


@ -7,32 +7,32 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
class HttpRequestNodeAuthorizationConfig(BaseModel):
type: Literal[None, 'basic', 'bearer', 'custom']
type: Literal[None, "basic", "bearer", "custom"]
api_key: Union[None, str] = None
header: Union[None, str] = None
class HttpRequestNodeAuthorization(BaseModel):
type: Literal['no-auth', 'api-key']
type: Literal["no-auth", "api-key"]
config: Optional[HttpRequestNodeAuthorizationConfig] = None
@field_validator('config', mode='before')
@field_validator("config", mode="before")
@classmethod
def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo):
"""
Check config, if type is no-auth, config should be None, otherwise it should be a dict.
"""
if values.data['type'] == 'no-auth':
if values.data["type"] == "no-auth":
return None
else:
if not v or not isinstance(v, dict):
raise ValueError('config should be a dict')
raise ValueError("config should be a dict")
return v
class HttpRequestNodeBody(BaseModel):
type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json']
type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json"]
data: Union[None, str] = None
@ -47,7 +47,7 @@ class HttpRequestNodeData(BaseNodeData):
Code Node Data.
"""
method: Literal['get', 'post', 'put', 'patch', 'delete', 'head']
method: Literal["get", "post", "put", "patch", "delete", "head"]
url: str
authorization: HttpRequestNodeAuthorization
headers: str
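
The check_config validator relies on two pydantic v2 behaviors: mode="before" runs on the raw input, and ValidationInfo.data exposes fields that were already validated, here the type field declared above it. A self-contained sketch of the same pattern:

from typing import Optional
from pydantic import BaseModel, ValidationInfo, field_validator

class Authorization(BaseModel):
    type: str
    config: Optional[dict] = None

    @field_validator("config", mode="before")
    @classmethod
    def check_config(cls, v, info: ValidationInfo):
        if info.data["type"] == "no-auth":
            return None  # any config is dropped when no auth is used
        if not v or not isinstance(v, dict):
            raise ValueError("config should be a dict")
        return v

assert Authorization(type="no-auth", config={"api_key": "k"}).config is None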


@ -6,8 +6,8 @@ from urllib.parse import urlencode
import httpx
import core.helper.ssrf_proxy as ssrf_proxy
from configs import dify_config
from core.helper import ssrf_proxy
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.http_request.entities import (
@ -33,12 +33,12 @@ class HttpExecutorResponse:
check if response is file
"""
content_type = self.get_content_type()
file_content_types = ['image', 'audio', 'video']
file_content_types = ["image", "audio", "video"]
return any(v in content_type for v in file_content_types)
def get_content_type(self) -> str:
return self.headers.get('content-type', '')
return self.headers.get("content-type", "")
def extract_file(self) -> tuple[str, bytes]:
"""
@ -47,28 +47,28 @@ class HttpExecutorResponse:
if self.is_file:
return self.get_content_type(), self.body
return '', b''
return "", b""
@property
def content(self) -> str:
if isinstance(self.response, httpx.Response):
return self.response.text
else:
raise ValueError(f'Invalid response type {type(self.response)}')
raise ValueError(f"Invalid response type {type(self.response)}")
@property
def body(self) -> bytes:
if isinstance(self.response, httpx.Response):
return self.response.content
else:
raise ValueError(f'Invalid response type {type(self.response)}')
raise ValueError(f"Invalid response type {type(self.response)}")
@property
def status_code(self) -> int:
if isinstance(self.response, httpx.Response):
return self.response.status_code
else:
raise ValueError(f'Invalid response type {type(self.response)}')
raise ValueError(f"Invalid response type {type(self.response)}")
@property
def size(self) -> int:
@ -77,11 +77,11 @@ class HttpExecutorResponse:
@property
def readable_size(self) -> str:
if self.size < 1024:
return f'{self.size} bytes'
return f"{self.size} bytes"
elif self.size < 1024 * 1024:
return f'{(self.size / 1024):.2f} KB'
return f"{(self.size / 1024):.2f} KB"
else:
return f'{(self.size / 1024 / 1024):.2f} MB'
return f"{(self.size / 1024 / 1024):.2f} MB"
class HttpExecutor:
@ -120,7 +120,7 @@ class HttpExecutor:
"""
check if body is json
"""
if body and body.type == 'json' and body.data:
if body and body.type == "json" and body.data:
try:
json.loads(body.data)
return True
@ -134,15 +134,15 @@ class HttpExecutor:
"""
Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}`
"""
kv_paris = convert_text.split('\n')
kv_paris = convert_text.split("\n")
result = {}
for kv in kv_paris:
if not kv.strip():
continue
kv = kv.split(':', maxsplit=1)
kv = kv.split(":", maxsplit=1)
if len(kv) == 1:
k, v = kv[0], ''
k, v = kv[0], ""
else:
k, v = kv
result[k.strip()] = v
@ -166,31 +166,31 @@ class HttpExecutor:
# check if it's a valid JSON
is_valid_json = self._is_json_body(node_data.body)
body_data = node_data.body.data or ''
body_data = node_data.body.data or ""
if body_data:
body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json)
content_type_is_set = any(key.lower() == 'content-type' for key in self.headers)
if node_data.body.type == 'json' and not content_type_is_set:
self.headers['Content-Type'] = 'application/json'
elif node_data.body.type == 'x-www-form-urlencoded' and not content_type_is_set:
self.headers['Content-Type'] = 'application/x-www-form-urlencoded'
content_type_is_set = any(key.lower() == "content-type" for key in self.headers)
if node_data.body.type == "json" and not content_type_is_set:
self.headers["Content-Type"] = "application/json"
elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set:
self.headers["Content-Type"] = "application/x-www-form-urlencoded"
if node_data.body.type in ['form-data', 'x-www-form-urlencoded']:
if node_data.body.type in {"form-data", "x-www-form-urlencoded"}:
body = self._to_dict(body_data)
if node_data.body.type == 'form-data':
self.files = {k: ('', v) for k, v in body.items()}
random_str = lambda n: ''.join([chr(randint(97, 122)) for _ in range(n)])
self.boundary = f'----WebKitFormBoundary{random_str(16)}'
if node_data.body.type == "form-data":
self.files = {k: ("", v) for k, v in body.items()}
random_str = lambda n: "".join([chr(randint(97, 122)) for _ in range(n)])
self.boundary = f"----WebKitFormBoundary{random_str(16)}"
self.headers['Content-Type'] = f'multipart/form-data; boundary={self.boundary}'
self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
else:
self.body = urlencode(body)
elif node_data.body.type in ['json', 'raw-text']:
elif node_data.body.type in {"json", "raw-text"}:
self.body = body_data
elif node_data.body.type == 'none':
self.body = ''
elif node_data.body.type == "none":
self.body = ""
self.variable_selectors = (
server_url_variable_selectors
@ -202,23 +202,23 @@ class HttpExecutor:
def _assembling_headers(self) -> dict[str, Any]:
authorization = deepcopy(self.authorization)
headers = deepcopy(self.headers) or {}
if self.authorization.type == 'api-key':
if self.authorization.type == "api-key":
if self.authorization.config is None:
raise ValueError('self.authorization config is required')
raise ValueError("self.authorization config is required")
if authorization.config is None:
raise ValueError('authorization config is required')
raise ValueError("authorization config is required")
if self.authorization.config.api_key is None:
raise ValueError('api_key is required')
raise ValueError("api_key is required")
if not authorization.config.header:
authorization.config.header = 'Authorization'
authorization.config.header = "Authorization"
if self.authorization.config.type == 'bearer':
headers[authorization.config.header] = f'Bearer {authorization.config.api_key}'
elif self.authorization.config.type == 'basic':
headers[authorization.config.header] = f'Basic {authorization.config.api_key}'
elif self.authorization.config.type == 'custom':
if self.authorization.config.type == "bearer":
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
elif self.authorization.config.type == "basic":
headers[authorization.config.header] = f"Basic {authorization.config.api_key}"
elif self.authorization.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key
return headers
@ -230,10 +230,13 @@ class HttpExecutor:
if isinstance(response, httpx.Response):
executor_response = HttpExecutorResponse(response)
else:
raise ValueError(f'Invalid response type {type(response)}')
raise ValueError(f"Invalid response type {type(response)}")
threshold_size = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE if executor_response.is_file \
threshold_size = (
dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE
if executor_response.is_file
else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
)
if executor_response.size > threshold_size:
raise ValueError(
f'{"File" if executor_response.is_file else "Text"} size is too large,'
@ -248,17 +251,17 @@ class HttpExecutor:
do http request depending on api bundle
"""
kwargs = {
'url': self.server_url,
'headers': headers,
'params': self.params,
'timeout': (self.timeout.connect, self.timeout.read, self.timeout.write),
'follow_redirects': True,
"url": self.server_url,
"headers": headers,
"params": self.params,
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
"follow_redirects": True,
}
if self.method in ('get', 'head', 'post', 'put', 'delete', 'patch'):
if self.method in {"get", "head", "post", "put", "delete", "patch"}:
response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
else:
raise ValueError(f'Invalid http method {self.method}')
raise ValueError(f"Invalid http method {self.method}")
return response
def invoke(self) -> HttpExecutorResponse:
@ -280,15 +283,15 @@ class HttpExecutor:
"""
server_url = self.server_url
if self.params:
server_url += f'?{urlencode(self.params)}'
server_url += f"?{urlencode(self.params)}"
raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n'
raw_request = f"{self.method.upper()} {server_url} HTTP/1.1\n"
headers = self._assembling_headers()
for k, v in headers.items():
# get authorization header
if self.authorization.type == 'api-key':
authorization_header = 'Authorization'
if self.authorization.type == "api-key":
authorization_header = "Authorization"
if self.authorization.config and self.authorization.config.header:
authorization_header = self.authorization.config.header
@ -296,21 +299,21 @@ class HttpExecutor:
raw_request += f'{k}: {"*" * len(v)}\n'
continue
raw_request += f'{k}: {v}\n'
raw_request += f"{k}: {v}\n"
raw_request += '\n'
raw_request += "\n"
# if files, use multipart/form-data with boundary
if self.files:
boundary = self.boundary
raw_request += f'--{boundary}'
raw_request += f"--{boundary}"
for k, v in self.files.items():
raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n'
raw_request += f'{v[1]}\n'
raw_request += f'--{boundary}'
raw_request += '--'
raw_request += f"{v[1]}\n"
raw_request += f"--{boundary}"
raw_request += "--"
else:
raw_request += self.body or ''
raw_request += self.body or ""
return raw_request
@ -328,9 +331,9 @@ class HttpExecutor:
for variable_selector in variable_selectors:
variable = variable_pool.get_any(variable_selector.value_selector)
if variable is None:
raise ValueError(f'Variable {variable_selector.variable} not found')
raise ValueError(f"Variable {variable_selector.variable} not found")
if escape_quotes and isinstance(variable, str):
value = variable.replace('"', '\\"').replace('\n', '\\n')
value = variable.replace('"', '\\"').replace("\n", "\\n")
else:
value = variable
variable_value_mapping[variable_selector.variable] = value
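
For form-data bodies the executor fabricates a WebKit-style boundary from sixteen random lowercase letters and pins it in the Content-Type header; randint(97, 122) spans ASCII "a" through "z". A sketch of that construction:

from random import randint

random_str = lambda n: "".join(chr(randint(97, 122)) for _ in range(n))  # noqa: E731
boundary = f"----WebKitFormBoundary{random_str(16)}"
headers = {"Content-Type": f"multipart/form-data; boundary={boundary}"}

assert len(boundary) == len("----WebKitFormBoundary") + 16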


@ -31,18 +31,18 @@ class HttpRequestNode(BaseNode):
@classmethod
def get_default_config(cls, filters: dict | None = None) -> dict:
return {
'type': 'http-request',
'config': {
'method': 'get',
'authorization': {
'type': 'no-auth',
"type": "http-request",
"config": {
"method": "get",
"authorization": {
"type": "no-auth",
},
'body': {'type': 'none'},
'timeout': {
"body": {"type": "none"},
"timeout": {
**HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(),
'max_connect_timeout': dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
'max_read_timeout': dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
'max_write_timeout': dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
"max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
"max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
},
},
}
@ -52,9 +52,8 @@ class HttpRequestNode(BaseNode):
# TODO: Switch to use segment directly
if node_data.authorization.config and node_data.authorization.config.api_key:
node_data.authorization.config.api_key = parser.convert_template(
template=node_data.authorization.config.api_key,
variable_pool=self.graph_runtime_state.variable_pool
).text
template=node_data.authorization.config.api_key, variable_pool=self.graph_runtime_state.variable_pool
).text
# init http executor
http_executor = None
@ -62,7 +61,7 @@ class HttpRequestNode(BaseNode):
http_executor = HttpExecutor(
node_data=node_data,
timeout=self._get_request_timeout(node_data),
variable_pool=self.graph_runtime_state.variable_pool
variable_pool=self.graph_runtime_state.variable_pool,
)
# invoke http executor
@ -71,7 +70,7 @@ class HttpRequestNode(BaseNode):
process_data = {}
if http_executor:
process_data = {
'request': http_executor.to_raw_request(),
"request": http_executor.to_raw_request(),
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -84,13 +83,13 @@ class HttpRequestNode(BaseNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'status_code': response.status_code,
'body': response.content if not files else '',
'headers': response.headers,
'files': files,
"status_code": response.status_code,
"body": response.content if not files else "",
"headers": response.headers,
"files": files,
},
process_data={
'request': http_executor.to_raw_request(),
"request": http_executor.to_raw_request(),
},
)
@ -107,10 +106,7 @@ class HttpRequestNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: HttpRequestNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: HttpRequestNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -126,11 +122,11 @@ class HttpRequestNode(BaseNode):
variable_mapping = {}
for variable_selector in variable_selectors:
variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
return variable_mapping
except Exception as e:
logging.exception(f'Failed to extract variable selector to variable mapping: {e}')
logging.exception(f"Failed to extract variable selector to variable mapping: {e}")
return {}
def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]:
@ -144,7 +140,7 @@ class HttpRequestNode(BaseNode):
# extract filename from url
filename = path.basename(url)
# extract extension if possible
extension = guess_extension(mimetype) or '.bin'
extension = guess_extension(mimetype) or ".bin"
tool_file = ToolFileManager.create_file_by_raw(
user_id=self.user_id,
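
When a response is detected as a file, the node derives a filename from the URL path and an extension from the MIME type, falling back to .bin for unknown types. A sketch using the same stdlib helpers, with illustrative values:

from mimetypes import guess_extension
from os import path

url, mimetype = "https://example.com/files/report", "application/pdf"

filename = path.basename(url)                    # "report"
extension = guess_extension(mimetype) or ".bin"  # ".pdf"; unknown types fall back to .bin
assert filename + extension == "report.pdf"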


@ -15,6 +15,7 @@ class IfElseNodeData(BaseNodeData):
"""
Case entity representing a single logical condition group
"""
case_id: str
logical_operator: Literal["and", "or"]
conditions: list[Condition]


@ -20,13 +20,9 @@ class IfElseNode(BaseNode):
node_data = self.node_data
node_data = cast(IfElseNodeData, node_data)
node_inputs: dict[str, list] = {
"conditions": []
}
node_inputs: dict[str, list] = {"conditions": []}
process_datas: dict[str, list] = {
"condition_results": []
}
process_datas: dict[str, list] = {"condition_results": []}
input_conditions = []
final_result = False
@ -37,8 +33,7 @@ class IfElseNode(BaseNode):
if node_data.cases:
for case in node_data.cases:
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions
variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions
)
# Apply the logical operator for the current case
@ -60,8 +55,7 @@ class IfElseNode(BaseNode):
else:
# Fallback to old structure if cases are not defined
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=node_data.conditions
variable_pool=self.graph_runtime_state.variable_pool, conditions=node_data.conditions
)
final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result)
@ -69,21 +63,14 @@ class IfElseNode(BaseNode):
selected_case_id = "true" if final_result else "false"
process_datas["condition_results"].append(
{
"group": "default",
"results": group_result,
"final_result": final_result
}
{"group": "default", "results": group_result, "final_result": final_result}
)
node_inputs["conditions"] = input_conditions
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=node_inputs,
process_data=process_datas,
error=str(e)
status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_datas, error=str(e)
)
outputs = {"result": final_result, "selected_case_id": selected_case_id}
@ -92,18 +79,15 @@ class IfElseNode(BaseNode):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
process_data=process_datas,
edge_source_handle=selected_case_id if selected_case_id else "false", # Use case ID or 'default'
outputs=outputs
edge_source_handle=selected_case_id or "false", # Use case ID or 'default'
outputs=outputs,
)
return data
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IfElseNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IfElseNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
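For intuition, the and/or reduction used above collapses each case's per-condition results into one boolean. A tiny self-contained sketch with invented results:

group_result = [True, True, False]
# logical_operator == "and": every condition in the case must pass
assert all(group_result) is False
# logical_operator == "or": a single passing condition is enough
assert any(group_result) is True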
View File
@ -7,21 +7,25 @@ class IterationNodeData(BaseIterationNodeData):
"""
Iteration Node Data.
"""
parent_loop_id: Optional[str] = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
parent_loop_id: Optional[str] = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
class IterationStartNodeData(BaseNodeData):
"""
Iteration Start Node Data.
"""
pass
class IterationState(BaseIterationState):
"""
Iteration State.
"""
outputs: list[Any] = None
current_output: Optional[Any] = None
@ -29,6 +33,7 @@ class IterationState(BaseIterationState):
"""
Data.
"""
iterator_length: int
def get_last_output(self) -> Optional[Any]:
@ -38,9 +43,9 @@ class IterationState(BaseIterationState):
if self.outputs:
return self.outputs[-1]
return None
def get_current_output(self) -> Optional[Any]:
"""
Get current output.
"""
return self.current_output
return self.current_output
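A quick sketch of how the two selectors above resolve (a plain dict stands in for the variable pool; the paths are invented):

pool = {("start", "items"): ["a", "b"], ("llm_1", "text"): "answer"}
iterator_selector = ["start", "items"]  # the list the iteration consumes
output_selector = ["llm_1", "text"]     # the per-iteration value to collect
assert pool[tuple(iterator_selector)] == ["a", "b"]
assert pool[tuple(output_selector)] == "answer"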
View File
@ -16,6 +16,7 @@ from core.workflow.graph_engine.entities.event import (
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
@ -33,6 +34,7 @@ class IterationNode(BaseNode):
"""
Iteration Node.
"""
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
@ -45,31 +47,26 @@ class IterationNode(BaseNode):
if not iterator_list_segment:
raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found")
iterator_list_value = iterator_list_segment.to_object()
if not isinstance(iterator_list_value, list):
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
inputs = {
"iterator_selector": iterator_list_value
}
inputs = {"iterator_selector": iterator_list_value}
graph_config = self.graph_config
if not self.node_data.start_node_id:
raise ValueError(f'field start_node_id in iteration {self.node_id} not found')
raise ValueError(f"field start_node_id in iteration {self.node_id} not found")
root_node_id = self.node_data.start_node_id
# init graph
iteration_graph = Graph.init(
graph_config=graph_config,
root_node_id=root_node_id
)
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
if not iteration_graph:
raise ValueError('iteration graph not found')
raise ValueError("iteration graph not found")
leaf_node_ids = iteration_graph.get_leaf_node_ids()
iteration_leaf_node_ids = []
@ -97,26 +94,21 @@ class IterationNode(BaseNode):
Condition(
variable_selector=[self.node_id, "index"],
comparison_operator="<",
value=str(len(iterator_list_value))
value=str(len(iterator_list_value)),
)
]
)
],
),
)
variable_pool = self.graph_runtime_state.variable_pool
# append iteration variable (item, index) to variable pool
variable_pool.add(
[self.node_id, 'index'],
0
)
variable_pool.add(
[self.node_id, 'item'],
iterator_list_value[0]
)
variable_pool.add([self.node_id, "index"], 0)
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
@ -130,7 +122,7 @@ class IterationNode(BaseNode):
graph_config=graph_config,
variable_pool=variable_pool,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
)
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
@ -142,10 +134,8 @@ class IterationNode(BaseNode):
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
metadata={
"iterator_length": len(iterator_list_value)
},
predecessor_node_id=self.previous_node_id
metadata={"iterator_length": len(iterator_list_value)},
predecessor_node_id=self.previous_node_id,
)
yield IterationRunNextEvent(
@ -154,7 +144,7 @@ class IterationNode(BaseNode):
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=0,
pre_iteration_output=None
pre_iteration_output=None,
)
outputs: list[Any] = []
@ -165,7 +155,11 @@ class IterationNode(BaseNode):
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START:
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.ITERATION_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
continue
if isinstance(event, NodeRunSucceededEvent):
@ -176,7 +170,9 @@ class IterationNode(BaseNode):
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
[self.node_id, "index"]
)
event.route_node_state.node_run_result.metadata = metadata
yield event
@ -192,21 +188,15 @@ class IterationNode(BaseNode):
variable_pool.remove_node(node_id)
# move to next iteration
current_index = variable_pool.get([self.node_id, 'index'])
current_index = variable_pool.get([self.node_id, "index"])
if current_index is None:
raise ValueError(f'iteration {self.node_id} current index not found')
raise ValueError(f"iteration {self.node_id} current index not found")
next_index = int(current_index.to_object()) + 1
variable_pool.add(
[self.node_id, 'index'],
next_index
)
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add(
[self.node_id, 'item'],
iterator_list_value[next_index]
)
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
@ -214,8 +204,9 @@ class IterationNode(BaseNode):
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
pre_iteration_output=jsonable_encoder(
current_iteration_output) if current_iteration_output else None
pre_iteration_output=jsonable_encoder(current_iteration_output)
if current_iteration_output
else None,
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
@ -227,13 +218,9 @@ class IterationNode(BaseNode):
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
},
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
@ -255,21 +242,14 @@ class IterationNode(BaseNode):
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
}
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'output': jsonable_encoder(outputs)
}
status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)}
)
)
except Exception as e:
@ -282,16 +262,11 @@ class IterationNode(BaseNode):
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
},
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
@ -301,15 +276,12 @@ class IterationNode(BaseNode):
)
finally:
# remove iteration variable (item, index) from variable pool after iteration run completed
variable_pool.remove([self.node_id, 'index'])
variable_pool.remove([self.node_id, 'item'])
variable_pool.remove([self.node_id, "index"])
variable_pool.remove([self.node_id, "item"])
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -319,36 +291,33 @@ class IterationNode(BaseNode):
:return:
"""
variable_mapping = {
f'{node_id}.input_selector': node_data.iterator_selector,
f"{node_id}.input_selector": node_data.iterator_selector,
}
# init graph
iteration_graph = Graph.init(
graph_config=graph_config,
root_node_id=node_data.start_node_id
)
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
if not iteration_graph:
raise ValueError('iteration graph not found')
raise ValueError("iteration graph not found")
for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
if sub_node_config.get('data', {}).get('iteration_id') != node_id:
if sub_node_config.get("data", {}).get("iteration_id") != node_id:
continue
# variable selector to variable mapping
try:
# Get node class
from core.workflow.nodes.node_mapping import node_classes
node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type'))
node_type = NodeType.value_of(sub_node_config.get("data", {}).get("type"))
node_cls = node_classes.get(node_type)
if not node_cls:
continue
node_cls = cast(BaseNode, node_cls)
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config,
config=sub_node_config
graph_config=graph_config, config=sub_node_config
)
sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
except NotImplementedError:
@ -356,7 +325,8 @@ class IterationNode(BaseNode):
# remove iteration variables
sub_node_variable_mapping = {
sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items()
sub_node_id + "." + key: value
for key, value in sub_node_variable_mapping.items()
if value[0] != node_id
}
@ -364,8 +334,7 @@ class IterationNode(BaseNode):
# remove variable out from iteration
variable_mapping = {
key: value for key, value in variable_mapping.items()
if value[0] not in iteration_graph.node_ids
key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
}
return variable_mapping
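Stripped of events and the graph engine, the index/item bookkeeping in _run above amounts to the following sketch (a dict stands in for the variable pool; the node id is invented):

iterator_list_value = ["a", "b", "c"]
pool = {("iter_1", "index"): 0, ("iter_1", "item"): "a"}  # seeded before the first pass
while pool[("iter_1", "index")] < len(iterator_list_value):
    # ... one pass of the iteration sub-graph runs here ...
    next_index = pool[("iter_1", "index")] + 1
    pool[("iter_1", "index")] = next_index
    if next_index < len(iterator_list_value):
        pool[("iter_1", "item")] = iterator_list_value[next_index]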
View File
@ -11,6 +11,7 @@ class IterationStartNode(BaseNode):
"""
Iteration Start Node.
"""
_node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START
@ -18,16 +19,11 @@ class IterationStartNode(BaseNode):
"""
Run the node.
"""
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
View File
@ -9,6 +9,7 @@ class RerankingModelConfig(BaseModel):
"""
Reranking Model Config.
"""
provider: str
model: str
@ -17,6 +18,7 @@ class VectorSetting(BaseModel):
"""
Vector Setting.
"""
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
@ -26,6 +28,7 @@ class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""
keyword_weight: float
@ -33,6 +36,7 @@ class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
"""
vector_setting: VectorSetting
keyword_setting: KeywordSetting
@ -41,17 +45,20 @@ class MultipleRetrievalConfig(BaseModel):
"""
Multiple Retrieval Config.
"""
top_k: int
score_threshold: Optional[float] = None
reranking_mode: str = 'reranking_model'
reranking_mode: str = "reranking_model"
reranking_enable: bool = True
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None
class ModelConfig(BaseModel):
"""
Model Config.
Model Config.
"""
provider: str
name: str
mode: str
@ -62,6 +69,7 @@ class SingleRetrievalConfig(BaseModel):
"""
Single Retrieval Config.
"""
model: ModelConfig
@ -69,9 +77,10 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
"""
Knowledge retrieval Node Data.
"""
type: str = 'knowledge-retrieval'
type: str = "knowledge-retrieval"
query_variable_selector: list[str]
dataset_ids: list[str]
retrieval_mode: Literal['single', 'multiple']
retrieval_mode: Literal["single", "multiple"]
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None
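Put together, a minimal node payload consistent with these models might look like this (the dataset id and selector are invented, and BaseNodeData's own fields are omitted for brevity):

{
    "type": "knowledge-retrieval",
    "query_variable_selector": ["start", "query"],
    "dataset_ids": ["dataset-123"],
    "retrieval_mode": "multiple",
    "multiple_retrieval_config": {"top_k": 2, "reranking_enable": True},
}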
View File
@ -24,14 +24,11 @@ from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__)
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
@ -45,62 +42,47 @@ class KnowledgeRetrievalNode(BaseNode):
# extract variables
variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector)
query = variable
variables = {
'query': query
}
variables = {"query": query}
if not query:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error="Query is required."
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
)
# retrieve knowledge
try:
results = self._fetch_dataset_retriever(
node_data=node_data, query=query
)
outputs = {
'result': results
}
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
outputs = {"result": results}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=None,
outputs=outputs
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
)
except Exception as e:
logger.exception("Error when running knowledge retrieval node")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e)
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[
dict[str, Any]]:
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
available_datasets = []
dataset_ids = node_data.dataset_ids
# Subquery: Count the number of available documents for each dataset
subquery = db.session.query(
Document.dataset_id,
func.count(Document.id).label('available_document_count')
).filter(
Document.indexing_status == 'completed',
Document.enabled == True,
Document.archived == False,
Document.dataset_id.in_(dataset_ids)
).group_by(Document.dataset_id).having(
func.count(Document.id) > 0
).subquery()
subquery = (
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
.filter(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.dataset_id.in_(dataset_ids),
)
.group_by(Document.dataset_id)
.having(func.count(Document.id) > 0)
.subquery()
)
results = db.session.query(Dataset).join(
subquery, Dataset.id == subquery.c.dataset_id
).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id.in_(dataset_ids)
).all()
results = (
db.session.query(Dataset)
.join(subquery, Dataset.id == subquery.c.dataset_id)
.filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
.all()
)
for dataset in results:
# pass if dataset is not available
@ -117,16 +99,14 @@ class KnowledgeRetrievalNode(BaseNode):
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model_config.model,
credentials=model_config.credentials
model=model_config.model, credentials=model_config.credentials
)
if model_schema:
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
all_documents = dataset_retrieval.single_retrieve(
available_datasets=available_datasets,
@ -137,111 +117,108 @@ class KnowledgeRetrievalNode(BaseNode):
query=query,
model_config=model_config,
model_instance=model_instance,
planning_strategy=planning_strategy
planning_strategy=planning_strategy,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model':
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
reranking_model = {
'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model.provider,
'reranking_model_name': node_data.multiple_retrieval_config.reranking_model.model
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
weights = None
elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score':
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
reranking_model = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
'vector_setting': {
"vector_weight": node_data.multiple_retrieval_config.weights.vector_setting.vector_weight,
"embedding_provider_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_provider_name,
"embedding_model_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_model_name,
"vector_setting": {
"vector_weight": vector_setting.vector_weight,
"embedding_provider_name": vector_setting.embedding_provider_name,
"embedding_model_name": vector_setting.embedding_model_name,
},
'keyword_setting': {
"keyword_setting": {
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
}
},
}
else:
reranking_model = None
weights = None
all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id,
self.user_from.value,
available_datasets, query,
node_data.multiple_retrieval_config.top_k,
node_data.multiple_retrieval_config.score_threshold,
node_data.multiple_retrieval_config.reranking_mode,
reranking_model,
weights,
node_data.multiple_retrieval_config.reranking_enable,
)
all_documents = dataset_retrieval.multiple_retrieve(
self.app_id,
self.tenant_id,
self.user_id,
self.user_from.value,
available_datasets,
query,
node_data.multiple_retrieval_config.top_k,
node_data.multiple_retrieval_config.score_threshold,
node_data.multiple_retrieval_config.reranking_mode,
reranking_model,
weights,
node_data.multiple_retrieval_config.reranking_enable,
)
context_list = []
if all_documents:
document_score_list = {}
page_number_list = {}
for item in all_documents:
if item.metadata.get('score'):
document_score_list[item.metadata['doc_id']] = item.metadata['score']
# both 'page' and 'score' are metadata fields
if item.metadata.get('page'):
page_number_list[item.metadata['doc_id']] = item.metadata['page']
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
index_node_ids = [document.metadata["doc_id"] for document in all_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == 'completed',
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids)
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(segments,
key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
float('inf')))
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
dataset = Dataset.query.filter_by(
id=segment.dataset_id
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
document = Document.query.filter(Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
resource_number = 1
if dataset and document:
source = {
'metadata': {
'_source': 'knowledge',
'position': resource_number,
'dataset_id': dataset.id,
'dataset_name': dataset.name,
'document_id': document.id,
'document_name': document.name,
'document_data_source_type': document.data_source_type,
'page': page_number_list.get(segment.index_node_id, None),
'segment_id': segment.id,
'retriever_from': 'workflow',
'score': document_score_list.get(segment.index_node_id, None),
'segment_hit_count': segment.hit_count,
'segment_word_count': segment.word_count,
'segment_position': segment.position,
'segment_index_node_hash': segment.index_node_hash,
"metadata": {
"_source": "knowledge",
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"document_data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": "workflow",
"score": document_score_list.get(segment.index_node_id, None),
"segment_hit_count": segment.hit_count,
"segment_word_count": segment.word_count,
"segment_position": segment.position,
"segment_index_node_hash": segment.index_node_hash,
},
'title': document.name
"title": document.name,
}
if segment.answer:
source['content'] = f'question:{segment.get_sign_content()} \nanswer:{segment.answer}'
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
else:
source['content'] = segment.get_sign_content()
source["content"] = segment.get_sign_content()
context_list.append(source)
resource_number += 1
return context_list
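The re-ordering step above ranks segments by where their index node appeared in the retrieval results. As a standalone sketch (ids invented, plain strings stand in for DocumentSegment rows):

index_node_ids = ["n2", "n1", "n3"]  # retrieval order
index_node_id_to_position = {id_: position for position, id_ in enumerate(index_node_ids)}
segments = ["n1", "n3", "n2", "n9"]
sorted_segments = sorted(segments, key=lambda s: index_node_id_to_position.get(s, float("inf")))
assert sorted_segments == ["n2", "n1", "n3", "n9"]  # unknown ids sink to the end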
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: KnowledgeRetrievalNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: KnowledgeRetrievalNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -251,11 +228,12 @@ class KnowledgeRetrievalNode(BaseNode):
:return:
"""
variable_mapping = {}
variable_mapping[node_id + '.query'] = node_data.query_variable_selector
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
return variable_mapping
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
ModelInstance, ModelConfigWithCredentialsEntity]:
def _fetch_model_config(
self, node_data: KnowledgeRetrievalNodeData
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config
:param node_data: node data
@ -266,10 +244,7 @@ class KnowledgeRetrievalNode(BaseNode):
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
provider=provider_name,
model=model_name
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
)
provider_model_bundle = model_instance.provider_model_bundle
@ -280,8 +255,7 @@ class KnowledgeRetrievalNode(BaseNode):
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_name,
model_type=ModelType.LLM
model=model_name, model_type=ModelType.LLM
)
if provider_model is None:
@ -297,19 +271,16 @@ class KnowledgeRetrievalNode(BaseNode):
# model config
completion_params = node_data.single_retrieval_config.model.completion_params
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = node_data.single_retrieval_config.model.mode
if not model_mode:
raise ValueError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(
model_name,
model_credentials
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
View File
@ -11,6 +11,7 @@ class ModelConfig(BaseModel):
"""
Model Config.
"""
provider: str
name: str
mode: str
@ -21,6 +22,7 @@ class ContextConfig(BaseModel):
"""
Context Config.
"""
enabled: bool
variable_selector: Optional[list[str]] = None
@ -29,37 +31,47 @@ class VisionConfig(BaseModel):
"""
Vision Config.
"""
class Configs(BaseModel):
"""
Configs.
"""
detail: Literal['low', 'high']
detail: Literal["low", "high"]
enabled: bool
configs: Optional[Configs] = None
class PromptConfig(BaseModel):
"""
Prompt Config.
"""
jinja2_variables: Optional[list[VariableSelector]] = None
class LLMNodeChatModelMessage(ChatModelMessage):
"""
LLM Node Chat Model Message.
"""
jinja2_text: Optional[str] = None
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
"""
LLM Node Chat Model Prompt Template.
"""
jinja2_text: Optional[str] = None
class LLMNodeData(BaseNodeData):
"""
LLM Node Data.
"""
model: ModelConfig
prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
prompt_config: Optional[PromptConfig] = None
View File
@ -45,11 +45,11 @@ if TYPE_CHECKING:
from core.file.file_obj import FileVar
class ModelInvokeCompleted(BaseModel):
"""
Model invoke completed
"""
text: str
usage: LLMUsage
finish_reason: Optional[str] = None
@ -89,7 +89,7 @@ class LLMNode(BaseNode):
files = self._fetch_files(node_data, variable_pool)
if files:
node_inputs['#files#'] = [file.to_dict() for file in files]
node_inputs["#files#"] = [file.to_dict() for file in files]
# fetch context value
generator = self._fetch_context(node_data, variable_pool)
@ -100,7 +100,7 @@ class LLMNode(BaseNode):
yield event
if context:
node_inputs['#context#'] = context # type: ignore
node_inputs["#context#"] = context # type: ignore
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
@ -111,24 +111,22 @@ class LLMNode(BaseNode):
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt_messages(
node_data=node_data,
query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value])
if node_data.memory else None,
query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None,
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
inputs=inputs,
files=files,
context=context,
memory=memory,
model_config=model_config
model_config=model_config,
)
process_data = {
'model_mode': model_config.mode,
'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode,
prompt_messages=prompt_messages
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
),
'model_provider': model_config.provider,
'model_name': model_config.model,
"model_provider": model_config.provider,
"model_name": model_config.model,
}
# handle invoke result
@ -136,10 +134,10 @@ class LLMNode(BaseNode):
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop
stop=stop,
)
result_text = ''
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
for event in generator:
@ -156,16 +154,12 @@ class LLMNode(BaseNode):
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
process_data=process_data
process_data=process_data,
)
)
return
outputs = {
'text': result_text,
'usage': jsonable_encoder(usage),
'finish_reason': finish_reason
}
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
yield RunCompletedEvent(
run_result=NodeRunResult(
@ -176,17 +170,19 @@ class LLMNode(BaseNode):
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
NodeRunMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage
llm_usage=usage,
)
)
def _invoke_llm(self, node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None) \
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
def _invoke_llm(
self,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
"""
Invoke large language model
:param node_data_model: node data model
@ -206,9 +202,7 @@ class LLMNode(BaseNode):
)
# handle invoke result
generator = self._handle_invoke_result(
invoke_result=invoke_result
)
generator = self._handle_invoke_result(invoke_result=invoke_result)
usage = LLMUsage.empty_usage()
for event in generator:
@ -219,8 +213,9 @@ class LLMNode(BaseNode):
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
def _handle_invoke_result(
self, invoke_result: LLMResult | Generator
) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
"""
Handle invoke result
:param invoke_result: invoke result
@ -231,17 +226,14 @@ class LLMNode(BaseNode):
model = None
prompt_messages: list[PromptMessage] = []
full_text = ''
full_text = ""
usage = None
finish_reason = None
for result in invoke_result:
text = result.delta.message.content
full_text += text
yield RunStreamChunkEvent(
chunk_content=text,
from_variable_selector=[self.node_id, 'text']
)
yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
if not model:
model = result.model
@ -258,15 +250,11 @@ class LLMNode(BaseNode):
if not usage:
usage = LLMUsage.empty_usage()
yield ModelInvokeCompleted(
text=full_text,
usage=usage,
finish_reason=finish_reason
)
yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason)
def _transform_chat_messages(self,
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
def _transform_chat_messages(
self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
"""
Transform chat messages
@ -275,13 +263,13 @@ class LLMNode(BaseNode):
"""
if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
if messages.edition_type == 'jinja2' and messages.jinja2_text:
if messages.edition_type == "jinja2" and messages.jinja2_text:
messages.text = messages.jinja2_text
return messages
for message in messages:
if message.edition_type == 'jinja2' and message.jinja2_text:
if message.edition_type == "jinja2" and message.jinja2_text:
message.text = message.jinja2_text
return messages
@ -300,17 +288,15 @@ class LLMNode(BaseNode):
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable = variable_selector.variable
value = variable_pool.get_any(
variable_selector.value_selector
)
value = variable_pool.get_any(variable_selector.value_selector)
def parse_dict(d: dict) -> str:
"""
Parse dict into string
"""
# check if it's a context structure
if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
return d['content']
if "metadata" in d and "_source" in d["metadata"] and "content" in d:
return d["content"]
# else, parse the dict
try:
@ -321,7 +307,7 @@ class LLMNode(BaseNode):
if isinstance(value, str):
value = value
elif isinstance(value, list):
result = ''
result = ""
for item in value:
if isinstance(item, dict):
result += parse_dict(item)
@ -331,7 +317,7 @@ class LLMNode(BaseNode):
result += str(item)
else:
result += str(item)
result += '\n'
result += "\n"
value = result.strip()
elif isinstance(value, dict):
value = parse_dict(value)
@ -366,18 +352,19 @@ class LLMNode(BaseNode):
for variable_selector in variable_selectors:
variable_value = variable_pool.get_any(variable_selector.value_selector)
if variable_value is None:
raise ValueError(f'Variable {variable_selector.variable} not found')
raise ValueError(f"Variable {variable_selector.variable} not found")
inputs[variable_selector.variable] = variable_value
memory = node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
.extract_variable_selectors())
query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template
).extract_variable_selectors()
for variable_selector in query_variable_selectors:
variable_value = variable_pool.get_any(variable_selector.value_selector)
if variable_value is None:
raise ValueError(f'Variable {variable_selector.variable} not found')
raise ValueError(f"Variable {variable_selector.variable} not found")
inputs[variable_selector.variable] = variable_value
@ -393,7 +380,7 @@ class LLMNode(BaseNode):
if not node_data.vision.enabled:
return []
files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value])
files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value])
if not files:
return []
@ -415,29 +402,25 @@ class LLMNode(BaseNode):
context_value = variable_pool.get_any(node_data.context.variable_selector)
if context_value:
if isinstance(context_value, str):
yield RunRetrieverResourceEvent(
retriever_resources=[],
context=context_value
)
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
elif isinstance(context_value, list):
context_str = ''
context_str = ""
original_retriever_resource = []
for item in context_value:
if isinstance(item, str):
context_str += item + '\n'
context_str += item + "\n"
else:
if 'content' not in item:
raise ValueError(f'Invalid context structure: {item}')
if "content" not in item:
raise ValueError(f"Invalid context structure: {item}")
context_str += item['content'] + '\n'
context_str += item["content"] + "\n"
retriever_resource = self._convert_to_original_retriever_resource(item)
if retriever_resource:
original_retriever_resource.append(retriever_resource)
yield RunRetrieverResourceEvent(
retriever_resources=original_retriever_resource,
context=context_str.strip()
retriever_resources=original_retriever_resource, context=context_str.strip()
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
@ -446,34 +429,37 @@ class LLMNode(BaseNode):
:param context_dict: context dict
:return:
"""
if ('metadata' in context_dict and '_source' in context_dict['metadata']
and context_dict['metadata']['_source'] == 'knowledge'):
metadata = context_dict.get('metadata', {})
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]
and context_dict["metadata"]["_source"] == "knowledge"
):
metadata = context_dict.get("metadata", {})
source = {
'position': metadata.get('position'),
'dataset_id': metadata.get('dataset_id'),
'dataset_name': metadata.get('dataset_name'),
'document_id': metadata.get('document_id'),
'document_name': metadata.get('document_name'),
'data_source_type': metadata.get('document_data_source_type'),
'segment_id': metadata.get('segment_id'),
'retriever_from': metadata.get('retriever_from'),
'score': metadata.get('score'),
'hit_count': metadata.get('segment_hit_count'),
'word_count': metadata.get('segment_word_count'),
'segment_position': metadata.get('segment_position'),
'index_node_hash': metadata.get('segment_index_node_hash'),
'content': context_dict.get('content'),
'page': metadata.get('page'),
"position": metadata.get("position"),
"dataset_id": metadata.get("dataset_id"),
"dataset_name": metadata.get("dataset_name"),
"document_id": metadata.get("document_id"),
"document_name": metadata.get("document_name"),
"data_source_type": metadata.get("document_data_source_type"),
"segment_id": metadata.get("segment_id"),
"retriever_from": metadata.get("retriever_from"),
"score": metadata.get("score"),
"hit_count": metadata.get("segment_hit_count"),
"word_count": metadata.get("segment_word_count"),
"segment_position": metadata.get("segment_position"),
"index_node_hash": metadata.get("segment_index_node_hash"),
"content": context_dict.get("content"),
}
return source
return None
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
ModelInstance, ModelConfigWithCredentialsEntity]:
def _fetch_model_config(
self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config
:param node_data_model: node data model
@ -484,10 +470,7 @@ class LLMNode(BaseNode):
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
provider=provider_name,
model=model_name
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
)
provider_model_bundle = model_instance.provider_model_bundle
@ -498,8 +481,7 @@ class LLMNode(BaseNode):
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_name,
model_type=ModelType.LLM
model=model_name, model_type=ModelType.LLM
)
if provider_model is None:
@ -515,19 +497,16 @@ class LLMNode(BaseNode):
# model config
completion_params = node_data_model.completion_params
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = node_data_model.mode
if not model_mode:
raise ValueError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(
model_name,
model_credentials
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
@ -543,9 +522,9 @@ class LLMNode(BaseNode):
stop=stop,
)
def _fetch_memory(self, node_data_memory: Optional[MemoryConfig],
variable_pool: VariablePool,
model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
def _fetch_memory(
self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
"""
Fetch memory
:param node_data_memory: node data memory
@ -556,35 +535,35 @@ class LLMNode(BaseNode):
return None
# get conversation id
conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value])
conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value])
if conversation_id is None:
return None
# get conversation
conversation = db.session.query(Conversation).filter(
Conversation.app_id == self.app_id,
Conversation.id == conversation_id
).first()
conversation = (
db.session.query(Conversation)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)
if not conversation:
return None
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def _fetch_prompt_messages(self, node_data: LLMNodeData,
query: Optional[str],
query_prompt_template: Optional[str],
inputs: dict[str, str],
files: list["FileVar"],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
def _fetch_prompt_messages(
self,
node_data: LLMNodeData,
query: Optional[str],
query_prompt_template: Optional[str],
inputs: dict[str, str],
files: list["FileVar"],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
"""
Fetch prompt messages
:param node_data: node data
@ -601,7 +580,7 @@ class LLMNode(BaseNode):
prompt_messages = prompt_transform.get_prompt(
prompt_template=node_data.prompt_template,
inputs=inputs,
query=query if query else '',
query=query or "",
files=files,
context=context,
memory_config=node_data.memory,
@ -621,8 +600,11 @@ class LLMNode(BaseNode):
if not isinstance(prompt_message.content, str):
prompt_message_content = []
for content_item in prompt_message.content:
if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(
content_item, ImagePromptMessageContent):
if (
vision_enabled
and content_item.type == PromptMessageContentType.IMAGE
and isinstance(content_item, ImagePromptMessageContent)
):
# Override vision config if LLM node has vision config
if vision_detail:
content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
@ -632,15 +614,18 @@ class LLMNode(BaseNode):
if len(prompt_message_content) > 1:
prompt_message.content = prompt_message_content
elif (len(prompt_message_content) == 1
and prompt_message_content[0].type == PromptMessageContentType.TEXT):
elif (
len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
):
prompt_message.content = prompt_message_content[0].data
filtered_prompt_messages.append(prompt_message)
if not filtered_prompt_messages:
raise ValueError("No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding.")
raise ValueError(
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
return filtered_prompt_messages, stop
@ -678,7 +663,7 @@ class LLMNode(BaseNode):
elif quota_unit == QuotaUnit.CREDITS:
used_quota = 1
if 'gpt-4' in model_instance.model:
if "gpt-4" in model_instance.model:
used_quota = 20
else:
used_quota = 1
@ -689,16 +674,13 @@ class LLMNode(BaseNode):
Provider.provider_name == model_instance.provider,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + used_quota})
Provider.quota_limit > Provider.quota_used,
).update({"quota_used": Provider.quota_used + used_quota})
db.session.commit()
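The credit branch above can be read as a pure function. A hedged sketch (model names are illustrative):

def used_credits(model: str) -> int:
    # mirrors the QuotaUnit.CREDITS branch: gpt-4 family models cost 20 credits, others 1
    return 20 if "gpt-4" in model else 1

assert used_credits("gpt-4o") == 20
assert used_credits("gpt-3.5-turbo") == 1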
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: LLMNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -712,11 +694,11 @@ class LLMNode(BaseNode):
variable_selectors = []
if isinstance(prompt_template, list):
for prompt in prompt_template:
if prompt.edition_type != 'jinja2':
if prompt.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt.text)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
else:
if prompt_template.edition_type != 'jinja2':
if prompt_template.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
variable_selectors = variable_template_parser.extract_variable_selectors()
@ -726,39 +708,38 @@ class LLMNode(BaseNode):
memory = node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
.extract_variable_selectors())
query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template
).extract_variable_selectors()
for variable_selector in query_variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
if node_data.context.enabled:
variable_mapping['#context#'] = node_data.context.variable_selector
variable_mapping["#context#"] = node_data.context.variable_selector
if node_data.vision.enabled:
variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value]
variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value]
if node_data.memory:
variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value]
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
if node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, list):
for prompt in prompt_template:
if prompt.edition_type == 'jinja2':
if prompt.edition_type == "jinja2":
enable_jinja = True
break
else:
if prompt_template.edition_type == 'jinja2':
if prompt_template.edition_type == "jinja2":
enable_jinja = True
if enable_jinja:
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {
node_id + '.' + key: value for key, value in variable_mapping.items()
}
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
return variable_mapping
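The final step above namespaces every key with the node id. In isolation (ids and selectors invented):

variable_mapping = {"#context#": ["kr_1", "result"], "name": ["start", "name"]}
node_id = "llm_1"
prefixed = {node_id + "." + key: value for key, value in variable_mapping.items()}
assert prefixed == {"llm_1.#context#": ["kr_1", "result"], "llm_1.name": ["start", "name"]}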
@ -775,26 +756,19 @@ class LLMNode(BaseNode):
"prompt_templates": {
"chat_model": {
"prompts": [
{
"role": "system",
"text": "You are a helpful AI assistant.",
"edition_type": "basic"
}
{"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"}
]
},
"completion_model": {
"conversation_histories_role": {
"user_prefix": "Human",
"assistant_prefix": "Assistant"
},
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
"prompt": {
"text": "Here is the chat histories between human and assistant, inside "
"<histories></histories> XML tags.\n\n<histories>\n{{"
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
"edition_type": "basic"
"<histories></histories> XML tags.\n\n<histories>\n{{"
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
"edition_type": "basic",
},
"stop": ["Human:"]
}
"stop": ["Human:"],
},
}
}
},
}
View File
@ -1,4 +1,3 @@
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
@ -7,7 +6,8 @@ class LoopNodeData(BaseIterationNodeData):
Loop Node Data.
"""
class LoopState(BaseIterationState):
"""
Loop State.
"""
"""
View File
@ -10,6 +10,7 @@ class LoopNode(BaseNode):
"""
Loop Node.
"""
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
@ -21,14 +22,16 @@ class LoopNode(BaseNode):
"""
Get conditions.
"""
node_id = node_config.get('id')
node_id = node_config.get("id")
if not node_id:
return []
# TODO waiting for implementation
return [Condition(
variable_selector=[node_id, 'index'],
comparison_operator="",
value_type="value_selector",
value_selector=[]
)]
return [
Condition(
variable_selector=[node_id, "index"],
comparison_operator="",
value_type="value_selector",
value_selector=[],
)
]
View File
@ -8,47 +8,52 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
class ModelConfig(BaseModel):
"""
Model Config.
Model Config.
"""
provider: str
name: str
mode: str
completion_params: dict[str, Any] = {}
class ParameterConfig(BaseModel):
"""
Parameter Config.
"""
name: str
type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]']
type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"]
options: Optional[list[str]] = None
description: str
required: bool
@field_validator('name', mode='before')
@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value) -> str:
if not value:
raise ValueError('Parameter name is required')
if value in ['__reason', '__is_success']:
raise ValueError('Invalid parameter name, __reason and __is_success are reserved')
raise ValueError("Parameter name is required")
if value in {"__reason", "__is_success"}:
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
return value
class ParameterExtractorNodeData(BaseNodeData):
"""
Parameter Extractor Node Data.
"""
model: ModelConfig
query: list[str]
parameters: list[ParameterConfig]
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
reasoning_mode: Literal['function_call', 'prompt']
reasoning_mode: Literal["function_call", "prompt"]
@field_validator('reasoning_mode', mode='before')
@field_validator("reasoning_mode", mode="before")
@classmethod
def set_reasoning_mode(cls, v) -> str:
return v or 'function_call'
return v or "function_call"
def get_parameter_json_schema(self) -> dict:
"""
@ -56,32 +61,26 @@ class ParameterExtractorNodeData(BaseNodeData):
:return: parameter json schema
"""
parameters = {
'type': 'object',
'properties': {},
'required': []
}
parameters = {"type": "object", "properties": {}, "required": []}
for parameter in self.parameters:
parameter_schema = {
'description': parameter.description
}
parameter_schema = {"description": parameter.description}
if parameter.type in ['string', 'select']:
parameter_schema['type'] = 'string'
elif parameter.type.startswith('array'):
parameter_schema['type'] = 'array'
if parameter.type in {"string", "select"}:
parameter_schema["type"] = "string"
elif parameter.type.startswith("array"):
parameter_schema["type"] = "array"
nested_type = parameter.type[6:-1]
parameter_schema['items'] = {'type': nested_type}
parameter_schema["items"] = {"type": nested_type}
else:
parameter_schema['type'] = parameter.type
parameter_schema["type"] = parameter.type
if parameter.type == 'select':
parameter_schema['enum'] = parameter.options
if parameter.type == "select":
parameter_schema["enum"] = parameter.options
parameters["properties"][parameter.name] = parameter_schema
parameters['properties'][parameter.name] = parameter_schema
if parameter.required:
parameters['required'].append(parameter.name)
parameters["required"].append(parameter.name)
return parameters
return parameters
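For a single required select parameter, the builder above yields a schema like this (field values invented):

{
    "type": "object",
    "properties": {
        "color": {
            "description": "Preferred color",
            "type": "string",
            "enum": ["red", "blue"],
        },
    },
    "required": ["color"],
}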
View File
@ -45,6 +45,7 @@ class ParameterExtractorNode(LLMNode):
"""
Parameter Extractor Node.
"""
_node_data_cls = ParameterExtractorNodeData
_node_type = NodeType.PARAMETER_EXTRACTOR
@ -57,11 +58,8 @@ class ParameterExtractorNode(LLMNode):
"model": {
"prompt_templates": {
"completion_model": {
"conversation_histories_role": {
"user_prefix": "Human",
"assistant_prefix": "Assistant"
},
"stop": ["Human:"]
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
"stop": ["Human:"],
}
}
}
@ -78,9 +76,9 @@ class ParameterExtractorNode(LLMNode):
query = variable
inputs = {
'query': query,
'parameters': jsonable_encoder(node_data.parameters),
'instruction': jsonable_encoder(node_data.instruction),
"query": query,
"parameters": jsonable_encoder(node_data.parameters),
"instruction": jsonable_encoder(node_data.instruction),
}
model_instance, model_config = self._fetch_model_config(node_data.model)
@ -95,30 +93,29 @@ class ParameterExtractorNode(LLMNode):
# fetch memory
memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance)
if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
and node_data.reasoning_mode == 'function_call':
# use function call
if (
set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
and node_data.reasoning_mode == "function_call"
):
# use function call
prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
)
else:
# use prompt engineering
prompt_messages = self._generate_prompt_engineering_prompt(node_data,
query,
self.graph_runtime_state.variable_pool,
model_config,
memory)
prompt_messages = self._generate_prompt_engineering_prompt(
node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
)
prompt_message_tools = []
process_data = {
'model_mode': model_config.mode,
'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode,
prompt_messages=prompt_messages
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
),
'usage': None,
'function': {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
'tool_call': None,
"usage": None,
"function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
"tool_call": None,
}
try:
@ -129,20 +126,17 @@ class ParameterExtractorNode(LLMNode):
tools=prompt_message_tools,
stop=model_config.stop,
)
process_data['usage'] = jsonable_encoder(usage)
process_data['tool_call'] = jsonable_encoder(tool_call)
process_data['llm_text'] = text
process_data["usage"] = jsonable_encoder(usage)
process_data["tool_call"] = jsonable_encoder(tool_call)
process_data["llm_text"] = text
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
outputs={
'__is_success': 0,
'__reason': str(e)
},
outputs={"__is_success": 0, "__reason": str(e)},
error=str(e),
metadata={}
metadata={},
)
error = None
@ -167,24 +161,23 @@ class ParameterExtractorNode(LLMNode):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={
'__is_success': 1 if not error else 0,
'__reason': error,
**result
},
outputs={"__is_success": 1 if not error else 0, "__reason": error, **result},
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
NodeRunMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage
llm_usage=usage,
)
def _invoke_llm(self, node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
stop: list[str]) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
def _invoke_llm(
self,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
stop: list[str],
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
"""
Invoke large language model
:param node_data_model: node data model
@ -217,32 +210,35 @@ class ParameterExtractorNode(LLMNode):
return text, usage, tool_call
def _generate_function_call_prompt(self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
def _generate_function_call_prompt(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
"""
Generate function call prompt.
"""
query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps(
node_data.get_parameter_json_schema()))
query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(
content=query, structure=json.dumps(node_data.get_parameter_json_schema())
)
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory,
rest_token)
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
prompt_template = self._get_function_calling_prompt_template(
node_data, query, variable_pool, memory, rest_token
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query='',
query="",
files=[],
context='',
context="",
memory_config=node_data.memory,
memory=None,
model_config=model_config
model_config=model_config,
)
# find last user message
@ -255,124 +251,125 @@ class ParameterExtractorNode(LLMNode):
example_messages = []
for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE:
id = uuid.uuid4().hex
example_messages.extend([
UserPromptMessage(content=example['user']['query']),
AssistantPromptMessage(
content=example['assistant']['text'],
tool_calls=[
AssistantPromptMessage.ToolCall(
id=id,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=example['assistant']['function_call']['name'],
arguments=json.dumps(example['assistant']['function_call']['parameters']
)
))
]
),
ToolPromptMessage(
content='Great! You have called the function with the correct parameters.',
tool_call_id=id
),
AssistantPromptMessage(
content='I have extracted the parameters, let\'s move on.',
)
])
example_messages.extend(
[
UserPromptMessage(content=example["user"]["query"]),
AssistantPromptMessage(
content=example["assistant"]["text"],
tool_calls=[
AssistantPromptMessage.ToolCall(
id=id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=example["assistant"]["function_call"]["name"],
arguments=json.dumps(example["assistant"]["function_call"]["parameters"]),
),
)
],
),
ToolPromptMessage(
content="Great! You have called the function with the correct parameters.", tool_call_id=id
),
AssistantPromptMessage(
content="I have extracted the parameters, let's move on.",
),
]
)
prompt_messages = prompt_messages[:last_user_message_idx] + \
example_messages + prompt_messages[last_user_message_idx:]
prompt_messages = (
prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:]
)
# generate tool
tool = PromptMessageTool(
name=FUNCTION_CALLING_EXTRACTOR_NAME,
description='Extract parameters from the natural language text',
description="Extract parameters from the natural language text",
parameters=node_data.get_parameter_json_schema(),
)
return prompt_messages, [tool]
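
A minimal sketch of the splice above, assuming plain strings stand in for Dify's PromptMessage objects: the few-shot examples are inserted immediately before the last user message so the real query stays last.

def splice_examples(prompt_messages: list[str], example_messages: list[str]) -> list[str]:
    # The "user:" prefix is a stand-in for checking isinstance(m, UserPromptMessage).
    last_user_message_idx = max(
        i for i, m in enumerate(prompt_messages) if m.startswith("user:")
    )
    # Keep the real query last; the examples slot in just before it.
    return (
        prompt_messages[:last_user_message_idx]
        + example_messages
        + prompt_messages[last_user_message_idx:]
    )

messages = ["system: extract parameters", "user: What is the weather today in SF?"]
examples = ["user: example query", "assistant: example tool call"]
print(splice_examples(messages, examples))
# ['system: extract parameters', 'user: example query',
#  'assistant: example tool call', 'user: What is the weather today in SF?']
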
def _generate_prompt_engineering_prompt(self,
data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
) -> list[PromptMessage]:
def _generate_prompt_engineering_prompt(
self,
data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
) -> list[PromptMessage]:
"""
Generate prompt engineering prompt.
"""
model_mode = ModelMode.value_of(data.model.mode)
if model_mode == ModelMode.COMPLETION:
return self._generate_prompt_engineering_completion_prompt(
data, query, variable_pool, model_config, memory
)
return self._generate_prompt_engineering_completion_prompt(data, query, variable_pool, model_config, memory)
elif model_mode == ModelMode.CHAT:
return self._generate_prompt_engineering_chat_prompt(
data, query, variable_pool, model_config, memory
)
return self._generate_prompt_engineering_chat_prompt(data, query, variable_pool, model_config, memory)
else:
raise ValueError(f"Invalid model mode: {model_mode}")
def _generate_prompt_engineering_completion_prompt(self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
) -> list[PromptMessage]:
def _generate_prompt_engineering_completion_prompt(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
) -> list[PromptMessage]:
"""
Generate completion prompt.
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory,
rest_token)
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
prompt_template = self._get_prompt_engineering_prompt_template(
node_data, query, variable_pool, memory, rest_token
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={
'structure': json.dumps(node_data.get_parameter_json_schema())
},
query='',
inputs={"structure": json.dumps(node_data.get_parameter_json_schema())},
query="",
files=[],
context='',
context="",
memory_config=node_data.memory,
memory=memory,
model_config=model_config
model_config=model_config,
)
return prompt_messages
def _generate_prompt_engineering_chat_prompt(self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
) -> list[PromptMessage]:
def _generate_prompt_engineering_chat_prompt(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
) -> list[PromptMessage]:
"""
Generate chat prompt.
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
prompt_template = self._get_prompt_engineering_prompt_template(
node_data,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
structure=json.dumps(node_data.get_parameter_json_schema()),
text=query
structure=json.dumps(node_data.get_parameter_json_schema()), text=query
),
variable_pool, memory, rest_token
variable_pool,
memory,
rest_token,
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query='',
query="",
files=[],
context='',
context="",
memory_config=node_data.memory,
memory=None,
model_config=model_config
model_config=model_config,
)
# find last user message
@ -384,18 +381,23 @@ class ParameterExtractorNode(LLMNode):
# add example messages before last user message
example_messages = []
for example in CHAT_EXAMPLE:
example_messages.extend([
UserPromptMessage(content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
structure=json.dumps(example['user']['json']),
text=example['user']['query'],
)),
AssistantPromptMessage(
content=json.dumps(example['assistant']['json']),
)
])
example_messages.extend(
[
UserPromptMessage(
content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
structure=json.dumps(example["user"]["json"]),
text=example["user"]["query"],
)
),
AssistantPromptMessage(
content=json.dumps(example["assistant"]["json"]),
),
]
)
prompt_messages = prompt_messages[:last_user_message_idx] + \
example_messages + prompt_messages[last_user_message_idx:]
prompt_messages = (
prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:]
)
return prompt_messages
@ -410,28 +412,28 @@ class ParameterExtractorNode(LLMNode):
if parameter.required and parameter.name not in result:
raise ValueError(f"Parameter {parameter.name} is required")
if parameter.type == 'select' and parameter.options and result.get(parameter.name) not in parameter.options:
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
raise ValueError(f"Invalid `select` value for parameter {parameter.name}")
if parameter.type == 'number' and not isinstance(result.get(parameter.name), int | float):
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
raise ValueError(f"Invalid `number` value for parameter {parameter.name}")
if parameter.type == 'bool' and not isinstance(result.get(parameter.name), bool):
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
raise ValueError(f"Invalid `bool` value for parameter {parameter.name}")
if parameter.type == 'string' and not isinstance(result.get(parameter.name), str):
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
raise ValueError(f"Invalid `string` value for parameter {parameter.name}")
if parameter.type.startswith('array'):
if parameter.type.startswith("array"):
if not isinstance(result.get(parameter.name), list):
raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
nested_type = parameter.type[6:-1]
for item in result.get(parameter.name):
if nested_type == 'number' and not isinstance(item, int | float):
if nested_type == "number" and not isinstance(item, int | float):
raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
if nested_type == 'string' and not isinstance(item, str):
if nested_type == "string" and not isinstance(item, str):
raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
if nested_type == 'object' and not isinstance(item, dict):
if nested_type == "object" and not isinstance(item, dict):
raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
return result
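
The checks above amount to one type rule per declared parameter; a standalone sketch under the assumption that parameters are plain dicts with name/type/required/options keys (Python 3.10+ for `int | float` in isinstance):

def validate_result(parameters: list[dict], result: dict) -> dict:
    # Same rules as the node above, on plain dicts (an assumption, not the
    # real parameter entities).
    for p in parameters:
        name, typ = p["name"], p["type"]
        if p.get("required") and name not in result:
            raise ValueError(f"Parameter {name} is required")
        value = result.get(name)
        if typ == "select" and p.get("options") and value not in p["options"]:
            raise ValueError(f"Invalid `select` value for parameter {name}")
        if typ == "number" and not isinstance(value, int | float):
            raise ValueError(f"Invalid `number` value for parameter {name}")
        if typ == "bool" and not isinstance(value, bool):
            raise ValueError(f"Invalid `bool` value for parameter {name}")
        if typ == "string" and not isinstance(value, str):
            raise ValueError(f"Invalid `string` value for parameter {name}")
        if typ.startswith("array"):
            if not isinstance(value, list):
                raise ValueError(f"Invalid `array` value for parameter {name}")
            nested = typ[6:-1]  # "array[number]" -> "number"
            expected = {"number": int | float, "string": str, "object": dict}[nested]
            for item in value:
                if not isinstance(item, expected):
                    raise ValueError(f"Invalid `array[{nested}]` value for parameter {name}")
    return result

print(validate_result([{"name": "location", "type": "string", "required": True}],
                      {"location": "San Francisco"}))
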
@ -443,12 +445,12 @@ class ParameterExtractorNode(LLMNode):
for parameter in data.parameters:
if parameter.name in result:
# transform value
if parameter.type == 'number':
if parameter.type == "number":
if isinstance(result[parameter.name], int | float):
transformed_result[parameter.name] = result[parameter.name]
elif isinstance(result[parameter.name], str):
try:
if '.' in result[parameter.name]:
if "." in result[parameter.name]:
result[parameter.name] = float(result[parameter.name])
else:
result[parameter.name] = int(result[parameter.name])
@ -465,40 +467,40 @@ class ParameterExtractorNode(LLMNode):
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
# elif isinstance(result[parameter.name], int):
# transformed_result[parameter.name] = bool(result[parameter.name])
elif parameter.type in ['string', 'select']:
elif parameter.type in {"string", "select"}:
if isinstance(result[parameter.name], str):
transformed_result[parameter.name] = result[parameter.name]
elif parameter.type.startswith('array'):
elif parameter.type.startswith("array"):
if isinstance(result[parameter.name], list):
nested_type = parameter.type[6:-1]
transformed_result[parameter.name] = []
for item in result[parameter.name]:
if nested_type == 'number':
if nested_type == "number":
if isinstance(item, int | float):
transformed_result[parameter.name].append(item)
elif isinstance(item, str):
try:
if '.' in item:
if "." in item:
transformed_result[parameter.name].append(float(item))
else:
transformed_result[parameter.name].append(int(item))
except ValueError:
pass
elif nested_type == 'string':
elif nested_type == "string":
if isinstance(item, str):
transformed_result[parameter.name].append(item)
elif nested_type == 'object':
elif nested_type == "object":
if isinstance(item, dict):
transformed_result[parameter.name].append(item)
if parameter.name not in transformed_result:
if parameter.type == 'number':
if parameter.type == "number":
transformed_result[parameter.name] = 0
elif parameter.type == 'bool':
elif parameter.type == "bool":
transformed_result[parameter.name] = False
elif parameter.type in ['string', 'select']:
transformed_result[parameter.name] = ''
elif parameter.type.startswith('array'):
elif parameter.type in {"string", "select"}:
transformed_result[parameter.name] = ""
elif parameter.type.startswith("array"):
transformed_result[parameter.name] = []
return transformed_result
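
The numeric coercion above follows a single rule: a string containing "." parses as float, otherwise as int, and unparsable strings are dropped. A small sketch of just that rule:

def coerce_number(value):
    # Ints and floats pass through; numeric strings are parsed;
    # anything else yields None, mirroring the transform above.
    if isinstance(value, int | float):
        return value
    if isinstance(value, str):
        try:
            return float(value) if "." in value else int(value)
        except ValueError:
            return None
    return None

print(coerce_number("3.14"))  # 3.14 (float)
print(coerce_number("42"))    # 42 (int)
print(coerce_number("n/a"))   # None
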
@ -514,24 +516,24 @@ class ParameterExtractorNode(LLMNode):
"""
stack = []
for i, c in enumerate(text):
if c == '{' or c == '[':
if c in {"{", "["}:
stack.append(c)
elif c == '}' or c == ']':
elif c in {"}", "]"}:
# check if stack is empty
if not stack:
return text[:i]
# check if the last element in stack is matching
if (c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '['):
if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["):
stack.pop()
if not stack:
return text[:i + 1]
return text[: i + 1]
else:
return text[:i]
return None
# extract json from the text
for idx in range(len(result)):
if result[idx] == '{' or result[idx] == '[':
if result[idx] == "{" or result[idx] == "[":
json_str = extract_json(result[idx:])
if json_str:
try:
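
A standalone sketch of the bracket scan above, with a usage example; note that, as in the node, the scan does not track quoted strings, so braces inside JSON string values would end the match early.

import json

def extract_json(text: str):
    # Stack-based scan: push openers, pop on matching closers; the first
    # time the stack empties, the prefix up to that point is a balanced
    # candidate JSON value.
    stack = []
    for i, c in enumerate(text):
        if c in {"{", "["}:
            stack.append(c)
        elif c in {"}", "]"}:
            if not stack:
                return text[:i]
            if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["):
                stack.pop()
                if not stack:
                    return text[: i + 1]
            else:
                return text[:i]
    return None

raw = 'Sure, here you go: {"location": "San Francisco"} -- done.'
start = raw.index("{")
print(json.loads(extract_json(raw[start:])))  # {'location': 'San Francisco'}
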
@ -554,12 +556,12 @@ class ParameterExtractorNode(LLMNode):
"""
result = {}
for parameter in data.parameters:
if parameter.type == 'number':
if parameter.type == "number":
result[parameter.name] = 0
elif parameter.type == 'bool':
elif parameter.type == "bool":
result[parameter.name] = False
elif parameter.type in ['string', 'select']:
result[parameter.name] = ''
elif parameter.type in {"string", "select"}:
result[parameter.name] = ""
return result
@ -575,71 +577,76 @@ class ParameterExtractorNode(LLMNode):
return variable_template_parser.format(inputs)
def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000) \
-> list[ChatModelMessage]:
def _get_function_calling_prompt_template(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode.value_of(node_data.model.mode)
input_text = query
memory_str = ''
instruction = self._render_instruction(node_data.instruction or '', variable_pool)
memory_str = ""
instruction = self._render_instruction(node_data.instruction or "", variable_pool)
if memory:
memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
message_limit=node_data.memory.window.size)
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction)
)
user_prompt_message = ChatModelMessage(
role=PromptMessageRole.USER,
text=input_text
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]
else:
raise ValueError(f"Model mode {model_mode} not support.")
def _get_prompt_engineering_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000) \
-> list[ChatModelMessage]:
def _get_prompt_engineering_prompt_template(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode.value_of(node_data.model.mode)
input_text = query
memory_str = ''
instruction = self._render_instruction(node_data.instruction or '', variable_pool)
memory_str = ""
instruction = self._render_instruction(node_data.instruction or "", variable_pool)
if memory:
memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
message_limit=node_data.memory.window.size)
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction)
)
user_prompt_message = ChatModelMessage(
role=PromptMessageRole.USER,
text=input_text
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]
elif model_mode == ModelMode.COMPLETION:
return CompletionModelPromptTemplate(
text=COMPLETION_GENERATE_JSON_PROMPT.format(histories=memory_str,
text=input_text,
instruction=instruction)
.replace('{γγγ', '')
.replace('}γγγ', '')
text=COMPLETION_GENERATE_JSON_PROMPT.format(
histories=memory_str, text=input_text, instruction=instruction
)
.replace("{γγγ", "")
.replace("}γγγ", "")
)
else:
raise ValueError(f"Model mode {model_mode} not support.")
def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
context: Optional[str]) -> int:
def _calculate_rest_token(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
context: Optional[str],
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
model_instance, model_config = self._fetch_model_config(node_data.model)
@ -659,12 +666,12 @@ class ParameterExtractorNode(LLMNode):
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query='',
query="",
files=[],
context=context,
memory_config=node_data.memory,
memory=None,
model_config=model_config
model_config=model_config,
)
rest_tokens = 2000
@ -673,26 +680,28 @@ class ParameterExtractorNode(LLMNode):
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
curr_message_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
prompt_messages
) + 1000 # add 1000 to ensure tool call messages
curr_message_tokens = (
model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000
) # add 1000 to ensure tool call messages
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
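
The budget above reduces to rest = context_size - max_tokens - (prompt_tokens + 1000), clamped at zero, where the fixed 1000 is the tool-call allowance noted in the comment. A sketch with made-up numbers:

def calculate_rest_tokens(context_size: int, max_tokens: int, prompt_tokens: int) -> int:
    # 1000 extra tokens reserved for tool call messages, as in the node above.
    curr_message_tokens = prompt_tokens + 1000
    return max(context_size - max_tokens - curr_message_tokens, 0)

print(calculate_rest_tokens(context_size=8192, max_tokens=1024, prompt_tokens=1500))
# 4668
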
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
ModelInstance, ModelConfigWithCredentialsEntity]:
def _fetch_model_config(
self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config.
"""
@ -703,10 +712,7 @@ class ParameterExtractorNode(LLMNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ParameterExtractorNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ParameterExtractorNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -715,17 +721,13 @@ class ParameterExtractorNode(LLMNode):
:param node_data: node data
:return:
"""
variable_mapping = {
'query': node_data.query
}
variable_mapping = {"query": node_data.query}
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
for selector in variable_template_parser.extract_variable_selectors():
variable_mapping[selector.variable] = selector.value_selector
variable_mapping = {
node_id + '.' + key: value for key, value in variable_mapping.items()
}
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
return variable_mapping
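
A sketch of the mapping built above, assuming selectors are plain lists: every key is namespaced with the node id so mappings from different nodes cannot collide.

def build_variable_mapping(node_id: str, query_selector: list[str],
                           instruction_selectors: dict[str, list[str]]) -> dict:
    variable_mapping = {"query": query_selector}
    variable_mapping.update(instruction_selectors)
    # Prefix each key with the node id, mirroring the classmethod above.
    return {f"{node_id}.{key}": value for key, value in variable_mapping.items()}

print(build_variable_mapping(
    "node_1",
    ["sys", "query"],
    {"user_name": ["start", "name"]},
))
# {'node_1.query': ['sys', 'query'], 'node_1.user_name': ['start', 'name']}
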

View File

@ -1,4 +1,4 @@
FUNCTION_CALLING_EXTRACTOR_NAME = 'extract_parameters'
FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters"
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
### Task
@ -23,7 +23,7 @@ Steps:
To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples.
### Final Output
Produce well-formatted function calls in json without XML tags, as shown in the example.
"""
""" # noqa: E501
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside <context></context> XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside <structure></structure> XML tags.
<context>
@ -33,63 +33,52 @@ FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information fr
<structure>
\x7bstructure\x7d
</structure>
"""
""" # noqa: E501
FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [{
'user': {
'query': 'What is the weather today in SF?',
'function': {
'name': FUNCTION_CALLING_EXTRACTOR_NAME,
'parameters': {
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'The location to get the weather information',
'required': True
FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [
{
"user": {
"query": "What is the weather today in SF?",
"function": {
"name": FUNCTION_CALLING_EXTRACTOR_NAME,
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The location to get the weather information",
"required": True,
},
},
"required": ["location"],
},
'required': ['location']
}
}
},
},
"assistant": {
"text": "I need always call the function with the correct parameters."
" in this case, I need to call the function with the location parameter.",
"function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}},
},
},
'assistant': {
'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the location parameter.',
'function_call' : {
'name': FUNCTION_CALLING_EXTRACTOR_NAME,
'parameters': {
'location': 'San Francisco'
}
}
}
}, {
'user': {
'query': 'I want to eat some apple pie.',
'function': {
'name': FUNCTION_CALLING_EXTRACTOR_NAME,
'parameters': {
'type': 'object',
'properties': {
'food': {
'type': 'string',
'description': 'The food to eat',
'required': True
}
{
"user": {
"query": "I want to eat some apple pie.",
"function": {
"name": FUNCTION_CALLING_EXTRACTOR_NAME,
"parameters": {
"type": "object",
"properties": {"food": {"type": "string", "description": "The food to eat", "required": True}},
"required": ["food"],
},
'required': ['food']
}
}
},
},
"assistant": {
"text": "I need always call the function with the correct parameters."
" in this case, I need to call the function with the food parameter.",
"function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}},
},
},
'assistant': {
'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.',
'function_call' : {
'name': FUNCTION_CALLING_EXTRACTOR_NAME,
'parameters': {
'food': 'apple pie'
}
}
}
}]
]
COMPLETION_GENERATE_JSON_PROMPT = """### Instructions:
Some extra information is provided below, and I should always follow the instructions as closely as I can.
@ -130,7 +119,7 @@ Inside <text></text> XML tags, there is a text that I should extract parameters
### Answer
I should always output a valid JSON object. Output nothing other than the JSON object.
```JSON
"""
""" # noqa: E501
CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object.
The structure of the JSON object can be found in the instructions.
@ -161,46 +150,33 @@ Inside <text></text> XML tags, there is a text that you should convert to a JSON
</text>
"""
CHAT_EXAMPLE = [{
'user': {
'query': 'What is the weather today in SF?',
'json': {
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'The location to get the weather information',
'required': True
}
CHAT_EXAMPLE = [
{
"user": {
"query": "What is the weather today in SF?",
"json": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The location to get the weather information",
"required": True,
}
},
"required": ["location"],
},
'required': ['location']
}
},
"assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}},
},
'assistant': {
'text': 'I need to output a valid JSON object.',
'json': {
'location': 'San Francisco'
}
}
}, {
'user': {
'query': 'I want to eat some apple pie.',
'json': {
'type': 'object',
'properties': {
'food': {
'type': 'string',
'description': 'The food to eat',
'required': True
}
{
"user": {
"query": "I want to eat some apple pie.",
"json": {
"type": "object",
"properties": {"food": {"type": "string", "description": "The food to eat", "required": True}},
"required": ["food"],
},
'required': ['food']
}
},
"assistant": {"text": "I need to output a valid JSON object.", "json": {"result": "apple pie"}},
},
'assistant': {
'text': 'I need to output a valid JSON object.',
'json': {
'result': 'apple pie'
}
}
}]
]

View File

@ -8,8 +8,9 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
class ModelConfig(BaseModel):
"""
Model Config.
Model Config.
"""
provider: str
name: str
mode: str
@ -20,6 +21,7 @@ class ClassConfig(BaseModel):
"""
Class Config.
"""
id: str
name: str
@ -28,8 +30,9 @@ class QuestionClassifierNodeData(BaseNodeData):
"""
Question classifier Node Data.
"""
query_variable_selector: list[str]
type: str = 'question-classifier'
type: str = "question-classifier"
model: ModelConfig
classes: list[ClassConfig]
instruction: Optional[str] = None

View File

@ -45,34 +45,25 @@ class QuestionClassifierNode(LLMNode):
# extract variables
variable = variable_pool.get(node_data.query_variable_selector)
query = variable.value if variable else None
variables = {
'query': query
}
variables = {"query": query}
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
# fetch memory
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
# fetch instruction
instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else ''
instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else ""
node_data.instruction = instruction
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt(
node_data=node_data,
context='',
query=query,
memory=memory,
model_config=model_config
node_data=node_data, context="", query=query, memory=memory, model_config=model_config
)
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop
node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop
)
result_text = ''
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
for event in generator:
@ -87,8 +78,8 @@ class QuestionClassifierNode(LLMNode):
try:
result_text_json = parse_and_check_json_markdown(result_text, [])
# result_text_json = json.loads(result_text.strip('```JSON\n'))
if 'category_name' in result_text_json and 'category_id' in result_text_json:
category_id_result = result_text_json['category_id']
if "category_name" in result_text_json and "category_id" in result_text_json:
category_id_result = result_text_json["category_id"]
classes = node_data.classes
classes_map = {class_.id: class_.name for class_ in classes}
category_ids = [_class.id for _class in classes]
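
A sketch of the id-to-name resolution prepared above, with plain dicts for the classes; the fallback to the first class when the id is unknown is an assumption here, since the deciding branch falls outside this hunk.

import json

def resolve_category(result_text: str, classes: list[dict]) -> str:
    classes_map = {c["id"]: c["name"] for c in classes}
    result = json.loads(result_text)
    category_id = result.get("category_id")
    # Fallback choice is an assumption; the real node decides outside
    # the visible hunk.
    return classes_map.get(category_id, classes[0]["name"])

classes = [{"id": "1", "name": "Customer Service"}, {"id": "2", "name": "Sales"}]
print(resolve_category('{"category_id": "2", "category_name": "Sales"}', classes))
# Sales
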
@ -100,17 +91,14 @@ class QuestionClassifierNode(LLMNode):
logging.error(f"Failed to parse result text: {result_text}")
try:
process_data = {
'model_mode': model_config.mode,
'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode,
prompt_messages=prompt_messages
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
),
'usage': jsonable_encoder(usage),
'finish_reason': finish_reason
}
outputs = {
'class_name': category_name
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
}
outputs = {"class_name": category_name}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -121,9 +109,9 @@ class QuestionClassifierNode(LLMNode):
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
NodeRunMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage
llm_usage=usage,
)
except ValueError as e:
@ -134,17 +122,14 @@ class QuestionClassifierNode(LLMNode):
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
NodeRunMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage
llm_usage=usage,
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: QuestionClassifierNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: QuestionClassifierNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -153,7 +138,7 @@ class QuestionClassifierNode(LLMNode):
:param node_data: node data
:return:
"""
variable_mapping = {'query': node_data.query_variable_selector}
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors = []
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
@ -161,10 +146,8 @@ class QuestionClassifierNode(LLMNode):
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {
node_id + '.' + key: value for key, value in variable_mapping.items()
}
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
return variable_mapping
@classmethod
@ -174,19 +157,16 @@ class QuestionClassifierNode(LLMNode):
:param filters: filter by node config parameters.
:return:
"""
return {
"type": "question-classifier",
"config": {
"instructions": ""
}
}
return {"type": "question-classifier", "config": {"instructions": ""}}
def _fetch_prompt(self, node_data: QuestionClassifierNodeData,
query: str,
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
def _fetch_prompt(
self,
node_data: QuestionClassifierNodeData,
query: str,
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
"""
Fetch prompt
:param node_data: node data
@ -202,118 +182,122 @@ class QuestionClassifierNode(LLMNode):
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query='',
query="",
files=[],
context=context,
memory_config=node_data.memory,
memory=None,
model_config=model_config
model_config=model_config,
)
stop = model_config.stop
return prompt_messages, stop
def _calculate_rest_token(self, node_data: QuestionClassifierNodeData, query: str,
model_config: ModelConfigWithCredentialsEntity,
context: Optional[str]) -> int:
def _calculate_rest_token(
self,
node_data: QuestionClassifierNodeData,
query: str,
model_config: ModelConfigWithCredentialsEntity,
context: Optional[str],
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query='',
query="",
files=[],
context=context,
memory_config=node_data.memory,
memory=None,
model_config=model_config
model_config=model_config,
)
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str,
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000) \
-> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
def _get_prompt_template(
self,
node_data: QuestionClassifierNodeData,
query: str,
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
) -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
model_mode = ModelMode.value_of(node_data.model.mode)
classes = node_data.classes
categories = []
for class_ in classes:
category = {
'category_id': class_.id,
'category_name': class_.name
}
category = {"category_id": class_.id, "category_name": class_.name}
categories.append(category)
instruction = node_data.instruction if node_data.instruction else ''
instruction = node_data.instruction or ""
input_text = query
memory_str = ''
memory_str = ""
if memory:
memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
message_limit=node_data.memory.window.size)
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
prompt_messages = []
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
)
prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = ChatModelMessage(
role=PromptMessageRole.USER,
text=QUESTION_CLASSIFIER_USER_PROMPT_1
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1
)
prompt_messages.append(user_prompt_message_1)
assistant_prompt_message_1 = ChatModelMessage(
role=PromptMessageRole.ASSISTANT,
text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = ChatModelMessage(
role=PromptMessageRole.USER,
text=QUESTION_CLASSIFIER_USER_PROMPT_2
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = ChatModelMessage(
role=PromptMessageRole.ASSISTANT,
text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
)
prompt_messages.append(assistant_prompt_message_2)
user_prompt_message_3 = ChatModelMessage(
role=PromptMessageRole.USER,
text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text,
categories=json.dumps(categories, ensure_ascii=False),
classification_instructions=instruction)
text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(
input_text=input_text,
categories=json.dumps(categories, ensure_ascii=False),
classification_instructions=instruction,
),
)
prompt_messages.append(user_prompt_message_3)
return prompt_messages
elif model_mode == ModelMode.COMPLETION:
return CompletionModelPromptTemplate(
text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str,
input_text=input_text,
categories=json.dumps(categories),
classification_instructions=instruction,
ensure_ascii=False)
text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
histories=memory_str,
input_text=input_text,
categories=json.dumps(categories),
classification_instructions=instruction,
ensure_ascii=False,
)
)
else:
@ -329,14 +313,12 @@ class QuestionClassifierNode(LLMNode):
variable = variable_pool.get(variable_selector.value_selector)
variable_value = variable.value if variable else None
if variable_value is None:
raise ValueError(f'Variable {variable_selector.variable} not found')
raise ValueError(f"Variable {variable_selector.variable} not found")
inputs[variable_selector.variable] = variable_value
prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
instruction = prompt_template.format(
prompt_inputs
)
instruction = prompt_template.format(prompt_inputs)
return instruction

View File

@ -1,5 +1,3 @@
QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
### Job Description
You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
@ -14,13 +12,13 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
<histories>
{histories}
</histories>
"""
""" # noqa: E501
QUESTION_CLASSIFIER_USER_PROMPT_1 = """
{ "input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],
"categories": [{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"},{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"},{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"},{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}],
"classification_instructions": ["classify the text based on the feedback provided by customer"]}
"""
""" # noqa: E501
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """
```json
@ -34,7 +32,7 @@ QUESTION_CLASSIFIER_USER_PROMPT_2 = """
{"input_text": ["bad service, slow to bring the food"],
"categories": [{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"},{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"},{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}],
"classification_instructions": []}
"""
""" # noqa: E501
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """
```json
@ -75,4 +73,4 @@ Here is the chat histories between human and assistant, inside <histories></hist
### User Input
{{"input_text" : ["{input_text}"], "categories" : {categories},"classification_instruction" : ["{classification_instructions}"]}}
### Assistant Output
"""
""" # noqa: E501

View File

@ -10,4 +10,5 @@ class StartNodeData(BaseNodeData):
"""
Start Node Data
"""
variables: Sequence[VariableEntity] = Field(default_factory=list)

View File

@ -1,4 +1,3 @@
from collections.abc import Mapping, Sequence
from typing import Any
@ -22,20 +21,13 @@ class StartNode(BaseNode):
system_inputs = self.graph_runtime_state.variable_pool.system_variables
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
outputs=node_inputs
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs)
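
A sketch of the namespacing above with plain dicts, assuming SYSTEM_VARIABLE_NODE_ID is the string "sys": each system variable lands in the node inputs under that prefix.

SYSTEM_VARIABLE_NODE_ID = "sys"  # assumed value of the imported constant

def collect_node_inputs(user_inputs: dict, system_inputs: dict) -> dict:
    node_inputs = dict(user_inputs)
    for var, value in system_inputs.items():
        node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = value
    return node_inputs

print(collect_node_inputs({"name": "Yeuoly"}, {"query": "hi", "files": []}))
# {'name': 'Yeuoly', 'sys.query': 'hi', 'sys.files': []}
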
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: StartNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: StartNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping

View File

@ -1,5 +1,3 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
@ -8,5 +6,6 @@ class TemplateTransformNodeData(BaseNodeData):
"""
Template transform Node Data.
"""
variables: list[VariableSelector]
template: str
template: str

View File

@ -2,13 +2,13 @@ import os
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from models.workflow import WorkflowNodeExecutionStatus
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000'))
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
class TemplateTransformNode(BaseNode):
@ -24,15 +24,7 @@ class TemplateTransformNode(BaseNode):
"""
return {
"type": "template-transform",
"config": {
"variables": [
{
"variable": "arg1",
"value_selector": []
}
],
"template": "{{ arg1 }}"
}
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
}
def _run(self) -> NodeRunResult:
@ -51,38 +43,25 @@ class TemplateTransformNode(BaseNode):
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=node_data.template,
inputs=variables
)
except CodeExecutionException as e:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e)
language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables
)
except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters"
error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters",
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
outputs={
'output': result['result']
}
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
)
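
A local sketch of the render-and-cap behavior, using the jinja2 package directly as a stand-in; the node itself delegates rendering to the sandboxed CodeExecutor rather than importing jinja2 in-process.

import os
from jinja2 import Template  # local stand-in for the sandboxed executor

MAX_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))

def render_template(template: str, variables: dict) -> str:
    result = Template(template).render(**variables)
    if len(result) > MAX_OUTPUT_LENGTH:
        raise ValueError(f"Output length exceeds {MAX_OUTPUT_LENGTH} characters")
    return result

print(render_template("Hello, {{ arg1 }}!", {"arg1": "world"}))  # Hello, world!
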
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: TemplateTransformNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -92,5 +71,6 @@ class TemplateTransformNode(BaseNode):
:return:
"""
return {
node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables
}

View File

@ -10,45 +10,46 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
class ToolEntity(BaseModel):
provider_id: str
provider_type: ToolProviderType
provider_name: str # redundancy
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy
tool_label: str # redundancy
tool_configurations: dict[str, Any]
@field_validator('tool_configurations', mode='before')
@field_validator("tool_configurations", mode="before")
@classmethod
def validate_tool_configurations(cls, value, values: ValidationInfo):
if not isinstance(value, dict):
raise ValueError('tool_configurations must be a dictionary')
for key in values.data.get('tool_configurations', {}).keys():
value = values.data.get('tool_configurations', {}).get(key)
raise ValueError("tool_configurations must be a dictionary")
for key in values.data.get("tool_configurations", {}):
value = values.data.get("tool_configurations", {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f'{key} must be a string')
raise ValueError(f"{key} must be a string")
return value
class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal['mixed', 'variable', 'constant']
type: Literal["mixed", "variable", "constant"]
@field_validator('type', mode='before')
@field_validator("type", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = validation_info.data.get('value')
if typ == 'mixed' and not isinstance(value, str):
raise ValueError('value must be a string')
elif typ == 'variable':
value = validation_info.data.get("value")
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "variable":
if not isinstance(value, list):
raise ValueError('value must be a list')
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError('value must be a list of strings')
elif typ == 'constant' and not isinstance(value, str | int | float | bool):
raise ValueError('value must be a string, int, float, or bool')
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, str | int | float | bool):
raise ValueError("value must be a string, int, float, or bool")
return typ
"""

View File

@ -35,10 +35,7 @@ class ToolNode(BaseNode):
node_data = cast(ToolNodeData, self.node_data)
# fetch tool icon
tool_info = {
'provider_type': node_data.provider_type.value,
'provider_id': node_data.provider_id
}
tool_info = {"provider_type": node_data.provider_type.value, "provider_id": node_data.provider_id}
# get tool runtime
try:
@ -50,10 +47,8 @@ class ToolNode(BaseNode):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
error=f'Failed to get tool runtime: {str(e)}'
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to get tool runtime: {str(e)}",
)
)
return
@ -61,15 +56,13 @@ class ToolNode(BaseNode):
# get parameters
tool_parameters = tool_runtime.get_runtime_parameters() or []
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data
tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
for_log=True
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
for_log=True,
)
try:
@ -86,10 +79,8 @@ class ToolNode(BaseNode):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
error=f'Failed to invoke tool: {str(e)}',
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool: {str(e)}",
)
)
return
@ -126,12 +117,10 @@ class ToolNode(BaseNode):
result[parameter_name] = None
continue
if parameter.type == ToolParameter.ToolParameterType.FILE:
result[parameter_name] = [
v.to_dict() for v in self._fetch_files(variable_pool)
]
result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)]
else:
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == 'variable':
if tool_input.type == "variable":
parameter_value_segment = variable_pool.get(tool_input.value)
if not parameter_value_segment:
raise Exception("input variable dose not exists")
@ -147,14 +136,16 @@ class ToolNode(BaseNode):
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _transform_message(self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any]) -> Generator[RunEvent, None, None]:
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
) -> Generator[RunEvent, None, None]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
@ -169,66 +160,65 @@ class ToolNode(BaseNode):
files: list[FileVar] = []
text = ""
json: list[dict] = []
variables: dict[str, Any] = {}
for message in message_stream:
if message.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
message.type == ToolInvokeMessage.MessageType.IMAGE:
if message.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
url = message.message.text
ext = path.splitext(url)[1]
mimetype = message.meta.get('mime_type', 'image/jpeg')
filename = message.save_as or url.split('/')[-1]
transfer_method = message.meta.get('transfer_method', FileTransferMethod.TOOL_FILE)
mimetype = message.meta.get("mime_type", "image/jpeg")
filename = message.save_as or url.split("/")[-1]
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
# get tool file id
tool_file_id = url.split('/')[-1].split('.')[0]
files.append(FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
url=url,
related_id=tool_file_id,
filename=filename,
extension=ext,
mime_type=mimetype,
))
tool_file_id = url.split("/")[-1].split(".")[0]
files.append(
FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
url=url,
related_id=tool_file_id,
filename=filename,
extension=ext,
mime_type=mimetype,
)
)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split('/')[-1].split('.')[0]
files.append(FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id,
filename=message.save_as,
extension=path.splitext(message.save_as)[1],
mime_type=message.meta.get('mime_type', 'application/octet-stream'),
))
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
files.append(
FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id,
filename=message.save_as,
extension=path.splitext(message.save_as)[1],
mime_type=message.meta.get("mime_type", "application/octet-stream"),
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text + '\n'
text += message.message.text + "\n"
yield RunStreamChunkEvent(
chunk_content=message.message.text,
from_variable_selector=[self.node_id, 'text']
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message, ToolInvokeMessage.JsonMessage)
json.append(message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f'Link: {message.message.text}\n'
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(
chunk_content=stream_text,
from_variable_selector=[self.node_id, 'text']
)
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
@ -241,8 +231,7 @@ class ToolNode(BaseNode):
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value,
from_variable_selector=[self.node_id, variable_name]
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
)
else:
variables[variable_name] = variable_value
@ -250,25 +239,15 @@ class ToolNode(BaseNode):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'text': text,
'files': files,
'json': json,
**variables
},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
inputs=parameters_for_log
outputs={"text": text, "files": files, "json": json, **variables},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
inputs=parameters_for_log,
)
)
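
A reduced sketch of the aggregation above, assuming (type, payload) tuples in place of ToolInvokeMessage objects: text and link chunks concatenate, JSON payloads accumulate in a list, and named variables merge into the outputs.

def aggregate_messages(messages):
    text, json_objects, variables = "", [], {}
    for kind, payload in messages:
        if kind == "text":
            text += payload + "\n"
        elif kind == "json":
            json_objects.append(payload)
        elif kind == "link":
            text += f"Link: {payload}\n"
        elif kind == "variable":
            name, value = payload
            variables[name] = value
    return {"text": text, "json": json_objects, **variables}

stream = [("text", "Searching..."), ("json", {"hits": 3}), ("variable", ("topic", "weather"))]
print(aggregate_messages(stream))
# {'text': 'Searching...\n', 'json': [{'hits': 3}], 'topic': 'weather'}
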
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ToolNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -280,18 +259,16 @@ class ToolNode(BaseNode):
result = {}
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
if input.type == 'mixed':
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == 'variable':
elif input.type == "variable":
result[parameter_name] = input.value
elif input.type == 'constant':
elif input.type == "constant":
pass
result = {
node_id + '.' + key: value for key, value in result.items()
}
result = {node_id + "." + key: value for key, value in result.items()}
return result

View File

@ -1,5 +1,3 @@
from typing import Literal, Optional
from pydantic import BaseModel
@ -11,23 +9,27 @@ class AdvancedSettings(BaseModel):
"""
Advanced setting.
"""
group_enabled: bool
class Group(BaseModel):
"""
Group.
"""
output_type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]']
output_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
variables: list[list[str]]
group_name: str
groups: list[Group]
class VariableAssignerNodeData(BaseNodeData):
"""
Variable assigner Node Data.
"""
type: str = 'variable-assigner'
type: str = "variable-assigner"
output_type: str
variables: list[list[str]]
advanced_settings: Optional[AdvancedSettings] = None

View File

@ -21,13 +21,9 @@ class VariableAggregatorNode(BaseNode):
for selector in node_data.variables:
variable = self.graph_runtime_state.variable_pool.get_any(selector)
if variable is not None:
outputs = {
"output": variable
}
outputs = {"output": variable}
inputs = {
'.'.join(selector[1:]): variable
}
inputs = {".".join(selector[1:]): variable}
break
else:
for group in node_data.advanced_settings.groups:
@ -35,24 +31,15 @@ class VariableAggregatorNode(BaseNode):
variable = self.graph_runtime_state.variable_pool.get_any(selector)
if variable is not None:
outputs[group.group_name] = {
'output': variable
}
inputs['.'.join(selector[1:])] = variable
outputs[group.group_name] = {"output": variable}
inputs[".".join(selector[1:])] = variable
break
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
inputs=inputs
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs)
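
The selection above is first-non-None wins; a sketch with a plain dict standing in for the variable pool (simple branch only; grouped mode applies the same rule per group):

def aggregate(variable_pool: dict, selectors: list[list[str]]):
    # Return the first selector whose value exists in the pool.
    for selector in selectors:
        value = variable_pool.get(tuple(selector))
        if value is not None:
            return {"output": value}, {".".join(selector[1:]): value}
    return {}, {}

pool = {("llm_2", "text"): "hello from branch 2"}
outputs, inputs = aggregate(pool, [["llm_1", "text"], ["llm_2", "text"]])
print(outputs, inputs)
# {'output': 'hello from branch 2'} {'text': 'hello from branch 2'}
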
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: VariableAssignerNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping

View File

@ -2,7 +2,7 @@ from .node import VariableAssignerNode
from .node_data import VariableAssignerData, WriteMode
__all__ = [
'VariableAssignerNode',
'VariableAssignerData',
'WriteMode',
"VariableAssignerNode",
"VariableAssignerData",
"WriteMode",
]

View File

@ -24,43 +24,43 @@ class VariableAssignerNode(BaseNode):
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError('assigned variable not found')
raise VariableAssignerNodeError("assigned variable not found")
match data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_variable = original_variable.model_copy(update={'value': income_value.value})
raise VariableAssignerNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
raise VariableAssignerNodeError("input value not found")
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={'value': updated_value})
updated_variable = original_variable.model_copy(update={"value": updated_value})
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
raise VariableAssignerNodeError(f"unsupported write mode: {data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id'])
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise VariableAssignerNodeError('conversation_id not found')
raise VariableAssignerNodeError("conversation_id not found")
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
'value': income_value.to_object(),
"value": income_value.to_object(),
},
)
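
The three write modes on plain values, as a sketch (the real node copies a Variable model; `type(original)()` stands in for get_zero_value below; match needs Python 3.10+):

def apply_write_mode(original, write_mode: str, income=None):
    match write_mode:
        case "over-write":
            return income
        case "append":
            return original + [income]
        case "clear":
            # Empty value of the same type: [] for lists, "" for strings, 0 for ints.
            return type(original)()
        case _:
            raise ValueError(f"unsupported write mode: {write_mode}")

print(apply_write_mode(["a"], "append", "b"))        # ['a', 'b']
print(apply_write_mode("old", "over-write", "new"))  # new
print(apply_write_mode([1, 2], "clear"))             # []
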
@ -72,7 +72,7 @@ def update_conversation_variable(conversation_id: str, variable: Variable):
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError('conversation variable not found in the database')
raise VariableAssignerNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
@ -84,8 +84,8 @@ def get_zero_value(t: SegmentType):
case SegmentType.OBJECT:
return factory.build_segment({})
case SegmentType.STRING:
return factory.build_segment('')
return factory.build_segment("")
case SegmentType.NUMBER:
return factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f'unsupported variable type: {t}')
raise VariableAssignerNodeError(f"unsupported variable type: {t}")

View File

@ -6,14 +6,14 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
class WriteMode(str, Enum):
OVER_WRITE = 'over-write'
APPEND = 'append'
CLEAR = 'clear'
OVER_WRITE = "over-write"
APPEND = "append"
CLEAR = "clear"
class VariableAssignerData(BaseNodeData):
title: str = 'Variable Assigner'
desc: Optional[str] = 'Assign a value to a variable'
title: str = "Variable Assigner"
desc: Optional[str] = "Assign a value to a variable"
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]