mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
merge main
This commit is contained in:
40
api/tests/unit_tests/core/test_file.py
Normal file
40
api/tests/unit_tests/core/test_file.py
Normal file
@ -0,0 +1,40 @@
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
|
||||
|
||||
def test_file_loads_and_dumps():
|
||||
file = File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
)
|
||||
|
||||
file_dict = file.model_dump()
|
||||
assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY
|
||||
assert file_dict["type"] == file.type.value
|
||||
assert isinstance(file_dict["type"], str)
|
||||
assert file_dict["transfer_method"] == file.transfer_method.value
|
||||
assert isinstance(file_dict["transfer_method"], str)
|
||||
assert "_extra_config" not in file_dict
|
||||
|
||||
file_obj = File.model_validate(file_dict)
|
||||
assert file_obj.id == file.id
|
||||
assert file_obj.tenant_id == file.tenant_id
|
||||
assert file_obj.type == file.type
|
||||
assert file_obj.transfer_method == file.transfer_method
|
||||
assert file_obj.remote_url == file.remote_url
|
||||
|
||||
|
||||
def test_file_to_dict():
|
||||
file = File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
)
|
||||
|
||||
file_dict = file.to_dict()
|
||||
assert "_extra_config" not in file_dict
|
||||
assert "url" in file_dict
|
||||
@ -1,56 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
|
||||
|
||||
def test_get_parameter_type():
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string"
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string"
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean"
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number"
|
||||
with pytest.raises(ValueError):
|
||||
ToolParameterConverter.get_parameter_type("unsupported_type")
|
||||
|
||||
|
||||
def test_cast_parameter_by_type():
|
||||
# string
|
||||
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == ""
|
||||
|
||||
# secret input
|
||||
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == ""
|
||||
|
||||
# select
|
||||
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == ""
|
||||
|
||||
# boolean
|
||||
true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"]
|
||||
for value in true_values:
|
||||
assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True
|
||||
|
||||
false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""]
|
||||
for value in false_values:
|
||||
assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False
|
||||
|
||||
# number
|
||||
assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1
|
||||
assert ToolParameterConverter.cast_parameter_by_type("1.0", ToolParameter.ToolParameterType.NUMBER) == 1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
|
||||
|
||||
# unknown
|
||||
assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
|
||||
49
api/tests/unit_tests/core/tools/test_tool_parameter_type.py
Normal file
49
api/tests/unit_tests/core/tools/test_tool_parameter_type.py
Normal file
@ -0,0 +1,49 @@
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
|
||||
|
||||
def test_get_parameter_type():
|
||||
assert ToolParameter.ToolParameterType.STRING.as_normal_type() == "string"
|
||||
assert ToolParameter.ToolParameterType.SELECT.as_normal_type() == "string"
|
||||
assert ToolParameter.ToolParameterType.SECRET_INPUT.as_normal_type() == "string"
|
||||
assert ToolParameter.ToolParameterType.BOOLEAN.as_normal_type() == "boolean"
|
||||
assert ToolParameter.ToolParameterType.NUMBER.as_normal_type() == "number"
|
||||
assert ToolParameter.ToolParameterType.FILE.as_normal_type() == "file"
|
||||
assert ToolParameter.ToolParameterType.FILES.as_normal_type() == "files"
|
||||
|
||||
|
||||
def test_cast_parameter_by_type():
|
||||
# string
|
||||
assert ToolParameter.ToolParameterType.STRING.cast_value("test") == "test"
|
||||
assert ToolParameter.ToolParameterType.STRING.cast_value(1) == "1"
|
||||
assert ToolParameter.ToolParameterType.STRING.cast_value(1.0) == "1.0"
|
||||
assert ToolParameter.ToolParameterType.STRING.cast_value(None) == ""
|
||||
|
||||
# secret input
|
||||
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value("test") == "test"
|
||||
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1) == "1"
|
||||
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1.0) == "1.0"
|
||||
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(None) == ""
|
||||
|
||||
# select
|
||||
assert ToolParameter.ToolParameterType.SELECT.cast_value("test") == "test"
|
||||
assert ToolParameter.ToolParameterType.SELECT.cast_value(1) == "1"
|
||||
assert ToolParameter.ToolParameterType.SELECT.cast_value(1.0) == "1.0"
|
||||
assert ToolParameter.ToolParameterType.SELECT.cast_value(None) == ""
|
||||
|
||||
# boolean
|
||||
true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"]
|
||||
for value in true_values:
|
||||
assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is True
|
||||
|
||||
false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""]
|
||||
for value in false_values:
|
||||
assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is False
|
||||
|
||||
# number
|
||||
assert ToolParameter.ToolParameterType.NUMBER.cast_value("1") == 1
|
||||
assert ToolParameter.ToolParameterType.NUMBER.cast_value("1.0") == 1.0
|
||||
assert ToolParameter.ToolParameterType.NUMBER.cast_value("-1.0") == -1.0
|
||||
assert ToolParameter.ToolParameterType.NUMBER.cast_value(1) == 1
|
||||
assert ToolParameter.ToolParameterType.NUMBER.cast_value(1.0) == 1.0
|
||||
assert ToolParameter.ToolParameterType.NUMBER.cast_value(-1.0) == -1.0
|
||||
assert ToolParameter.ToolParameterType.NUMBER.cast_value(None) is None
|
||||
@ -0,0 +1,167 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.variables import StringVariable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData
|
||||
from core.workflow.nodes.document_extractor.node import (
|
||||
_extract_text_from_doc,
|
||||
_extract_text_from_pdf,
|
||||
_extract_text_from_plain_text,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_extractor_node():
|
||||
node_data = DocumentExtractorNodeData(
|
||||
title="Test Document Extractor",
|
||||
variable_selector=["node_id", "variable_name"],
|
||||
)
|
||||
return DocumentExtractorNode(
|
||||
id="test_node_id",
|
||||
config={"id": "test_node_id", "data": node_data.model_dump()},
|
||||
graph_init_params=Mock(),
|
||||
graph=Mock(),
|
||||
graph_runtime_state=Mock(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_runtime_state():
|
||||
return Mock()
|
||||
|
||||
|
||||
def test_run_variable_not_found(document_extractor_node, mock_graph_runtime_state):
|
||||
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = None
|
||||
|
||||
result = document_extractor_node._run()
|
||||
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error is not None
|
||||
assert "File variable not found" in result.error
|
||||
|
||||
|
||||
def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_state):
|
||||
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = StringVariable(
|
||||
value="Not an ArrayFileSegment", name="test"
|
||||
)
|
||||
|
||||
result = document_extractor_node._run()
|
||||
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error is not None
|
||||
assert "is not an ArrayFileSegment" in result.error
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mime_type", "file_content", "expected_text", "transfer_method", "extension"),
|
||||
[
|
||||
("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"),
|
||||
(
|
||||
"application/pdf",
|
||||
b"%PDF-1.5\n%Test PDF content",
|
||||
["Mocked PDF content"],
|
||||
FileTransferMethod.LOCAL_FILE,
|
||||
".pdf",
|
||||
),
|
||||
(
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
b"PK\x03\x04",
|
||||
["Mocked DOCX content"],
|
||||
FileTransferMethod.REMOTE_URL,
|
||||
"",
|
||||
),
|
||||
("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None),
|
||||
],
|
||||
)
|
||||
def test_run_extract_text(
|
||||
document_extractor_node,
|
||||
mock_graph_runtime_state,
|
||||
mime_type,
|
||||
file_content,
|
||||
expected_text,
|
||||
transfer_method,
|
||||
extension,
|
||||
monkeypatch,
|
||||
):
|
||||
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
|
||||
|
||||
mock_file = Mock(spec=File)
|
||||
mock_file.mime_type = mime_type
|
||||
mock_file.transfer_method = transfer_method
|
||||
mock_file.related_id = "test_file_id" if transfer_method == FileTransferMethod.LOCAL_FILE else None
|
||||
mock_file.remote_url = "https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None
|
||||
mock_file.extension = extension
|
||||
|
||||
mock_array_file_segment = Mock(spec=ArrayFileSegment)
|
||||
mock_array_file_segment.value = [mock_file]
|
||||
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_array_file_segment
|
||||
|
||||
mock_download = Mock(return_value=file_content)
|
||||
mock_ssrf_proxy_get = Mock()
|
||||
mock_ssrf_proxy_get.return_value.content = file_content
|
||||
mock_ssrf_proxy_get.return_value.raise_for_status = Mock()
|
||||
|
||||
monkeypatch.setattr("core.file.file_manager.download", mock_download)
|
||||
monkeypatch.setattr("core.helper.ssrf_proxy.get", mock_ssrf_proxy_get)
|
||||
|
||||
if mime_type == "application/pdf":
|
||||
mock_pdf_extract = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract)
|
||||
elif mime_type.startswith("application/vnd.openxmlformats"):
|
||||
mock_docx_extract = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_doc", mock_docx_extract)
|
||||
|
||||
result = document_extractor_node._run()
|
||||
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["text"] == expected_text
|
||||
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt")
|
||||
elif transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
mock_download.assert_called_once_with(mock_file)
|
||||
|
||||
|
||||
def test_extract_text_from_plain_text():
|
||||
text = _extract_text_from_plain_text(b"Hello, world!")
|
||||
assert text == "Hello, world!"
|
||||
|
||||
|
||||
@patch("pypdfium2.PdfDocument")
|
||||
def test_extract_text_from_pdf(mock_pdf_document):
|
||||
mock_page = Mock()
|
||||
mock_text_page = Mock()
|
||||
mock_text_page.get_text_range.return_value = "PDF content"
|
||||
mock_page.get_textpage.return_value = mock_text_page
|
||||
mock_pdf_document.return_value = [mock_page]
|
||||
text = _extract_text_from_pdf(b"%PDF-1.5\n%Test PDF content")
|
||||
assert text == "PDF content"
|
||||
|
||||
|
||||
@patch("docx.Document")
|
||||
def test_extract_text_from_doc(mock_document):
|
||||
mock_paragraph1 = Mock()
|
||||
mock_paragraph1.text = "Paragraph 1"
|
||||
mock_paragraph2 = Mock()
|
||||
mock_paragraph2.text = "Paragraph 2"
|
||||
mock_document.return_value.paragraphs = [mock_paragraph1, mock_paragraph2]
|
||||
|
||||
text = _extract_text_from_doc(b"PK\x03\x04")
|
||||
assert text == "Paragraph 1\nParagraph 2"
|
||||
|
||||
|
||||
def test_node_type(document_extractor_node):
|
||||
assert document_extractor_node._node_type == NodeType.DOCUMENT_EXTRACTOR
|
||||
@ -0,0 +1,369 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import FileVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end import EndStreamParam
|
||||
from core.workflow.nodes.http_request import (
|
||||
BodyData,
|
||||
HttpRequestNode,
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeBody,
|
||||
HttpRequestNodeData,
|
||||
)
|
||||
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
||||
from core.workflow.nodes.http_request.executor import Executor, _plain_text_to_dict
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
def test_plain_text_to_dict():
|
||||
assert _plain_text_to_dict("aa\n cc:") == {"aa": "", "cc": ""}
|
||||
assert _plain_text_to_dict("aa:bb\n cc:dd") == {"aa": "bb", "cc": "dd"}
|
||||
assert _plain_text_to_dict("aa:bb\n cc:dd\n") == {"aa": "bb", "cc": "dd"}
|
||||
assert _plain_text_to_dict("aa:bb\n\n cc : dd\n\n") == {"aa": "bb", "cc": "dd"}
|
||||
|
||||
|
||||
def test_http_request_node_binary_file(monkeypatch):
|
||||
data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="post",
|
||||
url="http://example.org/post",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="binary",
|
||||
data=[
|
||||
BodyData(
|
||||
key="file",
|
||||
type="file",
|
||||
value="",
|
||||
file=["1111", "file"],
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(
|
||||
["1111", "file"],
|
||||
FileVariable(
|
||||
name="file",
|
||||
value=File(
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1111",
|
||||
),
|
||||
),
|
||||
)
|
||||
node = HttpRequestNode(
|
||||
id="1",
|
||||
config={
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.workflow.nodes.http_request.executor.file_manager.download",
|
||||
lambda *args, **kwargs: b"test",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.helper.ssrf_proxy.post",
|
||||
lambda *args, **kwargs: httpx.Response(200, content=kwargs["content"]),
|
||||
)
|
||||
result = node._run()
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["body"] == "test"
|
||||
|
||||
|
||||
def test_http_request_node_form_with_file(monkeypatch):
|
||||
data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="post",
|
||||
url="http://example.org/post",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="form-data",
|
||||
data=[
|
||||
BodyData(
|
||||
key="file",
|
||||
type="file",
|
||||
file=["1111", "file"],
|
||||
),
|
||||
BodyData(
|
||||
key="name",
|
||||
type="text",
|
||||
value="test",
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(
|
||||
["1111", "file"],
|
||||
FileVariable(
|
||||
name="file",
|
||||
value=File(
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1111",
|
||||
),
|
||||
),
|
||||
)
|
||||
node = HttpRequestNode(
|
||||
id="1",
|
||||
config={
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.workflow.nodes.http_request.executor.file_manager.download",
|
||||
lambda *args, **kwargs: b"test",
|
||||
)
|
||||
|
||||
def attr_checker(*args, **kwargs):
|
||||
assert kwargs["data"] == {"name": "test"}
|
||||
assert kwargs["files"] == {"file": b"test"}
|
||||
return httpx.Response(200, content=b"")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.helper.ssrf_proxy.post",
|
||||
attr_checker,
|
||||
)
|
||||
result = node._run()
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["body"] == ""
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_number_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "number"], 42)
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Number Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value='{"number": {{#pre_node_id.number#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"number": 42}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '{"number": 42}' in raw_request
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value="{{#pre_node_id.object#}}",
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '"name": "John Doe"' in raw_request
|
||||
assert '"age": 30' in raw_request
|
||||
assert '"email": "john@example.com"' in raw_request
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_nested_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Nested Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value='{"object": {{#pre_node_id.object#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '"object": {' in raw_request
|
||||
assert '"name": "John Doe"' in raw_request
|
||||
assert '"age": 30' in raw_request
|
||||
assert '"email": "john@example.com"' in raw_request
|
||||
111
api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
Normal file
111
api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
Normal file
@ -0,0 +1,111 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file import File
|
||||
from core.file.models import FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.workflow.nodes.list_operator.entities import FilterBy, FilterCondition, Limit, ListOperatorNodeData, OrderBy
|
||||
from core.workflow.nodes.list_operator.node import ListOperatorNode
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def list_operator_node():
|
||||
config = {
|
||||
"variable": ["test_variable"],
|
||||
"filter_by": FilterBy(
|
||||
enabled=True,
|
||||
conditions=[
|
||||
FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT])
|
||||
],
|
||||
),
|
||||
"order_by": OrderBy(enabled=False, value="asc"),
|
||||
"limit": Limit(enabled=False, size=0),
|
||||
"title": "Test Title",
|
||||
}
|
||||
node_data = ListOperatorNodeData(**config)
|
||||
node = ListOperatorNode(
|
||||
id="test_node_id",
|
||||
config={
|
||||
"id": "test_node_id",
|
||||
"data": node_data.model_dump(),
|
||||
},
|
||||
graph_init_params=MagicMock(),
|
||||
graph=MagicMock(),
|
||||
graph_runtime_state=MagicMock(),
|
||||
)
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.variable_pool = MagicMock()
|
||||
return node
|
||||
|
||||
|
||||
def test_filter_files_by_type(list_operator_node):
|
||||
# Setup test data
|
||||
files = [
|
||||
File(
|
||||
filename="image1.jpg",
|
||||
type=FileType.IMAGE,
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related1",
|
||||
),
|
||||
File(
|
||||
filename="document1.pdf",
|
||||
type=FileType.DOCUMENT,
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related2",
|
||||
),
|
||||
File(
|
||||
filename="image2.png",
|
||||
type=FileType.IMAGE,
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related3",
|
||||
),
|
||||
File(
|
||||
filename="audio1.mp3",
|
||||
type=FileType.AUDIO,
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related4",
|
||||
),
|
||||
]
|
||||
variable = ArrayFileSegment(value=files)
|
||||
list_operator_node.graph_runtime_state.variable_pool.get.return_value = variable
|
||||
|
||||
# Run the node
|
||||
result = list_operator_node._run()
|
||||
|
||||
# Verify the result
|
||||
expected_files = [
|
||||
{
|
||||
"filename": "image1.jpg",
|
||||
"type": FileType.IMAGE,
|
||||
"tenant_id": "tenant1",
|
||||
"transfer_method": FileTransferMethod.LOCAL_FILE,
|
||||
"related_id": "related1",
|
||||
},
|
||||
{
|
||||
"filename": "document1.pdf",
|
||||
"type": FileType.DOCUMENT,
|
||||
"tenant_id": "tenant1",
|
||||
"transfer_method": FileTransferMethod.LOCAL_FILE,
|
||||
"related_id": "related2",
|
||||
},
|
||||
{
|
||||
"filename": "image2.png",
|
||||
"type": FileType.IMAGE,
|
||||
"tenant_id": "tenant1",
|
||||
"transfer_method": FileTransferMethod.LOCAL_FILE,
|
||||
"related_id": "related3",
|
||||
},
|
||||
]
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
for expected_file, result_file in zip(expected_files, result.outputs["result"]):
|
||||
assert expected_file["filename"] == result_file.filename
|
||||
assert expected_file["type"] == result_file.type
|
||||
assert expected_file["tenant_id"] == result_file.tenant_id
|
||||
assert expected_file["transfer_method"] == result_file.transfer_method
|
||||
assert expected_file["related_id"] == result_file.related_id
|
||||
@ -0,0 +1,67 @@
|
||||
from core.model_runtime.entities import ImagePromptMessageContent
|
||||
from core.workflow.nodes.question_classifier import QuestionClassifierNodeData
|
||||
|
||||
|
||||
def test_init_question_classifier_node_data():
|
||||
data = {
|
||||
"title": "test classifier node",
|
||||
"query_variable_selector": ["id", "name"],
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
|
||||
"classes": [{"id": "1", "name": "class 1"}],
|
||||
"instruction": "This is a test instruction",
|
||||
"memory": {
|
||||
"role_prefix": {"user": "Human:", "assistant": "AI:"},
|
||||
"window": {"enabled": True, "size": 5},
|
||||
"query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:",
|
||||
},
|
||||
"vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}},
|
||||
}
|
||||
|
||||
node_data = QuestionClassifierNodeData(**data)
|
||||
|
||||
assert node_data.query_variable_selector == ["id", "name"]
|
||||
assert node_data.model.provider == "openai"
|
||||
assert node_data.classes[0].id == "1"
|
||||
assert node_data.instruction == "This is a test instruction"
|
||||
assert node_data.memory is not None
|
||||
assert node_data.memory.role_prefix is not None
|
||||
assert node_data.memory.role_prefix.user == "Human:"
|
||||
assert node_data.memory.role_prefix.assistant == "AI:"
|
||||
assert node_data.memory.window.enabled == True
|
||||
assert node_data.memory.window.size == 5
|
||||
assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:"
|
||||
assert node_data.vision.enabled == True
|
||||
assert node_data.vision.configs.variable_selector == ["image"]
|
||||
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
|
||||
def test_init_question_classifier_node_data_without_vision_config():
|
||||
data = {
|
||||
"title": "test classifier node",
|
||||
"query_variable_selector": ["id", "name"],
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
|
||||
"classes": [{"id": "1", "name": "class 1"}],
|
||||
"instruction": "This is a test instruction",
|
||||
"memory": {
|
||||
"role_prefix": {"user": "Human:", "assistant": "AI:"},
|
||||
"window": {"enabled": True, "size": 5},
|
||||
"query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:",
|
||||
},
|
||||
}
|
||||
|
||||
node_data = QuestionClassifierNodeData(**data)
|
||||
|
||||
assert node_data.query_variable_selector == ["id", "name"]
|
||||
assert node_data.model.provider == "openai"
|
||||
assert node_data.classes[0].id == "1"
|
||||
assert node_data.instruction == "This is a test instruction"
|
||||
assert node_data.memory is not None
|
||||
assert node_data.memory.role_prefix is not None
|
||||
assert node_data.memory.role_prefix.user == "Human:"
|
||||
assert node_data.memory.role_prefix.assistant == "AI:"
|
||||
assert node_data.memory.window.enabled == True
|
||||
assert node_data.memory.window.size == 5
|
||||
assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:"
|
||||
assert node_data.vision.enabled == False
|
||||
assert node_data.vision.configs.variable_selector == ["sys", "files"]
|
||||
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH
|
||||
45
api/tests/unit_tests/core/workflow/test_variable_pool.py
Normal file
45
api/tests/unit_tests/core/workflow/test_variable_pool.py
Normal file
@ -0,0 +1,45 @@
|
||||
import pytest
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import FileSegment, StringSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pool():
|
||||
return VariablePool(system_variables={}, user_inputs={})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file():
|
||||
return File(
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="test_related_id",
|
||||
remote_url="test_url",
|
||||
filename="test_file.txt",
|
||||
)
|
||||
|
||||
|
||||
def test_get_file_attribute(pool, file):
|
||||
# Add a FileSegment to the pool
|
||||
pool.add(("node_1", "file_var"), FileSegment(value=file))
|
||||
|
||||
# Test getting the 'name' attribute of the file
|
||||
result = pool.get(("node_1", "file_var", "name"))
|
||||
|
||||
assert result is not None
|
||||
assert result.value == file.filename
|
||||
|
||||
# Test getting a non-existent attribute
|
||||
result = pool.get(("node_1", "file_var", "non_existent_attr"))
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_use_long_selector(pool):
|
||||
pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value"))
|
||||
|
||||
result = pool.get(("node_1", "part_1", "part_2"))
|
||||
assert result is not None
|
||||
assert result.value == "test_value"
|
||||
@ -0,0 +1,28 @@
|
||||
from core.variables import SecretVariable
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.utils import variable_template_parser
|
||||
|
||||
|
||||
def test_extract_selectors_from_template():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey("user_id"): "fake-user-id",
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[
|
||||
SecretVariable(name="secret_key", value="fake-secret-key"),
|
||||
],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
|
||||
template = (
|
||||
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
|
||||
)
|
||||
selectors = variable_template_parser.extract_selectors_from_template(template)
|
||||
assert selectors == [
|
||||
VariableSelector(variable="#sys.user_id#", value_selector=["sys", "user_id"]),
|
||||
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
|
||||
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
|
||||
]
|
||||
Reference in New Issue
Block a user