Merge branch 'main' into fix/chore-fix

This commit is contained in:
Yeuoly
2024-11-15 15:43:32 +08:00
114 changed files with 4193 additions and 220 deletions

View File

@ -24,6 +24,7 @@ class NodeRunMetadataKey(str, Enum):
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
class NodeRunResult(BaseModel):

View File

@ -95,13 +95,16 @@ class VariablePool(BaseModel):
if len(selector) < 2:
raise ValueError("Invalid selector")
if isinstance(value, Variable):
variable = value
if isinstance(value, Segment):
v = value
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
else:
v = variable_factory.build_segment(value)
segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]][hash_key] = v
self.variable_dictionary[selector[0]][hash_key] = variable
def get(self, selector: Sequence[str], /) -> Segment | None:
"""

View File

@ -148,6 +148,7 @@ class IterationRunStartedEvent(BaseIterationEvent):
class IterationRunNextEvent(BaseIterationEvent):
index: int = Field(..., description="index")
pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output")
duration: Optional[float] = Field(None, description="duration")
class IterationRunSucceededEvent(BaseIterationEvent):
@ -156,6 +157,7 @@ class IterationRunSucceededEvent(BaseIterationEvent):
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
steps: int = 0
iteration_duration_map: Optional[dict[str, float]] = None
class IterationRunFailedEvent(BaseIterationEvent):

View File

@ -143,14 +143,14 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
def _extract_text_from_plain_text(file_content: bytes) -> str:
try:
return file_content.decode("utf-8")
return file_content.decode("utf-8", "ignore")
except UnicodeDecodeError as e:
raise TextExtractionError("Failed to decode plain text file") from e
def _extract_text_from_json(file_content: bytes) -> str:
try:
json_data = json.loads(file_content.decode("utf-8"))
json_data = json.loads(file_content.decode("utf-8", "ignore"))
return json.dumps(json_data, indent=2, ensure_ascii=False)
except (UnicodeDecodeError, json.JSONDecodeError) as e:
raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e
@ -159,7 +159,7 @@ def _extract_text_from_json(file_content: bytes) -> str:
def _extract_text_from_yaml(file_content: bytes) -> str:
"""Extract the content from yaml file"""
try:
yaml_data = yaml.safe_load_all(file_content.decode("utf-8"))
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore"))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, yaml.YAMLError) as e:
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
@ -217,7 +217,7 @@ def _extract_text_from_file(file: File):
def _extract_text_from_csv(file_content: bytes) -> str:
try:
csv_file = io.StringIO(file_content.decode("utf-8"))
csv_file = io.StringIO(file_content.decode("utf-8", "ignore"))
csv_reader = csv.reader(csv_file)
rows = list(csv_reader)

View File

@ -156,6 +156,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index=0,
pre_iteration_output=None,
)
iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value)
try:
if self.node_data.is_parallel:
@ -175,6 +176,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph,
index,
item,
iter_run_map,
)
future.add_done_callback(thread_pool.task_done_callback)
futures.append(future)
@ -213,6 +215,7 @@ class IterationNode(BaseNode[IterationNodeData]):
start_at,
graph_engine,
iteration_graph,
iter_run_map,
)
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]
@ -230,7 +233,9 @@ class IterationNode(BaseNode[IterationNodeData]):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)}
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": jsonable_encoder(outputs)},
metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map},
)
)
except IterationNodeError as e:
@ -356,15 +361,19 @@ class IterationNode(BaseNode[IterationNodeData]):
start_at: datetime,
graph_engine: "GraphEngine",
iteration_graph: Graph,
iter_run_map: dict[str, float],
parallel_mode_run_id: Optional[str] = None,
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration
"""
iter_start_at = datetime.now(timezone.utc).replace(tzinfo=None)
try:
rst = graph_engine.run()
# get current iteration index
current_index = variable_pool.get([self.node_id, "index"]).value
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
next_index = int(current_index) + 1
if current_index is None:
@ -431,6 +440,8 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@ -439,6 +450,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
@ -449,6 +461,8 @@ class IterationNode(BaseNode[IterationNodeData]):
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@ -457,6 +471,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
@ -474,7 +489,10 @@ class IterationNode(BaseNode[IterationNodeData]):
)
yield metadata_event
current_iteration_output = variable_pool.get(self.node_data.output_selector).value
current_output_segment = variable_pool.get(self.node_data.output_selector)
if current_output_segment is None:
raise IterationNodeError("iteration output selector not found")
current_iteration_output = current_output_segment.value
outputs[current_index] = current_iteration_output
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
@ -485,6 +503,8 @@ class IterationNode(BaseNode[IterationNodeData]):
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@ -493,6 +513,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
duration=duration,
)
except IterationNodeError as e:
@ -528,6 +549,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph: Graph,
index: int,
item: Any,
iter_run_map: dict[str, float],
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration in parallel mode
@ -546,6 +568,7 @@ class IterationNode(BaseNode[IterationNodeData]):
start_at=start_at,
graph_engine=graph_engine_copy,
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
parallel_mode_run_id=parallel_mode_run_id,
):
q.put(event)

View File

@ -59,4 +59,4 @@ class ListOperatorNodeData(BaseNodeData):
filter_by: FilterBy
order_by: OrderBy
limit: Limit
extract_by: ExtractConfig
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)

View File

@ -1,6 +1,8 @@
from collections.abc import Generator, Mapping, Sequence
from os import path
from typing import Any, cast
from collections.abc import Mapping, Sequence
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session