mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
Merge main
This commit is contained in:
@ -29,14 +29,12 @@ class AnswerNode(BaseNode):
|
||||
# generate routes
|
||||
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)
|
||||
|
||||
answer = ''
|
||||
answer = ""
|
||||
for part in generate_routes:
|
||||
if part.type == GenerateRouteChunk.ChunkType.VAR:
|
||||
part = cast(VarGenerateRouteChunk, part)
|
||||
value_selector = part.value_selector
|
||||
value = self.graph_runtime_state.variable_pool.get(
|
||||
value_selector
|
||||
)
|
||||
value = self.graph_runtime_state.variable_pool.get(value_selector)
|
||||
|
||||
if value:
|
||||
answer += value.markdown
|
||||
@ -44,19 +42,11 @@ class AnswerNode(BaseNode):
|
||||
part = cast(TextGenerateRouteChunk, part)
|
||||
answer += part.text
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"answer": answer
|
||||
}
|
||||
)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer})
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: AnswerNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@ -73,6 +63,6 @@ class AnswerNode(BaseNode):
|
||||
|
||||
variable_mapping = {}
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector
|
||||
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
@ -12,12 +11,12 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
class AnswerStreamGeneratorRouter:
|
||||
|
||||
@classmethod
|
||||
def init(cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined]
|
||||
) -> AnswerStreamGenerateRoute:
|
||||
def init(
|
||||
cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
) -> AnswerStreamGenerateRoute:
|
||||
"""
|
||||
Get stream generate routes.
|
||||
:return:
|
||||
@ -25,7 +24,7 @@ class AnswerStreamGeneratorRouter:
|
||||
# parse stream output node value selectors of answer nodes
|
||||
answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
|
||||
for answer_node_id, node_config in node_id_config_mapping.items():
|
||||
if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
|
||||
if node_config.get("data", {}).get("type") != NodeType.ANSWER.value:
|
||||
continue
|
||||
|
||||
# get generate route for stream output
|
||||
@ -37,12 +36,11 @@ class AnswerStreamGeneratorRouter:
|
||||
answer_dependencies = cls._fetch_answers_dependencies(
|
||||
answer_node_ids=answer_node_ids,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
node_id_config_mapping=node_id_config_mapping
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
)
|
||||
|
||||
return AnswerStreamGenerateRoute(
|
||||
answer_generate_route=answer_generate_route,
|
||||
answer_dependencies=answer_dependencies
|
||||
answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -56,8 +54,7 @@ class AnswerStreamGeneratorRouter:
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
value_selector_mapping = {
|
||||
variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in variable_selectors
|
||||
variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors
|
||||
}
|
||||
|
||||
variable_keys = list(value_selector_mapping.keys())
|
||||
@ -71,21 +68,17 @@ class AnswerStreamGeneratorRouter:
|
||||
|
||||
template = node_data.answer
|
||||
for var in variable_keys:
|
||||
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
|
||||
template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω")
|
||||
|
||||
generate_routes: list[GenerateRouteChunk] = []
|
||||
for part in template.split('Ω'):
|
||||
for part in template.split("Ω"):
|
||||
if part:
|
||||
if cls._is_variable(part, variable_keys):
|
||||
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
|
||||
var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "")
|
||||
value_selector = value_selector_mapping[var_key]
|
||||
generate_routes.append(VarGenerateRouteChunk(
|
||||
value_selector=value_selector
|
||||
))
|
||||
generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector))
|
||||
else:
|
||||
generate_routes.append(TextGenerateRouteChunk(
|
||||
text=part
|
||||
))
|
||||
generate_routes.append(TextGenerateRouteChunk(text=part))
|
||||
|
||||
return generate_routes
|
||||
|
||||
@ -101,15 +94,16 @@ class AnswerStreamGeneratorRouter:
|
||||
|
||||
@classmethod
|
||||
def _is_variable(cls, part, variable_keys):
|
||||
cleaned_part = part.replace('{{', '').replace('}}', '')
|
||||
return part.startswith('{{') and cleaned_part in variable_keys
|
||||
cleaned_part = part.replace("{{", "").replace("}}", "")
|
||||
return part.startswith("{{") and cleaned_part in variable_keys
|
||||
|
||||
@classmethod
|
||||
def _fetch_answers_dependencies(cls,
|
||||
answer_node_ids: list[str],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_id_config_mapping: dict[str, dict]
|
||||
) -> dict[str, list[str]]:
|
||||
def _fetch_answers_dependencies(
|
||||
cls,
|
||||
answer_node_ids: list[str],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
) -> dict[str, list[str]]:
|
||||
"""
|
||||
Fetch answer dependencies
|
||||
:param answer_node_ids: answer node ids
|
||||
@ -127,19 +121,20 @@ class AnswerStreamGeneratorRouter:
|
||||
answer_node_id=answer_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
answer_dependencies=answer_dependencies
|
||||
answer_dependencies=answer_dependencies,
|
||||
)
|
||||
|
||||
return answer_dependencies
|
||||
|
||||
@classmethod
|
||||
def _recursive_fetch_answer_dependencies(cls,
|
||||
current_node_id: str,
|
||||
answer_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
answer_dependencies: dict[str, list[str]]
|
||||
) -> None:
|
||||
def _recursive_fetch_answer_dependencies(
|
||||
cls,
|
||||
current_node_id: str,
|
||||
answer_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
answer_dependencies: dict[str, list[str]],
|
||||
) -> None:
|
||||
"""
|
||||
Recursive fetch answer dependencies
|
||||
:param current_node_id: current node id
|
||||
@ -152,12 +147,12 @@ class AnswerStreamGeneratorRouter:
|
||||
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
|
||||
for edge in reverse_edges:
|
||||
source_node_id = edge.source_node_id
|
||||
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
|
||||
if source_node_type in (
|
||||
NodeType.ANSWER.value,
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
):
|
||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||
if source_node_type in {
|
||||
NodeType.ANSWER.value,
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER.value,
|
||||
}:
|
||||
answer_dependencies[answer_node_id].append(source_node_id)
|
||||
else:
|
||||
cls._recursive_fetch_answer_dependencies(
|
||||
@ -165,5 +160,5 @@ class AnswerStreamGeneratorRouter:
|
||||
answer_node_id=answer_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
answer_dependencies=answer_dependencies
|
||||
answer_dependencies=answer_dependencies,
|
||||
)
|
||||
|
||||
@ -18,7 +18,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnswerStreamProcessor(StreamProcessor):
|
||||
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
super().__init__(graph, variable_pool)
|
||||
self.generate_routes = graph.answer_stream_generate_routes
|
||||
@ -27,9 +26,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
self.route_position[answer_node_id] = 0
|
||||
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
||||
|
||||
def process(self,
|
||||
generator: Generator[GraphEngineEvent, None, None]
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
for event in generator:
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
|
||||
@ -47,9 +44,9 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
]
|
||||
else:
|
||||
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
|
||||
self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
] = stream_out_answer_node_ids
|
||||
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
|
||||
stream_out_answer_node_ids
|
||||
)
|
||||
|
||||
for _ in stream_out_answer_node_ids:
|
||||
yield event
|
||||
@ -77,9 +74,9 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
self.rest_node_ids = self.graph.node_ids.copy()
|
||||
self.current_stream_chunk_generating_node_ids = {}
|
||||
|
||||
def _generate_stream_outputs_when_node_finished(self,
|
||||
event: NodeRunSucceededEvent
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
def _generate_stream_outputs_when_node_finished(
|
||||
self, event: NodeRunSucceededEvent
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:param event: node run succeeded event
|
||||
@ -87,10 +84,13 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
"""
|
||||
for answer_node_id, position in self.route_position.items():
|
||||
# all depends on answer node id not in rest node ids
|
||||
if (event.route_node_state.node_id != answer_node_id
|
||||
and (answer_node_id not in self.rest_node_ids
|
||||
or not all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
|
||||
if event.route_node_state.node_id != answer_node_id and (
|
||||
answer_node_id not in self.rest_node_ids
|
||||
or not all(
|
||||
dep_id not in self.rest_node_ids
|
||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
route_position = self.route_position[answer_node_id]
|
||||
@ -108,6 +108,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
from_variable_selector=[answer_node_id, "answer"],
|
||||
)
|
||||
else:
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
@ -115,9 +116,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
if not value_selector:
|
||||
break
|
||||
|
||||
value = self.variable_pool.get(
|
||||
value_selector
|
||||
)
|
||||
value = self.variable_pool.get(value_selector)
|
||||
|
||||
if value is None:
|
||||
break
|
||||
@ -158,8 +157,9 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
continue
|
||||
|
||||
# all depends on answer node id not in rest node ids
|
||||
if all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
|
||||
if all(
|
||||
dep_id not in self.rest_node_ids for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
|
||||
):
|
||||
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
|
||||
continue
|
||||
|
||||
@ -213,7 +213,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
if '__variant' in value and value['__variant'] == FileVar.__name__:
|
||||
if "__variant" in value and value["__variant"] == FileVar.__name__:
|
||||
return value
|
||||
elif isinstance(value, FileVar):
|
||||
return value.to_dict()
|
||||
|
||||
@ -7,16 +7,13 @@ from core.workflow.graph_engine.entities.graph import Graph
|
||||
|
||||
|
||||
class StreamProcessor(ABC):
|
||||
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
self.graph = graph
|
||||
self.variable_pool = variable_pool
|
||||
self.rest_node_ids = graph.node_ids.copy()
|
||||
|
||||
@abstractmethod
|
||||
def process(self,
|
||||
generator: Generator[GraphEngineEvent, None, None]
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
|
||||
@ -35,9 +32,11 @@ class StreamProcessor(ABC):
|
||||
reachable_node_ids = []
|
||||
unreachable_first_node_ids = []
|
||||
for edge in self.graph.edge_mapping[finished_node_id]:
|
||||
if (edge.run_condition
|
||||
and edge.run_condition.branch_identify
|
||||
and run_result.edge_source_handle == edge.run_condition.branch_identify):
|
||||
if (
|
||||
edge.run_condition
|
||||
and edge.run_condition.branch_identify
|
||||
and run_result.edge_source_handle == edge.run_condition.branch_identify
|
||||
):
|
||||
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
||||
continue
|
||||
else:
|
||||
|
||||
@ -9,6 +9,7 @@ class AnswerNodeData(BaseNodeData):
|
||||
"""
|
||||
Answer Node Data.
|
||||
"""
|
||||
|
||||
answer: str = Field(..., description="answer template string")
|
||||
|
||||
|
||||
@ -28,6 +29,7 @@ class VarGenerateRouteChunk(GenerateRouteChunk):
|
||||
"""
|
||||
Var Generate Route Chunk.
|
||||
"""
|
||||
|
||||
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
|
||||
"""generate route chunk type"""
|
||||
value_selector: list[str] = Field(..., description="value selector")
|
||||
@ -37,6 +39,7 @@ class TextGenerateRouteChunk(GenerateRouteChunk):
|
||||
"""
|
||||
Text Generate Route Chunk.
|
||||
"""
|
||||
|
||||
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT
|
||||
"""generate route chunk type"""
|
||||
text: str = Field(..., description="text")
|
||||
@ -52,11 +55,10 @@ class AnswerStreamGenerateRoute(BaseModel):
|
||||
"""
|
||||
AnswerStreamGenerateRoute entity
|
||||
"""
|
||||
|
||||
answer_dependencies: dict[str, list[str]] = Field(
|
||||
...,
|
||||
description="answer dependencies (answer node id -> dependent answer node ids)"
|
||||
..., description="answer dependencies (answer node id -> dependent answer node ids)"
|
||||
)
|
||||
answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
|
||||
...,
|
||||
description="answer generate route (answer node id -> generate route chunks)"
|
||||
..., description="answer generate route (answer node id -> generate route chunks)"
|
||||
)
|
||||
|
||||
@ -15,14 +15,16 @@ class BaseNode(ABC):
|
||||
_node_data_cls: type[BaseNodeData]
|
||||
_node_type: NodeType
|
||||
|
||||
def __init__(self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: GraphInitParams,
|
||||
graph: Graph,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: GraphInitParams,
|
||||
graph: Graph,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
self.id = id
|
||||
self.tenant_id = graph_init_params.tenant_id
|
||||
self.app_id = graph_init_params.app_id
|
||||
@ -46,8 +48,7 @@ class BaseNode(ABC):
|
||||
self.node_data = self._node_data_cls(**config.get("data", {}))
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) \
|
||||
-> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
|
||||
def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
Run node
|
||||
:return:
|
||||
@ -62,14 +63,14 @@ class BaseNode(ABC):
|
||||
result = self._run()
|
||||
|
||||
if isinstance(result, NodeRunResult):
|
||||
yield RunCompletedEvent(
|
||||
run_result=result
|
||||
)
|
||||
yield RunCompletedEvent(run_result=result)
|
||||
else:
|
||||
yield from result
|
||||
|
||||
@classmethod
|
||||
def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]:
|
||||
def extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], config: dict
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
@ -82,17 +83,12 @@ class BaseNode(ABC):
|
||||
|
||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||
return cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config,
|
||||
node_id=node_id,
|
||||
node_data=node_data
|
||||
graph_config=graph_config, node_id=node_id, node_data=node_data
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: BaseNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
|
||||
@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
@ -25,11 +25,10 @@ class CodeNode(BaseNode):
|
||||
"""
|
||||
code_language = CodeLanguage.PYTHON3
|
||||
if filters:
|
||||
code_language = (filters.get("code_language", CodeLanguage.PYTHON3))
|
||||
code_language = filters.get("code_language", CodeLanguage.PYTHON3)
|
||||
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] = next(p for p in providers
|
||||
if p.is_accept_language(code_language))
|
||||
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
|
||||
|
||||
return code_provider.get_default_config()
|
||||
|
||||
@ -62,18 +61,10 @@ class CodeNode(BaseNode):
|
||||
|
||||
# Transform result
|
||||
result = self._transform_result(result, node_data.outputs)
|
||||
except (CodeExecutionException, ValueError) as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e)
|
||||
)
|
||||
except (CodeExecutionError, ValueError) as e:
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
outputs=result
|
||||
)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
||||
def _check_string(self, value: str, variable: str) -> str:
|
||||
"""
|
||||
@ -83,16 +74,18 @@ class CodeNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
if isinstance(value, type(None)):
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
raise ValueError(f"Output variable `{variable}` must be a string")
|
||||
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
raise ValueError(f'The length of output variable `{variable}` must be'
|
||||
f' less than {dify_config.CODE_MAX_STRING_LENGTH} characters')
|
||||
|
||||
return value.replace('\x00', '')
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
raise ValueError(
|
||||
f"The length of output variable `{variable}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
|
||||
)
|
||||
|
||||
return value.replace("\x00", "")
|
||||
|
||||
def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
|
||||
"""
|
||||
@ -102,26 +95,30 @@ class CodeNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
if not isinstance(value, int | float):
|
||||
if isinstance(value, type(None)):
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
raise ValueError(f"Output variable `{variable}` must be a number")
|
||||
|
||||
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
|
||||
raise ValueError(f'Output variable `{variable}` is out of range,'
|
||||
f' it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}.')
|
||||
raise ValueError(
|
||||
f"Output variable `{variable}` is out of range,"
|
||||
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
|
||||
)
|
||||
|
||||
if isinstance(value, float):
|
||||
# raise error if precision is too high
|
||||
if len(str(value).split('.')[1]) > dify_config.CODE_MAX_PRECISION:
|
||||
raise ValueError(f'Output variable `{variable}` has too high precision,'
|
||||
f' it must be less than {dify_config.CODE_MAX_PRECISION} digits.')
|
||||
if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION:
|
||||
raise ValueError(
|
||||
f"Output variable `{variable}` has too high precision,"
|
||||
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
def _transform_result(self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]],
|
||||
prefix: str = '',
|
||||
depth: int = 1) -> dict:
|
||||
def _transform_result(
|
||||
self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1
|
||||
) -> dict:
|
||||
"""
|
||||
Transform result
|
||||
:param result: result
|
||||
@ -139,183 +136,187 @@ class CodeNode(BaseNode):
|
||||
self._transform_result(
|
||||
result=output_value,
|
||||
output_schema=None,
|
||||
prefix=f'{prefix}.{output_name}' if prefix else output_name,
|
||||
depth=depth + 1
|
||||
prefix=f"{prefix}.{output_name}" if prefix else output_name,
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif isinstance(output_value, int | float):
|
||||
self._check_number(
|
||||
value=output_value,
|
||||
variable=f'{prefix}.{output_name}' if prefix else output_name
|
||||
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
|
||||
)
|
||||
elif isinstance(output_value, str):
|
||||
self._check_string(
|
||||
value=output_value,
|
||||
variable=f'{prefix}.{output_name}' if prefix else output_name
|
||||
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
|
||||
)
|
||||
elif isinstance(output_value, list):
|
||||
first_element = output_value[0] if len(output_value) > 0 else None
|
||||
if first_element is not None:
|
||||
if isinstance(first_element, int | float) and all(value is None or isinstance(value, int | float) for value in output_value):
|
||||
if isinstance(first_element, int | float) and all(
|
||||
value is None or isinstance(value, int | float) for value in output_value
|
||||
):
|
||||
for i, value in enumerate(output_value):
|
||||
self._check_number(
|
||||
value=value,
|
||||
variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]'
|
||||
variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
|
||||
)
|
||||
elif isinstance(first_element, str) and all(value is None or isinstance(value, str) for value in output_value):
|
||||
elif isinstance(first_element, str) and all(
|
||||
value is None or isinstance(value, str) for value in output_value
|
||||
):
|
||||
for i, value in enumerate(output_value):
|
||||
self._check_string(
|
||||
value=value,
|
||||
variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]'
|
||||
variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
|
||||
)
|
||||
elif isinstance(first_element, dict) and all(value is None or isinstance(value, dict) for value in output_value):
|
||||
elif isinstance(first_element, dict) and all(
|
||||
value is None or isinstance(value, dict) for value in output_value
|
||||
):
|
||||
for i, value in enumerate(output_value):
|
||||
if value is not None:
|
||||
self._transform_result(
|
||||
result=value,
|
||||
output_schema=None,
|
||||
prefix=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]',
|
||||
depth=depth + 1
|
||||
prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
|
||||
depth=depth + 1,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Output {prefix}.{output_name} is not a valid array. make sure all elements are of the same type.')
|
||||
elif isinstance(output_value, type(None)):
|
||||
raise ValueError(
|
||||
f"Output {prefix}.{output_name} is not a valid array."
|
||||
f" make sure all elements are of the same type."
|
||||
)
|
||||
elif output_value is None:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f'Output {prefix}.{output_name} is not a valid type.')
|
||||
|
||||
raise ValueError(f"Output {prefix}.{output_name} is not a valid type.")
|
||||
|
||||
return result
|
||||
|
||||
parameters_validated = {}
|
||||
for output_name, output_config in output_schema.items():
|
||||
dot = '.' if prefix else ''
|
||||
dot = "." if prefix else ""
|
||||
if output_name not in result:
|
||||
raise ValueError(f'Output {prefix}{dot}{output_name} is missing.')
|
||||
|
||||
if output_config.type == 'object':
|
||||
raise ValueError(f"Output {prefix}{dot}{output_name} is missing.")
|
||||
|
||||
if output_config.type == "object":
|
||||
# check if output is object
|
||||
if not isinstance(result.get(output_name), dict):
|
||||
if isinstance(result.get(output_name), type(None)):
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Output {prefix}{dot}{output_name} is not an object, got {type(result.get(output_name))} instead.'
|
||||
f"Output {prefix}{dot}{output_name} is not an object,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
transformed_result[output_name] = self._transform_result(
|
||||
result=result[output_name],
|
||||
output_schema=output_config.children,
|
||||
prefix=f'{prefix}.{output_name}',
|
||||
depth=depth + 1
|
||||
prefix=f"{prefix}.{output_name}",
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif output_config.type == 'number':
|
||||
elif output_config.type == "number":
|
||||
# check if number available
|
||||
transformed_result[output_name] = self._check_number(
|
||||
value=result[output_name],
|
||||
variable=f'{prefix}{dot}{output_name}'
|
||||
value=result[output_name], variable=f"{prefix}{dot}{output_name}"
|
||||
)
|
||||
elif output_config.type == 'string':
|
||||
elif output_config.type == "string":
|
||||
# check if string available
|
||||
transformed_result[output_name] = self._check_string(
|
||||
value=result[output_name],
|
||||
variable=f'{prefix}{dot}{output_name}',
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == 'array[number]':
|
||||
elif output_config.type == "array[number]":
|
||||
# check if array of number available
|
||||
if not isinstance(result[output_name], list):
|
||||
if isinstance(result[output_name], type(None)):
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
|
||||
raise ValueError(
|
||||
f'The length of output variable `{prefix}{dot}{output_name}` must be'
|
||||
f' less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements.'
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
self._check_number(
|
||||
value=value,
|
||||
variable=f'{prefix}{dot}{output_name}[{i}]'
|
||||
)
|
||||
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == 'array[string]':
|
||||
elif output_config.type == "array[string]":
|
||||
# check if array of string available
|
||||
if not isinstance(result[output_name], list):
|
||||
if isinstance(result[output_name], type(None)):
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
|
||||
raise ValueError(
|
||||
f'The length of output variable `{prefix}{dot}{output_name}` must be'
|
||||
f' less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements.'
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
self._check_string(
|
||||
value=value,
|
||||
variable=f'{prefix}{dot}{output_name}[{i}]'
|
||||
)
|
||||
self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == 'array[object]':
|
||||
elif output_config.type == "array[object]":
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if isinstance(result[output_name], type(None)):
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
|
||||
raise ValueError(
|
||||
f'The length of output variable `{prefix}{dot}{output_name}` must be'
|
||||
f' less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements.'
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
||||
|
||||
for i, value in enumerate(result[output_name]):
|
||||
if not isinstance(value, dict):
|
||||
if isinstance(value, type(None)):
|
||||
if value is None:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Output {prefix}{dot}{output_name}[{i}] is not an object, got {type(value)} instead at index {i}.'
|
||||
f"Output {prefix}{dot}{output_name}[{i}] is not an object,"
|
||||
f" got {type(value)} instead at index {i}."
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
None if value is None else self._transform_result(
|
||||
None
|
||||
if value is None
|
||||
else self._transform_result(
|
||||
result=value,
|
||||
output_schema=output_config.children,
|
||||
prefix=f'{prefix}{dot}{output_name}[{i}]',
|
||||
depth=depth + 1
|
||||
prefix=f"{prefix}{dot}{output_name}[{i}]",
|
||||
depth=depth + 1,
|
||||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
else:
|
||||
raise ValueError(f'Output type {output_config.type} is not supported.')
|
||||
|
||||
raise ValueError(f"Output type {output_config.type} is not supported.")
|
||||
|
||||
parameters_validated[output_name] = True
|
||||
|
||||
# check if all output parameters are validated
|
||||
if len(parameters_validated) != len(result):
|
||||
raise ValueError('Not all output parameters are validated.')
|
||||
raise ValueError("Not all output parameters are validated.")
|
||||
|
||||
return transformed_result
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: CodeNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: CodeNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@ -325,5 +326,6 @@ class CodeNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
|
||||
node_id + "." + variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in node_data.variables
|
||||
}
|
||||
|
||||
@ -11,9 +11,10 @@ class CodeNodeData(BaseNodeData):
|
||||
"""
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]']
|
||||
children: Optional[dict[str, 'Output']] = None
|
||||
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
||||
children: Optional[dict[str, "Output"]] = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
name: str
|
||||
@ -23,4 +24,4 @@ class CodeNodeData(BaseNodeData):
|
||||
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
|
||||
code: str
|
||||
outputs: dict[str, Output]
|
||||
dependencies: Optional[list[Dependency]] = None
|
||||
dependencies: Optional[list[Dependency]] = None
|
||||
|
||||
@ -25,18 +25,11 @@ class EndNode(BaseNode):
|
||||
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
||||
outputs[variable_selector.variable] = value
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=outputs,
|
||||
outputs=outputs
|
||||
)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: EndNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: EndNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
|
||||
@ -3,13 +3,13 @@ from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
|
||||
|
||||
|
||||
class EndStreamGeneratorRouter:
|
||||
|
||||
@classmethod
|
||||
def init(cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_parallel_mapping: dict[str, str]
|
||||
) -> EndStreamParam:
|
||||
def init(
|
||||
cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_parallel_mapping: dict[str, str],
|
||||
) -> EndStreamParam:
|
||||
"""
|
||||
Get stream generate routes.
|
||||
:return:
|
||||
@ -17,7 +17,7 @@ class EndStreamGeneratorRouter:
|
||||
# parse stream output node value selector of end nodes
|
||||
end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
|
||||
for end_node_id, node_config in node_id_config_mapping.items():
|
||||
if not node_config.get('data', {}).get('type') == NodeType.END.value:
|
||||
if node_config.get("data", {}).get("type") != NodeType.END.value:
|
||||
continue
|
||||
|
||||
# skip end node in parallel
|
||||
@ -33,18 +33,18 @@ class EndStreamGeneratorRouter:
|
||||
end_dependencies = cls._fetch_ends_dependencies(
|
||||
end_node_ids=end_node_ids,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
node_id_config_mapping=node_id_config_mapping
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
)
|
||||
|
||||
return EndStreamParam(
|
||||
end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
|
||||
end_dependencies=end_dependencies
|
||||
end_dependencies=end_dependencies,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def extract_stream_variable_selector_from_node_data(cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
node_data: EndNodeData) -> list[list[str]]:
|
||||
def extract_stream_variable_selector_from_node_data(
|
||||
cls, node_id_config_mapping: dict[str, dict], node_data: EndNodeData
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Extract stream variable selector from node data
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
@ -59,21 +59,22 @@ class EndStreamGeneratorRouter:
|
||||
continue
|
||||
|
||||
node_id = variable_selector.value_selector[0]
|
||||
if node_id != 'sys' and node_id in node_id_config_mapping:
|
||||
if node_id != "sys" and node_id in node_id_config_mapping:
|
||||
node = node_id_config_mapping[node_id]
|
||||
node_type = node.get('data', {}).get('type')
|
||||
node_type = node.get("data", {}).get("type")
|
||||
if (
|
||||
variable_selector.value_selector not in value_selectors
|
||||
and node_type == NodeType.LLM.value
|
||||
and variable_selector.value_selector[1] == 'text'
|
||||
and node_type == NodeType.LLM.value
|
||||
and variable_selector.value_selector[1] == "text"
|
||||
):
|
||||
value_selectors.append(variable_selector.value_selector)
|
||||
|
||||
return value_selectors
|
||||
|
||||
@classmethod
|
||||
def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \
|
||||
-> list[list[str]]:
|
||||
def _extract_stream_variable_selector(
|
||||
cls, node_id_config_mapping: dict[str, dict], config: dict
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Extract stream variable selector from node config
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
@ -84,11 +85,12 @@ class EndStreamGeneratorRouter:
|
||||
return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)
|
||||
|
||||
@classmethod
|
||||
def _fetch_ends_dependencies(cls,
|
||||
end_node_ids: list[str],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_id_config_mapping: dict[str, dict]
|
||||
) -> dict[str, list[str]]:
|
||||
def _fetch_ends_dependencies(
|
||||
cls,
|
||||
end_node_ids: list[str],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
) -> dict[str, list[str]]:
|
||||
"""
|
||||
Fetch end dependencies
|
||||
:param end_node_ids: end node ids
|
||||
@ -106,20 +108,21 @@ class EndStreamGeneratorRouter:
|
||||
end_node_id=end_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
end_dependencies=end_dependencies
|
||||
end_dependencies=end_dependencies,
|
||||
)
|
||||
|
||||
return end_dependencies
|
||||
|
||||
@classmethod
|
||||
def _recursive_fetch_end_dependencies(cls,
|
||||
current_node_id: str,
|
||||
end_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]],
|
||||
# type: ignore[name-defined]
|
||||
end_dependencies: dict[str, list[str]]
|
||||
) -> None:
|
||||
def _recursive_fetch_end_dependencies(
|
||||
cls,
|
||||
current_node_id: str,
|
||||
end_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]],
|
||||
# type: ignore[name-defined]
|
||||
end_dependencies: dict[str, list[str]],
|
||||
) -> None:
|
||||
"""
|
||||
Recursive fetch end dependencies
|
||||
:param current_node_id: current node id
|
||||
@ -132,11 +135,11 @@ class EndStreamGeneratorRouter:
|
||||
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
|
||||
for edge in reverse_edges:
|
||||
source_node_id = edge.source_node_id
|
||||
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
|
||||
if source_node_type in (
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
):
|
||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||
if source_node_type in {
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
}:
|
||||
end_dependencies[end_node_id].append(source_node_id)
|
||||
else:
|
||||
cls._recursive_fetch_end_dependencies(
|
||||
@ -144,5 +147,5 @@ class EndStreamGeneratorRouter:
|
||||
end_node_id=end_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
end_dependencies=end_dependencies
|
||||
end_dependencies=end_dependencies,
|
||||
)
|
||||
|
||||
@ -15,7 +15,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndStreamProcessor(StreamProcessor):
|
||||
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
super().__init__(graph, variable_pool)
|
||||
self.end_stream_param = graph.end_stream_param
|
||||
@ -26,9 +25,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
self.has_outputed = False
|
||||
self.outputed_node_ids = set()
|
||||
|
||||
def process(self,
|
||||
generator: Generator[GraphEngineEvent, None, None]
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
for event in generator:
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
|
||||
@ -38,7 +35,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
if event.in_iteration_id:
|
||||
if self.has_outputed and event.node_id not in self.outputed_node_ids:
|
||||
event.chunk_content = '\n' + event.chunk_content
|
||||
event.chunk_content = "\n" + event.chunk_content
|
||||
|
||||
self.outputed_node_ids.add(event.node_id)
|
||||
self.has_outputed = True
|
||||
@ -51,13 +48,13 @@ class EndStreamProcessor(StreamProcessor):
|
||||
]
|
||||
else:
|
||||
stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
|
||||
self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
] = stream_out_end_node_ids
|
||||
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
|
||||
stream_out_end_node_ids
|
||||
)
|
||||
|
||||
if stream_out_end_node_ids:
|
||||
if self.has_outputed and event.node_id not in self.outputed_node_ids:
|
||||
event.chunk_content = '\n' + event.chunk_content
|
||||
event.chunk_content = "\n" + event.chunk_content
|
||||
|
||||
self.outputed_node_ids.add(event.node_id)
|
||||
self.has_outputed = True
|
||||
@ -86,9 +83,9 @@ class EndStreamProcessor(StreamProcessor):
|
||||
self.rest_node_ids = self.graph.node_ids.copy()
|
||||
self.current_stream_chunk_generating_node_ids = {}
|
||||
|
||||
def _generate_stream_outputs_when_node_finished(self,
|
||||
event: NodeRunSucceededEvent
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
def _generate_stream_outputs_when_node_finished(
|
||||
self, event: NodeRunSucceededEvent
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:param event: node run succeeded event
|
||||
@ -96,10 +93,12 @@ class EndStreamProcessor(StreamProcessor):
|
||||
"""
|
||||
for end_node_id, position in self.route_position.items():
|
||||
# all depends on end node id not in rest node ids
|
||||
if (event.route_node_state.node_id != end_node_id
|
||||
and (end_node_id not in self.rest_node_ids
|
||||
or not all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.end_stream_param.end_dependencies[end_node_id]))):
|
||||
if event.route_node_state.node_id != end_node_id and (
|
||||
end_node_id not in self.rest_node_ids
|
||||
or not all(
|
||||
dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
route_position = self.route_position[end_node_id]
|
||||
@ -116,9 +115,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
if not value_selector:
|
||||
continue
|
||||
|
||||
value = self.variable_pool.get(
|
||||
value_selector
|
||||
)
|
||||
value = self.variable_pool.get(value_selector)
|
||||
|
||||
if value is None:
|
||||
break
|
||||
@ -128,7 +125,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
if text:
|
||||
current_node_id = value_selector[0]
|
||||
if self.has_outputed and current_node_id not in self.outputed_node_ids:
|
||||
text = '\n' + text
|
||||
text = "\n" + text
|
||||
|
||||
self.outputed_node_ids.add(current_node_id)
|
||||
self.has_outputed = True
|
||||
@ -165,8 +162,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
continue
|
||||
|
||||
# all depends on end node id not in rest node ids
|
||||
if all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
|
||||
if all(dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
|
||||
if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]):
|
||||
continue
|
||||
|
||||
@ -178,7 +174,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
break
|
||||
|
||||
position += 1
|
||||
|
||||
|
||||
if not value_selector:
|
||||
continue
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ class EndNodeData(BaseNodeData):
|
||||
"""
|
||||
END Node Data.
|
||||
"""
|
||||
|
||||
outputs: list[VariableSelector]
|
||||
|
||||
|
||||
@ -15,11 +16,10 @@ class EndStreamParam(BaseModel):
|
||||
"""
|
||||
EndStreamParam entity
|
||||
"""
|
||||
|
||||
end_dependencies: dict[str, list[str]] = Field(
|
||||
...,
|
||||
description="end dependencies (end node id -> dependent node ids)"
|
||||
..., description="end dependencies (end node id -> dependent node ids)"
|
||||
)
|
||||
end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field(
|
||||
...,
|
||||
description="end stream variable selector mapping (end node id -> stream variable selectors)"
|
||||
..., description="end stream variable selector mapping (end node id -> stream variable selectors)"
|
||||
)
|
||||
|
||||
@ -7,32 +7,32 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class HttpRequestNodeAuthorizationConfig(BaseModel):
|
||||
type: Literal[None, 'basic', 'bearer', 'custom']
|
||||
type: Literal[None, "basic", "bearer", "custom"]
|
||||
api_key: Union[None, str] = None
|
||||
header: Union[None, str] = None
|
||||
|
||||
|
||||
class HttpRequestNodeAuthorization(BaseModel):
|
||||
type: Literal['no-auth', 'api-key']
|
||||
type: Literal["no-auth", "api-key"]
|
||||
config: Optional[HttpRequestNodeAuthorizationConfig] = None
|
||||
|
||||
@field_validator('config', mode='before')
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo):
|
||||
"""
|
||||
Check config, if type is no-auth, config should be None, otherwise it should be a dict.
|
||||
"""
|
||||
if values.data['type'] == 'no-auth':
|
||||
if values.data["type"] == "no-auth":
|
||||
return None
|
||||
else:
|
||||
if not v or not isinstance(v, dict):
|
||||
raise ValueError('config should be a dict')
|
||||
raise ValueError("config should be a dict")
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class HttpRequestNodeBody(BaseModel):
|
||||
type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json']
|
||||
type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json"]
|
||||
data: Union[None, str] = None
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
method: Literal['get', 'post', 'put', 'patch', 'delete', 'head']
|
||||
method: Literal["get", "post", "put", "patch", "delete", "head"]
|
||||
url: str
|
||||
authorization: HttpRequestNodeAuthorization
|
||||
headers: str
|
||||
|
||||
@ -6,8 +6,8 @@ from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
|
||||
import core.helper.ssrf_proxy as ssrf_proxy
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.http_request.entities import (
|
||||
@ -33,12 +33,12 @@ class HttpExecutorResponse:
|
||||
check if response is file
|
||||
"""
|
||||
content_type = self.get_content_type()
|
||||
file_content_types = ['image', 'audio', 'video']
|
||||
file_content_types = ["image", "audio", "video"]
|
||||
|
||||
return any(v in content_type for v in file_content_types)
|
||||
|
||||
def get_content_type(self) -> str:
|
||||
return self.headers.get('content-type', '')
|
||||
return self.headers.get("content-type", "")
|
||||
|
||||
def extract_file(self) -> tuple[str, bytes]:
|
||||
"""
|
||||
@ -47,28 +47,28 @@ class HttpExecutorResponse:
|
||||
if self.is_file:
|
||||
return self.get_content_type(), self.body
|
||||
|
||||
return '', b''
|
||||
return "", b""
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
if isinstance(self.response, httpx.Response):
|
||||
return self.response.text
|
||||
else:
|
||||
raise ValueError(f'Invalid response type {type(self.response)}')
|
||||
raise ValueError(f"Invalid response type {type(self.response)}")
|
||||
|
||||
@property
|
||||
def body(self) -> bytes:
|
||||
if isinstance(self.response, httpx.Response):
|
||||
return self.response.content
|
||||
else:
|
||||
raise ValueError(f'Invalid response type {type(self.response)}')
|
||||
raise ValueError(f"Invalid response type {type(self.response)}")
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
if isinstance(self.response, httpx.Response):
|
||||
return self.response.status_code
|
||||
else:
|
||||
raise ValueError(f'Invalid response type {type(self.response)}')
|
||||
raise ValueError(f"Invalid response type {type(self.response)}")
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
@ -77,11 +77,11 @@ class HttpExecutorResponse:
|
||||
@property
|
||||
def readable_size(self) -> str:
|
||||
if self.size < 1024:
|
||||
return f'{self.size} bytes'
|
||||
return f"{self.size} bytes"
|
||||
elif self.size < 1024 * 1024:
|
||||
return f'{(self.size / 1024):.2f} KB'
|
||||
return f"{(self.size / 1024):.2f} KB"
|
||||
else:
|
||||
return f'{(self.size / 1024 / 1024):.2f} MB'
|
||||
return f"{(self.size / 1024 / 1024):.2f} MB"
|
||||
|
||||
|
||||
class HttpExecutor:
|
||||
@ -120,7 +120,7 @@ class HttpExecutor:
|
||||
"""
|
||||
check if body is json
|
||||
"""
|
||||
if body and body.type == 'json' and body.data:
|
||||
if body and body.type == "json" and body.data:
|
||||
try:
|
||||
json.loads(body.data)
|
||||
return True
|
||||
@ -134,15 +134,15 @@ class HttpExecutor:
|
||||
"""
|
||||
Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}`
|
||||
"""
|
||||
kv_paris = convert_text.split('\n')
|
||||
kv_paris = convert_text.split("\n")
|
||||
result = {}
|
||||
for kv in kv_paris:
|
||||
if not kv.strip():
|
||||
continue
|
||||
|
||||
kv = kv.split(':', maxsplit=1)
|
||||
kv = kv.split(":", maxsplit=1)
|
||||
if len(kv) == 1:
|
||||
k, v = kv[0], ''
|
||||
k, v = kv[0], ""
|
||||
else:
|
||||
k, v = kv
|
||||
result[k.strip()] = v
|
||||
@ -166,31 +166,31 @@ class HttpExecutor:
|
||||
# check if it's a valid JSON
|
||||
is_valid_json = self._is_json_body(node_data.body)
|
||||
|
||||
body_data = node_data.body.data or ''
|
||||
body_data = node_data.body.data or ""
|
||||
if body_data:
|
||||
body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json)
|
||||
|
||||
content_type_is_set = any(key.lower() == 'content-type' for key in self.headers)
|
||||
if node_data.body.type == 'json' and not content_type_is_set:
|
||||
self.headers['Content-Type'] = 'application/json'
|
||||
elif node_data.body.type == 'x-www-form-urlencoded' and not content_type_is_set:
|
||||
self.headers['Content-Type'] = 'application/x-www-form-urlencoded'
|
||||
content_type_is_set = any(key.lower() == "content-type" for key in self.headers)
|
||||
if node_data.body.type == "json" and not content_type_is_set:
|
||||
self.headers["Content-Type"] = "application/json"
|
||||
elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set:
|
||||
self.headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
|
||||
if node_data.body.type in ['form-data', 'x-www-form-urlencoded']:
|
||||
if node_data.body.type in {"form-data", "x-www-form-urlencoded"}:
|
||||
body = self._to_dict(body_data)
|
||||
|
||||
if node_data.body.type == 'form-data':
|
||||
self.files = {k: ('', v) for k, v in body.items()}
|
||||
random_str = lambda n: ''.join([chr(randint(97, 122)) for _ in range(n)])
|
||||
self.boundary = f'----WebKitFormBoundary{random_str(16)}'
|
||||
if node_data.body.type == "form-data":
|
||||
self.files = {k: ("", v) for k, v in body.items()}
|
||||
random_str = lambda n: "".join([chr(randint(97, 122)) for _ in range(n)])
|
||||
self.boundary = f"----WebKitFormBoundary{random_str(16)}"
|
||||
|
||||
self.headers['Content-Type'] = f'multipart/form-data; boundary={self.boundary}'
|
||||
self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
|
||||
else:
|
||||
self.body = urlencode(body)
|
||||
elif node_data.body.type in ['json', 'raw-text']:
|
||||
elif node_data.body.type in {"json", "raw-text"}:
|
||||
self.body = body_data
|
||||
elif node_data.body.type == 'none':
|
||||
self.body = ''
|
||||
elif node_data.body.type == "none":
|
||||
self.body = ""
|
||||
|
||||
self.variable_selectors = (
|
||||
server_url_variable_selectors
|
||||
@ -202,23 +202,23 @@ class HttpExecutor:
|
||||
def _assembling_headers(self) -> dict[str, Any]:
|
||||
authorization = deepcopy(self.authorization)
|
||||
headers = deepcopy(self.headers) or {}
|
||||
if self.authorization.type == 'api-key':
|
||||
if self.authorization.type == "api-key":
|
||||
if self.authorization.config is None:
|
||||
raise ValueError('self.authorization config is required')
|
||||
raise ValueError("self.authorization config is required")
|
||||
if authorization.config is None:
|
||||
raise ValueError('authorization config is required')
|
||||
raise ValueError("authorization config is required")
|
||||
|
||||
if self.authorization.config.api_key is None:
|
||||
raise ValueError('api_key is required')
|
||||
raise ValueError("api_key is required")
|
||||
|
||||
if not authorization.config.header:
|
||||
authorization.config.header = 'Authorization'
|
||||
authorization.config.header = "Authorization"
|
||||
|
||||
if self.authorization.config.type == 'bearer':
|
||||
headers[authorization.config.header] = f'Bearer {authorization.config.api_key}'
|
||||
elif self.authorization.config.type == 'basic':
|
||||
headers[authorization.config.header] = f'Basic {authorization.config.api_key}'
|
||||
elif self.authorization.config.type == 'custom':
|
||||
if self.authorization.config.type == "bearer":
|
||||
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
|
||||
elif self.authorization.config.type == "basic":
|
||||
headers[authorization.config.header] = f"Basic {authorization.config.api_key}"
|
||||
elif self.authorization.config.type == "custom":
|
||||
headers[authorization.config.header] = authorization.config.api_key
|
||||
|
||||
return headers
|
||||
@ -230,10 +230,13 @@ class HttpExecutor:
|
||||
if isinstance(response, httpx.Response):
|
||||
executor_response = HttpExecutorResponse(response)
|
||||
else:
|
||||
raise ValueError(f'Invalid response type {type(response)}')
|
||||
raise ValueError(f"Invalid response type {type(response)}")
|
||||
|
||||
threshold_size = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE if executor_response.is_file \
|
||||
threshold_size = (
|
||||
dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE
|
||||
if executor_response.is_file
|
||||
else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
|
||||
)
|
||||
if executor_response.size > threshold_size:
|
||||
raise ValueError(
|
||||
f'{"File" if executor_response.is_file else "Text"} size is too large,'
|
||||
@ -248,17 +251,17 @@ class HttpExecutor:
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
kwargs = {
|
||||
'url': self.server_url,
|
||||
'headers': headers,
|
||||
'params': self.params,
|
||||
'timeout': (self.timeout.connect, self.timeout.read, self.timeout.write),
|
||||
'follow_redirects': True,
|
||||
"url": self.server_url,
|
||||
"headers": headers,
|
||||
"params": self.params,
|
||||
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
|
||||
"follow_redirects": True,
|
||||
}
|
||||
|
||||
if self.method in ('get', 'head', 'post', 'put', 'delete', 'patch'):
|
||||
if self.method in {"get", "head", "post", "put", "delete", "patch"}:
|
||||
response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
|
||||
else:
|
||||
raise ValueError(f'Invalid http method {self.method}')
|
||||
raise ValueError(f"Invalid http method {self.method}")
|
||||
return response
|
||||
|
||||
def invoke(self) -> HttpExecutorResponse:
|
||||
@ -280,15 +283,15 @@ class HttpExecutor:
|
||||
"""
|
||||
server_url = self.server_url
|
||||
if self.params:
|
||||
server_url += f'?{urlencode(self.params)}'
|
||||
server_url += f"?{urlencode(self.params)}"
|
||||
|
||||
raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n'
|
||||
raw_request = f"{self.method.upper()} {server_url} HTTP/1.1\n"
|
||||
|
||||
headers = self._assembling_headers()
|
||||
for k, v in headers.items():
|
||||
# get authorization header
|
||||
if self.authorization.type == 'api-key':
|
||||
authorization_header = 'Authorization'
|
||||
if self.authorization.type == "api-key":
|
||||
authorization_header = "Authorization"
|
||||
if self.authorization.config and self.authorization.config.header:
|
||||
authorization_header = self.authorization.config.header
|
||||
|
||||
@ -296,21 +299,21 @@ class HttpExecutor:
|
||||
raw_request += f'{k}: {"*" * len(v)}\n'
|
||||
continue
|
||||
|
||||
raw_request += f'{k}: {v}\n'
|
||||
raw_request += f"{k}: {v}\n"
|
||||
|
||||
raw_request += '\n'
|
||||
raw_request += "\n"
|
||||
|
||||
# if files, use multipart/form-data with boundary
|
||||
if self.files:
|
||||
boundary = self.boundary
|
||||
raw_request += f'--{boundary}'
|
||||
raw_request += f"--{boundary}"
|
||||
for k, v in self.files.items():
|
||||
raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n'
|
||||
raw_request += f'{v[1]}\n'
|
||||
raw_request += f'--{boundary}'
|
||||
raw_request += '--'
|
||||
raw_request += f"{v[1]}\n"
|
||||
raw_request += f"--{boundary}"
|
||||
raw_request += "--"
|
||||
else:
|
||||
raw_request += self.body or ''
|
||||
raw_request += self.body or ""
|
||||
|
||||
return raw_request
|
||||
|
||||
@ -328,9 +331,9 @@ class HttpExecutor:
|
||||
for variable_selector in variable_selectors:
|
||||
variable = variable_pool.get_any(variable_selector.value_selector)
|
||||
if variable is None:
|
||||
raise ValueError(f'Variable {variable_selector.variable} not found')
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
if escape_quotes and isinstance(variable, str):
|
||||
value = variable.replace('"', '\\"').replace('\n', '\\n')
|
||||
value = variable.replace('"', '\\"').replace("\n", "\\n")
|
||||
else:
|
||||
value = variable
|
||||
variable_value_mapping[variable_selector.variable] = value
|
||||
|
||||
@ -31,18 +31,18 @@ class HttpRequestNode(BaseNode):
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: dict | None = None) -> dict:
|
||||
return {
|
||||
'type': 'http-request',
|
||||
'config': {
|
||||
'method': 'get',
|
||||
'authorization': {
|
||||
'type': 'no-auth',
|
||||
"type": "http-request",
|
||||
"config": {
|
||||
"method": "get",
|
||||
"authorization": {
|
||||
"type": "no-auth",
|
||||
},
|
||||
'body': {'type': 'none'},
|
||||
'timeout': {
|
||||
"body": {"type": "none"},
|
||||
"timeout": {
|
||||
**HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(),
|
||||
'max_connect_timeout': dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
|
||||
'max_read_timeout': dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
|
||||
'max_write_timeout': dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
||||
"max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
|
||||
"max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
|
||||
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -52,9 +52,8 @@ class HttpRequestNode(BaseNode):
|
||||
# TODO: Switch to use segment directly
|
||||
if node_data.authorization.config and node_data.authorization.config.api_key:
|
||||
node_data.authorization.config.api_key = parser.convert_template(
|
||||
template=node_data.authorization.config.api_key,
|
||||
variable_pool=self.graph_runtime_state.variable_pool
|
||||
).text
|
||||
template=node_data.authorization.config.api_key, variable_pool=self.graph_runtime_state.variable_pool
|
||||
).text
|
||||
|
||||
# init http executor
|
||||
http_executor = None
|
||||
@ -62,7 +61,7 @@ class HttpRequestNode(BaseNode):
|
||||
http_executor = HttpExecutor(
|
||||
node_data=node_data,
|
||||
timeout=self._get_request_timeout(node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
)
|
||||
|
||||
# invoke http executor
|
||||
@ -71,7 +70,7 @@ class HttpRequestNode(BaseNode):
|
||||
process_data = {}
|
||||
if http_executor:
|
||||
process_data = {
|
||||
'request': http_executor.to_raw_request(),
|
||||
"request": http_executor.to_raw_request(),
|
||||
}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
@ -84,13 +83,13 @@ class HttpRequestNode(BaseNode):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
'status_code': response.status_code,
|
||||
'body': response.content if not files else '',
|
||||
'headers': response.headers,
|
||||
'files': files,
|
||||
"status_code": response.status_code,
|
||||
"body": response.content if not files else "",
|
||||
"headers": response.headers,
|
||||
"files": files,
|
||||
},
|
||||
process_data={
|
||||
'request': http_executor.to_raw_request(),
|
||||
"request": http_executor.to_raw_request(),
|
||||
},
|
||||
)
|
||||
|
||||
@ -107,10 +106,7 @@ class HttpRequestNode(BaseNode):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: HttpRequestNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: HttpRequestNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@ -126,11 +122,11 @@ class HttpRequestNode(BaseNode):
|
||||
|
||||
variable_mapping = {}
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector
|
||||
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
return variable_mapping
|
||||
except Exception as e:
|
||||
logging.exception(f'Failed to extract variable selector to variable mapping: {e}')
|
||||
logging.exception(f"Failed to extract variable selector to variable mapping: {e}")
|
||||
return {}
|
||||
|
||||
def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]:
|
||||
@ -144,7 +140,7 @@ class HttpRequestNode(BaseNode):
|
||||
# extract filename from url
|
||||
filename = path.basename(url)
|
||||
# extract extension if possible
|
||||
extension = guess_extension(mimetype) or '.bin'
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
|
||||
tool_file = ToolFileManager.create_file_by_raw(
|
||||
user_id=self.user_id,
|
||||
|
||||
@ -15,6 +15,7 @@ class IfElseNodeData(BaseNodeData):
|
||||
"""
|
||||
Case entity representing a single logical condition group
|
||||
"""
|
||||
|
||||
case_id: str
|
||||
logical_operator: Literal["and", "or"]
|
||||
conditions: list[Condition]
|
||||
|
||||
@ -20,13 +20,9 @@ class IfElseNode(BaseNode):
|
||||
node_data = self.node_data
|
||||
node_data = cast(IfElseNodeData, node_data)
|
||||
|
||||
node_inputs: dict[str, list] = {
|
||||
"conditions": []
|
||||
}
|
||||
node_inputs: dict[str, list] = {"conditions": []}
|
||||
|
||||
process_datas: dict[str, list] = {
|
||||
"condition_results": []
|
||||
}
|
||||
process_datas: dict[str, list] = {"condition_results": []}
|
||||
|
||||
input_conditions = []
|
||||
final_result = False
|
||||
@ -37,8 +33,7 @@ class IfElseNode(BaseNode):
|
||||
if node_data.cases:
|
||||
for case in node_data.cases:
|
||||
input_conditions, group_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=case.conditions
|
||||
variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions
|
||||
)
|
||||
|
||||
# Apply the logical operator for the current case
|
||||
@ -60,8 +55,7 @@ class IfElseNode(BaseNode):
|
||||
else:
|
||||
# Fallback to old structure if cases are not defined
|
||||
input_conditions, group_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=node_data.conditions
|
||||
variable_pool=self.graph_runtime_state.variable_pool, conditions=node_data.conditions
|
||||
)
|
||||
|
||||
final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result)
|
||||
@ -69,21 +63,14 @@ class IfElseNode(BaseNode):
|
||||
selected_case_id = "true" if final_result else "false"
|
||||
|
||||
process_datas["condition_results"].append(
|
||||
{
|
||||
"group": "default",
|
||||
"results": group_result,
|
||||
"final_result": final_result
|
||||
}
|
||||
{"group": "default", "results": group_result, "final_result": final_result}
|
||||
)
|
||||
|
||||
node_inputs["conditions"] = input_conditions
|
||||
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=node_inputs,
|
||||
process_data=process_datas,
|
||||
error=str(e)
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_datas, error=str(e)
|
||||
)
|
||||
|
||||
outputs = {"result": final_result, "selected_case_id": selected_case_id}
|
||||
@ -92,18 +79,15 @@ class IfElseNode(BaseNode):
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=node_inputs,
|
||||
process_data=process_datas,
|
||||
edge_source_handle=selected_case_id if selected_case_id else "false", # Use case ID or 'default'
|
||||
outputs=outputs
|
||||
edge_source_handle=selected_case_id or "false", # Use case ID or 'default'
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IfElseNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IfElseNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
|
||||
@ -7,21 +7,25 @@ class IterationNodeData(BaseIterationNodeData):
|
||||
"""
|
||||
Iteration Node Data.
|
||||
"""
|
||||
parent_loop_id: Optional[str] = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
|
||||
parent_loop_id: Optional[str] = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
|
||||
|
||||
class IterationStartNodeData(BaseNodeData):
|
||||
"""
|
||||
Iteration Start Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class IterationState(BaseIterationState):
|
||||
"""
|
||||
Iteration State.
|
||||
"""
|
||||
|
||||
outputs: list[Any] = None
|
||||
current_output: Optional[Any] = None
|
||||
|
||||
@ -29,6 +33,7 @@ class IterationState(BaseIterationState):
|
||||
"""
|
||||
Data.
|
||||
"""
|
||||
|
||||
iterator_length: int
|
||||
|
||||
def get_last_output(self) -> Optional[Any]:
|
||||
@ -38,9 +43,9 @@ class IterationState(BaseIterationState):
|
||||
if self.outputs:
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
|
||||
def get_current_output(self) -> Optional[Any]:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
return self.current_output
|
||||
return self.current_output
|
||||
|
||||
@ -16,6 +16,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
@ -33,6 +34,7 @@ class IterationNode(BaseNode):
|
||||
"""
|
||||
Iteration Node.
|
||||
"""
|
||||
|
||||
_node_data_cls = IterationNodeData
|
||||
_node_type = NodeType.ITERATION
|
||||
|
||||
@ -45,31 +47,26 @@ class IterationNode(BaseNode):
|
||||
|
||||
if not iterator_list_segment:
|
||||
raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found")
|
||||
|
||||
|
||||
iterator_list_value = iterator_list_segment.to_object()
|
||||
|
||||
if not isinstance(iterator_list_value, list):
|
||||
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
|
||||
|
||||
inputs = {
|
||||
"iterator_selector": iterator_list_value
|
||||
}
|
||||
inputs = {"iterator_selector": iterator_list_value}
|
||||
|
||||
graph_config = self.graph_config
|
||||
|
||||
|
||||
if not self.node_data.start_node_id:
|
||||
raise ValueError(f'field start_node_id in iteration {self.node_id} not found')
|
||||
raise ValueError(f"field start_node_id in iteration {self.node_id} not found")
|
||||
|
||||
root_node_id = self.node_data.start_node_id
|
||||
|
||||
# init graph
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
root_node_id=root_node_id
|
||||
)
|
||||
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
|
||||
|
||||
if not iteration_graph:
|
||||
raise ValueError('iteration graph not found')
|
||||
raise ValueError("iteration graph not found")
|
||||
|
||||
leaf_node_ids = iteration_graph.get_leaf_node_ids()
|
||||
iteration_leaf_node_ids = []
|
||||
@ -97,26 +94,21 @@ class IterationNode(BaseNode):
|
||||
Condition(
|
||||
variable_selector=[self.node_id, "index"],
|
||||
comparison_operator="<",
|
||||
value=str(len(iterator_list_value))
|
||||
value=str(len(iterator_list_value)),
|
||||
)
|
||||
]
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# append iteration variable (item, index) to variable pool
|
||||
variable_pool.add(
|
||||
[self.node_id, 'index'],
|
||||
0
|
||||
)
|
||||
variable_pool.add(
|
||||
[self.node_id, 'item'],
|
||||
iterator_list_value[0]
|
||||
)
|
||||
variable_pool.add([self.node_id, "index"], 0)
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
|
||||
|
||||
# init graph engine
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
@ -130,7 +122,7 @@ class IterationNode(BaseNode):
|
||||
graph_config=graph_config,
|
||||
variable_pool=variable_pool,
|
||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
)
|
||||
|
||||
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
@ -142,10 +134,8 @@ class IterationNode(BaseNode):
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
metadata={
|
||||
"iterator_length": len(iterator_list_value)
|
||||
},
|
||||
predecessor_node_id=self.previous_node_id
|
||||
metadata={"iterator_length": len(iterator_list_value)},
|
||||
predecessor_node_id=self.previous_node_id,
|
||||
)
|
||||
|
||||
yield IterationRunNextEvent(
|
||||
@ -154,7 +144,7 @@ class IterationNode(BaseNode):
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
index=0,
|
||||
pre_iteration_output=None
|
||||
pre_iteration_output=None,
|
||||
)
|
||||
|
||||
outputs: list[Any] = []
|
||||
@ -165,7 +155,11 @@ class IterationNode(BaseNode):
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
||||
event.in_iteration_id = self.node_id
|
||||
|
||||
if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START:
|
||||
if (
|
||||
isinstance(event, BaseNodeEvent)
|
||||
and event.node_type == NodeType.ITERATION_START
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
continue
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
@ -176,7 +170,9 @@ class IterationNode(BaseNode):
|
||||
|
||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
|
||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
|
||||
[self.node_id, "index"]
|
||||
)
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
|
||||
yield event
|
||||
@ -192,21 +188,15 @@ class IterationNode(BaseNode):
|
||||
variable_pool.remove_node(node_id)
|
||||
|
||||
# move to next iteration
|
||||
current_index = variable_pool.get([self.node_id, 'index'])
|
||||
current_index = variable_pool.get([self.node_id, "index"])
|
||||
if current_index is None:
|
||||
raise ValueError(f'iteration {self.node_id} current index not found')
|
||||
raise ValueError(f"iteration {self.node_id} current index not found")
|
||||
|
||||
next_index = int(current_index.to_object()) + 1
|
||||
variable_pool.add(
|
||||
[self.node_id, 'index'],
|
||||
next_index
|
||||
)
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add(
|
||||
[self.node_id, 'item'],
|
||||
iterator_list_value[next_index]
|
||||
)
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
@ -214,8 +204,9 @@ class IterationNode(BaseNode):
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
index=next_index,
|
||||
pre_iteration_output=jsonable_encoder(
|
||||
current_iteration_output) if current_iteration_output else None
|
||||
pre_iteration_output=jsonable_encoder(current_iteration_output)
|
||||
if current_iteration_output
|
||||
else None,
|
||||
)
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
@ -227,13 +218,9 @@ class IterationNode(BaseNode):
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={
|
||||
"output": jsonable_encoder(outputs)
|
||||
},
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
@ -255,21 +242,14 @@ class IterationNode(BaseNode):
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={
|
||||
"output": jsonable_encoder(outputs)
|
||||
},
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens
|
||||
}
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
'output': jsonable_encoder(outputs)
|
||||
}
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)}
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@ -282,16 +262,11 @@ class IterationNode(BaseNode):
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={
|
||||
"output": jsonable_encoder(outputs)
|
||||
},
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
@ -301,15 +276,12 @@ class IterationNode(BaseNode):
|
||||
)
|
||||
finally:
|
||||
# remove iteration variable (item, index) from variable pool after iteration run completed
|
||||
variable_pool.remove([self.node_id, 'index'])
|
||||
variable_pool.remove([self.node_id, 'item'])
|
||||
|
||||
variable_pool.remove([self.node_id, "index"])
|
||||
variable_pool.remove([self.node_id, "item"])
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IterationNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@ -319,36 +291,33 @@ class IterationNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
variable_mapping = {
|
||||
f'{node_id}.input_selector': node_data.iterator_selector,
|
||||
f"{node_id}.input_selector": node_data.iterator_selector,
|
||||
}
|
||||
|
||||
# init graph
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
root_node_id=node_data.start_node_id
|
||||
)
|
||||
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
|
||||
|
||||
if not iteration_graph:
|
||||
raise ValueError('iteration graph not found')
|
||||
|
||||
raise ValueError("iteration graph not found")
|
||||
|
||||
for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
|
||||
if sub_node_config.get('data', {}).get('iteration_id') != node_id:
|
||||
if sub_node_config.get("data", {}).get("iteration_id") != node_id:
|
||||
continue
|
||||
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type'))
|
||||
|
||||
node_type = NodeType.value_of(sub_node_config.get("data", {}).get("type"))
|
||||
node_cls = node_classes.get(node_type)
|
||||
if not node_cls:
|
||||
continue
|
||||
|
||||
node_cls = cast(BaseNode, node_cls)
|
||||
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config,
|
||||
config=sub_node_config
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
@ -356,7 +325,8 @@ class IterationNode(BaseNode):
|
||||
|
||||
# remove iteration variables
|
||||
sub_node_variable_mapping = {
|
||||
sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items()
|
||||
sub_node_id + "." + key: value
|
||||
for key, value in sub_node_variable_mapping.items()
|
||||
if value[0] != node_id
|
||||
}
|
||||
|
||||
@ -364,8 +334,7 @@ class IterationNode(BaseNode):
|
||||
|
||||
# remove variable out from iteration
|
||||
variable_mapping = {
|
||||
key: value for key, value in variable_mapping.items()
|
||||
if value[0] not in iteration_graph.node_ids
|
||||
key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
|
||||
}
|
||||
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@ -11,6 +11,7 @@ class IterationStartNode(BaseNode):
|
||||
"""
|
||||
Iteration Start Node.
|
||||
"""
|
||||
|
||||
_node_data_cls = IterationStartNodeData
|
||||
_node_type = NodeType.ITERATION_START
|
||||
|
||||
@ -18,16 +19,11 @@ class IterationStartNode(BaseNode):
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
)
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IterationNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
|
||||
@ -9,6 +9,7 @@ class RerankingModelConfig(BaseModel):
|
||||
"""
|
||||
Reranking Model Config.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
|
||||
@ -17,6 +18,7 @@ class VectorSetting(BaseModel):
|
||||
"""
|
||||
Vector Setting.
|
||||
"""
|
||||
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
@ -26,6 +28,7 @@ class KeywordSetting(BaseModel):
|
||||
"""
|
||||
Keyword Setting.
|
||||
"""
|
||||
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
@ -33,6 +36,7 @@ class WeightedScoreConfig(BaseModel):
|
||||
"""
|
||||
Weighted score Config.
|
||||
"""
|
||||
|
||||
vector_setting: VectorSetting
|
||||
keyword_setting: KeywordSetting
|
||||
|
||||
@ -41,17 +45,20 @@ class MultipleRetrievalConfig(BaseModel):
|
||||
"""
|
||||
Multiple Retrieval Config.
|
||||
"""
|
||||
|
||||
top_k: int
|
||||
score_threshold: Optional[float] = None
|
||||
reranking_mode: str = 'reranking_model'
|
||||
reranking_mode: str = "reranking_model"
|
||||
reranking_enable: bool = True
|
||||
reranking_model: Optional[RerankingModelConfig] = None
|
||||
weights: Optional[WeightedScoreConfig] = None
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
"""
|
||||
Model Config.
|
||||
Model Config.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
@ -62,6 +69,7 @@ class SingleRetrievalConfig(BaseModel):
|
||||
"""
|
||||
Single Retrieval Config.
|
||||
"""
|
||||
|
||||
model: ModelConfig
|
||||
|
||||
|
||||
@ -69,9 +77,10 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
"""
|
||||
Knowledge retrieval Node Data.
|
||||
"""
|
||||
type: str = 'knowledge-retrieval'
|
||||
|
||||
type: str = "knowledge-retrieval"
|
||||
query_variable_selector: list[str]
|
||||
dataset_ids: list[str]
|
||||
retrieval_mode: Literal['single', 'multiple']
|
||||
retrieval_mode: Literal["single", "multiple"]
|
||||
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
|
||||
single_retrieval_config: Optional[SingleRetrievalConfig] = None
|
||||
|
||||
@ -24,14 +24,11 @@ from models.workflow import WorkflowNodeExecutionStatus
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enabled': False
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
@ -45,62 +42,47 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
# extract variables
|
||||
variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector)
|
||||
query = variable
|
||||
variables = {
|
||||
'query': query
|
||||
}
|
||||
variables = {"query": query}
|
||||
if not query:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error="Query is required."
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
|
||||
)
|
||||
# retrieve knowledge
|
||||
try:
|
||||
results = self._fetch_dataset_retriever(
|
||||
node_data=node_data, query=query
|
||||
)
|
||||
outputs = {
|
||||
'result': results
|
||||
}
|
||||
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
|
||||
outputs = {"result": results}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
process_data=None,
|
||||
outputs=outputs
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error when running knowledge retrieval node")
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e)
|
||||
)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
|
||||
|
||||
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[
|
||||
dict[str, Any]]:
|
||||
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
|
||||
available_datasets = []
|
||||
dataset_ids = node_data.dataset_ids
|
||||
|
||||
# Subquery: Count the number of available documents for each dataset
|
||||
subquery = db.session.query(
|
||||
Document.dataset_id,
|
||||
func.count(Document.id).label('available_document_count')
|
||||
).filter(
|
||||
Document.indexing_status == 'completed',
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
Document.dataset_id.in_(dataset_ids)
|
||||
).group_by(Document.dataset_id).having(
|
||||
func.count(Document.id) > 0
|
||||
).subquery()
|
||||
subquery = (
|
||||
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
|
||||
.filter(
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
Document.dataset_id.in_(dataset_ids),
|
||||
)
|
||||
.group_by(Document.dataset_id)
|
||||
.having(func.count(Document.id) > 0)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
results = db.session.query(Dataset).join(
|
||||
subquery, Dataset.id == subquery.c.dataset_id
|
||||
).filter(
|
||||
Dataset.tenant_id == self.tenant_id,
|
||||
Dataset.id.in_(dataset_ids)
|
||||
).all()
|
||||
results = (
|
||||
db.session.query(Dataset)
|
||||
.join(subquery, Dataset.id == subquery.c.dataset_id)
|
||||
.filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
for dataset in results:
|
||||
# pass if dataset is not available
|
||||
@ -117,16 +99,14 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
# get model schema
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model=model_config.model,
|
||||
credentials=model_config.credentials
|
||||
model=model_config.model, credentials=model_config.credentials
|
||||
)
|
||||
|
||||
if model_schema:
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
features = model_schema.features
|
||||
if features:
|
||||
if ModelFeature.TOOL_CALL in features \
|
||||
or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
planning_strategy = PlanningStrategy.ROUTER
|
||||
all_documents = dataset_retrieval.single_retrieve(
|
||||
available_datasets=available_datasets,
|
||||
@ -137,111 +117,108 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
query=query,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
planning_strategy=planning_strategy
|
||||
planning_strategy=planning_strategy,
|
||||
)
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
||||
if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model':
|
||||
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
||||
reranking_model = {
|
||||
'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
'reranking_model_name': node_data.multiple_retrieval_config.reranking_model.model
|
||||
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
|
||||
}
|
||||
weights = None
|
||||
elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score':
|
||||
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
|
||||
reranking_model = None
|
||||
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
|
||||
weights = {
|
||||
'vector_setting': {
|
||||
"vector_weight": node_data.multiple_retrieval_config.weights.vector_setting.vector_weight,
|
||||
"embedding_provider_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_provider_name,
|
||||
"embedding_model_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_model_name,
|
||||
"vector_setting": {
|
||||
"vector_weight": vector_setting.vector_weight,
|
||||
"embedding_provider_name": vector_setting.embedding_provider_name,
|
||||
"embedding_model_name": vector_setting.embedding_model_name,
|
||||
},
|
||||
'keyword_setting': {
|
||||
"keyword_setting": {
|
||||
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
|
||||
}
|
||||
},
|
||||
}
|
||||
else:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id,
|
||||
self.user_from.value,
|
||||
available_datasets, query,
|
||||
node_data.multiple_retrieval_config.top_k,
|
||||
node_data.multiple_retrieval_config.score_threshold,
|
||||
node_data.multiple_retrieval_config.reranking_mode,
|
||||
reranking_model,
|
||||
weights,
|
||||
node_data.multiple_retrieval_config.reranking_enable,
|
||||
)
|
||||
all_documents = dataset_retrieval.multiple_retrieve(
|
||||
self.app_id,
|
||||
self.tenant_id,
|
||||
self.user_id,
|
||||
self.user_from.value,
|
||||
available_datasets,
|
||||
query,
|
||||
node_data.multiple_retrieval_config.top_k,
|
||||
node_data.multiple_retrieval_config.score_threshold,
|
||||
node_data.multiple_retrieval_config.reranking_mode,
|
||||
reranking_model,
|
||||
weights,
|
||||
node_data.multiple_retrieval_config.reranking_enable,
|
||||
)
|
||||
|
||||
context_list = []
|
||||
if all_documents:
|
||||
document_score_list = {}
|
||||
page_number_list = {}
|
||||
for item in all_documents:
|
||||
if item.metadata.get('score'):
|
||||
document_score_list[item.metadata['doc_id']] = item.metadata['score']
|
||||
# both 'page' and 'score' are metadata fields
|
||||
if item.metadata.get('page'):
|
||||
page_number_list[item.metadata['doc_id']] = item.metadata['page']
|
||||
if item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
|
||||
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents]
|
||||
segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.dataset_id.in_(dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids)
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
).all()
|
||||
if segments:
|
||||
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
||||
sorted_segments = sorted(segments,
|
||||
key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
|
||||
float('inf')))
|
||||
sorted_segments = sorted(
|
||||
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
|
||||
)
|
||||
|
||||
for segment in sorted_segments:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=segment.dataset_id
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = Document.query.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
).first()
|
||||
document = Document.query.filter(Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
).first()
|
||||
|
||||
resource_number = 1
|
||||
if dataset and document:
|
||||
source = {
|
||||
'metadata': {
|
||||
'_source': 'knowledge',
|
||||
'position': resource_number,
|
||||
'dataset_id': dataset.id,
|
||||
'dataset_name': dataset.name,
|
||||
'document_id': document.id,
|
||||
'document_name': document.name,
|
||||
'document_data_source_type': document.data_source_type,
|
||||
'page': page_number_list.get(segment.index_node_id, None),
|
||||
'segment_id': segment.id,
|
||||
'retriever_from': 'workflow',
|
||||
'score': document_score_list.get(segment.index_node_id, None),
|
||||
'segment_hit_count': segment.hit_count,
|
||||
'segment_word_count': segment.word_count,
|
||||
'segment_position': segment.position,
|
||||
'segment_index_node_hash': segment.index_node_hash,
|
||||
"metadata": {
|
||||
"_source": "knowledge",
|
||||
"position": resource_number,
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"document_data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": "workflow",
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
"segment_hit_count": segment.hit_count,
|
||||
"segment_word_count": segment.word_count,
|
||||
"segment_position": segment.position,
|
||||
"segment_index_node_hash": segment.index_node_hash,
|
||||
},
|
||||
'title': document.name
|
||||
"title": document.name,
|
||||
}
|
||||
if segment.answer:
|
||||
source['content'] = f'question:{segment.get_sign_content()} \nanswer:{segment.answer}'
|
||||
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source['content'] = segment.get_sign_content()
|
||||
source["content"] = segment.get_sign_content()
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
return context_list
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: KnowledgeRetrievalNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@ -251,11 +228,12 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
variable_mapping = {}
|
||||
variable_mapping[node_id + '.query'] = node_data.query_variable_selector
|
||||
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
|
||||
return variable_mapping
|
||||
|
||||
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
|
||||
ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
def _fetch_model_config(
|
||||
self, node_data: KnowledgeRetrievalNodeData
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config
|
||||
:param node_data: node data
|
||||
@ -266,10 +244,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=provider_name,
|
||||
model=model_name
|
||||
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
|
||||
)
|
||||
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
@ -280,8 +255,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_name,
|
||||
model_type=ModelType.LLM
|
||||
model=model_name, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
@ -297,19 +271,16 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
# model config
|
||||
completion_params = node_data.single_retrieval_config.model.completion_params
|
||||
stop = []
|
||||
if 'stop' in completion_params:
|
||||
stop = completion_params['stop']
|
||||
del completion_params['stop']
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
# get model mode
|
||||
model_mode = node_data.single_retrieval_config.model.mode
|
||||
if not model_mode:
|
||||
raise ValueError("LLM mode is required.")
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model_name,
|
||||
model_credentials
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
@ -11,6 +11,7 @@ class ModelConfig(BaseModel):
|
||||
"""
|
||||
Model Config.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
@ -21,6 +22,7 @@ class ContextConfig(BaseModel):
|
||||
"""
|
||||
Context Config.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
variable_selector: Optional[list[str]] = None
|
||||
|
||||
@ -29,37 +31,47 @@ class VisionConfig(BaseModel):
|
||||
"""
|
||||
Vision Config.
|
||||
"""
|
||||
|
||||
class Configs(BaseModel):
|
||||
"""
|
||||
Configs.
|
||||
"""
|
||||
detail: Literal['low', 'high']
|
||||
|
||||
detail: Literal["low", "high"]
|
||||
|
||||
enabled: bool
|
||||
configs: Optional[Configs] = None
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""
|
||||
Prompt Config.
|
||||
"""
|
||||
|
||||
jinja2_variables: Optional[list[VariableSelector]] = None
|
||||
|
||||
|
||||
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||
"""
|
||||
LLM Node Chat Model Message.
|
||||
"""
|
||||
|
||||
jinja2_text: Optional[str] = None
|
||||
|
||||
|
||||
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
"""
|
||||
LLM Node Chat Model Prompt Template.
|
||||
"""
|
||||
|
||||
jinja2_text: Optional[str] = None
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
"""
|
||||
LLM Node Data.
|
||||
"""
|
||||
|
||||
model: ModelConfig
|
||||
prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
|
||||
prompt_config: Optional[PromptConfig] = None
|
||||
|
||||
@ -45,11 +45,11 @@ if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
|
||||
|
||||
class ModelInvokeCompleted(BaseModel):
|
||||
"""
|
||||
Model invoke completed
|
||||
"""
|
||||
|
||||
text: str
|
||||
usage: LLMUsage
|
||||
finish_reason: Optional[str] = None
|
||||
@ -89,7 +89,7 @@ class LLMNode(BaseNode):
|
||||
files = self._fetch_files(node_data, variable_pool)
|
||||
|
||||
if files:
|
||||
node_inputs['#files#'] = [file.to_dict() for file in files]
|
||||
node_inputs["#files#"] = [file.to_dict() for file in files]
|
||||
|
||||
# fetch context value
|
||||
generator = self._fetch_context(node_data, variable_pool)
|
||||
@ -100,7 +100,7 @@ class LLMNode(BaseNode):
|
||||
yield event
|
||||
|
||||
if context:
|
||||
node_inputs['#context#'] = context # type: ignore
|
||||
node_inputs["#context#"] = context # type: ignore
|
||||
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
@ -111,24 +111,22 @@ class LLMNode(BaseNode):
|
||||
# fetch prompt messages
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
node_data=node_data,
|
||||
query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value])
|
||||
if node_data.memory else None,
|
||||
query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None,
|
||||
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
process_data = {
|
||||
'model_mode': model_config.mode,
|
||||
'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode,
|
||||
prompt_messages=prompt_messages
|
||||
"model_mode": model_config.mode,
|
||||
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode, prompt_messages=prompt_messages
|
||||
),
|
||||
'model_provider': model_config.provider,
|
||||
'model_name': model_config.model,
|
||||
"model_provider": model_config.provider,
|
||||
"model_name": model_config.model,
|
||||
}
|
||||
|
||||
# handle invoke result
|
||||
@ -136,10 +134,10 @@ class LLMNode(BaseNode):
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
result_text = ''
|
||||
result_text = ""
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
for event in generator:
|
||||
@ -156,16 +154,12 @@ class LLMNode(BaseNode):
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
inputs=node_inputs,
|
||||
process_data=process_data
|
||||
process_data=process_data,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
outputs = {
|
||||
'text': result_text,
|
||||
'usage': jsonable_encoder(usage),
|
||||
'finish_reason': finish_reason
|
||||
}
|
||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
@ -176,17 +170,19 @@ class LLMNode(BaseNode):
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
|
||||
def _invoke_llm(self, node_data_model: ModelConfig,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[list[str]] = None) \
|
||||
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
|
||||
def _invoke_llm(
|
||||
self,
|
||||
node_data_model: ModelConfig,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
|
||||
"""
|
||||
Invoke large language model
|
||||
:param node_data_model: node data model
|
||||
@ -206,9 +202,7 @@ class LLMNode(BaseNode):
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
generator = self._handle_invoke_result(
|
||||
invoke_result=invoke_result
|
||||
)
|
||||
generator = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
usage = LLMUsage.empty_usage()
|
||||
for event in generator:
|
||||
@ -219,8 +213,9 @@ class LLMNode(BaseNode):
|
||||
# deduct quota
|
||||
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
|
||||
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
|
||||
def _handle_invoke_result(
|
||||
self, invoke_result: LLMResult | Generator
|
||||
) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
@ -231,17 +226,14 @@ class LLMNode(BaseNode):
|
||||
|
||||
model = None
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
full_text = ''
|
||||
full_text = ""
|
||||
usage = None
|
||||
finish_reason = None
|
||||
for result in invoke_result:
|
||||
text = result.delta.message.content
|
||||
full_text += text
|
||||
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=text,
|
||||
from_variable_selector=[self.node_id, 'text']
|
||||
)
|
||||
yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
@ -258,15 +250,11 @@ class LLMNode(BaseNode):
|
||||
if not usage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
yield ModelInvokeCompleted(
|
||||
text=full_text,
|
||||
usage=usage,
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason)
|
||||
|
||||
def _transform_chat_messages(self,
|
||||
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||
def _transform_chat_messages(
|
||||
self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||
"""
|
||||
Transform chat messages
|
||||
|
||||
@ -275,13 +263,13 @@ class LLMNode(BaseNode):
|
||||
"""
|
||||
|
||||
if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
|
||||
if messages.edition_type == 'jinja2' and messages.jinja2_text:
|
||||
if messages.edition_type == "jinja2" and messages.jinja2_text:
|
||||
messages.text = messages.jinja2_text
|
||||
|
||||
return messages
|
||||
|
||||
for message in messages:
|
||||
if message.edition_type == 'jinja2' and message.jinja2_text:
|
||||
if message.edition_type == "jinja2" and message.jinja2_text:
|
||||
message.text = message.jinja2_text
|
||||
|
||||
return messages
|
||||
@ -300,17 +288,15 @@ class LLMNode(BaseNode):
|
||||
|
||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||
variable = variable_selector.variable
|
||||
value = variable_pool.get_any(
|
||||
variable_selector.value_selector
|
||||
)
|
||||
value = variable_pool.get_any(variable_selector.value_selector)
|
||||
|
||||
def parse_dict(d: dict) -> str:
|
||||
"""
|
||||
Parse dict into string
|
||||
"""
|
||||
# check if it's a context structure
|
||||
if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
|
||||
return d['content']
|
||||
if "metadata" in d and "_source" in d["metadata"] and "content" in d:
|
||||
return d["content"]
|
||||
|
||||
# else, parse the dict
|
||||
try:
|
||||
@ -321,7 +307,7 @@ class LLMNode(BaseNode):
|
||||
if isinstance(value, str):
|
||||
value = value
|
||||
elif isinstance(value, list):
|
||||
result = ''
|
||||
result = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
result += parse_dict(item)
|
||||
@ -331,7 +317,7 @@ class LLMNode(BaseNode):
|
||||
result += str(item)
|
||||
else:
|
||||
result += str(item)
|
||||
result += '\n'
|
||||
result += "\n"
|
||||
value = result.strip()
|
||||
elif isinstance(value, dict):
|
||||
value = parse_dict(value)
|
||||
@ -366,18 +352,19 @@ class LLMNode(BaseNode):
|
||||
for variable_selector in variable_selectors:
|
||||
variable_value = variable_pool.get_any(variable_selector.value_selector)
|
||||
if variable_value is None:
|
||||
raise ValueError(f'Variable {variable_selector.variable} not found')
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
|
||||
inputs[variable_selector.variable] = variable_value
|
||||
|
||||
memory = node_data.memory
|
||||
if memory and memory.query_prompt_template:
|
||||
query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
|
||||
.extract_variable_selectors())
|
||||
query_variable_selectors = VariableTemplateParser(
|
||||
template=memory.query_prompt_template
|
||||
).extract_variable_selectors()
|
||||
for variable_selector in query_variable_selectors:
|
||||
variable_value = variable_pool.get_any(variable_selector.value_selector)
|
||||
if variable_value is None:
|
||||
raise ValueError(f'Variable {variable_selector.variable} not found')
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
|
||||
inputs[variable_selector.variable] = variable_value
|
||||
|
||||
@ -393,7 +380,7 @@ class LLMNode(BaseNode):
|
||||
if not node_data.vision.enabled:
|
||||
return []
|
||||
|
||||
files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value])
|
||||
files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value])
|
||||
if not files:
|
||||
return []
|
||||
|
||||
@ -415,29 +402,25 @@ class LLMNode(BaseNode):
|
||||
context_value = variable_pool.get_any(node_data.context.variable_selector)
|
||||
if context_value:
|
||||
if isinstance(context_value, str):
|
||||
yield RunRetrieverResourceEvent(
|
||||
retriever_resources=[],
|
||||
context=context_value
|
||||
)
|
||||
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
|
||||
elif isinstance(context_value, list):
|
||||
context_str = ''
|
||||
context_str = ""
|
||||
original_retriever_resource = []
|
||||
for item in context_value:
|
||||
if isinstance(item, str):
|
||||
context_str += item + '\n'
|
||||
context_str += item + "\n"
|
||||
else:
|
||||
if 'content' not in item:
|
||||
raise ValueError(f'Invalid context structure: {item}')
|
||||
if "content" not in item:
|
||||
raise ValueError(f"Invalid context structure: {item}")
|
||||
|
||||
context_str += item['content'] + '\n'
|
||||
context_str += item["content"] + "\n"
|
||||
|
||||
retriever_resource = self._convert_to_original_retriever_resource(item)
|
||||
if retriever_resource:
|
||||
original_retriever_resource.append(retriever_resource)
|
||||
|
||||
yield RunRetrieverResourceEvent(
|
||||
retriever_resources=original_retriever_resource,
|
||||
context=context_str.strip()
|
||||
retriever_resources=original_retriever_resource, context=context_str.strip()
|
||||
)
|
||||
|
||||
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
|
||||
@ -446,34 +429,37 @@ class LLMNode(BaseNode):
|
||||
:param context_dict: context dict
|
||||
:return:
|
||||
"""
|
||||
if ('metadata' in context_dict and '_source' in context_dict['metadata']
|
||||
and context_dict['metadata']['_source'] == 'knowledge'):
|
||||
metadata = context_dict.get('metadata', {})
|
||||
if (
|
||||
"metadata" in context_dict
|
||||
and "_source" in context_dict["metadata"]
|
||||
and context_dict["metadata"]["_source"] == "knowledge"
|
||||
):
|
||||
metadata = context_dict.get("metadata", {})
|
||||
|
||||
source = {
|
||||
'position': metadata.get('position'),
|
||||
'dataset_id': metadata.get('dataset_id'),
|
||||
'dataset_name': metadata.get('dataset_name'),
|
||||
'document_id': metadata.get('document_id'),
|
||||
'document_name': metadata.get('document_name'),
|
||||
'data_source_type': metadata.get('document_data_source_type'),
|
||||
'segment_id': metadata.get('segment_id'),
|
||||
'retriever_from': metadata.get('retriever_from'),
|
||||
'score': metadata.get('score'),
|
||||
'hit_count': metadata.get('segment_hit_count'),
|
||||
'word_count': metadata.get('segment_word_count'),
|
||||
'segment_position': metadata.get('segment_position'),
|
||||
'index_node_hash': metadata.get('segment_index_node_hash'),
|
||||
'content': context_dict.get('content'),
|
||||
'page': metadata.get('page'),
|
||||
"position": metadata.get("position"),
|
||||
"dataset_id": metadata.get("dataset_id"),
|
||||
"dataset_name": metadata.get("dataset_name"),
|
||||
"document_id": metadata.get("document_id"),
|
||||
"document_name": metadata.get("document_name"),
|
||||
"data_source_type": metadata.get("document_data_source_type"),
|
||||
"segment_id": metadata.get("segment_id"),
|
||||
"retriever_from": metadata.get("retriever_from"),
|
||||
"score": metadata.get("score"),
|
||||
"hit_count": metadata.get("segment_hit_count"),
|
||||
"word_count": metadata.get("segment_word_count"),
|
||||
"segment_position": metadata.get("segment_position"),
|
||||
"index_node_hash": metadata.get("segment_index_node_hash"),
|
||||
"content": context_dict.get("content"),
|
||||
}
|
||||
|
||||
return source
|
||||
|
||||
return None
|
||||
|
||||
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
|
||||
ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
def _fetch_model_config(
|
||||
self, node_data_model: ModelConfig
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config
|
||||
:param node_data_model: node data model
|
||||
@ -484,10 +470,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=provider_name,
|
||||
model=model_name
|
||||
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
|
||||
)
|
||||
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
@ -498,8 +481,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_name,
|
||||
model_type=ModelType.LLM
|
||||
model=model_name, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
@ -515,19 +497,16 @@ class LLMNode(BaseNode):
|
||||
# model config
|
||||
completion_params = node_data_model.completion_params
|
||||
stop = []
|
||||
if 'stop' in completion_params:
|
||||
stop = completion_params['stop']
|
||||
del completion_params['stop']
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
# get model mode
|
||||
model_mode = node_data_model.mode
|
||||
if not model_mode:
|
||||
raise ValueError("LLM mode is required.")
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model_name,
|
||||
model_credentials
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
@ -543,9 +522,9 @@ class LLMNode(BaseNode):
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
def _fetch_memory(self, node_data_memory: Optional[MemoryConfig],
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
|
||||
def _fetch_memory(
|
||||
self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance
|
||||
) -> Optional[TokenBufferMemory]:
|
||||
"""
|
||||
Fetch memory
|
||||
:param node_data_memory: node data memory
|
||||
@ -556,35 +535,35 @@ class LLMNode(BaseNode):
|
||||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value])
|
||||
conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value])
|
||||
if conversation_id is None:
|
||||
return None
|
||||
|
||||
# get conversation
|
||||
conversation = db.session.query(Conversation).filter(
|
||||
Conversation.app_id == self.app_id,
|
||||
Conversation.id == conversation_id
|
||||
).first()
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=conversation,
|
||||
model_instance=model_instance
|
||||
)
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
return memory
|
||||
|
||||
def _fetch_prompt_messages(self, node_data: LLMNodeData,
|
||||
query: Optional[str],
|
||||
query_prompt_template: Optional[str],
|
||||
inputs: dict[str, str],
|
||||
files: list["FileVar"],
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) \
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
def _fetch_prompt_messages(
|
||||
self,
|
||||
node_data: LLMNodeData,
|
||||
query: Optional[str],
|
||||
query_prompt_template: Optional[str],
|
||||
inputs: dict[str, str],
|
||||
files: list["FileVar"],
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
"""
|
||||
Fetch prompt messages
|
||||
:param node_data: node data
|
||||
@ -601,7 +580,7 @@ class LLMNode(BaseNode):
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=node_data.prompt_template,
|
||||
inputs=inputs,
|
||||
query=query if query else '',
|
||||
query=query or "",
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=node_data.memory,
|
||||
@ -621,8 +600,11 @@ class LLMNode(BaseNode):
|
||||
if not isinstance(prompt_message.content, str):
|
||||
prompt_message_content = []
|
||||
for content_item in prompt_message.content:
|
||||
if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(
|
||||
content_item, ImagePromptMessageContent):
|
||||
if (
|
||||
vision_enabled
|
||||
and content_item.type == PromptMessageContentType.IMAGE
|
||||
and isinstance(content_item, ImagePromptMessageContent)
|
||||
):
|
||||
# Override vision config if LLM node has vision config
|
||||
if vision_detail:
|
||||
content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
|
||||
@ -632,15 +614,18 @@ class LLMNode(BaseNode):
|
||||
|
||||
if len(prompt_message_content) > 1:
|
||||
prompt_message.content = prompt_message_content
|
||||
elif (len(prompt_message_content) == 1
|
||||
and prompt_message_content[0].type == PromptMessageContentType.TEXT):
|
||||
elif (
|
||||
len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
|
||||
):
|
||||
prompt_message.content = prompt_message_content[0].data
|
||||
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
|
||||
if not filtered_prompt_messages:
|
||||
raise ValueError("No prompt found in the LLM configuration. "
|
||||
"Please ensure a prompt is properly configured before proceeding.")
|
||||
raise ValueError(
|
||||
"No prompt found in the LLM configuration. "
|
||||
"Please ensure a prompt is properly configured before proceeding."
|
||||
)
|
||||
|
||||
return filtered_prompt_messages, stop
|
||||
|
||||
@ -678,7 +663,7 @@ class LLMNode(BaseNode):
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = 1
|
||||
|
||||
if 'gpt-4' in model_instance.model:
|
||||
if "gpt-4" in model_instance.model:
|
||||
used_quota = 20
|
||||
else:
|
||||
used_quota = 1
|
||||
@ -689,16 +674,13 @@ class LLMNode(BaseNode):
|
||||
Provider.provider_name == model_instance.provider,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used
|
||||
).update({'quota_used': Provider.quota_used + used_quota})
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
).update({"quota_used": Provider.quota_used + used_quota})
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: LLMNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@ -712,11 +694,11 @@ class LLMNode(BaseNode):
|
||||
variable_selectors = []
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
if prompt.edition_type != 'jinja2':
|
||||
if prompt.edition_type != "jinja2":
|
||||
variable_template_parser = VariableTemplateParser(template=prompt.text)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
else:
|
||||
if prompt_template.edition_type != 'jinja2':
|
||||
if prompt_template.edition_type != "jinja2":
|
||||
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
@ -726,39 +708,38 @@ class LLMNode(BaseNode):
|
||||
|
||||
memory = node_data.memory
|
||||
if memory and memory.query_prompt_template:
|
||||
query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
|
||||
.extract_variable_selectors())
|
||||
query_variable_selectors = VariableTemplateParser(
|
||||
template=memory.query_prompt_template
|
||||
).extract_variable_selectors()
|
||||
for variable_selector in query_variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
if node_data.context.enabled:
|
||||
variable_mapping['#context#'] = node_data.context.variable_selector
|
||||
variable_mapping["#context#"] = node_data.context.variable_selector
|
||||
|
||||
if node_data.vision.enabled:
|
||||
variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value]
|
||||
variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value]
|
||||
|
||||
if node_data.memory:
|
||||
variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value]
|
||||
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
|
||||
|
||||
if node_data.prompt_config:
|
||||
enable_jinja = False
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
if prompt.edition_type == 'jinja2':
|
||||
if prompt.edition_type == "jinja2":
|
||||
enable_jinja = True
|
||||
break
|
||||
else:
|
||||
if prompt_template.edition_type == 'jinja2':
|
||||
if prompt_template.edition_type == "jinja2":
|
||||
enable_jinja = True
|
||||
|
||||
if enable_jinja:
|
||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
variable_mapping = {
|
||||
node_id + '.' + key: value for key, value in variable_mapping.items()
|
||||
}
|
||||
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@ -775,26 +756,19 @@ class LLMNode(BaseNode):
|
||||
"prompt_templates": {
|
||||
"chat_model": {
|
||||
"prompts": [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "You are a helpful AI assistant.",
|
||||
"edition_type": "basic"
|
||||
}
|
||||
{"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"}
|
||||
]
|
||||
},
|
||||
"completion_model": {
|
||||
"conversation_histories_role": {
|
||||
"user_prefix": "Human",
|
||||
"assistant_prefix": "Assistant"
|
||||
},
|
||||
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
|
||||
"prompt": {
|
||||
"text": "Here is the chat histories between human and assistant, inside "
|
||||
"<histories></histories> XML tags.\n\n<histories>\n{{"
|
||||
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
|
||||
"edition_type": "basic"
|
||||
"<histories></histories> XML tags.\n\n<histories>\n{{"
|
||||
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
|
||||
"edition_type": "basic",
|
||||
},
|
||||
"stop": ["Human:"]
|
||||
}
|
||||
"stop": ["Human:"],
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
|
||||
|
||||
|
||||
@ -7,7 +6,8 @@ class LoopNodeData(BaseIterationNodeData):
|
||||
Loop Node Data.
|
||||
"""
|
||||
|
||||
|
||||
class LoopState(BaseIterationState):
|
||||
"""
|
||||
Loop State.
|
||||
"""
|
||||
"""
|
||||
|
||||
@ -10,6 +10,7 @@ class LoopNode(BaseNode):
|
||||
"""
|
||||
Loop Node.
|
||||
"""
|
||||
|
||||
_node_data_cls = LoopNodeData
|
||||
_node_type = NodeType.LOOP
|
||||
|
||||
@ -21,14 +22,16 @@ class LoopNode(BaseNode):
|
||||
"""
|
||||
Get conditions.
|
||||
"""
|
||||
node_id = node_config.get('id')
|
||||
node_id = node_config.get("id")
|
||||
if not node_id:
|
||||
return []
|
||||
|
||||
# TODO waiting for implementation
|
||||
return [Condition(
|
||||
variable_selector=[node_id, 'index'],
|
||||
comparison_operator="≤",
|
||||
value_type="value_selector",
|
||||
value_selector=[]
|
||||
)]
|
||||
return [
|
||||
Condition(
|
||||
variable_selector=[node_id, "index"],
|
||||
comparison_operator="≤",
|
||||
value_type="value_selector",
|
||||
value_selector=[],
|
||||
)
|
||||
]
|
||||
|
||||
@ -8,47 +8,52 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
"""
|
||||
Model Config.
|
||||
Model Config.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
completion_params: dict[str, Any] = {}
|
||||
|
||||
|
||||
class ParameterConfig(BaseModel):
|
||||
"""
|
||||
Parameter Config.
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]']
|
||||
type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"]
|
||||
options: Optional[list[str]] = None
|
||||
description: str
|
||||
required: bool
|
||||
|
||||
@field_validator('name', mode='before')
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def validate_name(cls, value) -> str:
|
||||
if not value:
|
||||
raise ValueError('Parameter name is required')
|
||||
if value in ['__reason', '__is_success']:
|
||||
raise ValueError('Invalid parameter name, __reason and __is_success are reserved')
|
||||
raise ValueError("Parameter name is required")
|
||||
if value in {"__reason", "__is_success"}:
|
||||
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
|
||||
return value
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
"""
|
||||
Parameter Extractor Node Data.
|
||||
"""
|
||||
|
||||
model: ModelConfig
|
||||
query: list[str]
|
||||
parameters: list[ParameterConfig]
|
||||
instruction: Optional[str] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
reasoning_mode: Literal['function_call', 'prompt']
|
||||
reasoning_mode: Literal["function_call", "prompt"]
|
||||
|
||||
@field_validator('reasoning_mode', mode='before')
|
||||
@field_validator("reasoning_mode", mode="before")
|
||||
@classmethod
|
||||
def set_reasoning_mode(cls, v) -> str:
|
||||
return v or 'function_call'
|
||||
return v or "function_call"
|
||||
|
||||
def get_parameter_json_schema(self) -> dict:
|
||||
"""
|
||||
@ -56,32 +61,26 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
|
||||
:return: parameter json schema
|
||||
"""
|
||||
parameters = {
|
||||
'type': 'object',
|
||||
'properties': {},
|
||||
'required': []
|
||||
}
|
||||
parameters = {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
for parameter in self.parameters:
|
||||
parameter_schema = {
|
||||
'description': parameter.description
|
||||
}
|
||||
parameter_schema = {"description": parameter.description}
|
||||
|
||||
if parameter.type in ['string', 'select']:
|
||||
parameter_schema['type'] = 'string'
|
||||
elif parameter.type.startswith('array'):
|
||||
parameter_schema['type'] = 'array'
|
||||
if parameter.type in {"string", "select"}:
|
||||
parameter_schema["type"] = "string"
|
||||
elif parameter.type.startswith("array"):
|
||||
parameter_schema["type"] = "array"
|
||||
nested_type = parameter.type[6:-1]
|
||||
parameter_schema['items'] = {'type': nested_type}
|
||||
parameter_schema["items"] = {"type": nested_type}
|
||||
else:
|
||||
parameter_schema['type'] = parameter.type
|
||||
parameter_schema["type"] = parameter.type
|
||||
|
||||
if parameter.type == 'select':
|
||||
parameter_schema['enum'] = parameter.options
|
||||
if parameter.type == "select":
|
||||
parameter_schema["enum"] = parameter.options
|
||||
|
||||
parameters["properties"][parameter.name] = parameter_schema
|
||||
|
||||
parameters['properties'][parameter.name] = parameter_schema
|
||||
|
||||
if parameter.required:
|
||||
parameters['required'].append(parameter.name)
|
||||
parameters["required"].append(parameter.name)
|
||||
|
||||
return parameters
|
||||
return parameters
|
||||
|
||||
@ -45,6 +45,7 @@ class ParameterExtractorNode(LLMNode):
|
||||
"""
|
||||
Parameter Extractor Node.
|
||||
"""
|
||||
|
||||
_node_data_cls = ParameterExtractorNodeData
|
||||
_node_type = NodeType.PARAMETER_EXTRACTOR
|
||||
|
||||
@ -57,11 +58,8 @@ class ParameterExtractorNode(LLMNode):
|
||||
"model": {
|
||||
"prompt_templates": {
|
||||
"completion_model": {
|
||||
"conversation_histories_role": {
|
||||
"user_prefix": "Human",
|
||||
"assistant_prefix": "Assistant"
|
||||
},
|
||||
"stop": ["Human:"]
|
||||
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
|
||||
"stop": ["Human:"],
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -78,9 +76,9 @@ class ParameterExtractorNode(LLMNode):
|
||||
query = variable
|
||||
|
||||
inputs = {
|
||||
'query': query,
|
||||
'parameters': jsonable_encoder(node_data.parameters),
|
||||
'instruction': jsonable_encoder(node_data.instruction),
|
||||
"query": query,
|
||||
"parameters": jsonable_encoder(node_data.parameters),
|
||||
"instruction": jsonable_encoder(node_data.instruction),
|
||||
}
|
||||
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
@ -95,30 +93,29 @@ class ParameterExtractorNode(LLMNode):
|
||||
# fetch memory
|
||||
memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance)
|
||||
|
||||
if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
|
||||
and node_data.reasoning_mode == 'function_call':
|
||||
# use function call
|
||||
if (
|
||||
set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
|
||||
and node_data.reasoning_mode == "function_call"
|
||||
):
|
||||
# use function call
|
||||
prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
|
||||
node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
|
||||
)
|
||||
else:
|
||||
# use prompt engineering
|
||||
prompt_messages = self._generate_prompt_engineering_prompt(node_data,
|
||||
query,
|
||||
self.graph_runtime_state.variable_pool,
|
||||
model_config,
|
||||
memory)
|
||||
prompt_messages = self._generate_prompt_engineering_prompt(
|
||||
node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
|
||||
)
|
||||
prompt_message_tools = []
|
||||
|
||||
process_data = {
|
||||
'model_mode': model_config.mode,
|
||||
'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode,
|
||||
prompt_messages=prompt_messages
|
||||
"model_mode": model_config.mode,
|
||||
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode, prompt_messages=prompt_messages
|
||||
),
|
||||
'usage': None,
|
||||
'function': {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
|
||||
'tool_call': None,
|
||||
"usage": None,
|
||||
"function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
|
||||
"tool_call": None,
|
||||
}
|
||||
|
||||
try:
|
||||
@ -129,20 +126,17 @@ class ParameterExtractorNode(LLMNode):
|
||||
tools=prompt_message_tools,
|
||||
stop=model_config.stop,
|
||||
)
|
||||
process_data['usage'] = jsonable_encoder(usage)
|
||||
process_data['tool_call'] = jsonable_encoder(tool_call)
|
||||
process_data['llm_text'] = text
|
||||
process_data["usage"] = jsonable_encoder(usage)
|
||||
process_data["tool_call"] = jsonable_encoder(tool_call)
|
||||
process_data["llm_text"] = text
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={
|
||||
'__is_success': 0,
|
||||
'__reason': str(e)
|
||||
},
|
||||
outputs={"__is_success": 0, "__reason": str(e)},
|
||||
error=str(e),
|
||||
metadata={}
|
||||
metadata={},
|
||||
)
|
||||
|
||||
error = None
|
||||
@ -167,24 +161,23 @@ class ParameterExtractorNode(LLMNode):
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={
|
||||
'__is_success': 1 if not error else 0,
|
||||
'__reason': error,
|
||||
**result
|
||||
},
|
||||
outputs={"__is_success": 1 if not error else 0, "__reason": error, **result},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
def _invoke_llm(self, node_data_model: ModelConfig,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str]) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
|
||||
def _invoke_llm(
|
||||
self,
|
||||
node_data_model: ModelConfig,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
|
||||
"""
|
||||
Invoke large language model
|
||||
:param node_data_model: node data model
|
||||
@ -217,32 +210,35 @@ class ParameterExtractorNode(LLMNode):
|
||||
|
||||
return text, usage, tool_call
|
||||
|
||||
def _generate_function_call_prompt(self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
|
||||
def _generate_function_call_prompt(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
|
||||
"""
|
||||
Generate function call prompt.
|
||||
"""
|
||||
query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps(
|
||||
node_data.get_parameter_json_schema()))
|
||||
query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(
|
||||
content=query, structure=json.dumps(node_data.get_parameter_json_schema())
|
||||
)
|
||||
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
|
||||
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory,
|
||||
rest_token)
|
||||
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
|
||||
prompt_template = self._get_function_calling_prompt_template(
|
||||
node_data, query, variable_pool, memory, rest_token
|
||||
)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query='',
|
||||
query="",
|
||||
files=[],
|
||||
context='',
|
||||
context="",
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
model_config=model_config
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
# find last user message
|
||||
@ -255,124 +251,125 @@ class ParameterExtractorNode(LLMNode):
|
||||
example_messages = []
|
||||
for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE:
|
||||
id = uuid.uuid4().hex
|
||||
example_messages.extend([
|
||||
UserPromptMessage(content=example['user']['query']),
|
||||
AssistantPromptMessage(
|
||||
content=example['assistant']['text'],
|
||||
tool_calls=[
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=id,
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=example['assistant']['function_call']['name'],
|
||||
arguments=json.dumps(example['assistant']['function_call']['parameters']
|
||||
)
|
||||
))
|
||||
]
|
||||
),
|
||||
ToolPromptMessage(
|
||||
content='Great! You have called the function with the correct parameters.',
|
||||
tool_call_id=id
|
||||
),
|
||||
AssistantPromptMessage(
|
||||
content='I have extracted the parameters, let\'s move on.',
|
||||
)
|
||||
])
|
||||
example_messages.extend(
|
||||
[
|
||||
UserPromptMessage(content=example["user"]["query"]),
|
||||
AssistantPromptMessage(
|
||||
content=example["assistant"]["text"],
|
||||
tool_calls=[
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=example["assistant"]["function_call"]["name"],
|
||||
arguments=json.dumps(example["assistant"]["function_call"]["parameters"]),
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
ToolPromptMessage(
|
||||
content="Great! You have called the function with the correct parameters.", tool_call_id=id
|
||||
),
|
||||
AssistantPromptMessage(
|
||||
content="I have extracted the parameters, let's move on.",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
prompt_messages = prompt_messages[:last_user_message_idx] + \
|
||||
example_messages + prompt_messages[last_user_message_idx:]
|
||||
prompt_messages = (
|
||||
prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:]
|
||||
)
|
||||
|
||||
# generate tool
|
||||
tool = PromptMessageTool(
|
||||
name=FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
description='Extract parameters from the natural language text',
|
||||
description="Extract parameters from the natural language text",
|
||||
parameters=node_data.get_parameter_json_schema(),
|
||||
)
|
||||
|
||||
return prompt_messages, [tool]
|
||||
|
||||
def _generate_prompt_engineering_prompt(self,
|
||||
data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
) -> list[PromptMessage]:
|
||||
def _generate_prompt_engineering_prompt(
|
||||
self,
|
||||
data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate prompt engineering prompt.
|
||||
"""
|
||||
model_mode = ModelMode.value_of(data.model.mode)
|
||||
|
||||
if model_mode == ModelMode.COMPLETION:
|
||||
return self._generate_prompt_engineering_completion_prompt(
|
||||
data, query, variable_pool, model_config, memory
|
||||
)
|
||||
return self._generate_prompt_engineering_completion_prompt(data, query, variable_pool, model_config, memory)
|
||||
elif model_mode == ModelMode.CHAT:
|
||||
return self._generate_prompt_engineering_chat_prompt(
|
||||
data, query, variable_pool, model_config, memory
|
||||
)
|
||||
return self._generate_prompt_engineering_chat_prompt(data, query, variable_pool, model_config, memory)
|
||||
else:
|
||||
raise ValueError(f"Invalid model mode: {model_mode}")
|
||||
|
||||
def _generate_prompt_engineering_completion_prompt(self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
) -> list[PromptMessage]:
|
||||
def _generate_prompt_engineering_completion_prompt(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate completion prompt.
|
||||
"""
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
|
||||
prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory,
|
||||
rest_token)
|
||||
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
|
||||
prompt_template = self._get_prompt_engineering_prompt_template(
|
||||
node_data, query, variable_pool, memory, rest_token
|
||||
)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={
|
||||
'structure': json.dumps(node_data.get_parameter_json_schema())
|
||||
},
|
||||
query='',
|
||||
inputs={"structure": json.dumps(node_data.get_parameter_json_schema())},
|
||||
query="",
|
||||
files=[],
|
||||
context='',
|
||||
context="",
|
||||
memory_config=node_data.memory,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _generate_prompt_engineering_chat_prompt(self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
) -> list[PromptMessage]:
|
||||
def _generate_prompt_engineering_chat_prompt(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate chat prompt.
|
||||
"""
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
|
||||
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
|
||||
prompt_template = self._get_prompt_engineering_prompt_template(
|
||||
node_data,
|
||||
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
|
||||
structure=json.dumps(node_data.get_parameter_json_schema()),
|
||||
text=query
|
||||
structure=json.dumps(node_data.get_parameter_json_schema()), text=query
|
||||
),
|
||||
variable_pool, memory, rest_token
|
||||
variable_pool,
|
||||
memory,
|
||||
rest_token,
|
||||
)
|
||||
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query='',
|
||||
query="",
|
||||
files=[],
|
||||
context='',
|
||||
context="",
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
model_config=model_config
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
# find last user message
|
||||
@ -384,18 +381,23 @@ class ParameterExtractorNode(LLMNode):
|
||||
# add example messages before last user message
|
||||
example_messages = []
|
||||
for example in CHAT_EXAMPLE:
|
||||
example_messages.extend([
|
||||
UserPromptMessage(content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
|
||||
structure=json.dumps(example['user']['json']),
|
||||
text=example['user']['query'],
|
||||
)),
|
||||
AssistantPromptMessage(
|
||||
content=json.dumps(example['assistant']['json']),
|
||||
)
|
||||
])
|
||||
example_messages.extend(
|
||||
[
|
||||
UserPromptMessage(
|
||||
content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
|
||||
structure=json.dumps(example["user"]["json"]),
|
||||
text=example["user"]["query"],
|
||||
)
|
||||
),
|
||||
AssistantPromptMessage(
|
||||
content=json.dumps(example["assistant"]["json"]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
prompt_messages = prompt_messages[:last_user_message_idx] + \
|
||||
example_messages + prompt_messages[last_user_message_idx:]
|
||||
prompt_messages = (
|
||||
prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:]
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
@ -410,28 +412,28 @@ class ParameterExtractorNode(LLMNode):
|
||||
if parameter.required and parameter.name not in result:
|
||||
raise ValueError(f"Parameter {parameter.name} is required")
|
||||
|
||||
if parameter.type == 'select' and parameter.options and result.get(parameter.name) not in parameter.options:
|
||||
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
|
||||
raise ValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == 'number' and not isinstance(result.get(parameter.name), int | float):
|
||||
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
|
||||
raise ValueError(f"Invalid `number` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == 'bool' and not isinstance(result.get(parameter.name), bool):
|
||||
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
|
||||
raise ValueError(f"Invalid `bool` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == 'string' and not isinstance(result.get(parameter.name), str):
|
||||
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
|
||||
raise ValueError(f"Invalid `string` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type.startswith('array'):
|
||||
if parameter.type.startswith("array"):
|
||||
if not isinstance(result.get(parameter.name), list):
|
||||
raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
|
||||
nested_type = parameter.type[6:-1]
|
||||
for item in result.get(parameter.name):
|
||||
if nested_type == 'number' and not isinstance(item, int | float):
|
||||
if nested_type == "number" and not isinstance(item, int | float):
|
||||
raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
|
||||
if nested_type == 'string' and not isinstance(item, str):
|
||||
if nested_type == "string" and not isinstance(item, str):
|
||||
raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
|
||||
if nested_type == 'object' and not isinstance(item, dict):
|
||||
if nested_type == "object" and not isinstance(item, dict):
|
||||
raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
|
||||
return result
|
||||
|
||||
@ -443,12 +445,12 @@ class ParameterExtractorNode(LLMNode):
|
||||
for parameter in data.parameters:
|
||||
if parameter.name in result:
|
||||
# transform value
|
||||
if parameter.type == 'number':
|
||||
if parameter.type == "number":
|
||||
if isinstance(result[parameter.name], int | float):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
elif isinstance(result[parameter.name], str):
|
||||
try:
|
||||
if '.' in result[parameter.name]:
|
||||
if "." in result[parameter.name]:
|
||||
result[parameter.name] = float(result[parameter.name])
|
||||
else:
|
||||
result[parameter.name] = int(result[parameter.name])
|
||||
@ -465,40 +467,40 @@ class ParameterExtractorNode(LLMNode):
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
|
||||
# elif isinstance(result[parameter.name], int):
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
elif parameter.type in ['string', 'select']:
|
||||
elif parameter.type in {"string", "select"}:
|
||||
if isinstance(result[parameter.name], str):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
elif parameter.type.startswith('array'):
|
||||
elif parameter.type.startswith("array"):
|
||||
if isinstance(result[parameter.name], list):
|
||||
nested_type = parameter.type[6:-1]
|
||||
transformed_result[parameter.name] = []
|
||||
for item in result[parameter.name]:
|
||||
if nested_type == 'number':
|
||||
if nested_type == "number":
|
||||
if isinstance(item, int | float):
|
||||
transformed_result[parameter.name].append(item)
|
||||
elif isinstance(item, str):
|
||||
try:
|
||||
if '.' in item:
|
||||
if "." in item:
|
||||
transformed_result[parameter.name].append(float(item))
|
||||
else:
|
||||
transformed_result[parameter.name].append(int(item))
|
||||
except ValueError:
|
||||
pass
|
||||
elif nested_type == 'string':
|
||||
elif nested_type == "string":
|
||||
if isinstance(item, str):
|
||||
transformed_result[parameter.name].append(item)
|
||||
elif nested_type == 'object':
|
||||
elif nested_type == "object":
|
||||
if isinstance(item, dict):
|
||||
transformed_result[parameter.name].append(item)
|
||||
|
||||
if parameter.name not in transformed_result:
|
||||
if parameter.type == 'number':
|
||||
if parameter.type == "number":
|
||||
transformed_result[parameter.name] = 0
|
||||
elif parameter.type == 'bool':
|
||||
elif parameter.type == "bool":
|
||||
transformed_result[parameter.name] = False
|
||||
elif parameter.type in ['string', 'select']:
|
||||
transformed_result[parameter.name] = ''
|
||||
elif parameter.type.startswith('array'):
|
||||
elif parameter.type in {"string", "select"}:
|
||||
transformed_result[parameter.name] = ""
|
||||
elif parameter.type.startswith("array"):
|
||||
transformed_result[parameter.name] = []
|
||||
|
||||
return transformed_result
|
||||
@ -514,24 +516,24 @@ class ParameterExtractorNode(LLMNode):
|
||||
"""
|
||||
stack = []
|
||||
for i, c in enumerate(text):
|
||||
if c == '{' or c == '[':
|
||||
if c in {"{", "["}:
|
||||
stack.append(c)
|
||||
elif c == '}' or c == ']':
|
||||
elif c in {"}", "]"}:
|
||||
# check if stack is empty
|
||||
if not stack:
|
||||
return text[:i]
|
||||
# check if the last element in stack is matching
|
||||
if (c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '['):
|
||||
if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["):
|
||||
stack.pop()
|
||||
if not stack:
|
||||
return text[:i + 1]
|
||||
return text[: i + 1]
|
||||
else:
|
||||
return text[:i]
|
||||
return None
|
||||
|
||||
# extract json from the text
|
||||
for idx in range(len(result)):
|
||||
if result[idx] == '{' or result[idx] == '[':
|
||||
if result[idx] == "{" or result[idx] == "[":
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
try:
|
||||
@ -554,12 +556,12 @@ class ParameterExtractorNode(LLMNode):
|
||||
"""
|
||||
result = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.type == 'number':
|
||||
if parameter.type == "number":
|
||||
result[parameter.name] = 0
|
||||
elif parameter.type == 'bool':
|
||||
elif parameter.type == "bool":
|
||||
result[parameter.name] = False
|
||||
elif parameter.type in ['string', 'select']:
|
||||
result[parameter.name] = ''
|
||||
elif parameter.type in {"string", "select"}:
|
||||
result[parameter.name] = ""
|
||||
|
||||
return result
|
||||
|
||||
@ -575,71 +577,76 @@ class ParameterExtractorNode(LLMNode):
|
||||
|
||||
return variable_template_parser.format(inputs)
|
||||
|
||||
def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
max_token_limit: int = 2000) \
|
||||
-> list[ChatModelMessage]:
|
||||
def _get_function_calling_prompt_template(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
max_token_limit: int = 2000,
|
||||
) -> list[ChatModelMessage]:
|
||||
model_mode = ModelMode.value_of(node_data.model.mode)
|
||||
input_text = query
|
||||
memory_str = ''
|
||||
instruction = self._render_instruction(node_data.instruction or '', variable_pool)
|
||||
memory_str = ""
|
||||
instruction = self._render_instruction(node_data.instruction or "", variable_pool)
|
||||
|
||||
if memory:
|
||||
memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
|
||||
message_limit=node_data.memory.window.size)
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction)
|
||||
)
|
||||
user_prompt_message = ChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=input_text
|
||||
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
|
||||
)
|
||||
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||
return [system_prompt_messages, user_prompt_message]
|
||||
else:
|
||||
raise ValueError(f"Model mode {model_mode} not support.")
|
||||
|
||||
def _get_prompt_engineering_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
max_token_limit: int = 2000) \
|
||||
-> list[ChatModelMessage]:
|
||||
|
||||
def _get_prompt_engineering_prompt_template(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
max_token_limit: int = 2000,
|
||||
) -> list[ChatModelMessage]:
|
||||
model_mode = ModelMode.value_of(node_data.model.mode)
|
||||
input_text = query
|
||||
memory_str = ''
|
||||
instruction = self._render_instruction(node_data.instruction or '', variable_pool)
|
||||
memory_str = ""
|
||||
instruction = self._render_instruction(node_data.instruction or "", variable_pool)
|
||||
|
||||
if memory:
|
||||
memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
|
||||
message_limit=node_data.memory.window.size)
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction)
|
||||
)
|
||||
user_prompt_message = ChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=input_text
|
||||
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
|
||||
)
|
||||
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||
return [system_prompt_messages, user_prompt_message]
|
||||
elif model_mode == ModelMode.COMPLETION:
|
||||
return CompletionModelPromptTemplate(
|
||||
text=COMPLETION_GENERATE_JSON_PROMPT.format(histories=memory_str,
|
||||
text=input_text,
|
||||
instruction=instruction)
|
||||
.replace('{γγγ', '')
|
||||
.replace('}γγγ', '')
|
||||
text=COMPLETION_GENERATE_JSON_PROMPT.format(
|
||||
histories=memory_str, text=input_text, instruction=instruction
|
||||
)
|
||||
.replace("{γγγ", "")
|
||||
.replace("}γγγ", "")
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model mode {model_mode} not support.")
|
||||
|
||||
def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: Optional[str]) -> int:
|
||||
def _calculate_rest_token(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: Optional[str],
|
||||
) -> int:
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
@ -659,12 +666,12 @@ class ParameterExtractorNode(LLMNode):
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query='',
|
||||
query="",
|
||||
files=[],
|
||||
context=context,
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
model_config=model_config
|
||||
model_config=model_config,
|
||||
)
|
||||
rest_tokens = 2000
|
||||
|
||||
@ -673,26 +680,28 @@ class ParameterExtractorNode(LLMNode):
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
curr_message_tokens = model_type_instance.get_num_tokens(
|
||||
model_config.model,
|
||||
model_config.credentials,
|
||||
prompt_messages
|
||||
) + 1000 # add 1000 to ensure tool call messages
|
||||
curr_message_tokens = (
|
||||
model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000
|
||||
) # add 1000 to ensure tool call messages
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
max_tokens = (model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)) or 0
|
||||
if parameter_rule.name == "max_tokens" or (
|
||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)
|
||||
) or 0
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
|
||||
return rest_tokens
|
||||
|
||||
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
|
||||
ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
def _fetch_model_config(
|
||||
self, node_data_model: ModelConfig
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config.
|
||||
"""
|
||||
@ -703,10 +712,7 @@ class ParameterExtractorNode(LLMNode):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: ParameterExtractorNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ParameterExtractorNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@ -715,17 +721,13 @@ class ParameterExtractorNode(LLMNode):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
variable_mapping = {
|
||||
'query': node_data.query
|
||||
}
|
||||
variable_mapping = {"query": node_data.query}
|
||||
|
||||
if node_data.instruction:
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
|
||||
for selector in variable_template_parser.extract_variable_selectors():
|
||||
variable_mapping[selector.variable] = selector.value_selector
|
||||
|
||||
variable_mapping = {
|
||||
node_id + '.' + key: value for key, value in variable_mapping.items()
|
||||
}
|
||||
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
FUNCTION_CALLING_EXTRACTOR_NAME = 'extract_parameters'
|
||||
FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters"
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
|
||||
### Task
|
||||
@ -23,7 +23,7 @@ Steps:
|
||||
To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples.
|
||||
### Final Output
|
||||
Produce well-formatted function calls in json without XML tags, as shown in the example.
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside <context></context> XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside <structure></structure> XML tags.
|
||||
<context>
|
||||
@ -33,63 +33,52 @@ FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information fr
|
||||
<structure>
|
||||
\x7bstructure\x7d
|
||||
</structure>
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [{
|
||||
'user': {
|
||||
'query': 'What is the weather today in SF?',
|
||||
'function': {
|
||||
'name': FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'location': {
|
||||
'type': 'string',
|
||||
'description': 'The location to get the weather information',
|
||||
'required': True
|
||||
FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [
|
||||
{
|
||||
"user": {
|
||||
"query": "What is the weather today in SF?",
|
||||
"function": {
|
||||
"name": FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The location to get the weather information",
|
||||
"required": True,
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
'required': ['location']
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
"assistant": {
|
||||
"text": "I need always call the function with the correct parameters."
|
||||
" in this case, I need to call the function with the location parameter.",
|
||||
"function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}},
|
||||
},
|
||||
},
|
||||
'assistant': {
|
||||
'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the location parameter.',
|
||||
'function_call' : {
|
||||
'name': FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
'parameters': {
|
||||
'location': 'San Francisco'
|
||||
}
|
||||
}
|
||||
}
|
||||
}, {
|
||||
'user': {
|
||||
'query': 'I want to eat some apple pie.',
|
||||
'function': {
|
||||
'name': FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'food': {
|
||||
'type': 'string',
|
||||
'description': 'The food to eat',
|
||||
'required': True
|
||||
}
|
||||
{
|
||||
"user": {
|
||||
"query": "I want to eat some apple pie.",
|
||||
"function": {
|
||||
"name": FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"food": {"type": "string", "description": "The food to eat", "required": True}},
|
||||
"required": ["food"],
|
||||
},
|
||||
'required': ['food']
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
"assistant": {
|
||||
"text": "I need always call the function with the correct parameters."
|
||||
" in this case, I need to call the function with the food parameter.",
|
||||
"function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}},
|
||||
},
|
||||
},
|
||||
'assistant': {
|
||||
'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.',
|
||||
'function_call' : {
|
||||
'name': FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
'parameters': {
|
||||
'food': 'apple pie'
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
]
|
||||
|
||||
COMPLETION_GENERATE_JSON_PROMPT = """### Instructions:
|
||||
Some extra information are provided below, I should always follow the instructions as possible as I can.
|
||||
@ -130,7 +119,7 @@ Inside <text></text> XML tags, there is a text that I should extract parameters
|
||||
### Answer
|
||||
I should always output a valid JSON object. Output nothing other than the JSON object.
|
||||
```JSON
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object.
|
||||
The structure of the JSON object you can found in the instructions.
|
||||
@ -161,46 +150,33 @@ Inside <text></text> XML tags, there is a text that you should convert to a JSON
|
||||
</text>
|
||||
"""
|
||||
|
||||
CHAT_EXAMPLE = [{
|
||||
'user': {
|
||||
'query': 'What is the weather today in SF?',
|
||||
'json': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'location': {
|
||||
'type': 'string',
|
||||
'description': 'The location to get the weather information',
|
||||
'required': True
|
||||
}
|
||||
CHAT_EXAMPLE = [
|
||||
{
|
||||
"user": {
|
||||
"query": "What is the weather today in SF?",
|
||||
"json": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The location to get the weather information",
|
||||
"required": True,
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
'required': ['location']
|
||||
}
|
||||
},
|
||||
"assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}},
|
||||
},
|
||||
'assistant': {
|
||||
'text': 'I need to output a valid JSON object.',
|
||||
'json': {
|
||||
'location': 'San Francisco'
|
||||
}
|
||||
}
|
||||
}, {
|
||||
'user': {
|
||||
'query': 'I want to eat some apple pie.',
|
||||
'json': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'food': {
|
||||
'type': 'string',
|
||||
'description': 'The food to eat',
|
||||
'required': True
|
||||
}
|
||||
{
|
||||
"user": {
|
||||
"query": "I want to eat some apple pie.",
|
||||
"json": {
|
||||
"type": "object",
|
||||
"properties": {"food": {"type": "string", "description": "The food to eat", "required": True}},
|
||||
"required": ["food"],
|
||||
},
|
||||
'required': ['food']
|
||||
}
|
||||
},
|
||||
"assistant": {"text": "I need to output a valid JSON object.", "json": {"result": "apple pie"}},
|
||||
},
|
||||
'assistant': {
|
||||
'text': 'I need to output a valid JSON object.',
|
||||
'json': {
|
||||
'result': 'apple pie'
|
||||
}
|
||||
}
|
||||
}]
|
||||
]
|
||||
|
||||
@ -8,8 +8,9 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
"""
|
||||
Model Config.
|
||||
Model Config.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
@ -20,6 +21,7 @@ class ClassConfig(BaseModel):
|
||||
"""
|
||||
Class Config.
|
||||
"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
@ -28,8 +30,9 @@ class QuestionClassifierNodeData(BaseNodeData):
|
||||
"""
|
||||
Knowledge retrieval Node Data.
|
||||
"""
|
||||
|
||||
query_variable_selector: list[str]
|
||||
type: str = 'question-classifier'
|
||||
type: str = "question-classifier"
|
||||
model: ModelConfig
|
||||
classes: list[ClassConfig]
|
||||
instruction: Optional[str] = None
|
||||
|
||||
@ -45,34 +45,25 @@ class QuestionClassifierNode(LLMNode):
|
||||
# extract variables
|
||||
variable = variable_pool.get(node_data.query_variable_selector)
|
||||
query = variable.value if variable else None
|
||||
variables = {
|
||||
'query': query
|
||||
}
|
||||
variables = {"query": query}
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
# fetch memory
|
||||
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
|
||||
# fetch instruction
|
||||
instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else ''
|
||||
instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else ""
|
||||
node_data.instruction = instruction
|
||||
# fetch prompt messages
|
||||
prompt_messages, stop = self._fetch_prompt(
|
||||
node_data=node_data,
|
||||
context='',
|
||||
query=query,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
node_data=node_data, context="", query=query, memory=memory, model_config=model_config
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
generator = self._invoke_llm(
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop
|
||||
node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop
|
||||
)
|
||||
|
||||
result_text = ''
|
||||
result_text = ""
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
for event in generator:
|
||||
@ -87,8 +78,8 @@ class QuestionClassifierNode(LLMNode):
|
||||
try:
|
||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||
# result_text_json = json.loads(result_text.strip('```JSON\n'))
|
||||
if 'category_name' in result_text_json and 'category_id' in result_text_json:
|
||||
category_id_result = result_text_json['category_id']
|
||||
if "category_name" in result_text_json and "category_id" in result_text_json:
|
||||
category_id_result = result_text_json["category_id"]
|
||||
classes = node_data.classes
|
||||
classes_map = {class_.id: class_.name for class_ in classes}
|
||||
category_ids = [_class.id for _class in classes]
|
||||
@ -100,17 +91,14 @@ class QuestionClassifierNode(LLMNode):
|
||||
logging.error(f"Failed to parse result text: {result_text}")
|
||||
try:
|
||||
process_data = {
|
||||
'model_mode': model_config.mode,
|
||||
'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode,
|
||||
prompt_messages=prompt_messages
|
||||
"model_mode": model_config.mode,
|
||||
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode, prompt_messages=prompt_messages
|
||||
),
|
||||
'usage': jsonable_encoder(usage),
|
||||
'finish_reason': finish_reason
|
||||
}
|
||||
outputs = {
|
||||
'class_name': category_name
|
||||
"usage": jsonable_encoder(usage),
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
outputs = {"class_name": category_name}
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@ -121,9 +109,9 @@ class QuestionClassifierNode(LLMNode):
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
@ -134,17 +122,14 @@ class QuestionClassifierNode(LLMNode):
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: QuestionClassifierNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: QuestionClassifierNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@ -153,7 +138,7 @@ class QuestionClassifierNode(LLMNode):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
variable_mapping = {'query': node_data.query_variable_selector}
|
||||
variable_mapping = {"query": node_data.query_variable_selector}
|
||||
variable_selectors = []
|
||||
if node_data.instruction:
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
|
||||
@ -161,10 +146,8 @@ class QuestionClassifierNode(LLMNode):
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
variable_mapping = {
|
||||
node_id + '.' + key: value for key, value in variable_mapping.items()
|
||||
}
|
||||
|
||||
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
@ -174,19 +157,16 @@ class QuestionClassifierNode(LLMNode):
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
"type": "question-classifier",
|
||||
"config": {
|
||||
"instructions": ""
|
||||
}
|
||||
}
|
||||
return {"type": "question-classifier", "config": {"instructions": ""}}
|
||||
|
||||
def _fetch_prompt(self, node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) \
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
def _fetch_prompt(
|
||||
self,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
"""
|
||||
Fetch prompt
|
||||
:param node_data: node data
|
||||
@ -202,118 +182,122 @@ class QuestionClassifierNode(LLMNode):
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query='',
|
||||
query="",
|
||||
files=[],
|
||||
context=context,
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
model_config=model_config
|
||||
model_config=model_config,
|
||||
)
|
||||
stop = model_config.stop
|
||||
|
||||
return prompt_messages, stop
|
||||
|
||||
def _calculate_rest_token(self, node_data: QuestionClassifierNodeData, query: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: Optional[str]) -> int:
|
||||
def _calculate_rest_token(
|
||||
self,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: Optional[str],
|
||||
) -> int:
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query='',
|
||||
query="",
|
||||
files=[],
|
||||
context=context,
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
model_config=model_config
|
||||
model_config=model_config,
|
||||
)
|
||||
rest_tokens = 2000
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
if model_context_tokens:
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=model_config.provider_model_bundle,
|
||||
model=model_config.model
|
||||
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
||||
)
|
||||
|
||||
curr_message_tokens = model_instance.get_llm_num_tokens(
|
||||
prompt_messages
|
||||
)
|
||||
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
max_tokens = (model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)) or 0
|
||||
if parameter_rule.name == "max_tokens" or (
|
||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)
|
||||
) or 0
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
|
||||
return rest_tokens
|
||||
|
||||
def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
max_token_limit: int = 2000) \
|
||||
-> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
|
||||
def _get_prompt_template(
|
||||
self,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
max_token_limit: int = 2000,
|
||||
) -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
|
||||
model_mode = ModelMode.value_of(node_data.model.mode)
|
||||
classes = node_data.classes
|
||||
categories = []
|
||||
for class_ in classes:
|
||||
category = {
|
||||
'category_id': class_.id,
|
||||
'category_name': class_.name
|
||||
}
|
||||
category = {"category_id": class_.id, "category_name": class_.name}
|
||||
categories.append(category)
|
||||
instruction = node_data.instruction if node_data.instruction else ''
|
||||
instruction = node_data.instruction or ""
|
||||
input_text = query
|
||||
memory_str = ''
|
||||
memory_str = ""
|
||||
if memory:
|
||||
memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
|
||||
message_limit=node_data.memory.window.size)
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
)
|
||||
prompt_messages = []
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
|
||||
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
|
||||
)
|
||||
prompt_messages.append(system_prompt_messages)
|
||||
user_prompt_message_1 = ChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=QUESTION_CLASSIFIER_USER_PROMPT_1
|
||||
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_1)
|
||||
assistant_prompt_message_1 = ChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT,
|
||||
text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
|
||||
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_1)
|
||||
user_prompt_message_2 = ChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=QUESTION_CLASSIFIER_USER_PROMPT_2
|
||||
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_2)
|
||||
assistant_prompt_message_2 = ChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT,
|
||||
text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
|
||||
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_2)
|
||||
user_prompt_message_3 = ChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text,
|
||||
categories=json.dumps(categories, ensure_ascii=False),
|
||||
classification_instructions=instruction)
|
||||
text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(
|
||||
input_text=input_text,
|
||||
categories=json.dumps(categories, ensure_ascii=False),
|
||||
classification_instructions=instruction,
|
||||
),
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_3)
|
||||
return prompt_messages
|
||||
elif model_mode == ModelMode.COMPLETION:
|
||||
return CompletionModelPromptTemplate(
|
||||
text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str,
|
||||
input_text=input_text,
|
||||
categories=json.dumps(categories),
|
||||
classification_instructions=instruction,
|
||||
ensure_ascii=False)
|
||||
text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
|
||||
histories=memory_str,
|
||||
input_text=input_text,
|
||||
categories=json.dumps(categories),
|
||||
classification_instructions=instruction,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
@ -329,14 +313,12 @@ class QuestionClassifierNode(LLMNode):
|
||||
variable = variable_pool.get(variable_selector.value_selector)
|
||||
variable_value = variable.value if variable else None
|
||||
if variable_value is None:
|
||||
raise ValueError(f'Variable {variable_selector.variable} not found')
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
|
||||
inputs[variable_selector.variable] = variable_value
|
||||
|
||||
prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
|
||||
instruction = prompt_template.format(
|
||||
prompt_inputs
|
||||
)
|
||||
instruction = prompt_template.format(prompt_inputs)
|
||||
return instruction
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
|
||||
### Job Description',
|
||||
You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
|
||||
@ -14,13 +12,13 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
|
||||
<histories>
|
||||
{histories}
|
||||
</histories>
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_1 = """
|
||||
{ "input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],
|
||||
"categories": [{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"},{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"},{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"},{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}],
|
||||
"classification_instructions": ["classify the text based on the feedback provided by customer"]}
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """
|
||||
```json
|
||||
@ -34,7 +32,7 @@ QUESTION_CLASSIFIER_USER_PROMPT_2 = """
|
||||
{"input_text": ["bad service, slow to bring the food"],
|
||||
"categories": [{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"},{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"},{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}],
|
||||
"classification_instructions": []}
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """
|
||||
```json
|
||||
@ -75,4 +73,4 @@ Here is the chat histories between human and assistant, inside <histories></hist
|
||||
### User Input
|
||||
{{"input_text" : ["{input_text}"], "categories" : {categories},"classification_instruction" : ["{classification_instructions}"]}}
|
||||
### Assistant Output
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
@ -10,4 +10,5 @@ class StartNodeData(BaseNodeData):
|
||||
"""
|
||||
Start Node Data
|
||||
"""
|
||||
|
||||
variables: Sequence[VariableEntity] = Field(default_factory=list)
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
@ -22,20 +21,13 @@ class StartNode(BaseNode):
|
||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables
|
||||
|
||||
for var in system_inputs:
|
||||
node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
|
||||
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=node_inputs,
|
||||
outputs=node_inputs
|
||||
)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: StartNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: StartNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
@ -8,5 +6,6 @@ class TemplateTransformNodeData(BaseNodeData):
|
||||
"""
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
variables: list[VariableSelector]
|
||||
template: str
|
||||
template: str
|
||||
|
||||
@ -2,13 +2,13 @@ import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000'))
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
|
||||
|
||||
|
||||
class TemplateTransformNode(BaseNode):
|
||||
@ -24,15 +24,7 @@ class TemplateTransformNode(BaseNode):
|
||||
"""
|
||||
return {
|
||||
"type": "template-transform",
|
||||
"config": {
|
||||
"variables": [
|
||||
{
|
||||
"variable": "arg1",
|
||||
"value_selector": []
|
||||
}
|
||||
],
|
||||
"template": "{{ arg1 }}"
|
||||
}
|
||||
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
|
||||
}
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
@ -51,38 +43,25 @@ class TemplateTransformNode(BaseNode):
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2,
|
||||
code=node_data.template,
|
||||
inputs=variables
|
||||
)
|
||||
except CodeExecutionException as e:
|
||||
return NodeRunResult(
|
||||
inputs=variables,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e)
|
||||
language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables
|
||||
)
|
||||
except CodeExecutionError as e:
|
||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||
|
||||
if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
return NodeRunResult(
|
||||
inputs=variables,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters"
|
||||
error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters",
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
outputs={
|
||||
'output': result['result']
|
||||
}
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: TemplateTransformNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@ -92,5 +71,6 @@ class TemplateTransformNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
|
||||
node_id + "." + variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in node_data.variables
|
||||
}
|
||||
|
||||
@ -10,45 +10,46 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
class ToolEntity(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: ToolProviderType
|
||||
provider_name: str # redundancy
|
||||
provider_name: str # redundancy
|
||||
tool_name: str
|
||||
tool_label: str # redundancy
|
||||
tool_label: str # redundancy
|
||||
tool_configurations: dict[str, Any]
|
||||
|
||||
@field_validator('tool_configurations', mode='before')
|
||||
@field_validator("tool_configurations", mode="before")
|
||||
@classmethod
|
||||
def validate_tool_configurations(cls, value, values: ValidationInfo):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError('tool_configurations must be a dictionary')
|
||||
|
||||
for key in values.data.get('tool_configurations', {}).keys():
|
||||
value = values.data.get('tool_configurations', {}).get(key)
|
||||
raise ValueError("tool_configurations must be a dictionary")
|
||||
|
||||
for key in values.data.get("tool_configurations", {}):
|
||||
value = values.data.get("tool_configurations", {}).get(key)
|
||||
if not isinstance(value, str | int | float | bool):
|
||||
raise ValueError(f'{key} must be a string')
|
||||
|
||||
raise ValueError(f"{key} must be a string")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
class ToolInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal['mixed', 'variable', 'constant']
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
|
||||
@field_validator('type', mode='before')
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get('value')
|
||||
if typ == 'mixed' and not isinstance(value, str):
|
||||
raise ValueError('value must be a string')
|
||||
elif typ == 'variable':
|
||||
value = validation_info.data.get("value")
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
if not isinstance(value, list):
|
||||
raise ValueError('value must be a list')
|
||||
raise ValueError("value must be a list")
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError('value must be a list of strings')
|
||||
elif typ == 'constant' and not isinstance(value, str | int | float | bool):
|
||||
raise ValueError('value must be a string, int, float, or bool')
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, str | int | float | bool):
|
||||
raise ValueError("value must be a string, int, float, or bool")
|
||||
return typ
|
||||
|
||||
"""
|
||||
|
||||
@ -35,10 +35,7 @@ class ToolNode(BaseNode):
|
||||
node_data = cast(ToolNodeData, self.node_data)
|
||||
|
||||
# fetch tool icon
|
||||
tool_info = {
|
||||
'provider_type': node_data.provider_type.value,
|
||||
'provider_id': node_data.provider_id
|
||||
}
|
||||
tool_info = {"provider_type": node_data.provider_type.value, "provider_id": node_data.provider_id}
|
||||
|
||||
# get tool runtime
|
||||
try:
|
||||
@ -50,10 +47,8 @@ class ToolNode(BaseNode):
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
error=f'Failed to get tool runtime: {str(e)}'
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to get tool runtime: {str(e)}",
|
||||
)
|
||||
)
|
||||
return
|
||||
@ -61,15 +56,13 @@ class ToolNode(BaseNode):
|
||||
# get parameters
|
||||
tool_parameters = tool_runtime.get_runtime_parameters() or []
|
||||
parameters = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=node_data
|
||||
tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=node_data,
|
||||
for_log=True
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=node_data,
|
||||
for_log=True,
|
||||
)
|
||||
|
||||
try:
|
||||
@ -86,10 +79,8 @@ class ToolNode(BaseNode):
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
error=f'Failed to invoke tool: {str(e)}',
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to invoke tool: {str(e)}",
|
||||
)
|
||||
)
|
||||
return
|
||||
@ -126,12 +117,10 @@ class ToolNode(BaseNode):
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
result[parameter_name] = [
|
||||
v.to_dict() for v in self._fetch_files(variable_pool)
|
||||
]
|
||||
result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)]
|
||||
else:
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == 'variable':
|
||||
if tool_input.type == "variable":
|
||||
parameter_value_segment = variable_pool.get(tool_input.value)
|
||||
if not parameter_value_segment:
|
||||
raise Exception("input variable dose not exists")
|
||||
@ -147,14 +136,16 @@ class ToolNode(BaseNode):
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
||||
variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
def _transform_message(self,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any]) -> Generator[RunEvent, None, None]:
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
) -> Generator[RunEvent, None, None]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
@ -169,66 +160,65 @@ class ToolNode(BaseNode):
|
||||
files: list[FileVar] = []
|
||||
text = ""
|
||||
json: list[dict] = []
|
||||
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
message.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
if message.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
url = message.message.text
|
||||
ext = path.splitext(url)[1]
|
||||
mimetype = message.meta.get('mime_type', 'image/jpeg')
|
||||
filename = message.save_as or url.split('/')[-1]
|
||||
transfer_method = message.meta.get('transfer_method', FileTransferMethod.TOOL_FILE)
|
||||
mimetype = message.meta.get("mime_type", "image/jpeg")
|
||||
filename = message.save_as or url.split("/")[-1]
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# get tool file id
|
||||
tool_file_id = url.split('/')[-1].split('.')[0]
|
||||
files.append(FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
url=url,
|
||||
related_id=tool_file_id,
|
||||
filename=filename,
|
||||
extension=ext,
|
||||
mime_type=mimetype,
|
||||
))
|
||||
tool_file_id = url.split("/")[-1].split(".")[0]
|
||||
files.append(
|
||||
FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
url=url,
|
||||
related_id=tool_file_id,
|
||||
filename=filename,
|
||||
extension=ext,
|
||||
mime_type=mimetype,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split('/')[-1].split('.')[0]
|
||||
files.append(FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=tool_file_id,
|
||||
filename=message.save_as,
|
||||
extension=path.splitext(message.save_as)[1],
|
||||
mime_type=message.meta.get('mime_type', 'application/octet-stream'),
|
||||
))
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
files.append(
|
||||
FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=tool_file_id,
|
||||
filename=message.save_as,
|
||||
extension=path.splitext(message.save_as)[1],
|
||||
mime_type=message.meta.get("mime_type", "application/octet-stream"),
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text + '\n'
|
||||
text += message.message.text + "\n"
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=message.message.text,
|
||||
from_variable_selector=[self.node_id, 'text']
|
||||
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message, ToolInvokeMessage.JsonMessage)
|
||||
json.append(message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f'Link: {message.message.text}\n'
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=stream_text,
|
||||
from_variable_selector=[self.node_id, 'text']
|
||||
)
|
||||
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
@ -241,8 +231,7 @@ class ToolNode(BaseNode):
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=variable_value,
|
||||
from_variable_selector=[self.node_id, variable_name]
|
||||
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
@ -250,25 +239,15 @@ class ToolNode(BaseNode):
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
'text': text,
|
||||
'files': files,
|
||||
'json': json,
|
||||
**variables
|
||||
},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
inputs=parameters_for_log
|
||||
outputs={"text": text, "files": files, "json": json, **variables},
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: ToolNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@ -280,18 +259,16 @@ class ToolNode(BaseNode):
|
||||
result = {}
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
input = node_data.tool_parameters[parameter_name]
|
||||
if input.type == 'mixed':
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == 'variable':
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
elif input.type == 'constant':
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
|
||||
result = {
|
||||
node_id + '.' + key: value for key, value in result.items()
|
||||
}
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -11,23 +9,27 @@ class AdvancedSettings(BaseModel):
|
||||
"""
|
||||
Advanced setting.
|
||||
"""
|
||||
|
||||
group_enabled: bool
|
||||
|
||||
class Group(BaseModel):
|
||||
"""
|
||||
Group.
|
||||
"""
|
||||
output_type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]']
|
||||
|
||||
output_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
||||
variables: list[list[str]]
|
||||
group_name: str
|
||||
|
||||
groups: list[Group]
|
||||
|
||||
|
||||
class VariableAssignerNodeData(BaseNodeData):
|
||||
"""
|
||||
Knowledge retrieval Node Data.
|
||||
"""
|
||||
type: str = 'variable-assigner'
|
||||
|
||||
type: str = "variable-assigner"
|
||||
output_type: str
|
||||
variables: list[list[str]]
|
||||
advanced_settings: Optional[AdvancedSettings] = None
|
||||
|
||||
@ -21,13 +21,9 @@ class VariableAggregatorNode(BaseNode):
|
||||
for selector in node_data.variables:
|
||||
variable = self.graph_runtime_state.variable_pool.get_any(selector)
|
||||
if variable is not None:
|
||||
outputs = {
|
||||
"output": variable
|
||||
}
|
||||
outputs = {"output": variable}
|
||||
|
||||
inputs = {
|
||||
'.'.join(selector[1:]): variable
|
||||
}
|
||||
inputs = {".".join(selector[1:]): variable}
|
||||
break
|
||||
else:
|
||||
for group in node_data.advanced_settings.groups:
|
||||
@ -35,24 +31,15 @@ class VariableAggregatorNode(BaseNode):
|
||||
variable = self.graph_runtime_state.variable_pool.get_any(selector)
|
||||
|
||||
if variable is not None:
|
||||
outputs[group.group_name] = {
|
||||
'output': variable
|
||||
}
|
||||
inputs['.'.join(selector[1:])] = variable
|
||||
outputs[group.group_name] = {"output": variable}
|
||||
inputs[".".join(selector[1:])] = variable
|
||||
break
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
inputs=inputs
|
||||
)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: VariableAssignerNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
|
||||
@ -2,7 +2,7 @@ from .node import VariableAssignerNode
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
__all__ = [
|
||||
'VariableAssignerNode',
|
||||
'VariableAssignerData',
|
||||
'WriteMode',
|
||||
"VariableAssignerNode",
|
||||
"VariableAssignerData",
|
||||
"WriteMode",
|
||||
]
|
||||
|
||||
@ -24,43 +24,43 @@ class VariableAssignerNode(BaseNode):
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
raise VariableAssignerNodeError('assigned variable not found')
|
||||
raise VariableAssignerNodeError("assigned variable not found")
|
||||
|
||||
match data.write_mode:
|
||||
case WriteMode.OVER_WRITE:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError('input value not found')
|
||||
updated_variable = original_variable.model_copy(update={'value': income_value.value})
|
||||
raise VariableAssignerNodeError("input value not found")
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.value})
|
||||
|
||||
case WriteMode.APPEND:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError('input value not found')
|
||||
raise VariableAssignerNodeError("input value not found")
|
||||
updated_value = original_variable.value + [income_value.value]
|
||||
updated_variable = original_variable.model_copy(update={'value': updated_value})
|
||||
updated_variable = original_variable.model_copy(update={"value": updated_value})
|
||||
|
||||
case WriteMode.CLEAR:
|
||||
income_value = get_zero_value(original_variable.value_type)
|
||||
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
|
||||
raise VariableAssignerNodeError(f"unsupported write mode: {data.write_mode}")
|
||||
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)
|
||||
|
||||
# TODO: Move database operation to the pipeline.
|
||||
# Update conversation variable.
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id'])
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
raise VariableAssignerNodeError('conversation_id not found')
|
||||
raise VariableAssignerNodeError("conversation_id not found")
|
||||
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={
|
||||
'value': income_value.to_object(),
|
||||
"value": income_value.to_object(),
|
||||
},
|
||||
)
|
||||
|
||||
@ -72,7 +72,7 @@ def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableAssignerNodeError('conversation variable not found in the database')
|
||||
raise VariableAssignerNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
@ -84,8 +84,8 @@ def get_zero_value(t: SegmentType):
|
||||
case SegmentType.OBJECT:
|
||||
return factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
return factory.build_segment('')
|
||||
return factory.build_segment("")
|
||||
case SegmentType.NUMBER:
|
||||
return factory.build_segment(0)
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f'unsupported variable type: {t}')
|
||||
raise VariableAssignerNodeError(f"unsupported variable type: {t}")
|
||||
|
||||
@ -6,14 +6,14 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class WriteMode(str, Enum):
|
||||
OVER_WRITE = 'over-write'
|
||||
APPEND = 'append'
|
||||
CLEAR = 'clear'
|
||||
OVER_WRITE = "over-write"
|
||||
APPEND = "append"
|
||||
CLEAR = "clear"
|
||||
|
||||
|
||||
class VariableAssignerData(BaseNodeData):
|
||||
title: str = 'Variable Assigner'
|
||||
desc: Optional[str] = 'Assign a value to a variable'
|
||||
title: str = "Variable Assigner"
|
||||
desc: Optional[str] = "Assign a value to a variable"
|
||||
assigned_variable_selector: Sequence[str]
|
||||
write_mode: WriteMode
|
||||
input_variable_selector: Sequence[str]
|
||||
|
||||
Reference in New Issue
Block a user