Merge branch 'fix/chore-fix' into dev/plugin-deploy

This commit is contained in:
Yeuoly
2024-11-15 15:56:45 +08:00
112 changed files with 4206 additions and 219 deletions

View File

@@ -24,6 +24,7 @@ class NodeRunMetadataKey(str, Enum):
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
class NodeRunResult(BaseModel):

View File

@@ -95,13 +95,16 @@ class VariablePool(BaseModel):
if len(selector) < 2:
raise ValueError("Invalid selector")
if isinstance(value, Variable):
variable = value
if isinstance(value, Segment):
v = value
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
else:
v = variable_factory.build_segment(value)
segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]][hash_key] = v
self.variable_dictionary[selector[0]][hash_key] = variable
def get(self, selector: Sequence[str], /) -> Segment | None:
"""

View File

@@ -148,6 +148,7 @@ class IterationRunStartedEvent(BaseIterationEvent):
class IterationRunNextEvent(BaseIterationEvent):
index: int = Field(..., description="index")
pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output")
duration: Optional[float] = Field(None, description="duration")
class IterationRunSucceededEvent(BaseIterationEvent):
@@ -156,6 +157,7 @@ class IterationRunSucceededEvent(BaseIterationEvent):
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
steps: int = 0
iteration_duration_map: Optional[dict[str, float]] = None
class IterationRunFailedEvent(BaseIterationEvent):

View File

@@ -143,14 +143,14 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
def _extract_text_from_plain_text(file_content: bytes) -> str:
try:
return file_content.decode("utf-8")
return file_content.decode("utf-8", "ignore")
except UnicodeDecodeError as e:
raise TextExtractionError("Failed to decode plain text file") from e
def _extract_text_from_json(file_content: bytes) -> str:
try:
json_data = json.loads(file_content.decode("utf-8"))
json_data = json.loads(file_content.decode("utf-8", "ignore"))
return json.dumps(json_data, indent=2, ensure_ascii=False)
except (UnicodeDecodeError, json.JSONDecodeError) as e:
raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e
@@ -159,7 +159,7 @@ def _extract_text_from_json(file_content: bytes) -> str:
def _extract_text_from_yaml(file_content: bytes) -> str:
"""Extract the content from yaml file"""
try:
yaml_data = yaml.safe_load_all(file_content.decode("utf-8"))
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore"))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, yaml.YAMLError) as e:
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
@@ -217,7 +217,7 @@ def _extract_text_from_file(file: File):
def _extract_text_from_csv(file_content: bytes) -> str:
try:
csv_file = io.StringIO(file_content.decode("utf-8"))
csv_file = io.StringIO(file_content.decode("utf-8", "ignore"))
csv_reader = csv.reader(csv_file)
rows = list(csv_reader)

View File

@@ -156,6 +156,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index=0,
pre_iteration_output=None,
)
iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value)
try:
if self.node_data.is_parallel:
@@ -175,6 +176,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph,
index,
item,
iter_run_map,
)
future.add_done_callback(thread_pool.task_done_callback)
futures.append(future)
@@ -213,6 +215,7 @@ class IterationNode(BaseNode[IterationNodeData]):
start_at,
graph_engine,
iteration_graph,
iter_run_map,
)
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]
@@ -230,7 +233,9 @@ class IterationNode(BaseNode[IterationNodeData]):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)}
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": jsonable_encoder(outputs)},
metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map},
)
)
except IterationNodeError as e:
@@ -356,15 +361,19 @@ class IterationNode(BaseNode[IterationNodeData]):
start_at: datetime,
graph_engine: "GraphEngine",
iteration_graph: Graph,
iter_run_map: dict[str, float],
parallel_mode_run_id: Optional[str] = None,
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration
"""
iter_start_at = datetime.now(timezone.utc).replace(tzinfo=None)
try:
rst = graph_engine.run()
# get current iteration index
current_index = variable_pool.get([self.node_id, "index"]).value
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
next_index = int(current_index) + 1
if current_index is None:
@@ -431,6 +440,8 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@@ -439,6 +450,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
@@ -449,6 +461,8 @@ class IterationNode(BaseNode[IterationNodeData]):
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@@ -457,6 +471,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
@@ -474,7 +489,10 @@ class IterationNode(BaseNode[IterationNodeData]):
)
yield metadata_event
current_iteration_output = variable_pool.get(self.node_data.output_selector).value
current_output_segment = variable_pool.get(self.node_data.output_selector)
if current_output_segment is None:
raise IterationNodeError("iteration output selector not found")
current_iteration_output = current_output_segment.value
outputs[current_index] = current_iteration_output
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
@@ -485,6 +503,8 @@ class IterationNode(BaseNode[IterationNodeData]):
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@@ -493,6 +513,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
duration=duration,
)
except IterationNodeError as e:
@@ -528,6 +549,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph: Graph,
index: int,
item: Any,
iter_run_map: dict[str, float],
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration in parallel mode
@@ -546,6 +568,7 @@ class IterationNode(BaseNode[IterationNodeData]):
start_at=start_at,
graph_engine=graph_engine_copy,
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
parallel_mode_run_id=parallel_mode_run_id,
):
q.put(event)

View File

@@ -59,4 +59,4 @@ class ListOperatorNodeData(BaseNodeData):
filter_by: FilterBy
order_by: OrderBy
limit: Limit
extract_by: ExtractConfig
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)