Merge fix/chore-fix into dev/plugin-deploy

This commit is contained in:
Yeuoly
2024-12-25 15:12:05 +08:00
733 changed files with 7774 additions and 5226 deletions

View File

@ -147,6 +147,8 @@ class AnswerStreamGeneratorRouter:
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
if source_node_id not in node_id_config_mapping:
continue
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
if (

View File

@ -60,11 +60,10 @@ class AnswerStreamProcessor(StreamProcessor):
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
# remove unreachable nodes
self._remove_unreachable_nodes(event)
# generate stream outputs
yield from self._generate_stream_outputs_when_node_finished(event)
yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event))
else:
yield event
@ -131,7 +130,7 @@ class AnswerStreamProcessor(StreamProcessor):
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text,
from_variable_selector=value_selector,
from_variable_selector=list(value_selector),
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,

View File

@ -1,10 +1,13 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
logger = logging.getLogger(__name__)
class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
@ -16,7 +19,7 @@ class StreamProcessor(ABC):
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
@ -29,15 +32,24 @@ class StreamProcessor(ABC):
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
reachable_node_ids: list[str] = []
unreachable_first_node_ids: list[str] = []
if finished_node_id not in self.graph.edge_mapping:
logger.warning(f"node {finished_node_id} has no edge mapping")
return
for edge in self.graph.edge_mapping[finished_node_id]:
if (
edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify
):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
# remove unreachable nodes
# FIXME: because of the code branch can combine directly, so for answer node
# we remove the node maybe shortcut the answer node, so comment this code for now
# there is not effect on the answer node and the workflow, when we have a better solution
# we can open this code. Issues: #11542 #9560 #10638 #10564
# reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)

View File

@ -38,7 +38,8 @@ class DefaultValue(BaseModel):
@staticmethod
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
# FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
@staticmethod
def _convert_number(value: str) -> float:
@ -84,7 +85,7 @@ class DefaultValue(BaseModel):
},
}
validator = type_validators.get(self.type)
validator: dict[str, Any] = type_validators.get(self.type, {})
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
@ -106,12 +107,25 @@ class DefaultValue(BaseModel):
return self
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
version: str = "1"
retry_config: RetryConfig = RetryConfig()
@property
def default_value_dict(self):

View File

@ -1,4 +1,4 @@
class BaseNodeError(Exception):
class BaseNodeError(ValueError):
"""Base class for node errors."""
pass

View File

@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus
@ -72,7 +72,11 @@ class BaseNode(Generic[GenericNodeData]):
result = self._run()
except Exception as e:
logger.exception(f"Node {self.node_id} failed to run")
result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError")
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
error_type="WorkflowNodeError",
)
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(run_result=result)
@ -143,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]):
bool: if should continue on error
"""
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
@property
def should_retry(self) -> bool:
"""judge if should retry
Returns:
bool: if should retry
"""
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union
from typing import Any, Optional
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
@ -59,7 +59,7 @@ class CodeNode(BaseNode[CodeNodeData]):
)
# Transform result
result = self._transform_result(result, self.node_data.outputs)
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@ -67,18 +67,17 @@ class CodeNode(BaseNode[CodeNodeData]):
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
def _check_string(self, value: str, variable: str) -> str:
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string
:param value: value
:param variable: variable
:return:
"""
if value is None:
return None
if not isinstance(value, str):
if value is None:
return None
else:
raise OutputValidationError(f"Output variable `{variable}` must be a string")
raise OutputValidationError(f"Output variable `{variable}` must be a string")
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
raise OutputValidationError(
@ -88,18 +87,17 @@ class CodeNode(BaseNode[CodeNodeData]):
return value.replace("\x00", "")
def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
"""
Check number
:param value: value
:param variable: variable
:return:
"""
if value is None:
return None
if not isinstance(value, int | float):
if value is None:
return None
else:
raise OutputValidationError(f"Output variable `{variable}` must be a number")
raise OutputValidationError(f"Output variable `{variable}` must be a number")
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
raise OutputValidationError(
@ -118,18 +116,16 @@ class CodeNode(BaseNode[CodeNodeData]):
return value
def _transform_result(
self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1
) -> dict:
"""
Transform result
:param result: result
:param output_schema: output schema
:return:
"""
self,
result: Mapping[str, Any],
output_schema: Optional[dict[str, CodeNodeData.Output]],
prefix: str = "",
depth: int = 1,
):
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
transformed_result = {}
transformed_result: dict[str, Any] = {}
if output_schema is None:
# validate output thought instance type
for output_name, output_value in result.items():

View File

@ -14,7 +14,7 @@ class CodeNodeData(BaseNodeData):
class Output(BaseModel):
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
children: Optional[dict[str, "Output"]] = None
children: Optional[dict[str, "CodeNodeData.Output"]] = None
class Dependency(BaseModel):
name: str

View File

@ -1,19 +1,15 @@
import csv
import io
import json
import logging
import os
import tempfile
from typing import cast
import docx
import pandas as pd
import pypdfium2 # type: ignore
import yaml # type: ignore
from unstructured.partition.api import partition_via_api
from unstructured.partition.email import partition_email
from unstructured.partition.epub import partition_epub
from unstructured.partition.msg import partition_msg
from unstructured.partition.ppt import partition_ppt
from unstructured.partition.pptx import partition_pptx
from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
@ -28,6 +24,8 @@ from models.workflow import WorkflowNodeExecutionStatus
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
logger = logging.getLogger(__name__)
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
"""
@ -162,7 +160,7 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
"""Extract the content from yaml file"""
try:
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore"))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
except (UnicodeDecodeError, yaml.YAMLError) as e:
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
@ -183,10 +181,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
def _extract_text_from_doc(file_content: bytes) -> str:
"""
Extract text from a DOC/DOCX file.
For now support only paragraph and table add more if needed
"""
try:
doc_file = io.BytesIO(file_content)
doc = docx.Document(doc_file)
return "\n".join([paragraph.text for paragraph in doc.paragraphs])
text = []
# Process paragraphs
for paragraph in doc.paragraphs:
if paragraph.text.strip():
text.append(paragraph.text)
# Process tables
for table in doc.tables:
# Table header
try:
# table maybe cause errors so ignore it.
if len(table.rows) > 0 and table.rows[0].cells is not None:
# Check if any cell in the table has text
has_content = False
for row in table.rows:
if any(cell.text.strip() for cell in row.cells):
has_content = True
break
if has_content:
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
for row in table.rows[1:]:
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
text.append(markdown_table)
except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
continue
return "\n".join(text)
except Exception as e:
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
@ -199,9 +230,9 @@ def _download_file_content(file: File) -> bytes:
raise FileDownloadError("Missing URL for remote file")
response = ssrf_proxy.get(file.remote_url)
response.raise_for_status()
return response.content
return cast(bytes, response.content)
else:
return file_manager.download(file)
return cast(bytes, file_manager.download(file))
except Exception as e:
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
@ -256,6 +287,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:
def _extract_text_from_ppt(file_content: bytes) -> str:
from unstructured.partition.ppt import partition_ppt
try:
with io.BytesIO(file_content) as file:
elements = partition_ppt(file=file)
@ -265,6 +298,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
def _extract_text_from_pptx(file_content: bytes) -> str:
from unstructured.partition.api import partition_via_api
from unstructured.partition.pptx import partition_pptx
try:
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
@ -287,6 +323,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
def _extract_text_from_epub(file_content: bytes) -> str:
from unstructured.partition.epub import partition_epub
try:
with io.BytesIO(file_content) as file:
elements = partition_epub(file=file)
@ -296,6 +334,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:
def _extract_text_from_eml(file_content: bytes) -> str:
from unstructured.partition.email import partition_email
try:
with io.BytesIO(file_content) as file:
elements = partition_email(file=file)
@ -305,6 +345,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:
def _extract_text_from_msg(file_content: bytes) -> str:
from unstructured.partition.msg import partition_msg
try:
with io.BytesIO(file_content) as file:
elements = partition_msg(file=file)

View File

@ -67,7 +67,7 @@ class EndStreamGeneratorRouter:
and node_type == NodeType.LLM.value
and variable_selector.value_selector[1] == "text"
):
value_selectors.append(variable_selector.value_selector)
value_selectors.append(list(variable_selector.value_selector))
return value_selectors
@ -119,8 +119,7 @@ class EndStreamGeneratorRouter:
current_node_id: str,
end_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]],
# type: ignore[name-defined]
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
end_dependencies: dict[str, list[str]],
) -> None:
"""
@ -135,6 +134,8 @@ class EndStreamGeneratorRouter:
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
if source_node_id not in node_id_config_mapping:
continue
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in {
NodeType.IF_ELSE.value,

View File

@ -23,7 +23,7 @@ class EndStreamProcessor(StreamProcessor):
self.route_position[end_node_id] = 0
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
self.has_output = False
self.output_node_ids = set()
self.output_node_ids: set[str] = set()
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
for event in generator:

View File

@ -36,3 +36,4 @@ class FailBranchSourceHandle(StrEnum):
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE

View File

@ -1,4 +1,10 @@
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from .event import (
ModelInvokeCompletedEvent,
RunCompletedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
RunStreamChunkEvent,
)
from .types import NodeEvent
__all__ = [
@ -6,5 +12,6 @@ __all__ = [
"NodeEvent",
"RunCompletedEvent",
"RunRetrieverResourceEvent",
"RunRetryEvent",
"RunStreamChunkEvent",
]

View File

@ -1,7 +1,10 @@
from datetime import datetime
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
class RunCompletedEvent(BaseModel):
@ -26,3 +29,19 @@ class ModelInvokeCompletedEvent(BaseModel):
text: str
usage: LLMUsage
finish_reason: str | None = None
class RunRetryEvent(BaseModel):
"""Node Run Retry event"""
error: str = Field(..., description="error")
retry_index: int = Field(..., description="Retry attempt number")
start_at: datetime = Field(..., description="Retry start time")
class SingleStepRetryEvent(NodeRunResult):
"""Single step retry event"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY
elapsed_time: float = Field(..., description="elapsed time")

View File

@ -16,3 +16,7 @@ class InvalidHttpMethodError(HttpRequestNodeError):
class ResponseSizeError(HttpRequestNodeError):
"""Raised when the response size exceeds the allowed threshold."""
class RequestBodyError(HttpRequestNodeError):
"""Raised when the request body is invalid."""

View File

@ -23,6 +23,7 @@ from .exc import (
FileFetchError,
HttpRequestNodeError,
InvalidHttpMethodError,
RequestBodyError,
ResponseSizeError,
)
@ -45,6 +46,7 @@ class Executor:
headers: dict[str, str]
auth: HttpRequestNodeAuthorization
timeout: HttpRequestNodeTimeout
max_retries: int
boundary: str
@ -54,6 +56,7 @@ class Executor:
node_data: HttpRequestNodeData,
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@ -73,6 +76,7 @@ class Executor:
self.files = None
self.data = None
self.json = None
self.max_retries = max_retries
# init template
self.variable_pool = variable_pool
@ -103,9 +107,9 @@ class Executor:
if not (key := key.strip()):
continue
value = value[0].strip() if value else ""
value_str = value[0].strip() if value else ""
result.append(
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value).text)
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text)
)
self.params = result
@ -140,13 +144,19 @@ class Executor:
case "none":
self.content = ""
case "raw-text":
if len(data) != 1:
raise RequestBodyError("raw-text body type should have exactly one item")
self.content = self.variable_pool.convert_template(data[0].value).text
case "json":
if len(data) != 1:
raise RequestBodyError("json body type should have exactly one item")
json_string = self.variable_pool.convert_template(data[0].value).text
json_object = json.loads(json_string, strict=False)
self.json = json_object
# self.json = self._parse_object_contains_variables(json_object)
case "binary":
if len(data) != 1:
raise RequestBodyError("binary body type should have exactly one item")
file_selector = data[0].file
file_variable = self.variable_pool.get_file(file_selector)
if file_variable is None:
@ -172,9 +182,10 @@ class Executor:
self.variable_pool.convert_template(item.key).text: item.file
for item in filter(lambda item: item.type == "file", data)
}
files: dict[str, Any] = {}
files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()}
files = {k: v for k, v in files.items() if v is not None}
files = {k: variable.value for k, variable in files.items()}
files = {k: variable.value for k, variable in files.items() if variable is not None}
files = {
k: (v.filename, file_manager.download(v), v.mime_type or "application/octet-stream")
for k, v in files.items()
@ -241,13 +252,15 @@ class Executor:
"params": self.params,
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
"follow_redirects": True,
"max_retries": self.max_retries,
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response = getattr(ssrf_proxy, self.method)(**request_args)
except ssrf_proxy.MaxRetriesExceededError as e:
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e))
return response
# FIXME: fix type ignore, this maybe httpx type issue
return response # type: ignore
def invoke(self) -> Response:
# assemble headers
@ -289,35 +302,37 @@ class Executor:
continue
raw += f"{k}: {v}\r\n"
body = ""
body_string = ""
if self.files:
for k, v in self.files.items():
body += f"--{boundary}\r\n"
body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
body += f"{v[1]}\r\n"
body += f"--{boundary}--\r\n"
body_string += f"--{boundary}\r\n"
body_string += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
body_string += f"{v[1]}\r\n"
body_string += f"--{boundary}--\r\n"
elif self.node_data.body:
if self.content:
if isinstance(self.content, str):
body = self.content
body_string = self.content
elif isinstance(self.content, bytes):
body = self.content.decode("utf-8", errors="replace")
body_string = self.content.decode("utf-8", errors="replace")
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
body = urlencode(self.data)
body_string = urlencode(self.data)
elif self.data and self.node_data.body.type == "form-data":
for key, value in self.data.items():
body += f"--{boundary}\r\n"
body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
body += f"{value}\r\n"
body += f"--{boundary}--\r\n"
body_string += f"--{boundary}\r\n"
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
body_string += f"{value}\r\n"
body_string += f"--{boundary}--\r\n"
elif self.json:
body = json.dumps(self.json)
body_string = json.dumps(self.json)
elif self.node_data.body.type == "raw-text":
body = self.node_data.body.data[0].value
if body:
raw += f"Content-Length: {len(body)}\r\n"
if len(self.node_data.body.data) != 1:
raise RequestBodyError("raw-text body type should have exactly one item")
body_string = self.node_data.body.data[0].value
if body_string:
raw += f"Content-Length: {len(body_string)}\r\n"
raw += "\r\n" # Empty line between headers and body
raw += body
raw += body_string
return raw

View File

@ -1,6 +1,7 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
from typing import Any
from typing import Any, Optional
from configs import dify_config
from core.file import File, FileTransferMethod
@ -19,7 +20,7 @@ from .entities import (
HttpRequestNodeTimeout,
Response,
)
from .exc import HttpRequestNodeError
from .exc import HttpRequestNodeError, RequestBodyError
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
@ -35,7 +36,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
_node_type = NodeType.HTTP_REQUEST
@classmethod
def get_default_config(cls, filters: dict | None = None) -> dict:
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
return {
"type": "http-request",
"config": {
@ -51,6 +52,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
},
},
"retry_config": {
"max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
"retry_interval": 0.5 * (2**2),
"retry_enabled": True,
},
}
def _run(self) -> NodeRunResult:
@ -60,12 +66,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
node_data=self.node_data,
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
)
process_data["request"] = http_executor.to_log()
response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
if not response.response.is_success and self.should_continue_on_error:
if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
@ -129,9 +136,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
data = node_data.body.data
match body_type:
case "binary":
if len(data) != 1:
raise RequestBodyError("invalid body data, should have only one item")
selector = data[0].file
selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector))
case "json" | "raw-text":
if len(data) != 1:
raise RequestBodyError("invalid body data, should have only one item")
selectors += variable_template_parser.extract_selectors_from_template(data[0].key)
selectors += variable_template_parser.extract_selectors_from_template(data[0].value)
case "x-www-form-urlencoded":
@ -149,27 +160,31 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
)
mapping = {}
for selector in selectors:
mapping[node_id + "." + selector.variable] = selector.value_selector
for selector_iter in selectors:
mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector
return mapping
def extract_files(self, url: str, response: Response) -> list[File]:
"""
Extract files from response
Extract files from response by checking both Content-Type header and URL
"""
files = []
is_file = response.is_file
content_type = response.content_type
content = response.content
if is_file and content_type:
if is_file:
# Guess file extension from URL or Content-Type header
filename = url.split("?")[0].split("/")[-1] or ""
mime_type = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
tool_file = ToolFileManager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
file_binary=content,
mimetype=content_type,
mimetype=mime_type,
)
mapping = {

View File

@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
from flask import Flask, current_app
from configs import dify_config
from core.variables import IntegerVariable
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeRunResult,
@ -76,12 +76,15 @@ class IterationNode(BaseNode[IterationNodeData]):
"""
Run the node.
"""
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not iterator_list_segment:
raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found")
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
if len(iterator_list_segment.value) == 0:
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -90,7 +93,7 @@ class IterationNode(BaseNode[IterationNodeData]):
)
return
iterator_list_value = iterator_list_segment.to_object()
iterator_list_value = variable.to_object()
if not isinstance(iterator_list_value, list):
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
@ -360,13 +363,16 @@ class IterationNode(BaseNode[IterationNodeData]):
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
if self.node_data.is_parallel:
metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
else:
metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index
metadata = {
**metadata,
NodeRunMetadataKey.ITERATION_ID: self.node_id,
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID
if self.node_data.is_parallel
else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id
if self.node_data.is_parallel
else iter_run_index,
}
event.route_node_state.node_run_result.metadata = metadata
return event

View File

@ -70,7 +70,20 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
except KnowledgeRetrievalNodeError as e:
logger.warning("Error when running knowledge retrieval node")
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
# Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
available_datasets = []
@ -134,6 +147,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
planning_strategy=planning_strategy,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
@ -144,6 +159,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
reranking_model = None
weights = None
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
if node_data.multiple_retrieval_config.weights is None:
raise ValueError("weights is required")
reranking_model = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
@ -160,18 +177,20 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
reranking_model = None
weights = None
all_documents = dataset_retrieval.multiple_retrieve(
self.app_id,
self.tenant_id,
self.user_id,
self.user_from.value,
available_datasets,
query,
node_data.multiple_retrieval_config.top_k,
node_data.multiple_retrieval_config.score_threshold,
node_data.multiple_retrieval_config.reranking_mode,
reranking_model,
weights,
node_data.multiple_retrieval_config.reranking_enable,
app_id=self.app_id,
tenant_id=self.tenant_id,
user_id=self.user_id,
user_from=self.user_from.value,
available_datasets=available_datasets,
query=query,
top_k=node_data.multiple_retrieval_config.top_k,
score_threshold=node_data.multiple_retrieval_config.score_threshold
if node_data.multiple_retrieval_config.score_threshold is not None
else 0.0,
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
)
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
@ -192,7 +211,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
"content": item.page_content,
}
retrieval_resource_list.append(source)
document_score_list = {}
document_score_list: dict[str, float] = {}
# deal with dify documents
if dify_documents:
document_score_list = {}
@ -247,7 +266,9 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
retrieval_resource_list.append(source)
if retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list, key=lambda x: x.get("metadata").get("score") or 0.0, reverse=True
retrieval_resource_list,
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
reverse=True,
)
position = 1
for item in retrieval_resource_list:
@ -282,6 +303,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
:param node_data: node data
:return:
"""
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
model_name = node_data.single_retrieval_config.model.name
provider_name = node_data.single_retrieval_config.model.provider

View File

@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence
from typing import Literal, Union
from typing import Any, Literal, Union
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@ -17,9 +17,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_type = NodeType.LIST_OPERATOR
def _run(self):
inputs = {}
process_data = {}
outputs = {}
inputs: dict[str, list] = {}
process_data: dict[str, list] = {}
outputs: dict[str, Any] = {}
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
if variable is None:
@ -93,6 +93,8 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
filter_func: Callable[[Any], bool]
result: list[Any] = []
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
@ -236,6 +238,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
@ -249,47 +252,47 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
raise InvalidKeyError(f"Invalid key: {key}")
def _contains(value: str):
def _contains(value: str) -> Callable[[str], bool]:
return lambda x: value in x
def _startswith(value: str):
def _startswith(value: str) -> Callable[[str], bool]:
return lambda x: x.startswith(value)
def _endswith(value: str):
def _endswith(value: str) -> Callable[[str], bool]:
return lambda x: x.endswith(value)
def _is(value: str):
def _is(value: str) -> Callable[[str], bool]:
return lambda x: x is value
def _in(value: str | Sequence[str]):
def _in(value: str | Sequence[str]) -> Callable[[str], bool]:
return lambda x: x in value
def _eq(value: int | float):
def _eq(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x == value
def _ne(value: int | float):
def _ne(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x != value
def _lt(value: int | float):
def _lt(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x < value
def _le(value: int | float):
def _le(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x <= value
def _gt(value: int | float):
def _gt(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x > value
def _ge(value: int | float):
def _ge(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x >= value
@ -302,6 +305,7 @@ def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
extract_func: Callable[[File], Any]
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
extract_func = _get_file_extract_string_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")

View File

@ -50,6 +50,7 @@ class PromptConfig(BaseModel):
class LLMNodeChatModelMessage(ChatModelMessage):
text: str = ""
jinja2_text: Optional[str] = None

View File

@ -88,8 +88,8 @@ class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]:
node_inputs = None
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
node_inputs: Optional[dict[str, Any]] = None
process_data = None
try:
@ -145,8 +145,8 @@ class LLMNode(BaseNode[LLMNodeData]):
query = query_variable.text
prompt_messages, stop = self._fetch_prompt_messages(
user_query=query,
user_files=files,
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
@ -196,7 +196,6 @@ class LLMNode(BaseNode[LLMNodeData]):
error_type=type(e).__name__,
)
)
return
except Exception as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
@ -206,7 +205,6 @@ class LLMNode(BaseNode[LLMNodeData]):
process_data=process_data,
)
)
return
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
@ -302,7 +300,7 @@ class LLMNode(BaseNode[LLMNodeData]):
return messages
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
variables = {}
variables: dict[str, Any] = {}
if not node_data.prompt_config:
return variables
@ -319,7 +317,7 @@ class LLMNode(BaseNode[LLMNodeData]):
"""
# check if it's a context structure
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
return input_dict["content"]
return str(input_dict["content"])
# else, parse the dict
try:
@ -545,8 +543,8 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_prompt_messages(
self,
*,
user_query: str | None = None,
user_files: Sequence["File"],
sys_query: str | None = None,
sys_files: Sequence["File"],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
@ -557,12 +555,13 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages = []
# FIXME: fix the type error cause prompt_messages is type quick a few times
prompt_messages: list[Any] = []
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
_handle_list_messages(
self._handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
@ -581,14 +580,14 @@ class LLMNode(BaseNode[LLMNodeData]):
prompt_messages.extend(memory_messages)
# Add current query to the prompt messages
if user_query:
if sys_query:
message = LLMNodeChatModelMessage(
text=user_query,
text=sys_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
prompt_messages.extend(
_handle_list_messages(
self._handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
@ -635,24 +634,27 @@ class LLMNode(BaseNode[LLMNodeData]):
raise ValueError("Invalid prompt content type")
# Add current query to the prompt message
if user_query:
if sys_query:
if prompt_content_type == str:
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
elif prompt_content_type == list:
for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT:
content_item.data = user_query + "\n" + content_item.data
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
else:
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
if vision_enabled and user_files:
# The sys_files will be deprecated later
if vision_enabled and sys_files:
file_prompts = []
for file in user_files:
for file in sys_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
@ -662,7 +664,7 @@ class LLMNode(BaseNode[LLMNodeData]):
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Filter prompt messages
# Remove empty messages and filter unsupported content
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list):
@ -780,7 +782,7 @@ class LLMNode(BaseNode[LLMNodeData]):
else:
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
variable_mapping = {}
variable_mapping: dict[str, Any] = {}
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
@ -846,6 +848,68 @@ class LLMNode(BaseNode[LLMNodeData]):
},
}
def _handle_list_messages(
self,
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
prompt_messages: list[PromptMessage] = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=message.role
)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
elif isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
prompt_messages.append(prompt_message)
return prompt_messages
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
match role:
@ -880,68 +944,6 @@ def _render_jinja2_message(
return result_text
def _handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
prompt_messages = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=message.role
)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
if isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
prompt_messages.append(prompt_message)
return prompt_messages
def _calculate_rest_token(
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
@ -978,7 +980,7 @@ def _handle_memory_chat_mode(
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
memory_messages = []
memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)

View File

@ -14,8 +14,8 @@ class LoopNode(BaseNode[LoopNodeData]):
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
def _run(self) -> LoopState:
return super()._run()
def _run(self) -> LoopState: # type: ignore
return super()._run() # type: ignore
@classmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
@ -28,7 +28,7 @@ class LoopNode(BaseNode[LoopNodeData]):
# TODO waiting for implementation
return [
Condition(
Condition( # type: ignore
variable_selector=[node_id, "index"],
comparison_operator="",
value_type="value_selector",

View File

@ -25,7 +25,7 @@ class ParameterConfig(BaseModel):
raise ValueError("Parameter name is required")
if value in {"__reason", "__is_success"}:
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
return value
return str(value)
class ParameterExtractorNodeData(BaseNodeData):
@ -52,7 +52,7 @@ class ParameterExtractorNodeData(BaseNodeData):
:return: parameter json schema
"""
parameters = {"type": "object", "properties": {}, "required": []}
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description}

View File

@ -63,7 +63,8 @@ class ParameterExtractorNode(LLMNode):
Parameter Extractor Node.
"""
_node_data_cls = ParameterExtractorNodeData
# FIXME: figure out why here is different from super class
_node_data_cls = ParameterExtractorNodeData # type: ignore
_node_type = NodeType.PARAMETER_EXTRACTOR
_model_instance: Optional[ModelInstance] = None
@ -179,6 +180,15 @@ class ParameterExtractorNode(LLMNode):
error=str(e),
metadata={},
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)},
error=str(e),
metadata={},
)
error = None
@ -244,6 +254,9 @@ class ParameterExtractorNode(LLMNode):
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
if text is None:
text = ""
return text, usage, tool_call
def _generate_function_call_prompt(
@ -596,9 +609,10 @@ class ParameterExtractorNode(LLMNode):
json_str = extract_json(result[idx:])
if json_str:
try:
return json.loads(json_str)
return cast(dict, json.loads(json_str))
except Exception:
pass
return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
"""
@ -607,13 +621,13 @@ class ParameterExtractorNode(LLMNode):
if not tool_call or not tool_call.function.arguments:
return None
return json.loads(tool_call.function.arguments)
return cast(dict, json.loads(tool_call.function.arguments))
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
"""
Generate default result.
"""
result = {}
result: dict[str, Any] = {}
for parameter in data.parameters:
if parameter.type == "number":
result[parameter.name] = 0
@ -763,7 +777,7 @@ class ParameterExtractorNode(LLMNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ParameterExtractorNodeData,
node_data: ParameterExtractorNodeData, # type: ignore
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -772,6 +786,7 @@ class ParameterExtractorNode(LLMNode):
:param node_data: node data
:return:
"""
# FIXME: fix the type error later
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
if node_data.instruction:

View File

@ -1,3 +1,5 @@
from typing import Any
FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters"
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
@ -35,7 +37,7 @@ FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information fr
</structure>
""" # noqa: E501
FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [
FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [
{
"user": {
"query": "What is the weather today in SF?",

View File

@ -1,10 +1,8 @@
import json
import logging
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.llm_generator.output_parser.errors import OutputParserError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
@ -36,12 +34,9 @@ from .template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
if TYPE_CHECKING:
from core.file import File
class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData
_node_data_cls = QuestionClassifierNodeData # type: ignore
_node_type = NodeType.QUESTION_CLASSIFIER
def _run(self):
@ -66,7 +61,7 @@ class QuestionClassifierNode(LLMNode):
node_data.instruction = variable_pool.convert_template(
node_data.instruction).text
files: Sequence[File] = (
files = (
self._fetch_files(
selector=node_data.vision.configs.variable_selector,
)
@ -89,37 +84,38 @@ class QuestionClassifierNode(LLMNode):
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template,
user_query=query,
sys_query=query,
memory=memory,
model_config=model_config,
user_files=files,
sys_files=files,
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
)
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break
category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
try:
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break
category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
result_text_json = parse_and_check_json_markdown(result_text, [])
# result_text_json = json.loads(result_text.strip('```JSON\n'))
if "category_name" in result_text_json and "category_id" in result_text_json:
@ -130,10 +126,6 @@ class QuestionClassifierNode(LLMNode):
if category_id_result in category_ids:
category_name = classes_map[category_id_result]
category_id = category_id_result
except OutputParserError:
logging.exception(f"Failed to parse result text: {result_text}")
try:
process_data = {
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
@ -157,7 +149,6 @@ class QuestionClassifierNode(LLMNode):
},
llm_usage=usage,
)
except ValueError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -177,7 +168,7 @@ class QuestionClassifierNode(LLMNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: QuestionClassifierNodeData,
node_data: Any,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -186,6 +177,7 @@ class QuestionClassifierNode(LLMNode):
:param node_data: node data
:return:
"""
node_data = cast(QuestionClassifierNodeData, node_data)
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors = []
if node_data.instruction:

View File

@ -9,7 +9,6 @@ from core.file import File, FileTransferMethod, FileType
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
@ -58,6 +57,8 @@ class ToolNode(BaseNode[ToolNodeData]):
# get tool runtime
try:
from core.tools.tool_manager import ToolManager
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
)
@ -145,7 +146,7 @@ class ToolNode(BaseNode[ToolNodeData]):
"""
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
result = {}
result: dict[str, Any] = {}
for parameter_name in node_data.tool_parameters:
parameter = tool_parameters_dictionary.get(parameter_name)
if not parameter:

View File

@ -1,4 +1,4 @@
class VariableOperatorNodeError(Exception):
class VariableOperatorNodeError(ValueError):
"""Base error type, don't use directly."""
pass

View File

@ -36,6 +36,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
if income_value is None:
raise VariableOperatorNodeError("income value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:

View File

@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, cast
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
@ -29,7 +29,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
process_data = {}
process_data: dict[str, Any] = {}
# NOTE: This node has no outputs
updated_variables: list[Variable] = []
@ -119,7 +119,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
else:
conversation_id = conversation_id.value
common_helpers.update_conversation_variable(
conversation_id=conversation_id,
conversation_id=cast(str, conversation_id),
variable=variable,
)