mirror of
https://github.com/langgenius/dify.git
synced 2026-05-01 07:58:02 +08:00
merge main
This commit is contained in:
@ -45,3 +45,6 @@ class NodeRunResult(BaseModel):
|
||||
|
||||
error: Optional[str] = None # error message if status is failed
|
||||
error_type: Optional[str] = None # error type if status is failed
|
||||
|
||||
# single step node run retry
|
||||
retry_index: int = 0
|
||||
|
||||
@ -33,7 +33,7 @@ class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
|
||||
error: str = Field(..., description="failed reason")
|
||||
exceptions_count: Optional[int] = Field(description="exception count", default=0)
|
||||
exceptions_count: int = Field(description="exception count", default=0)
|
||||
|
||||
|
||||
class GraphRunPartialSucceededEvent(BaseGraphEvent):
|
||||
@ -97,6 +97,12 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||
error: str = Field(..., description="error")
|
||||
retry_index: int = Field(..., description="which retry attempt is about to be performed")
|
||||
start_at: datetime = Field(..., description="retry start time")
|
||||
|
||||
|
||||
###########################################
|
||||
# Parallel Branch Events
|
||||
###########################################
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
@ -170,7 +171,9 @@ class Graph(BaseModel):
|
||||
for parallel in parallel_mapping.values():
|
||||
if parallel.parent_parallel_id:
|
||||
cls._check_exceed_parallel_limit(
|
||||
parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id
|
||||
parallel_mapping=parallel_mapping,
|
||||
level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||
parent_parallel_id=parallel.parent_parallel_id,
|
||||
)
|
||||
|
||||
# init answer stream generate routes
|
||||
|
||||
@ -5,6 +5,7 @@ import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from copy import copy, deepcopy
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
@ -25,6 +26,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
@ -581,7 +583,7 @@ class GraphEngine:
|
||||
|
||||
def _run_node(
|
||||
self,
|
||||
node_instance: BaseNode,
|
||||
node_instance: BaseNode[BaseNodeData],
|
||||
route_node_state: RouteNodeState,
|
||||
parallel_id: Optional[str] = None,
|
||||
parallel_start_node_id: Optional[str] = None,
|
||||
@ -607,36 +609,120 @@ class GraphEngine:
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
max_retries = node_instance.node_data.retry_config.max_retries
|
||||
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
|
||||
retries = 0
|
||||
shoudl_continue_retry = True
|
||||
while shoudl_continue_retry and retries <= max_retries:
|
||||
try:
|
||||
# run node
|
||||
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
generator = node_instance.run()
|
||||
for item in generator:
|
||||
if isinstance(item, GraphEngineEvent):
|
||||
if isinstance(item, BaseIterationEvent):
|
||||
# add parallel info to iteration event
|
||||
item.parallel_id = parallel_id
|
||||
item.parallel_start_node_id = parallel_start_node_id
|
||||
item.parent_parallel_id = parent_parallel_id
|
||||
item.parent_parallel_start_node_id = parent_parallel_start_node_id
|
||||
|
||||
try:
|
||||
# run node
|
||||
generator = node_instance.run()
|
||||
for item in generator:
|
||||
if isinstance(item, GraphEngineEvent):
|
||||
if isinstance(item, BaseIterationEvent):
|
||||
# add parallel info to iteration event
|
||||
item.parallel_id = parallel_id
|
||||
item.parallel_start_node_id = parallel_start_node_id
|
||||
item.parent_parallel_id = parent_parallel_id
|
||||
item.parent_parallel_start_node_id = parent_parallel_start_node_id
|
||||
yield item
|
||||
else:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
run_result = item.run_result
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
if (
|
||||
retries == max_retries
|
||||
and node_instance.node_type == NodeType.HTTP_REQUEST
|
||||
and run_result.outputs
|
||||
and not node_instance.should_continue_on_error
|
||||
):
|
||||
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
if node_instance.should_retry and retries < max_retries:
|
||||
retries += 1
|
||||
route_node_state.node_run_result = run_result
|
||||
yield NodeRunRetryEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
predecessor_node_id=node_instance.previous_node_id,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
error=run_result.error,
|
||||
retry_index=retries,
|
||||
start_at=retry_start_at,
|
||||
)
|
||||
time.sleep(retry_interval)
|
||||
continue
|
||||
route_node_state.set_finished(run_result=run_result)
|
||||
|
||||
yield item
|
||||
else:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
run_result = item.run_result
|
||||
route_node_state.set_finished(run_result=run_result)
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
if node_instance.should_continue_on_error:
|
||||
# if run failed, handle error
|
||||
run_result = self._handle_continue_on_error(
|
||||
node_instance,
|
||||
item.run_result,
|
||||
self.graph_runtime_state.variable_pool,
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
route_node_state.node_run_result = run_result
|
||||
route_node_state.status = RouteNodeState.Status.EXCEPTION
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
node_id=node_instance.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
)
|
||||
yield NodeRunExceptionEvent(
|
||||
error=run_result.error or "System Error",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
shoudl_continue_retry = False
|
||||
else:
|
||||
yield NodeRunFailedEvent(
|
||||
error=route_node_state.failed_reason or "Unknown error.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
shoudl_continue_retry = False
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
|
||||
node_instance.node_id
|
||||
):
|
||||
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
||||
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
# plus state total_tokens
|
||||
self.graph_runtime_state.total_tokens += int(
|
||||
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
if node_instance.should_continue_on_error:
|
||||
# if run failed, handle error
|
||||
run_result = self._handle_continue_on_error(
|
||||
node_instance,
|
||||
item.run_result,
|
||||
self.graph_runtime_state.variable_pool,
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
route_node_state.node_run_result = run_result
|
||||
route_node_state.status = RouteNodeState.Status.EXCEPTION
|
||||
if run_result.llm_usage:
|
||||
# use the latest usage
|
||||
self.graph_runtime_state.llm_usage += run_result.llm_usage
|
||||
|
||||
# append node output variables to variable pool
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
@ -645,21 +731,23 @@ class GraphEngine:
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
)
|
||||
yield NodeRunExceptionEvent(
|
||||
error=run_result.error or "System Error",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
else:
|
||||
yield NodeRunFailedEvent(
|
||||
error=route_node_state.failed_reason or "Unknown error.",
|
||||
|
||||
# add parallel info to run result metadata
|
||||
if parallel_id and parallel_start_node_id:
|
||||
if not run_result.metadata:
|
||||
run_result.metadata = {}
|
||||
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
|
||||
parallel_start_node_id
|
||||
)
|
||||
if parent_parallel_id and parent_parallel_start_node_id:
|
||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
|
||||
parent_parallel_start_node_id
|
||||
)
|
||||
|
||||
yield NodeRunSucceededEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
@ -670,108 +758,59 @@ class GraphEngine:
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
shoudl_continue_retry = False
|
||||
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
|
||||
node_instance.node_id
|
||||
):
|
||||
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
||||
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
# plus state total_tokens
|
||||
self.graph_runtime_state.total_tokens += int(
|
||||
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if run_result.llm_usage:
|
||||
# use the latest usage
|
||||
self.graph_runtime_state.llm_usage += run_result.llm_usage
|
||||
|
||||
# append node output variables to variable pool
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
node_id=node_instance.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
)
|
||||
|
||||
# add parallel info to run result metadata
|
||||
if parallel_id and parallel_start_node_id:
|
||||
if not run_result.metadata:
|
||||
run_result.metadata = {}
|
||||
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
|
||||
if parent_parallel_id and parent_parallel_start_node_id:
|
||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
|
||||
parent_parallel_start_node_id
|
||||
)
|
||||
|
||||
yield NodeRunSucceededEvent(
|
||||
break
|
||||
elif isinstance(item, RunStreamChunkEvent):
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
chunk_content=item.chunk_content,
|
||||
from_variable_selector=item.from_variable_selector,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
|
||||
break
|
||||
elif isinstance(item, RunStreamChunkEvent):
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
chunk_content=item.chunk_content,
|
||||
from_variable_selector=item.from_variable_selector,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
elif isinstance(item, RunRetrieverResourceEvent):
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
retriever_resources=item.retriever_resources,
|
||||
context=item.context,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
except GenerateTaskStoppedError:
|
||||
# trigger node run failed event
|
||||
route_node_state.status = RouteNodeState.Status.FAILED
|
||||
route_node_state.failed_reason = "Workflow stopped."
|
||||
yield NodeRunFailedEvent(
|
||||
error="Workflow stopped.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {node_instance.node_data.title} run failed")
|
||||
raise e
|
||||
finally:
|
||||
db.session.close()
|
||||
elif isinstance(item, RunRetrieverResourceEvent):
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
retriever_resources=item.retriever_resources,
|
||||
context=item.context,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
except GenerateTaskStoppedError:
|
||||
# trigger node run failed event
|
||||
route_node_state.status = RouteNodeState.Status.FAILED
|
||||
route_node_state.failed_reason = "Workflow stopped."
|
||||
yield NodeRunFailedEvent(
|
||||
error="Workflow stopped.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {node_instance.node_data.title} run failed")
|
||||
raise e
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
||||
"""
|
||||
|
||||
@ -147,6 +147,8 @@ class AnswerStreamGeneratorRouter:
|
||||
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
|
||||
for edge in reverse_edges:
|
||||
source_node_id = edge.source_node_id
|
||||
if source_node_id not in node_id_config_mapping:
|
||||
continue
|
||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
|
||||
if (
|
||||
|
||||
@ -106,12 +106,25 @@ class DefaultValue(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
error_strategy: Optional[ErrorStrategy] = None
|
||||
default_value: Optional[list[DefaultValue]] = None
|
||||
version: str = "1"
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
@property
|
||||
def default_value_dict(self):
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
class BaseNodeError(Exception):
|
||||
class BaseNodeError(ValueError):
|
||||
"""Base class for node errors."""
|
||||
|
||||
pass
|
||||
|
||||
@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
|
||||
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
@ -72,7 +72,11 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
result = self._run()
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {self.node_id} failed to run")
|
||||
result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError")
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
error_type="WorkflowNodeError",
|
||||
)
|
||||
|
||||
if isinstance(result, NodeRunResult):
|
||||
yield RunCompletedEvent(run_result=result)
|
||||
@ -143,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
bool: if should continue on error
|
||||
"""
|
||||
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
|
||||
|
||||
@property
|
||||
def should_retry(self) -> bool:
|
||||
"""judge if should retry
|
||||
|
||||
Returns:
|
||||
bool: if should retry
|
||||
"""
|
||||
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
@ -59,7 +59,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
)
|
||||
|
||||
# Transform result
|
||||
result = self._transform_result(result, self.node_data.outputs)
|
||||
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
|
||||
except (CodeExecutionError, CodeNodeError) as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
||||
@ -67,18 +67,17 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
||||
def _check_string(self, value: str, variable: str) -> str:
|
||||
def _check_string(self, value: str | None, variable: str) -> str | None:
|
||||
"""
|
||||
Check string
|
||||
:param value: value
|
||||
:param variable: variable
|
||||
:return:
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, str):
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a string")
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a string")
|
||||
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
raise OutputValidationError(
|
||||
@ -88,18 +87,17 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
|
||||
return value.replace("\x00", "")
|
||||
|
||||
def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
|
||||
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
|
||||
"""
|
||||
Check number
|
||||
:param value: value
|
||||
:param variable: variable
|
||||
:return:
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, int | float):
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a number")
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a number")
|
||||
|
||||
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
|
||||
raise OutputValidationError(
|
||||
@ -118,14 +116,12 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
return value
|
||||
|
||||
def _transform_result(
|
||||
self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1
|
||||
) -> dict:
|
||||
"""
|
||||
Transform result
|
||||
:param result: result
|
||||
:param output_schema: output schema
|
||||
:return:
|
||||
"""
|
||||
self,
|
||||
result: Mapping[str, Any],
|
||||
output_schema: Optional[dict[str, CodeNodeData.Output]],
|
||||
prefix: str = "",
|
||||
depth: int = 1,
|
||||
):
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
@ -8,12 +9,6 @@ import docx
|
||||
import pandas as pd
|
||||
import pypdfium2 # type: ignore
|
||||
import yaml # type: ignore
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.email import partition_email
|
||||
from unstructured.partition.epub import partition_epub
|
||||
from unstructured.partition.msg import partition_msg
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
@ -28,6 +23,8 @@ from models.workflow import WorkflowNodeExecutionStatus
|
||||
from .entities import DocumentExtractorNodeData
|
||||
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
|
||||
"""
|
||||
@ -183,10 +180,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
|
||||
|
||||
|
||||
def _extract_text_from_doc(file_content: bytes) -> str:
|
||||
"""
|
||||
Extract text from a DOC/DOCX file.
|
||||
For now support only paragraph and table add more if needed
|
||||
"""
|
||||
try:
|
||||
doc_file = io.BytesIO(file_content)
|
||||
doc = docx.Document(doc_file)
|
||||
return "\n".join([paragraph.text for paragraph in doc.paragraphs])
|
||||
text = []
|
||||
# Process paragraphs
|
||||
for paragraph in doc.paragraphs:
|
||||
if paragraph.text.strip():
|
||||
text.append(paragraph.text)
|
||||
|
||||
# Process tables
|
||||
for table in doc.tables:
|
||||
# Table header
|
||||
try:
|
||||
# table maybe cause errors so ignore it.
|
||||
if len(table.rows) > 0 and table.rows[0].cells is not None:
|
||||
# Check if any cell in the table has text
|
||||
has_content = False
|
||||
for row in table.rows:
|
||||
if any(cell.text.strip() for cell in row.cells):
|
||||
has_content = True
|
||||
break
|
||||
|
||||
if has_content:
|
||||
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
|
||||
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
|
||||
for row in table.rows[1:]:
|
||||
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
|
||||
text.append(markdown_table)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
|
||||
continue
|
||||
|
||||
return "\n".join(text)
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
|
||||
|
||||
@ -256,6 +286,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:
|
||||
|
||||
|
||||
def _extract_text_from_ppt(file_content: bytes) -> str:
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
|
||||
try:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_ppt(file=file)
|
||||
@ -265,6 +297,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
|
||||
|
||||
|
||||
def _extract_text_from_pptx(file_content: bytes) -> str:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
try:
|
||||
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
|
||||
with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
|
||||
@ -287,6 +322,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
|
||||
|
||||
|
||||
def _extract_text_from_epub(file_content: bytes) -> str:
|
||||
from unstructured.partition.epub import partition_epub
|
||||
|
||||
try:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_epub(file=file)
|
||||
@ -296,6 +333,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:
|
||||
|
||||
|
||||
def _extract_text_from_eml(file_content: bytes) -> str:
|
||||
from unstructured.partition.email import partition_email
|
||||
|
||||
try:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_email(file=file)
|
||||
@ -305,6 +344,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:
|
||||
|
||||
|
||||
def _extract_text_from_msg(file_content: bytes) -> str:
|
||||
from unstructured.partition.msg import partition_msg
|
||||
|
||||
try:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_msg(file=file)
|
||||
|
||||
@ -135,6 +135,8 @@ class EndStreamGeneratorRouter:
|
||||
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
|
||||
for edge in reverse_edges:
|
||||
source_node_id = edge.source_node_id
|
||||
if source_node_id not in node_id_config_mapping:
|
||||
continue
|
||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||
if source_node_type in {
|
||||
NodeType.IF_ELSE.value,
|
||||
|
||||
@ -35,3 +35,4 @@ class FailBranchSourceHandle(StrEnum):
|
||||
|
||||
|
||||
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
|
||||
RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST]
|
||||
|
||||
@ -1,4 +1,10 @@
|
||||
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from .event import (
|
||||
ModelInvokeCompletedEvent,
|
||||
RunCompletedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
RunRetryEvent,
|
||||
RunStreamChunkEvent,
|
||||
)
|
||||
from .types import NodeEvent
|
||||
|
||||
__all__ = [
|
||||
@ -6,5 +12,6 @@ __all__ = [
|
||||
"NodeEvent",
|
||||
"RunCompletedEvent",
|
||||
"RunRetrieverResourceEvent",
|
||||
"RunRetryEvent",
|
||||
"RunStreamChunkEvent",
|
||||
]
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class RunCompletedEvent(BaseModel):
|
||||
@ -26,3 +29,19 @@ class ModelInvokeCompletedEvent(BaseModel):
|
||||
text: str
|
||||
usage: LLMUsage
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
class RunRetryEvent(BaseModel):
|
||||
"""Node Run Retry event"""
|
||||
|
||||
error: str = Field(..., description="error")
|
||||
retry_index: int = Field(..., description="Retry attempt number")
|
||||
start_at: datetime = Field(..., description="Retry start time")
|
||||
|
||||
|
||||
class SingleStepRetryEvent(NodeRunResult):
|
||||
"""Single step retry event"""
|
||||
|
||||
status: str = WorkflowNodeExecutionStatus.RETRY.value
|
||||
|
||||
elapsed_time: float = Field(..., description="elapsed time")
|
||||
|
||||
@ -45,6 +45,7 @@ class Executor:
|
||||
headers: dict[str, str]
|
||||
auth: HttpRequestNodeAuthorization
|
||||
timeout: HttpRequestNodeTimeout
|
||||
max_retries: int
|
||||
|
||||
boundary: str
|
||||
|
||||
@ -54,6 +55,7 @@ class Executor:
|
||||
node_data: HttpRequestNodeData,
|
||||
timeout: HttpRequestNodeTimeout,
|
||||
variable_pool: VariablePool,
|
||||
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
):
|
||||
# If authorization API key is present, convert the API key using the variable pool
|
||||
if node_data.authorization.type == "api-key":
|
||||
@ -73,6 +75,7 @@ class Executor:
|
||||
self.files = None
|
||||
self.data = None
|
||||
self.json = None
|
||||
self.max_retries = max_retries
|
||||
|
||||
# init template
|
||||
self.variable_pool = variable_pool
|
||||
@ -241,11 +244,12 @@ class Executor:
|
||||
"params": self.params,
|
||||
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
|
||||
"follow_redirects": True,
|
||||
"max_retries": self.max_retries,
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
try:
|
||||
response = getattr(ssrf_proxy, self.method)(**request_args)
|
||||
except ssrf_proxy.MaxRetriesExceededError as e:
|
||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||
raise HttpRequestNodeError(str(e))
|
||||
return response
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
@ -51,6 +52,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
||||
},
|
||||
},
|
||||
"retry_config": {
|
||||
"max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
"retry_interval": 0.5 * (2**2),
|
||||
"retry_enabled": True,
|
||||
},
|
||||
}
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
@ -60,12 +66,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
node_data=self.node_data,
|
||||
timeout=self._get_request_timeout(self.node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
max_retries=0,
|
||||
)
|
||||
process_data["request"] = http_executor.to_log()
|
||||
|
||||
response = http_executor.invoke()
|
||||
files = self.extract_files(url=http_executor.url, response=response)
|
||||
if not response.response.is_success and self.should_continue_on_error:
|
||||
if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
outputs={
|
||||
@ -156,20 +163,24 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
|
||||
def extract_files(self, url: str, response: Response) -> list[File]:
|
||||
"""
|
||||
Extract files from response
|
||||
Extract files from response by checking both Content-Type header and URL
|
||||
"""
|
||||
files = []
|
||||
is_file = response.is_file
|
||||
content_type = response.content_type
|
||||
content = response.content
|
||||
|
||||
if is_file and content_type:
|
||||
if is_file:
|
||||
# Guess file extension from URL or Content-Type header
|
||||
filename = url.split("?")[0].split("/")[-1] or ""
|
||||
mime_type = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
|
||||
tool_file = ToolFileManager.create_file_by_raw(
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=content,
|
||||
mimetype=content_type,
|
||||
mimetype=mime_type,
|
||||
)
|
||||
|
||||
mapping = {
|
||||
|
||||
@ -70,7 +70,20 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
|
||||
except KnowledgeRetrievalNodeError as e:
|
||||
logger.warning("Error when running knowledge retrieval node")
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
# Temporary handle all exceptions from DatasetRetrieval class here.
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
|
||||
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
|
||||
available_datasets = []
|
||||
@ -160,18 +173,18 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
reranking_model = None
|
||||
weights = None
|
||||
all_documents = dataset_retrieval.multiple_retrieve(
|
||||
self.app_id,
|
||||
self.tenant_id,
|
||||
self.user_id,
|
||||
self.user_from.value,
|
||||
available_datasets,
|
||||
query,
|
||||
node_data.multiple_retrieval_config.top_k,
|
||||
node_data.multiple_retrieval_config.score_threshold,
|
||||
node_data.multiple_retrieval_config.reranking_mode,
|
||||
reranking_model,
|
||||
weights,
|
||||
node_data.multiple_retrieval_config.reranking_enable,
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from.value,
|
||||
available_datasets=available_datasets,
|
||||
query=query,
|
||||
top_k=node_data.multiple_retrieval_config.top_k,
|
||||
score_threshold=node_data.multiple_retrieval_config.score_threshold,
|
||||
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
|
||||
reranking_model=reranking_model,
|
||||
weights=weights,
|
||||
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
|
||||
)
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||
|
||||
@ -50,6 +50,7 @@ class PromptConfig(BaseModel):
|
||||
|
||||
|
||||
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||
text: str = ""
|
||||
jinja2_text: Optional[str] = None
|
||||
|
||||
|
||||
|
||||
@ -145,8 +145,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
query = query_variable.text
|
||||
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
user_query=query,
|
||||
user_files=files,
|
||||
sys_query=query,
|
||||
sys_files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
@ -545,8 +545,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
def _fetch_prompt_messages(
|
||||
self,
|
||||
*,
|
||||
user_query: str | None = None,
|
||||
user_files: Sequence["File"],
|
||||
sys_query: str | None = None,
|
||||
sys_files: Sequence["File"],
|
||||
context: str | None = None,
|
||||
memory: TokenBufferMemory | None = None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
@ -562,7 +562,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
if isinstance(prompt_template, list):
|
||||
# For chat model
|
||||
prompt_messages.extend(
|
||||
_handle_list_messages(
|
||||
self._handle_list_messages(
|
||||
messages=prompt_template,
|
||||
context=context,
|
||||
jinja2_variables=jinja2_variables,
|
||||
@ -581,14 +581,14 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
prompt_messages.extend(memory_messages)
|
||||
|
||||
# Add current query to the prompt messages
|
||||
if user_query:
|
||||
if sys_query:
|
||||
message = LLMNodeChatModelMessage(
|
||||
text=user_query,
|
||||
text=sys_query,
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
prompt_messages.extend(
|
||||
_handle_list_messages(
|
||||
self._handle_list_messages(
|
||||
messages=[message],
|
||||
context="",
|
||||
jinja2_variables=[],
|
||||
@ -635,24 +635,27 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
raise ValueError("Invalid prompt content type")
|
||||
|
||||
# Add current query to the prompt message
|
||||
if user_query:
|
||||
if sys_query:
|
||||
if prompt_content_type == str:
|
||||
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
|
||||
prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
|
||||
prompt_messages[0].content = prompt_content
|
||||
elif prompt_content_type == list:
|
||||
for content_item in prompt_content:
|
||||
if content_item.type == PromptMessageContentType.TEXT:
|
||||
content_item.data = user_query + "\n" + content_item.data
|
||||
content_item.data = sys_query + "\n" + content_item.data
|
||||
else:
|
||||
raise ValueError("Invalid prompt content type")
|
||||
else:
|
||||
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
|
||||
|
||||
if vision_enabled and user_files:
|
||||
# The sys_files will be deprecated later
|
||||
if vision_enabled and sys_files:
|
||||
file_prompts = []
|
||||
for file in user_files:
|
||||
for file in sys_files:
|
||||
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
||||
file_prompts.append(file_prompt)
|
||||
# If last prompt is a user prompt, add files into its contents,
|
||||
# otherwise append a new user prompt
|
||||
if (
|
||||
len(prompt_messages) > 0
|
||||
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||
@ -662,7 +665,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
# Filter prompt messages
|
||||
# Remove empty messages and filter unsupported content
|
||||
filtered_prompt_messages = []
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message.content, list):
|
||||
@ -846,6 +849,68 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
},
|
||||
}
|
||||
|
||||
def _handle_list_messages(
|
||||
self,
|
||||
*,
|
||||
messages: Sequence[LLMNodeChatModelMessage],
|
||||
context: Optional[str],
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||
) -> Sequence[PromptMessage]:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
for message in messages:
|
||||
if message.edition_type == "jinja2":
|
||||
result_text = _render_jinja2_message(
|
||||
template=message.jinja2_text or "",
|
||||
jinjia2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
prompt_message = _combine_message_content_with_role(
|
||||
contents=[TextPromptMessageContent(data=result_text)], role=message.role
|
||||
)
|
||||
prompt_messages.append(prompt_message)
|
||||
else:
|
||||
# Get segment group from basic message
|
||||
if context:
|
||||
template = message.text.replace("{#context#}", context)
|
||||
else:
|
||||
template = message.text
|
||||
segment_group = variable_pool.convert_template(template)
|
||||
|
||||
# Process segments for images
|
||||
file_contents = []
|
||||
for segment in segment_group.value:
|
||||
if isinstance(segment, ArrayFileSegment):
|
||||
for file in segment.value:
|
||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||
file_content = file_manager.to_prompt_message_content(
|
||||
file, image_detail_config=vision_detail_config
|
||||
)
|
||||
file_contents.append(file_content)
|
||||
elif isinstance(segment, FileSegment):
|
||||
file = segment.value
|
||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||
file_content = file_manager.to_prompt_message_content(
|
||||
file, image_detail_config=vision_detail_config
|
||||
)
|
||||
file_contents.append(file_content)
|
||||
|
||||
# Create message with text from all segments
|
||||
plain_text = segment_group.text
|
||||
if plain_text:
|
||||
prompt_message = _combine_message_content_with_role(
|
||||
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
|
||||
)
|
||||
prompt_messages.append(prompt_message)
|
||||
|
||||
if file_contents:
|
||||
# Create message with image contents
|
||||
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
|
||||
prompt_messages.append(prompt_message)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
|
||||
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
|
||||
match role:
|
||||
@ -880,68 +945,6 @@ def _render_jinja2_message(
|
||||
return result_text
|
||||
|
||||
|
||||
def _handle_list_messages(
|
||||
*,
|
||||
messages: Sequence[LLMNodeChatModelMessage],
|
||||
context: Optional[str],
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||
) -> Sequence[PromptMessage]:
|
||||
prompt_messages = []
|
||||
for message in messages:
|
||||
if message.edition_type == "jinja2":
|
||||
result_text = _render_jinja2_message(
|
||||
template=message.jinja2_text or "",
|
||||
jinjia2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
prompt_message = _combine_message_content_with_role(
|
||||
contents=[TextPromptMessageContent(data=result_text)], role=message.role
|
||||
)
|
||||
prompt_messages.append(prompt_message)
|
||||
else:
|
||||
# Get segment group from basic message
|
||||
if context:
|
||||
template = message.text.replace("{#context#}", context)
|
||||
else:
|
||||
template = message.text
|
||||
segment_group = variable_pool.convert_template(template)
|
||||
|
||||
# Process segments for images
|
||||
file_contents = []
|
||||
for segment in segment_group.value:
|
||||
if isinstance(segment, ArrayFileSegment):
|
||||
for file in segment.value:
|
||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||
file_content = file_manager.to_prompt_message_content(
|
||||
file, image_detail_config=vision_detail_config
|
||||
)
|
||||
file_contents.append(file_content)
|
||||
if isinstance(segment, FileSegment):
|
||||
file = segment.value
|
||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||
file_content = file_manager.to_prompt_message_content(
|
||||
file, image_detail_config=vision_detail_config
|
||||
)
|
||||
file_contents.append(file_content)
|
||||
|
||||
# Create message with text from all segments
|
||||
plain_text = segment_group.text
|
||||
if plain_text:
|
||||
prompt_message = _combine_message_content_with_role(
|
||||
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
|
||||
)
|
||||
prompt_messages.append(prompt_message)
|
||||
|
||||
if file_contents:
|
||||
# Create message with image contents
|
||||
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
|
||||
prompt_messages.append(prompt_message)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
|
||||
def _calculate_rest_token(
|
||||
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
|
||||
) -> int:
|
||||
|
||||
@ -179,6 +179,15 @@ class ParameterExtractorNode(LLMNode):
|
||||
error=str(e),
|
||||
metadata={},
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)},
|
||||
error=str(e),
|
||||
metadata={},
|
||||
)
|
||||
|
||||
error = None
|
||||
|
||||
|
||||
@ -89,10 +89,10 @@ class QuestionClassifierNode(LLMNode):
|
||||
)
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
user_query=query,
|
||||
sys_query=query,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
user_files=files,
|
||||
sys_files=files,
|
||||
vision_enabled=node_data.vision.enabled,
|
||||
vision_detail=node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
@ -157,8 +157,7 @@ class QuestionClassifierNode(LLMNode):
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
|
||||
@ -92,6 +92,16 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
error=f"Failed to invoke tool: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info,
|
||||
},
|
||||
error=f"Failed to invoke tool: {str(e)}",
|
||||
error_type="UnknownError",
|
||||
)
|
||||
|
||||
# convert tool messages
|
||||
plain_text, files, json = self._convert_tool_messages(messages)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
class VariableOperatorNodeError(Exception):
|
||||
class VariableOperatorNodeError(ValueError):
|
||||
"""Base error type, don't use directly."""
|
||||
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user