merge main

This commit is contained in:
Joel
2024-12-23 15:33:08 +08:00
396 changed files with 7187 additions and 2056 deletions

View File

@ -45,3 +45,6 @@ class NodeRunResult(BaseModel):
error: Optional[str] = None # error message if status is failed
error_type: Optional[str] = None # error type if status is failed
# single step node run retry
retry_index: int = 0

View File

@ -33,7 +33,7 @@ class GraphRunSucceededEvent(BaseGraphEvent):
class GraphRunFailedEvent(BaseGraphEvent):
    """Event describing a failed graph run: the failure reason plus an exception count."""

    # reason the run failed (required)
    error: str = Field(..., description="failed reason")
    # exception count; declared once as a plain int (the duplicate Optional[int]
    # declaration was shadowed by this one and has been dropped)
    exceptions_count: int = Field(description="exception count", default=0)
class GraphRunPartialSucceededEvent(BaseGraphEvent):
@ -97,6 +97,12 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeRunRetryEvent(NodeRunStartedEvent):
    """Event describing a retry of a node run.

    Extends ``NodeRunStartedEvent`` with the error, the retry index and the
    retry start time.
    """

    # error associated with the retry (required)
    error: str = Field(..., description="error")
    # which retry attempt is about to be performed (required)
    retry_index: int = Field(..., description="which retry attempt is about to be performed")
    # time at which the retry started (required)
    start_at: datetime = Field(..., description="retry start time")
###########################################
# Parallel Branch Events
###########################################

View File

@ -4,6 +4,7 @@ from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from configs import dify_config
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
@ -170,7 +171,9 @@ class Graph(BaseModel):
for parallel in parallel_mapping.values():
if parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id
parallel_mapping=parallel_mapping,
level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
parent_parallel_id=parallel.parent_parallel_id,
)
# init answer stream generate routes

View File

@ -5,6 +5,7 @@ import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from datetime import UTC, datetime
from typing import Any, Optional, cast
from flask import Flask, current_app
@ -25,6 +26,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
@ -581,7 +583,7 @@ class GraphEngine:
def _run_node(
self,
node_instance: BaseNode,
node_instance: BaseNode[BaseNodeData],
route_node_state: RouteNodeState,
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
@ -607,36 +609,120 @@ class GraphEngine:
)
db.session.close()
max_retries = node_instance.node_data.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
retries = 0
shoudl_continue_retry = True
while shoudl_continue_retry and retries <= max_retries:
try:
# run node
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
generator = node_instance.run()
for item in generator:
if isinstance(item, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id
try:
# run node
generator = node_instance.run()
for item in generator:
if isinstance(item, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id
yield item
else:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if (
retries == max_retries
and node_instance.node_type == NodeType.HTTP_REQUEST
and run_result.outputs
and not node_instance.should_continue_on_error
):
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
if node_instance.should_retry and retries < max_retries:
retries += 1
route_node_state.node_run_result = run_result
yield NodeRunRetryEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
predecessor_node_id=node_instance.previous_node_id,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
error=run_result.error,
retry_index=retries,
start_at=retry_start_at,
)
time.sleep(retry_interval)
continue
route_node_state.set_finished(run_result=run_result)
yield item
else:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
shoudl_continue_retry = False
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
shoudl_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
@ -645,21 +731,23 @@ class GraphEngine:
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
parallel_start_node_id
)
if parent_parallel_id and parent_parallel_start_node_id:
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)
yield NodeRunSucceededEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
@ -670,108 +758,59 @@ class GraphEngine:
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
shoudl_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
if parent_parallel_id and parent_parallel_start_node_id:
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)
yield NodeRunSucceededEvent(
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
except GenerateTaskStoppedError:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
except GenerateTaskStoppedError:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""

View File

@ -147,6 +147,8 @@ class AnswerStreamGeneratorRouter:
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
if source_node_id not in node_id_config_mapping:
continue
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
if (

View File

@ -106,12 +106,25 @@ class DefaultValue(BaseModel):
return self
class RetryConfig(BaseModel):
    """Node retry configuration.

    ``retry_interval`` is expressed in milliseconds; ``retry_interval_seconds``
    exposes the same value converted to seconds.
    """

    max_retries: int = 0  # max retry times
    retry_interval: int = 0  # retry interval in milliseconds
    retry_enabled: bool = False  # whether retry is enabled

    @property
    def retry_interval_seconds(self) -> float:
        """Return ``retry_interval`` converted from milliseconds to seconds."""
        return self.retry_interval / 1000
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
version: str = "1"
retry_config: RetryConfig = RetryConfig()
@property
def default_value_dict(self):

View File

@ -1,4 +1,4 @@
class BaseNodeError(ValueError):
    """Base class for node errors.

    Derives from ``ValueError``, which is itself an ``Exception`` subclass, so
    handlers that catch either type also catch node errors.
    """

View File

@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus
@ -72,7 +72,11 @@ class BaseNode(Generic[GenericNodeData]):
result = self._run()
except Exception as e:
logger.exception(f"Node {self.node_id} failed to run")
result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError")
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
error_type="WorkflowNodeError",
)
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(run_result=result)
@ -143,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]):
bool: if should continue on error
"""
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
@property
def should_retry(self) -> bool:
    """Tell whether this node should be retried.

    Returns:
        bool: truthy only when retry is enabled in the node's retry config
        and the node type is listed in ``RETRY_ON_ERROR_NODE_TYPE``
    """
    retry_config = self.node_data.retry_config
    return retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union
from typing import Any, Optional
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
@ -59,7 +59,7 @@ class CodeNode(BaseNode[CodeNodeData]):
)
# Transform result
result = self._transform_result(result, self.node_data.outputs)
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@ -67,18 +67,17 @@ class CodeNode(BaseNode[CodeNodeData]):
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
def _check_string(self, value: str, variable: str) -> str:
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string
:param value: value
:param variable: variable
:return:
"""
if value is None:
return None
if not isinstance(value, str):
if value is None:
return None
else:
raise OutputValidationError(f"Output variable `{variable}` must be a string")
raise OutputValidationError(f"Output variable `{variable}` must be a string")
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
raise OutputValidationError(
@ -88,18 +87,17 @@ class CodeNode(BaseNode[CodeNodeData]):
return value.replace("\x00", "")
def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
"""
Check number
:param value: value
:param variable: variable
:return:
"""
if value is None:
return None
if not isinstance(value, int | float):
if value is None:
return None
else:
raise OutputValidationError(f"Output variable `{variable}` must be a number")
raise OutputValidationError(f"Output variable `{variable}` must be a number")
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
raise OutputValidationError(
@ -118,14 +116,12 @@ class CodeNode(BaseNode[CodeNodeData]):
return value
def _transform_result(
self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1
) -> dict:
"""
Transform result
:param result: result
:param output_schema: output schema
:return:
"""
self,
result: Mapping[str, Any],
output_schema: Optional[dict[str, CodeNodeData.Output]],
prefix: str = "",
depth: int = 1,
):
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")

View File

@ -1,6 +1,7 @@
import csv
import io
import json
import logging
import os
import tempfile
@ -8,12 +9,6 @@ import docx
import pandas as pd
import pypdfium2 # type: ignore
import yaml # type: ignore
from unstructured.partition.api import partition_via_api
from unstructured.partition.email import partition_email
from unstructured.partition.epub import partition_epub
from unstructured.partition.msg import partition_msg
from unstructured.partition.ppt import partition_ppt
from unstructured.partition.pptx import partition_pptx
from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
@ -28,6 +23,8 @@ from models.workflow import WorkflowNodeExecutionStatus
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
logger = logging.getLogger(__name__)
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
"""
@ -183,10 +180,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
def _extract_text_from_doc(file_content: bytes) -> str:
    """
    Extract text from a DOC/DOCX file.

    Only paragraphs and tables are supported for now; add more element types if needed.
    Non-blank paragraphs are emitted first, then every table that contains any text is
    rendered as a Markdown table whose first row is used as the header row.

    :param file_content: raw bytes of the document
    :return: the extracted pieces joined with newlines
    :raises TextExtractionError: if the document cannot be opened or processed
    """
    try:
        doc_file = io.BytesIO(file_content)
        doc = docx.Document(doc_file)
        text = []

        # Process paragraphs, skipping the ones that are blank
        for paragraph in doc.paragraphs:
            if paragraph.text.strip():
                text.append(paragraph.text)

        # Process tables
        for table in doc.tables:
            try:
                # a table may raise while being read; such tables are logged and skipped
                if len(table.rows) > 0 and table.rows[0].cells is not None:
                    # only emit tables in which at least one cell has text
                    has_content = any(cell.text.strip() for row in table.rows for cell in row.cells)
                    if has_content:
                        header_cells = table.rows[0].cells
                        lines = [
                            "| " + " | ".join(cell.text for cell in header_cells) + " |",
                            "| " + " | ".join(["---"] * len(header_cells)) + " |",
                        ]
                        for row in table.rows[1:]:
                            lines.append("| " + " | ".join(cell.text for cell in row.cells) + " |")
                        # every table line, including the last, ends with a newline
                        text.append("\n".join(lines) + "\n")
            except Exception as e:
                logger.warning("Failed to extract table from DOC/DOCX: %s", e)

        return "\n".join(text)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
@ -256,6 +286,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:
def _extract_text_from_ppt(file_content: bytes) -> str:
from unstructured.partition.ppt import partition_ppt
try:
with io.BytesIO(file_content) as file:
elements = partition_ppt(file=file)
@ -265,6 +297,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
def _extract_text_from_pptx(file_content: bytes) -> str:
from unstructured.partition.api import partition_via_api
from unstructured.partition.pptx import partition_pptx
try:
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
@ -287,6 +322,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
def _extract_text_from_epub(file_content: bytes) -> str:
from unstructured.partition.epub import partition_epub
try:
with io.BytesIO(file_content) as file:
elements = partition_epub(file=file)
@ -296,6 +333,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:
def _extract_text_from_eml(file_content: bytes) -> str:
from unstructured.partition.email import partition_email
try:
with io.BytesIO(file_content) as file:
elements = partition_email(file=file)
@ -305,6 +344,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:
def _extract_text_from_msg(file_content: bytes) -> str:
from unstructured.partition.msg import partition_msg
try:
with io.BytesIO(file_content) as file:
elements = partition_msg(file=file)

View File

@ -135,6 +135,8 @@ class EndStreamGeneratorRouter:
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
if source_node_id not in node_id_config_mapping:
continue
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in {
NodeType.IF_ELSE.value,

View File

@ -35,3 +35,4 @@ class FailBranchSourceHandle(StrEnum):
# node types eligible for continue-on-error handling
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
# node types eligible for retry on error (unlike the list above, CODE is not included)
RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST]

View File

@ -1,4 +1,10 @@
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from .event import (
ModelInvokeCompletedEvent,
RunCompletedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
RunStreamChunkEvent,
)
from .types import NodeEvent
__all__ = [
@ -6,5 +12,6 @@ __all__ = [
"NodeEvent",
"RunCompletedEvent",
"RunRetrieverResourceEvent",
"RunRetryEvent",
"RunStreamChunkEvent",
]

View File

@ -1,7 +1,10 @@
from datetime import datetime
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
class RunCompletedEvent(BaseModel):
@ -26,3 +29,19 @@ class ModelInvokeCompletedEvent(BaseModel):
text: str
usage: LLMUsage
finish_reason: str | None = None
class RunRetryEvent(BaseModel):
    """Node run retry event carrying the error, the retry attempt number and the retry start time."""

    # error associated with the retry (required)
    error: str = Field(..., description="error")
    # retry attempt number (required)
    retry_index: int = Field(..., description="Retry attempt number")
    # time at which the retry started (required)
    start_at: datetime = Field(..., description="Retry start time")
class SingleStepRetryEvent(NodeRunResult):
    """Single step retry event: a ``NodeRunResult`` whose status defaults to the RETRY value."""

    # defaults to the string value of WorkflowNodeExecutionStatus.RETRY
    status: str = WorkflowNodeExecutionStatus.RETRY.value
    # elapsed time (required); the unit is not stated here
    elapsed_time: float = Field(..., description="elapsed time")

View File

@ -45,6 +45,7 @@ class Executor:
headers: dict[str, str]
auth: HttpRequestNodeAuthorization
timeout: HttpRequestNodeTimeout
max_retries: int
boundary: str
@ -54,6 +55,7 @@ class Executor:
node_data: HttpRequestNodeData,
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@ -73,6 +75,7 @@ class Executor:
self.files = None
self.data = None
self.json = None
self.max_retries = max_retries
# init template
self.variable_pool = variable_pool
@ -241,11 +244,12 @@ class Executor:
"params": self.params,
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
"follow_redirects": True,
"max_retries": self.max_retries,
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response = getattr(ssrf_proxy, self.method)(**request_args)
except ssrf_proxy.MaxRetriesExceededError as e:
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e))
return response

View File

@ -1,4 +1,5 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
from typing import Any
@ -51,6 +52,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
},
},
"retry_config": {
"max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
"retry_interval": 0.5 * (2**2),
"retry_enabled": True,
},
}
def _run(self) -> NodeRunResult:
@ -60,12 +66,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
node_data=self.node_data,
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
)
process_data["request"] = http_executor.to_log()
response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
if not response.response.is_success and self.should_continue_on_error:
if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
@ -156,20 +163,24 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
def extract_files(self, url: str, response: Response) -> list[File]:
"""
Extract files from response
Extract files from response by checking both Content-Type header and URL
"""
files = []
is_file = response.is_file
content_type = response.content_type
content = response.content
if is_file and content_type:
if is_file:
# Guess file extension from URL or Content-Type header
filename = url.split("?")[0].split("/")[-1] or ""
mime_type = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
tool_file = ToolFileManager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
file_binary=content,
mimetype=content_type,
mimetype=mime_type,
)
mapping = {

View File

@ -70,7 +70,20 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
except KnowledgeRetrievalNodeError as e:
logger.warning("Error when running knowledge retrieval node")
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
# Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
available_datasets = []
@ -160,18 +173,18 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
reranking_model = None
weights = None
all_documents = dataset_retrieval.multiple_retrieve(
self.app_id,
self.tenant_id,
self.user_id,
self.user_from.value,
available_datasets,
query,
node_data.multiple_retrieval_config.top_k,
node_data.multiple_retrieval_config.score_threshold,
node_data.multiple_retrieval_config.reranking_mode,
reranking_model,
weights,
node_data.multiple_retrieval_config.reranking_enable,
app_id=self.app_id,
tenant_id=self.tenant_id,
user_id=self.user_id,
user_from=self.user_from.value,
available_datasets=available_datasets,
query=query,
top_k=node_data.multiple_retrieval_config.top_k,
score_threshold=node_data.multiple_retrieval_config.score_threshold,
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
)
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]

View File

@ -50,6 +50,7 @@ class PromptConfig(BaseModel):
class LLMNodeChatModelMessage(ChatModelMessage):
text: str = ""
jinja2_text: Optional[str] = None

View File

@ -145,8 +145,8 @@ class LLMNode(BaseNode[LLMNodeData]):
query = query_variable.text
prompt_messages, stop = self._fetch_prompt_messages(
user_query=query,
user_files=files,
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
@ -545,8 +545,8 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_prompt_messages(
self,
*,
user_query: str | None = None,
user_files: Sequence["File"],
sys_query: str | None = None,
sys_files: Sequence["File"],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
@ -562,7 +562,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
_handle_list_messages(
self._handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
@ -581,14 +581,14 @@ class LLMNode(BaseNode[LLMNodeData]):
prompt_messages.extend(memory_messages)
# Add current query to the prompt messages
if user_query:
if sys_query:
message = LLMNodeChatModelMessage(
text=user_query,
text=sys_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
prompt_messages.extend(
_handle_list_messages(
self._handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
@ -635,24 +635,27 @@ class LLMNode(BaseNode[LLMNodeData]):
raise ValueError("Invalid prompt content type")
# Add current query to the prompt message
if user_query:
if sys_query:
if prompt_content_type == str:
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
elif prompt_content_type == list:
for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT:
content_item.data = user_query + "\n" + content_item.data
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
else:
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
if vision_enabled and user_files:
# The sys_files will be deprecated later
if vision_enabled and sys_files:
file_prompts = []
for file in user_files:
for file in sys_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
@ -662,7 +665,7 @@ class LLMNode(BaseNode[LLMNodeData]):
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Filter prompt messages
# Remove empty messages and filter unsupported content
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list):
@ -846,6 +849,68 @@ class LLMNode(BaseNode[LLMNodeData]):
},
}
def _handle_list_messages(
    self,
    *,
    messages: Sequence[LLMNodeChatModelMessage],
    context: Optional[str],
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
    vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
    """Turn chat-model template messages into prompt messages.

    Each message is handled according to its ``edition_type``:

    * ``"jinja2"``: ``jinja2_text`` (or ``""`` when it is falsy) is rendered
      with ``_render_jinja2_message`` and emitted as one text message.
    * anything else: ``{#context#}`` in ``message.text`` is replaced by
      ``context`` (only when ``context`` is truthy), the result is passed to
      ``variable_pool.convert_template``, and up to two messages are emitted
      in this order: a text message when the converted text is non-empty, then
      a message carrying the contents of every image/video/audio/document file
      found in the converted segments.

    :param messages: template messages to convert, processed in order
    :param context: text substituted for ``{#context#}``; skipped when falsy
    :param jinja2_variables: variables forwarded to the Jinja2 renderer
    :param variable_pool: pool used to resolve the template into segments
    :param vision_detail_config: detail setting forwarded to
        ``file_manager.to_prompt_message_content``
    :return: prompt messages, each carrying the role of its source message
    """
    results: list[PromptMessage] = []

    for message in messages:
        # Jinja2 messages are rendered to plain text and need no file handling.
        if message.edition_type == "jinja2":
            rendered = _render_jinja2_message(
                template=message.jinja2_text or "",
                jinjia2_variables=jinja2_variables,
                variable_pool=variable_pool,
            )
            results.append(
                _combine_message_content_with_role(
                    contents=[TextPromptMessageContent(data=rendered)], role=message.role
                )
            )
            continue

        # Basic message: inject the context (if any), then resolve variables.
        template = message.text.replace("{#context#}", context) if context else message.text
        segment_group = variable_pool.convert_template(template)

        # Gather prompt contents for every supported file in the segments.
        attachments = []
        for segment in segment_group.value:
            if isinstance(segment, ArrayFileSegment):
                candidates = segment.value
            elif isinstance(segment, FileSegment):
                candidates = [segment.value]
            else:
                continue
            for file in candidates:
                if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                    attachments.append(
                        file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
                    )

        # Text first, files second - both under the original message's role.
        text = segment_group.text
        if text:
            results.append(
                _combine_message_content_with_role(
                    contents=[TextPromptMessageContent(data=text)], role=message.role
                )
            )
        if attachments:
            results.append(_combine_message_content_with_role(contents=attachments, role=message.role))

    return results
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
match role:
@ -880,68 +945,6 @@ def _render_jinja2_message(
return result_text
def _handle_list_messages(
    *,
    messages: Sequence[LLMNodeChatModelMessage],
    context: Optional[str],
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
    vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
    """Build prompt messages from a list of chat-model template messages.

    A message whose ``edition_type`` is ``"jinja2"`` is rendered through
    ``_render_jinja2_message`` and becomes a single text message. Any other
    message has ``{#context#}`` replaced by ``context`` (when ``context`` is
    truthy), is converted with ``variable_pool.convert_template`` and yields a
    text message (if the converted text is non-empty) followed by a message
    holding the contents of all image/video/audio/document files referenced by
    its segments (if there are any).

    :param messages: template messages, handled in the order given
    :param context: replacement for the ``{#context#}`` placeholder
    :param jinja2_variables: variables passed on to the Jinja2 renderer
    :param variable_pool: pool that resolves templates into segments
    :param vision_detail_config: detail level passed to
        ``file_manager.to_prompt_message_content``
    :return: the generated prompt messages
    """

    def _supported_contents(files):
        # Convert only the file kinds that may be attached to a prompt.
        return [
            file_manager.to_prompt_message_content(item, image_detail_config=vision_detail_config)
            for item in files
            if item.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}
        ]

    built = []
    for message in messages:
        if message.edition_type != "jinja2":
            # Basic message: substitute the context, then resolve variables.
            template = message.text
            if context:
                template = template.replace("{#context#}", context)
            segment_group = variable_pool.convert_template(template)

            # The two checks are deliberately independent, as in the original.
            file_contents = []
            for segment in segment_group.value:
                if isinstance(segment, ArrayFileSegment):
                    file_contents.extend(_supported_contents(segment.value))
                if isinstance(segment, FileSegment):
                    file_contents.extend(_supported_contents([segment.value]))

            # Emit the text part before the file part.
            plain_text = segment_group.text
            if plain_text:
                built.append(
                    _combine_message_content_with_role(
                        contents=[TextPromptMessageContent(data=plain_text)], role=message.role
                    )
                )
            if file_contents:
                built.append(_combine_message_content_with_role(contents=file_contents, role=message.role))
        else:
            # Jinja2 message: the rendered template is the whole content.
            result_text = _render_jinja2_message(
                template=message.jinja2_text or "",
                jinjia2_variables=jinja2_variables,
                variable_pool=variable_pool,
            )
            built.append(
                _combine_message_content_with_role(
                    contents=[TextPromptMessageContent(data=result_text)], role=message.role
                )
            )

    return built
def _calculate_rest_token(
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:

View File

@ -179,6 +179,15 @@ class ParameterExtractorNode(LLMNode):
error=str(e),
metadata={},
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)},
error=str(e),
metadata={},
)
error = None

View File

@ -89,10 +89,10 @@ class QuestionClassifierNode(LLMNode):
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template,
user_query=query,
sys_query=query,
memory=memory,
model_config=model_config,
user_files=files,
sys_files=files,
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
@ -157,8 +157,7 @@ class QuestionClassifierNode(LLMNode):
},
llm_usage=usage,
)
except ValueError as e:
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,

View File

@ -92,6 +92,16 @@ class ToolNode(BaseNode[ToolNodeData]):
error=f"Failed to invoke tool: {str(e)}",
error_type=type(e).__name__,
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to invoke tool: {str(e)}",
error_type="UnknownError",
)
# convert tool messages
plain_text, files, json = self._convert_tool_messages(messages)

View File

@ -1,4 +1,4 @@
class VariableOperatorNodeError(Exception):
class VariableOperatorNodeError(ValueError):
    """Base error type, don't use directly.

    Derives from ``ValueError``, so any handler that catches ``ValueError``
    also catches this error and its subclasses.
    """