mirror of
https://github.com/langgenius/dify.git
synced 2026-04-30 15:38:08 +08:00
Merge branch 'fix/chore-fix' into dev/plugin-deploy
This commit is contained in:
@ -24,6 +24,7 @@ class NodeRunMetadataKey(str, Enum):
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
|
||||
@ -95,13 +95,16 @@ class VariablePool(BaseModel):
|
||||
if len(selector) < 2:
|
||||
raise ValueError("Invalid selector")
|
||||
|
||||
if isinstance(value, Variable):
|
||||
variable = value
|
||||
if isinstance(value, Segment):
|
||||
v = value
|
||||
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
|
||||
else:
|
||||
v = variable_factory.build_segment(value)
|
||||
segment = variable_factory.build_segment(value)
|
||||
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
|
||||
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self.variable_dictionary[selector[0]][hash_key] = v
|
||||
self.variable_dictionary[selector[0]][hash_key] = variable
|
||||
|
||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||
"""
|
||||
|
||||
@ -148,6 +148,7 @@ class IterationRunStartedEvent(BaseIterationEvent):
|
||||
class IterationRunNextEvent(BaseIterationEvent):
|
||||
index: int = Field(..., description="index")
|
||||
pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output")
|
||||
duration: Optional[float] = Field(None, description="duration")
|
||||
|
||||
|
||||
class IterationRunSucceededEvent(BaseIterationEvent):
|
||||
@ -156,6 +157,7 @@ class IterationRunSucceededEvent(BaseIterationEvent):
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
steps: int = 0
|
||||
iteration_duration_map: Optional[dict[str, float]] = None
|
||||
|
||||
|
||||
class IterationRunFailedEvent(BaseIterationEvent):
|
||||
|
||||
@ -143,14 +143,14 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
|
||||
|
||||
def _extract_text_from_plain_text(file_content: bytes) -> str:
|
||||
try:
|
||||
return file_content.decode("utf-8")
|
||||
return file_content.decode("utf-8", "ignore")
|
||||
except UnicodeDecodeError as e:
|
||||
raise TextExtractionError("Failed to decode plain text file") from e
|
||||
|
||||
|
||||
def _extract_text_from_json(file_content: bytes) -> str:
|
||||
try:
|
||||
json_data = json.loads(file_content.decode("utf-8"))
|
||||
json_data = json.loads(file_content.decode("utf-8", "ignore"))
|
||||
return json.dumps(json_data, indent=2, ensure_ascii=False)
|
||||
except (UnicodeDecodeError, json.JSONDecodeError) as e:
|
||||
raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e
|
||||
@ -159,7 +159,7 @@ def _extract_text_from_json(file_content: bytes) -> str:
|
||||
def _extract_text_from_yaml(file_content: bytes) -> str:
|
||||
"""Extract the content from yaml file"""
|
||||
try:
|
||||
yaml_data = yaml.safe_load_all(file_content.decode("utf-8"))
|
||||
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore"))
|
||||
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
|
||||
except (UnicodeDecodeError, yaml.YAMLError) as e:
|
||||
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
|
||||
@ -217,7 +217,7 @@ def _extract_text_from_file(file: File):
|
||||
|
||||
def _extract_text_from_csv(file_content: bytes) -> str:
|
||||
try:
|
||||
csv_file = io.StringIO(file_content.decode("utf-8"))
|
||||
csv_file = io.StringIO(file_content.decode("utf-8", "ignore"))
|
||||
csv_reader = csv.reader(csv_file)
|
||||
rows = list(csv_reader)
|
||||
|
||||
|
||||
@ -156,6 +156,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
index=0,
|
||||
pre_iteration_output=None,
|
||||
)
|
||||
iter_run_map: dict[str, float] = {}
|
||||
outputs: list[Any] = [None] * len(iterator_list_value)
|
||||
try:
|
||||
if self.node_data.is_parallel:
|
||||
@ -175,6 +176,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
iteration_graph,
|
||||
index,
|
||||
item,
|
||||
iter_run_map,
|
||||
)
|
||||
future.add_done_callback(thread_pool.task_done_callback)
|
||||
futures.append(future)
|
||||
@ -213,6 +215,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
start_at,
|
||||
graph_engine,
|
||||
iteration_graph,
|
||||
iter_run_map,
|
||||
)
|
||||
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
outputs = [output for output in outputs if output is not None]
|
||||
@ -230,7 +233,9 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)}
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map},
|
||||
)
|
||||
)
|
||||
except IterationNodeError as e:
|
||||
@ -356,15 +361,19 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
start_at: datetime,
|
||||
graph_engine: "GraphEngine",
|
||||
iteration_graph: Graph,
|
||||
iter_run_map: dict[str, float],
|
||||
parallel_mode_run_id: Optional[str] = None,
|
||||
) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
run single iteration
|
||||
"""
|
||||
iter_start_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
try:
|
||||
rst = graph_engine.run()
|
||||
# get current iteration index
|
||||
current_index = variable_pool.get([self.node_id, "index"]).value
|
||||
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
|
||||
next_index = int(current_index) + 1
|
||||
|
||||
if current_index is None:
|
||||
@ -431,6 +440,8 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
iter_run_map[iteration_run_id] = duration
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
@ -439,6 +450,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
index=next_index,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
pre_iteration_output=None,
|
||||
duration=duration,
|
||||
)
|
||||
return
|
||||
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
@ -449,6 +461,8 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
iter_run_map[iteration_run_id] = duration
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
@ -457,6 +471,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
index=next_index,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
pre_iteration_output=None,
|
||||
duration=duration,
|
||||
)
|
||||
return
|
||||
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
|
||||
@ -474,7 +489,10 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
)
|
||||
yield metadata_event
|
||||
|
||||
current_iteration_output = variable_pool.get(self.node_data.output_selector).value
|
||||
current_output_segment = variable_pool.get(self.node_data.output_selector)
|
||||
if current_output_segment is None:
|
||||
raise IterationNodeError("iteration output selector not found")
|
||||
current_iteration_output = current_output_segment.value
|
||||
outputs[current_index] = current_iteration_output
|
||||
# remove all nodes outputs from variable pool
|
||||
for node_id in iteration_graph.node_ids:
|
||||
@ -485,6 +503,8 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
iter_run_map[iteration_run_id] = duration
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
@ -493,6 +513,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
index=next_index,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
except IterationNodeError as e:
|
||||
@ -528,6 +549,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
iteration_graph: Graph,
|
||||
index: int,
|
||||
item: Any,
|
||||
iter_run_map: dict[str, float],
|
||||
) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
run single iteration in parallel mode
|
||||
@ -546,6 +568,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
start_at=start_at,
|
||||
graph_engine=graph_engine_copy,
|
||||
iteration_graph=iteration_graph,
|
||||
iter_run_map=iter_run_map,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
):
|
||||
q.put(event)
|
||||
|
||||
@ -59,4 +59,4 @@ class ListOperatorNodeData(BaseNodeData):
|
||||
filter_by: FilterBy
|
||||
order_by: OrderBy
|
||||
limit: Limit
|
||||
extract_by: ExtractConfig
|
||||
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)
|
||||
|
||||
Reference in New Issue
Block a user