merge main

This commit is contained in:
Joel
2024-10-25 11:25:04 +08:00
parent ae00211691
commit bdb990eb90
375 changed files with 18637 additions and 7426 deletions

View File

@ -0,0 +1,40 @@
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
def test_file_loads_and_dumps():
file = File(
id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
)
file_dict = file.model_dump()
assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY
assert file_dict["type"] == file.type.value
assert isinstance(file_dict["type"], str)
assert file_dict["transfer_method"] == file.transfer_method.value
assert isinstance(file_dict["transfer_method"], str)
assert "_extra_config" not in file_dict
file_obj = File.model_validate(file_dict)
assert file_obj.id == file.id
assert file_obj.tenant_id == file.tenant_id
assert file_obj.type == file.type
assert file_obj.transfer_method == file.transfer_method
assert file_obj.remote_url == file.remote_url
def test_file_to_dict():
file = File(
id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
)
file_dict = file.to_dict()
assert "_extra_config" not in file_dict
assert "url" in file_dict

View File

@ -1,56 +0,0 @@
import pytest
from core.tools.entities.tool_entities import ToolParameter
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
def test_get_parameter_type():
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string"
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string"
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean"
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number"
with pytest.raises(ValueError):
ToolParameterConverter.get_parameter_type("unsupported_type")
def test_cast_parameter_by_type():
# string
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test"
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1"
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0"
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == ""
# secret input
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test"
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1"
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0"
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == ""
# select
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test"
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1"
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0"
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == ""
# boolean
true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"]
for value in true_values:
assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True
false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""]
for value in false_values:
assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False
# number
assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1
assert ToolParameterConverter.cast_parameter_by_type("1.0", ToolParameter.ToolParameterType.NUMBER) == 1.0
assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0
assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
# unknown
assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1"
assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1"
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None

View File

@ -0,0 +1,49 @@
from core.tools.entities.tool_entities import ToolParameter
def test_get_parameter_type():
assert ToolParameter.ToolParameterType.STRING.as_normal_type() == "string"
assert ToolParameter.ToolParameterType.SELECT.as_normal_type() == "string"
assert ToolParameter.ToolParameterType.SECRET_INPUT.as_normal_type() == "string"
assert ToolParameter.ToolParameterType.BOOLEAN.as_normal_type() == "boolean"
assert ToolParameter.ToolParameterType.NUMBER.as_normal_type() == "number"
assert ToolParameter.ToolParameterType.FILE.as_normal_type() == "file"
assert ToolParameter.ToolParameterType.FILES.as_normal_type() == "files"
def test_cast_parameter_by_type():
# string
assert ToolParameter.ToolParameterType.STRING.cast_value("test") == "test"
assert ToolParameter.ToolParameterType.STRING.cast_value(1) == "1"
assert ToolParameter.ToolParameterType.STRING.cast_value(1.0) == "1.0"
assert ToolParameter.ToolParameterType.STRING.cast_value(None) == ""
# secret input
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value("test") == "test"
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1) == "1"
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1.0) == "1.0"
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(None) == ""
# select
assert ToolParameter.ToolParameterType.SELECT.cast_value("test") == "test"
assert ToolParameter.ToolParameterType.SELECT.cast_value(1) == "1"
assert ToolParameter.ToolParameterType.SELECT.cast_value(1.0) == "1.0"
assert ToolParameter.ToolParameterType.SELECT.cast_value(None) == ""
# boolean
true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"]
for value in true_values:
assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is True
false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""]
for value in false_values:
assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is False
# number
assert ToolParameter.ToolParameterType.NUMBER.cast_value("1") == 1
assert ToolParameter.ToolParameterType.NUMBER.cast_value("1.0") == 1.0
assert ToolParameter.ToolParameterType.NUMBER.cast_value("-1.0") == -1.0
assert ToolParameter.ToolParameterType.NUMBER.cast_value(1) == 1
assert ToolParameter.ToolParameterType.NUMBER.cast_value(1.0) == 1.0
assert ToolParameter.ToolParameterType.NUMBER.cast_value(-1.0) == -1.0
assert ToolParameter.ToolParameterType.NUMBER.cast_value(None) is None

View File

@ -0,0 +1,167 @@
from unittest.mock import Mock, patch
import pytest
from core.file import File, FileTransferMethod
from core.variables import ArrayFileSegment
from core.variables.variables import StringVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData
from core.workflow.nodes.document_extractor.node import (
_extract_text_from_doc,
_extract_text_from_pdf,
_extract_text_from_plain_text,
)
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
@pytest.fixture
def document_extractor_node():
node_data = DocumentExtractorNodeData(
title="Test Document Extractor",
variable_selector=["node_id", "variable_name"],
)
return DocumentExtractorNode(
id="test_node_id",
config={"id": "test_node_id", "data": node_data.model_dump()},
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
)
@pytest.fixture
def mock_graph_runtime_state():
return Mock()
def test_run_variable_not_found(document_extractor_node, mock_graph_runtime_state):
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
mock_graph_runtime_state.variable_pool.get.return_value = None
result = document_extractor_node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error is not None
assert "File variable not found" in result.error
def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_state):
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
mock_graph_runtime_state.variable_pool.get.return_value = StringVariable(
value="Not an ArrayFileSegment", name="test"
)
result = document_extractor_node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error is not None
assert "is not an ArrayFileSegment" in result.error
@pytest.mark.parametrize(
("mime_type", "file_content", "expected_text", "transfer_method", "extension"),
[
("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"),
(
"application/pdf",
b"%PDF-1.5\n%Test PDF content",
["Mocked PDF content"],
FileTransferMethod.LOCAL_FILE,
".pdf",
),
(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
b"PK\x03\x04",
["Mocked DOCX content"],
FileTransferMethod.REMOTE_URL,
"",
),
("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None),
],
)
def test_run_extract_text(
document_extractor_node,
mock_graph_runtime_state,
mime_type,
file_content,
expected_text,
transfer_method,
extension,
monkeypatch,
):
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
mock_file = Mock(spec=File)
mock_file.mime_type = mime_type
mock_file.transfer_method = transfer_method
mock_file.related_id = "test_file_id" if transfer_method == FileTransferMethod.LOCAL_FILE else None
mock_file.remote_url = "https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None
mock_file.extension = extension
mock_array_file_segment = Mock(spec=ArrayFileSegment)
mock_array_file_segment.value = [mock_file]
mock_graph_runtime_state.variable_pool.get.return_value = mock_array_file_segment
mock_download = Mock(return_value=file_content)
mock_ssrf_proxy_get = Mock()
mock_ssrf_proxy_get.return_value.content = file_content
mock_ssrf_proxy_get.return_value.raise_for_status = Mock()
monkeypatch.setattr("core.file.file_manager.download", mock_download)
monkeypatch.setattr("core.helper.ssrf_proxy.get", mock_ssrf_proxy_get)
if mime_type == "application/pdf":
mock_pdf_extract = Mock(return_value=expected_text[0])
monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract)
elif mime_type.startswith("application/vnd.openxmlformats"):
mock_docx_extract = Mock(return_value=expected_text[0])
monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_doc", mock_docx_extract)
result = document_extractor_node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["text"] == expected_text
if transfer_method == FileTransferMethod.REMOTE_URL:
mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt")
elif transfer_method == FileTransferMethod.LOCAL_FILE:
mock_download.assert_called_once_with(mock_file)
def test_extract_text_from_plain_text():
text = _extract_text_from_plain_text(b"Hello, world!")
assert text == "Hello, world!"
@patch("pypdfium2.PdfDocument")
def test_extract_text_from_pdf(mock_pdf_document):
mock_page = Mock()
mock_text_page = Mock()
mock_text_page.get_text_range.return_value = "PDF content"
mock_page.get_textpage.return_value = mock_text_page
mock_pdf_document.return_value = [mock_page]
text = _extract_text_from_pdf(b"%PDF-1.5\n%Test PDF content")
assert text == "PDF content"
@patch("docx.Document")
def test_extract_text_from_doc(mock_document):
mock_paragraph1 = Mock()
mock_paragraph1.text = "Paragraph 1"
mock_paragraph2 = Mock()
mock_paragraph2.text = "Paragraph 2"
mock_document.return_value.paragraphs = [mock_paragraph1, mock_paragraph2]
text = _extract_text_from_doc(b"PK\x03\x04")
assert text == "Paragraph 1\nParagraph 2"
def test_node_type(document_extractor_node):
assert document_extractor_node._node_type == NodeType.DOCUMENT_EXTRACTOR

View File

@ -0,0 +1,369 @@
import json
import httpx
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import FileVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.http_request import (
BodyData,
HttpRequestNode,
HttpRequestNodeAuthorization,
HttpRequestNodeBody,
HttpRequestNodeData,
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.executor import Executor, _plain_text_to_dict
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
def test_plain_text_to_dict():
assert _plain_text_to_dict("aa\n cc:") == {"aa": "", "cc": ""}
assert _plain_text_to_dict("aa:bb\n cc:dd") == {"aa": "bb", "cc": "dd"}
assert _plain_text_to_dict("aa:bb\n cc:dd\n") == {"aa": "bb", "cc": "dd"}
assert _plain_text_to_dict("aa:bb\n\n cc : dd\n\n") == {"aa": "bb", "cc": "dd"}
def test_http_request_node_binary_file(monkeypatch):
data = HttpRequestNodeData(
title="test",
method="post",
url="http://example.org/post",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="",
params="",
body=HttpRequestNodeBody(
type="binary",
data=[
BodyData(
key="file",
type="file",
value="",
file=["1111", "file"],
)
],
),
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(
["1111", "file"],
FileVariable(
name="file",
value=File(
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1111",
),
),
)
node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
)
monkeypatch.setattr(
"core.helper.ssrf_proxy.post",
lambda *args, **kwargs: httpx.Response(200, content=kwargs["content"]),
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["body"] == "test"
def test_http_request_node_form_with_file(monkeypatch):
data = HttpRequestNodeData(
title="test",
method="post",
url="http://example.org/post",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="",
params="",
body=HttpRequestNodeBody(
type="form-data",
data=[
BodyData(
key="file",
type="file",
file=["1111", "file"],
),
BodyData(
key="name",
type="text",
value="test",
),
],
),
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(
["1111", "file"],
FileVariable(
name="file",
value=File(
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1111",
),
),
)
node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
)
def attr_checker(*args, **kwargs):
assert kwargs["data"] == {"name": "test"}
assert kwargs["files"] == {"file": b"test"}
return httpx.Response(200, content=b"")
monkeypatch.setattr(
"core.helper.ssrf_proxy.post",
attr_checker,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["body"] == ""
def test_executor_with_json_body_and_number_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "number"], 42)
# Prepare the node data
node_data = HttpRequestNodeData(
title="Test JSON Body with Number Variable",
method="post",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: application/json",
params="",
body=HttpRequestNodeBody(
type="json",
data=[
BodyData(
key="",
type="text",
value='{"number": {{#pre_node_id.number#}}}',
)
],
),
)
# Initialize the Executor
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == {}
assert executor.json == {"number": 42}
assert executor.data is None
assert executor.files is None
assert executor.content is None
# Check the raw request (to_log method)
raw_request = executor.to_log()
assert "POST /data HTTP/1.1" in raw_request
assert "Host: api.example.com" in raw_request
assert "Content-Type: application/json" in raw_request
assert '{"number": 42}' in raw_request
def test_executor_with_json_body_and_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
# Prepare the node data
node_data = HttpRequestNodeData(
title="Test JSON Body with Object Variable",
method="post",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: application/json",
params="",
body=HttpRequestNodeBody(
type="json",
data=[
BodyData(
key="",
type="text",
value="{{#pre_node_id.object#}}",
)
],
),
)
# Initialize the Executor
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == {}
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
assert executor.data is None
assert executor.files is None
assert executor.content is None
# Check the raw request (to_log method)
raw_request = executor.to_log()
assert "POST /data HTTP/1.1" in raw_request
assert "Host: api.example.com" in raw_request
assert "Content-Type: application/json" in raw_request
assert '"name": "John Doe"' in raw_request
assert '"age": 30' in raw_request
assert '"email": "john@example.com"' in raw_request
def test_executor_with_json_body_and_nested_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
# Prepare the node data
node_data = HttpRequestNodeData(
title="Test JSON Body with Nested Object Variable",
method="post",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: application/json",
params="",
body=HttpRequestNodeBody(
type="json",
data=[
BodyData(
key="",
type="text",
value='{"object": {{#pre_node_id.object#}}}',
)
],
),
)
# Initialize the Executor
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == {}
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
assert executor.data is None
assert executor.files is None
assert executor.content is None
# Check the raw request (to_log method)
raw_request = executor.to_log()
assert "POST /data HTTP/1.1" in raw_request
assert "Host: api.example.com" in raw_request
assert "Content-Type: application/json" in raw_request
assert '"object": {' in raw_request
assert '"name": "John Doe"' in raw_request
assert '"age": 30' in raw_request
assert '"email": "john@example.com"' in raw_request

View File

@ -0,0 +1,111 @@
from unittest.mock import MagicMock
import pytest
from core.file import File
from core.file.models import FileTransferMethod, FileType
from core.variables import ArrayFileSegment
from core.workflow.nodes.list_operator.entities import FilterBy, FilterCondition, Limit, ListOperatorNodeData, OrderBy
from core.workflow.nodes.list_operator.node import ListOperatorNode
from models.workflow import WorkflowNodeExecutionStatus
@pytest.fixture
def list_operator_node():
config = {
"variable": ["test_variable"],
"filter_by": FilterBy(
enabled=True,
conditions=[
FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT])
],
),
"order_by": OrderBy(enabled=False, value="asc"),
"limit": Limit(enabled=False, size=0),
"title": "Test Title",
}
node_data = ListOperatorNodeData(**config)
node = ListOperatorNode(
id="test_node_id",
config={
"id": "test_node_id",
"data": node_data.model_dump(),
},
graph_init_params=MagicMock(),
graph=MagicMock(),
graph_runtime_state=MagicMock(),
)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock()
return node
def test_filter_files_by_type(list_operator_node):
# Setup test data
files = [
File(
filename="image1.jpg",
type=FileType.IMAGE,
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related1",
),
File(
filename="document1.pdf",
type=FileType.DOCUMENT,
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related2",
),
File(
filename="image2.png",
type=FileType.IMAGE,
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related3",
),
File(
filename="audio1.mp3",
type=FileType.AUDIO,
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related4",
),
]
variable = ArrayFileSegment(value=files)
list_operator_node.graph_runtime_state.variable_pool.get.return_value = variable
# Run the node
result = list_operator_node._run()
# Verify the result
expected_files = [
{
"filename": "image1.jpg",
"type": FileType.IMAGE,
"tenant_id": "tenant1",
"transfer_method": FileTransferMethod.LOCAL_FILE,
"related_id": "related1",
},
{
"filename": "document1.pdf",
"type": FileType.DOCUMENT,
"tenant_id": "tenant1",
"transfer_method": FileTransferMethod.LOCAL_FILE,
"related_id": "related2",
},
{
"filename": "image2.png",
"type": FileType.IMAGE,
"tenant_id": "tenant1",
"transfer_method": FileTransferMethod.LOCAL_FILE,
"related_id": "related3",
},
]
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
for expected_file, result_file in zip(expected_files, result.outputs["result"]):
assert expected_file["filename"] == result_file.filename
assert expected_file["type"] == result_file.type
assert expected_file["tenant_id"] == result_file.tenant_id
assert expected_file["transfer_method"] == result_file.transfer_method
assert expected_file["related_id"] == result_file.related_id

View File

@ -0,0 +1,67 @@
from core.model_runtime.entities import ImagePromptMessageContent
from core.workflow.nodes.question_classifier import QuestionClassifierNodeData
def test_init_question_classifier_node_data():
data = {
"title": "test classifier node",
"query_variable_selector": ["id", "name"],
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
"classes": [{"id": "1", "name": "class 1"}],
"instruction": "This is a test instruction",
"memory": {
"role_prefix": {"user": "Human:", "assistant": "AI:"},
"window": {"enabled": True, "size": 5},
"query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:",
},
"vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}},
}
node_data = QuestionClassifierNodeData(**data)
assert node_data.query_variable_selector == ["id", "name"]
assert node_data.model.provider == "openai"
assert node_data.classes[0].id == "1"
assert node_data.instruction == "This is a test instruction"
assert node_data.memory is not None
assert node_data.memory.role_prefix is not None
assert node_data.memory.role_prefix.user == "Human:"
assert node_data.memory.role_prefix.assistant == "AI:"
assert node_data.memory.window.enabled == True
assert node_data.memory.window.size == 5
assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:"
assert node_data.vision.enabled == True
assert node_data.vision.configs.variable_selector == ["image"]
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.LOW
def test_init_question_classifier_node_data_without_vision_config():
data = {
"title": "test classifier node",
"query_variable_selector": ["id", "name"],
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
"classes": [{"id": "1", "name": "class 1"}],
"instruction": "This is a test instruction",
"memory": {
"role_prefix": {"user": "Human:", "assistant": "AI:"},
"window": {"enabled": True, "size": 5},
"query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:",
},
}
node_data = QuestionClassifierNodeData(**data)
assert node_data.query_variable_selector == ["id", "name"]
assert node_data.model.provider == "openai"
assert node_data.classes[0].id == "1"
assert node_data.instruction == "This is a test instruction"
assert node_data.memory is not None
assert node_data.memory.role_prefix is not None
assert node_data.memory.role_prefix.user == "Human:"
assert node_data.memory.role_prefix.assistant == "AI:"
assert node_data.memory.window.enabled == True
assert node_data.memory.window.size == 5
assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:"
assert node_data.vision.enabled == False
assert node_data.vision.configs.variable_selector == ["sys", "files"]
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH

View File

@ -0,0 +1,45 @@
import pytest
from core.file import File, FileTransferMethod, FileType
from core.variables import FileSegment, StringSegment
from core.workflow.entities.variable_pool import VariablePool
@pytest.fixture
def pool():
return VariablePool(system_variables={}, user_inputs={})
@pytest.fixture
def file():
return File(
tenant_id="test_tenant_id",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_related_id",
remote_url="test_url",
filename="test_file.txt",
)
def test_get_file_attribute(pool, file):
# Add a FileSegment to the pool
pool.add(("node_1", "file_var"), FileSegment(value=file))
# Test getting the 'name' attribute of the file
result = pool.get(("node_1", "file_var", "name"))
assert result is not None
assert result.value == file.filename
# Test getting a non-existent attribute
result = pool.get(("node_1", "file_var", "non_existent_attr"))
assert result is None
def test_use_long_selector(pool):
pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value"))
result = pool.get(("node_1", "part_1", "part_2"))
assert result is not None
assert result.value == "test_value"

View File

@ -0,0 +1,28 @@
from core.variables import SecretVariable
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.utils import variable_template_parser
def test_extract_selectors_from_template():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
user_inputs={},
environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"),
],
conversation_variables=[],
)
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
template = (
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
)
selectors = variable_template_parser.extract_selectors_from_template(template)
assert selectors == [
VariableSelector(variable="#sys.user_id#", value_selector=["sys", "user_id"]),
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
]

View File

@ -0,0 +1,100 @@
import os
from typing import Union
from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from tos import TosClientV2
from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput
class AttrDict(dict):
def __getattr__(self, item):
return self.get(item)
def get_example_bucket() -> str:
return "dify"
def get_example_filename() -> str:
return "test.txt"
def get_example_data() -> bytes:
return b"test"
def get_example_filepath() -> str:
return "/test"
class MockVolcengineTosClass:
def __init__(self, ak="", sk="", endpoint="", region=""):
self.bucket_name = get_example_bucket()
self.key = get_example_filename()
self.content = get_example_data()
self.filepath = get_example_filepath()
self.resp = AttrDict(
{
"x-tos-server-side-encryption": "kms",
"x-tos-server-side-encryption-kms-key-id": "trn:kms:cn-beijing:****:keyrings/ring-test/keys/key-test",
"x-tos-server-side-encryption-customer-algorithm": "AES256",
"x-tos-version-id": "test",
"x-tos-hash-crc64ecma": 123456,
"request_id": "test",
"headers": {
"x-tos-id-2": "test",
"ETag": "123456",
},
"status": 200,
}
)
def put_object(self, bucket: str, key: str, content=None) -> PutObjectOutput:
assert bucket == self.bucket_name
assert key == self.key
assert content == self.content
return PutObjectOutput(self.resp)
def get_object(self, bucket: str, key: str) -> GetObjectOutput:
assert bucket == self.bucket_name
assert key == self.key
get_object_output = MagicMock(GetObjectOutput)
get_object_output.read.return_value = self.content
return get_object_output
def get_object_to_file(self, bucket: str, key: str, file_path: str):
assert bucket == self.bucket_name
assert key == self.key
assert file_path == self.filepath
def head_object(self, bucket: str, key: str) -> HeadObjectOutput:
assert bucket == self.bucket_name
assert key == self.key
return HeadObjectOutput(self.resp)
def delete_object(self, bucket: str, key: str):
assert bucket == self.bucket_name
assert key == self.key
return DeleteObjectOutput(self.resp)
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_volcengine_tos_mock(monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(TosClientV2, "__init__", MockVolcengineTosClass.__init__)
monkeypatch.setattr(TosClientV2, "put_object", MockVolcengineTosClass.put_object)
monkeypatch.setattr(TosClientV2, "get_object", MockVolcengineTosClass.get_object)
monkeypatch.setattr(TosClientV2, "get_object_to_file", MockVolcengineTosClass.get_object_to_file)
monkeypatch.setattr(TosClientV2, "head_object", MockVolcengineTosClass.head_object)
monkeypatch.setattr(TosClientV2, "delete_object", MockVolcengineTosClass.delete_object)
yield
if MOCK:
monkeypatch.undo()

View File

@ -0,0 +1,67 @@
from collections.abc import Generator
from flask import Flask
from tos import TosClientV2
from tos.clientv2 import GetObjectOutput, HeadObjectOutput, PutObjectOutput
from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
from tests.unit_tests.oss.__mock.volcengine_tos import (
get_example_bucket,
get_example_data,
get_example_filename,
get_example_filepath,
setup_volcengine_tos_mock,
)
class VolcengineTosTest:
_instance = None
def __new__(cls):
if cls._instance == None:
cls._instance = object.__new__(cls)
return cls._instance
else:
return cls._instance
def __init__(self):
self.storage = VolcengineTosStorage()
self.storage.bucket_name = get_example_bucket()
self.storage.client = TosClientV2(
ak="dify",
sk="dify",
endpoint="https://xxx.volces.com",
region="cn-beijing",
)
def test_save(setup_volcengine_tos_mock):
volc_tos = VolcengineTosTest()
volc_tos.storage.save(get_example_filename(), get_example_data())
def test_load_once(setup_volcengine_tos_mock):
volc_tos = VolcengineTosTest()
assert volc_tos.storage.load_once(get_example_filename()) == get_example_data()
def test_load_stream(setup_volcengine_tos_mock):
volc_tos = VolcengineTosTest()
generator = volc_tos.storage.load_stream(get_example_filename())
assert isinstance(generator, Generator)
assert next(generator) == get_example_data()
def test_download(setup_volcengine_tos_mock):
volc_tos = VolcengineTosTest()
volc_tos.storage.download(get_example_filename(), get_example_filepath())
def test_exists(setup_volcengine_tos_mock):
volc_tos = VolcengineTosTest()
assert volc_tos.storage.exists(get_example_filename())
def test_delete(setup_volcengine_tos_mock):
volc_tos = VolcengineTosTest()
volc_tos.storage.delete(get_example_filename())