This commit is contained in:
takatost
2024-07-23 00:10:23 +08:00
parent a603e01f5e
commit 2c695ded79
9 changed files with 140 additions and 127 deletions

View File

@ -2,7 +2,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from typing_extensions import deprecated
from core.app.segments import ArrayVariable, ObjectVariable, Variable, factory
@ -21,7 +21,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
_variable_dictionary: dict[str, dict[int, Variable]] = Field(
variable_dictionary: dict[str, dict[int, Variable]] = Field(
description='Variables mapping',
default=defaultdict(dict)
)
@ -36,10 +36,12 @@ class VariablePool(BaseModel):
)
environment_variables: Sequence[Variable] = Field(
description="Environment variables."
description="Environment variables.",
default_factory=list
)
def __post_init__(self):
@model_validator(mode="after")
def val_model_after(self):
"""
Append system variables
:return:
@ -52,6 +54,8 @@ class VariablePool(BaseModel):
for var in self.environment_variables or []:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
return self
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Adds a variable to the variable pool.
@ -78,7 +82,7 @@ class VariablePool(BaseModel):
v = value
hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]][hash_key] = v
self.variable_dictionary[selector[0]][hash_key] = v
def get(self, selector: Sequence[str], /) -> Variable | None:
"""
@ -96,7 +100,7 @@ class VariablePool(BaseModel):
if len(selector) < 2:
raise ValueError('Invalid selector')
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
value = self.variable_dictionary[selector[0]].get(hash_key)
return value
@ -117,7 +121,7 @@ class VariablePool(BaseModel):
if len(selector) < 2:
raise ValueError('Invalid selector')
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
value = self.variable_dictionary[selector[0]].get(hash_key)
if value is None:
return value
@ -140,7 +144,7 @@ class VariablePool(BaseModel):
if not selector:
return
if len(selector) == 1:
self._variable_dictionary[selector[0]] = {}
self.variable_dictionary[selector[0]] = {}
return
hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]].pop(hash_key, None)
self.variable_dictionary[selector[0]].pop(hash_key, None)

View File

@ -285,16 +285,18 @@ class Graph(BaseModel):
)
# collect all branches node ids
end_to_node_id: Optional[str] = None
for branch_node_id, node_ids in in_branch_node_ids.items():
for node_id in node_ids:
node_parallel_mapping[node_id] = parallel.id
if not end_to_node_id and edge_mapping.get(node_id):
node_edges = edge_mapping[node_id]
target_node_id = node_edges[0].target_node_id
if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
end_to_node_id = target_node_id
end_to_node_id: Optional[str] = None
for node_id in node_parallel_mapping:
if not end_to_node_id and edge_mapping.get(node_id):
node_edges = edge_mapping[node_id]
target_node_id = node_edges[0].target_node_id
if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
end_to_node_id = target_node_id
break
if end_to_node_id:
parallel.end_to_node_id = end_to_node_id

View File

@ -244,8 +244,8 @@ class GraphEngine:
next_node_id = final_node_id
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
break
# if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
# break
def _run_parallel_node(self,
flask_app: Flask,
@ -402,10 +402,9 @@ class GraphEngine:
:param variable_value: variable value
:return:
"""
self.graph_runtime_state.variable_pool.append_variable(
node_id=node_id,
variable_key_list=variable_key_list,
value=variable_value # type: ignore[arg-type]
self.graph_runtime_state.variable_pool.add(
[node_id] + variable_key_list,
variable_value
)
# if variable_value is a dict, then recursively append variables

View File

@ -35,7 +35,7 @@ class AnswerNode(BaseNode):
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
value = self.graph_runtime_state.variable_pool.get(
variable_selector=value_selector
value_selector
)
if value:

View File

@ -1,4 +1,3 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, cast
@ -98,7 +97,7 @@ class AnswerStreamProcessor:
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping[node_id]:
for edge in self.graph.edge_mapping.get(node_id, []):
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
@ -108,7 +107,7 @@ class AnswerStreamProcessor:
remove target node ids until merge
"""
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping[node_id]:
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
@ -124,8 +123,10 @@ class AnswerStreamProcessor:
"""
for answer_node_id, position in self.route_position.items():
# all depends on answer node id not in rest node ids
if not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
if (event.route_node_state.node_id != answer_node_id
and (answer_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
continue
route_position = self.route_position[answer_node_id]
@ -145,53 +146,14 @@ class AnswerStreamProcessor:
if not value_selector:
break
value = self.variable_pool.get_variable_value(
variable_selector=value_selector
value = self.variable_pool.get(
value_selector
)
if value is None:
break
text = ''
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, FileVar):
# convert file to markdown
text = value.to_markdown()
elif isinstance(value, dict):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
if file_vars:
file_var = file_vars[0]
try:
file_var_obj = FileVar(**file_var)
# convert file to markdown
text = file_var_obj.to_markdown()
except Exception as e:
logger.error(f'Error creating file var: {e}')
if not text:
# other types
text = json.dumps(value, ensure_ascii=False)
elif isinstance(value, list):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
for file_var in file_vars:
try:
file_var_obj = FileVar(**file_var)
except Exception as e:
logger.error(f'Error creating file var: {e}')
continue
# convert file to markdown
text = file_var_obj.to_markdown() + ' '
text = text.strip()
if not text and value:
# other types
text = json.dumps(value, ensure_ascii=False)
text = value.markdown
if text:
yield NodeRunStreamChunkEvent(

View File

@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import Any, Optional
from core.file.file_obj import FileVar
@ -7,15 +8,15 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
class ConditionProcessor:
def process_conditions(self, variable_pool: VariablePool, conditions: list[Condition]):
def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
input_conditions = []
group_result = []
index = 0
for condition in conditions:
index += 1
actual_value = variable_pool.get_variable_value(
variable_selector=condition.variable_selector
actual_value = variable_pool.get_any(
condition.variable_selector
)
expected_value = None
@ -24,8 +25,8 @@ class ConditionProcessor:
variable_selectors = variable_template_parser.extract_variable_selectors()
if variable_selectors:
for variable_selector in variable_selectors:
value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
value = variable_pool.get_any(
variable_selector.value_selector
)
expected_value = variable_template_parser.format({variable_selector.variable: value})
@ -63,37 +64,37 @@ class ConditionProcessor:
:return: bool
"""
if comparison_operator == "contains":
return self._assert_contains(actual_value, expected_value) # type: ignore
return self._assert_contains(actual_value, expected_value)
elif comparison_operator == "not contains":
return self._assert_not_contains(actual_value, expected_value) # type: ignore
return self._assert_not_contains(actual_value, expected_value)
elif comparison_operator == "start with":
return self._assert_start_with(actual_value, expected_value) # type: ignore
return self._assert_start_with(actual_value, expected_value)
elif comparison_operator == "end with":
return self._assert_end_with(actual_value, expected_value) # type: ignore
return self._assert_end_with(actual_value, expected_value)
elif comparison_operator == "is":
return self._assert_is(actual_value, expected_value) # type: ignore
return self._assert_is(actual_value, expected_value)
elif comparison_operator == "is not":
return self._assert_is_not(actual_value, expected_value) # type: ignore
return self._assert_is_not(actual_value, expected_value)
elif comparison_operator == "empty":
return self._assert_empty(actual_value) # type: ignore
return self._assert_empty(actual_value)
elif comparison_operator == "not empty":
return self._assert_not_empty(actual_value) # type: ignore
return self._assert_not_empty(actual_value)
elif comparison_operator == "=":
return self._assert_equal(actual_value, expected_value) # type: ignore
return self._assert_equal(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_not_equal(actual_value, expected_value) # type: ignore
return self._assert_not_equal(actual_value, expected_value)
elif comparison_operator == ">":
return self._assert_greater_than(actual_value, expected_value) # type: ignore
return self._assert_greater_than(actual_value, expected_value)
elif comparison_operator == "<":
return self._assert_less_than(actual_value, expected_value) # type: ignore
return self._assert_less_than(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_greater_than_or_equal(actual_value, expected_value) # type: ignore
return self._assert_greater_than_or_equal(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_less_than_or_equal(actual_value, expected_value) # type: ignore
return self._assert_less_than_or_equal(actual_value, expected_value)
elif comparison_operator == "null":
return self._assert_null(actual_value) # type: ignore
return self._assert_null(actual_value)
elif comparison_operator == "not null":
return self._assert_not_null(actual_value) # type: ignore
return self._assert_not_null(actual_value)
else:
raise ValueError(f"Invalid comparison operator: {comparison_operator}")