Merge branch 'fix/chore-fix' into dev/plugin-deploy

This commit is contained in:
Yeuoly
2024-12-16 14:29:05 +08:00
371 changed files with 10899 additions and 6959 deletions

View File

@ -7,9 +7,10 @@ env = environs.Env()
class Config:
SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-*************-proxy-search-pub.lindorm.aliyuncs.com:30070")
SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070")
SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN")
SEARCH_PWD = env.str("SEARCH_PWD", "PWD")
SEARCH_PWD = env.str("SEARCH_PWD", "ADMIN")
USING_UGC = env.bool("USING_UGC", True)
class TestLindormVectorStore(AbstractVectorTest):
@ -31,5 +32,27 @@ class TestLindormVectorStore(AbstractVectorTest):
assert ids[0] == self.example_doc_id
def test_lindorm_vector(setup_mock_redis):
class TestLindormVectorStoreUGC(AbstractVectorTest):
    """Vector-store test case bound to a Lindorm store whose ``using_ugc`` flag comes from ``Config``."""

    def __init__(self):
        super().__init__()
        store_config = LindormVectorStoreConfig(
            hosts=Config.SEARCH_ENDPOINT,
            username=Config.SEARCH_USERNAME,
            password=Config.SEARCH_PWD,
            using_ugc=Config.USING_UGC,
        )
        # routing_value is read from self.collection_name, which is not assigned
        # in this class (presumably set by the base class __init__ - verify).
        self.vector = LindormVectorStore(
            collection_name="ugc_index_test",
            config=store_config,
            routing_value=self.collection_name,
        )

    def get_ids_by_metadata_field(self):
        """Look up the example document by ``doc_id`` and expect exactly one matching id."""
        matched_ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
        assert matched_ids is not None
        assert len(matched_ids) == 1
        assert matched_ids[0] == self.example_doc_id
def test_lindorm_vector_ugc(setup_mock_redis):
TestLindormVectorStore().run_all_tests()
TestLindormVectorStoreUGC().run_all_tests()

View File

@ -12,11 +12,11 @@ def tidb_vector():
return TiDBVector(
collection_name="test_collection",
config=TiDBVectorConfig(
host="xxx.eu-central-1.xxx.aws.tidbcloud.com",
port="4000",
user="xxx.root",
password="xxxxxx",
database="dify",
host="localhost",
port=4000,
user="root",
password="",
database="test",
program_name="langgenius/dify",
),
)
@ -27,35 +27,14 @@ class TiDBVectorTest(AbstractVectorTest):
super().__init__()
self.vector = vector
def text_exists(self):
exist = self.vector.text_exists(self.example_doc_id)
assert exist == False
def search_by_vector(self):
hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
assert len(hits_by_vector) == 0
def search_by_full_text(self):
hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
assert len(ids) == 0
ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
assert len(ids) == 1
def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_session):
def test_tidb_vector(setup_mock_redis, tidb_vector):
TiDBVectorTest(vector=tidb_vector).run_all_tests()
@pytest.fixture
def mock_session():
    """Yield a ``MagicMock`` patched over ``Session`` in the TiDB vector module."""
    session_target = "core.rag.datasource.vdb.tidb_vector.tidb_vector.Session"
    with patch(session_target, new_callable=MagicMock) as patched_session:
        yield patched_session
@pytest.fixture
def setup_tidbvector_mock(tidb_vector, mock_session):
    """Yield ``tidb_vector`` while ``create_engine`` and the engine's ``connect`` are patched."""
    engine_factory_patch = patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine")
    # The second patch is created only after the first one is active, as in nested ``with`` blocks.
    with engine_factory_patch, patch.object(tidb_vector._engine, "connect"):
        yield tidb_vector

View File

@ -0,0 +1,20 @@
import pytest
from extensions.storage.opendal_storage import is_r2_endpoint
@pytest.mark.parametrize(
    "endpoint, expected",
    [
        # Hosts under r2.cloudflarestorage.com, with and without a trailing slash or path.
        ("https://bucket.r2.cloudflarestorage.com", True),
        ("https://custom-domain.r2.cloudflarestorage.com/", True),
        ("https://bucket.r2.cloudflarestorage.com/path", True),
        # Other storage hosts.
        ("https://s3.amazonaws.com", False),
        ("https://storage.googleapis.com", False),
        ("http://localhost:9000", False),
        # Inputs that are not URLs at all.
        ("invalid-url", False),
        ("", False),
    ],
)
def test_is_r2_endpoint(endpoint: str, expected: bool):
    """``is_r2_endpoint`` must return ``expected`` for each sample ``endpoint``."""
    assert is_r2_endpoint(endpoint) == expected

View File

@ -1,271 +0,0 @@
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.http_request import (
BodyData,
HttpRequestNodeAuthorization,
HttpRequestNodeBody,
HttpRequestNodeData,
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.executor import Executor
def test_executor_with_json_body_and_number_variable():
    """A number variable interpolated into a JSON body template is sent as a JSON number."""
    # Variable pool holding the number referenced by the body template.
    pool = VariablePool(system_variables={}, user_inputs={})
    pool.add(["pre_node_id", "number"], 42)

    # JSON body whose single entry wraps the pooled number in an object.
    json_body = HttpRequestNodeBody(
        type="json",
        data=[BodyData(key="", type="text", value='{"number": {{#pre_node_id.number#}}}')],
    )
    request_data = HttpRequestNodeData(
        title="Test JSON Body with Number Variable",
        method="post",
        url="https://api.example.com/data",
        authorization=HttpRequestNodeAuthorization(type="no-auth"),
        headers="Content-Type: application/json",
        params="",
        body=json_body,
    )

    executor = Executor(
        node_data=request_data,
        timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
        variable_pool=pool,
    )

    # Parsed request attributes.
    assert executor.method == "post"
    assert executor.url == "https://api.example.com/data"
    assert executor.headers == {"Content-Type": "application/json"}
    assert executor.params == {}
    assert executor.json == {"number": 42}
    assert executor.data is None
    assert executor.files is None
    assert executor.content is None

    # Raw request text produced by to_log().
    request_log = executor.to_log()
    assert "POST /data HTTP/1.1" in request_log
    assert "Host: api.example.com" in request_log
    assert "Content-Type: application/json" in request_log
    assert '{"number": 42}' in request_log
def test_executor_with_json_body_and_object_variable():
    """An object variable used as the whole JSON body template becomes the JSON payload."""
    # Variable pool holding the object referenced by the body template.
    pool = VariablePool(system_variables={}, user_inputs={})
    pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})

    # JSON body consisting of nothing but the variable reference.
    json_body = HttpRequestNodeBody(
        type="json",
        data=[BodyData(key="", type="text", value="{{#pre_node_id.object#}}")],
    )
    request_data = HttpRequestNodeData(
        title="Test JSON Body with Object Variable",
        method="post",
        url="https://api.example.com/data",
        authorization=HttpRequestNodeAuthorization(type="no-auth"),
        headers="Content-Type: application/json",
        params="",
        body=json_body,
    )

    executor = Executor(
        node_data=request_data,
        timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
        variable_pool=pool,
    )

    # Parsed request attributes.
    assert executor.method == "post"
    assert executor.url == "https://api.example.com/data"
    assert executor.headers == {"Content-Type": "application/json"}
    assert executor.params == {}
    assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
    assert executor.data is None
    assert executor.files is None
    assert executor.content is None

    # Raw request text produced by to_log().
    request_log = executor.to_log()
    assert "POST /data HTTP/1.1" in request_log
    assert "Host: api.example.com" in request_log
    assert "Content-Type: application/json" in request_log
    assert '"name": "John Doe"' in request_log
    assert '"age": 30' in request_log
    assert '"email": "john@example.com"' in request_log
def test_executor_with_json_body_and_nested_object_variable():
    """An object variable nested under a key in the JSON body template is embedded as an object."""
    # Variable pool holding the object referenced by the body template.
    pool = VariablePool(system_variables={}, user_inputs={})
    pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})

    # JSON body that places the variable under the "object" key.
    json_body = HttpRequestNodeBody(
        type="json",
        data=[BodyData(key="", type="text", value='{"object": {{#pre_node_id.object#}}}')],
    )
    request_data = HttpRequestNodeData(
        title="Test JSON Body with Nested Object Variable",
        method="post",
        url="https://api.example.com/data",
        authorization=HttpRequestNodeAuthorization(type="no-auth"),
        headers="Content-Type: application/json",
        params="",
        body=json_body,
    )

    executor = Executor(
        node_data=request_data,
        timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
        variable_pool=pool,
    )

    # Parsed request attributes.
    assert executor.method == "post"
    assert executor.url == "https://api.example.com/data"
    assert executor.headers == {"Content-Type": "application/json"}
    assert executor.params == {}
    assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
    assert executor.data is None
    assert executor.files is None
    assert executor.content is None

    # Raw request text produced by to_log().
    request_log = executor.to_log()
    assert "POST /data HTTP/1.1" in request_log
    assert "Host: api.example.com" in request_log
    assert "Content-Type: application/json" in request_log
    assert '"object": {' in request_log
    assert '"name": "John Doe"' in request_log
    assert '"age": 30' in request_log
    assert '"email": "john@example.com"' in request_log
def test_extract_selectors_from_template_with_newline():
    """A variable value containing a newline is kept intact when substituted into ``params``."""
    pool = VariablePool()
    pool.add(("node_id", "custom_query"), "line1\nline2")

    empty_body = HttpRequestNodeBody(type="none", data=[])
    request_data = HttpRequestNodeData(
        title="Test JSON Body with Nested Object Variable",
        method="post",
        url="https://api.example.com/data",
        authorization=HttpRequestNodeAuthorization(type="no-auth"),
        headers="Content-Type: application/json",
        params="test: {{#node_id.custom_query#}}",
        body=empty_body,
    )

    executor = Executor(
        node_data=request_data,
        timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
        variable_pool=pool,
    )

    assert executor.params == {"test": "line1\nline2"}
def test_executor_with_form_data():
    """Text form-data fields are resolved from the variable pool and exposed as string values."""
    # Variable pool holding one text value and one number value.
    pool = VariablePool(system_variables={}, user_inputs={})
    pool.add(["pre_node_id", "text_field"], "Hello, World!")
    pool.add(["pre_node_id", "number_field"], 42)

    # Form-data body with one field per pooled variable.
    form_fields = [
        BodyData(key="text_field", type="text", value="{{#pre_node_id.text_field#}}"),
        BodyData(key="number_field", type="text", value="{{#pre_node_id.number_field#}}"),
    ]
    request_data = HttpRequestNodeData(
        title="Test Form Data",
        method="post",
        url="https://api.example.com/upload",
        authorization=HttpRequestNodeAuthorization(type="no-auth"),
        headers="Content-Type: multipart/form-data",
        params="",
        body=HttpRequestNodeBody(type="form-data", data=form_fields),
    )

    executor = Executor(
        node_data=request_data,
        timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
        variable_pool=pool,
    )

    # Parsed request attributes.
    assert executor.method == "post"
    assert executor.url == "https://api.example.com/upload"
    assert "Content-Type" in executor.headers
    assert "multipart/form-data" in executor.headers["Content-Type"]
    assert executor.params == {}
    assert executor.json is None
    assert executor.files is None
    assert executor.content is None

    # The form fields land in executor.data; the number is expected as the string "42".
    form_payload = executor.data
    assert isinstance(form_payload, dict)
    assert "text_field" in form_payload
    assert form_payload["text_field"] == "Hello, World!"
    assert "number_field" in form_payload
    assert form_payload["number_field"] == "42"

    # Raw request text produced by to_log().
    request_log = executor.to_log()
    assert "POST /upload HTTP/1.1" in request_log
    assert "Host: api.example.com" in request_log
    assert "Content-Type: multipart/form-data" in request_log
    assert "text_field" in request_log
    assert "Hello, World!" in request_log
    assert "number_field" in request_log
    assert "42" in request_log

View File

@ -1,204 +0,0 @@
import httpx
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import FileVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.http_request import (
BodyData,
HttpRequestNode,
HttpRequestNodeAuthorization,
HttpRequestNodeBody,
HttpRequestNodeData,
)
from core.workflow.nodes.http_request.executor import _plain_text_to_dict
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
def test_plain_text_to_dict():
    """Check ``_plain_text_to_dict`` on ``key:value`` lines with missing values, padding and blank lines."""
    cases = [
        ("aa\n cc:", {"aa": "", "cc": ""}),
        ("aa:bb\n cc:dd", {"aa": "bb", "cc": "dd"}),
        ("aa:bb\n cc:dd\n", {"aa": "bb", "cc": "dd"}),
        ("aa:bb\n\n cc : dd\n\n", {"aa": "bb", "cc": "dd"}),
    ]
    for plain_text, expected in cases:
        assert _plain_text_to_dict(plain_text) == expected
def test_http_request_node_binary_file(monkeypatch):
    """A binary body built from a file variable is posted as content and echoed back as the node body."""
    # Request whose binary body points at the file variable ["1111", "file"].
    binary_body = HttpRequestNodeBody(
        type="binary",
        data=[BodyData(key="file", type="file", value="", file=["1111", "file"])],
    )
    data = HttpRequestNodeData(
        title="test",
        method="post",
        url="http://example.org/post",
        authorization=HttpRequestNodeAuthorization(type="no-auth"),
        headers="",
        params="",
        body=binary_body,
    )

    # Variable pool containing the referenced file.
    pool = VariablePool(system_variables={}, user_inputs={})
    image_file = File(
        tenant_id="1",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.LOCAL_FILE,
        related_id="1111",
    )
    pool.add(["1111", "file"], FileVariable(name="file", value=image_file))

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config={},
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.SERVICE_API,
        call_depth=0,
    )
    graph = Graph(
        root_node_id="1",
        answer_stream_generate_routes=AnswerStreamGenerateRoute(
            answer_dependencies={},
            answer_generate_route={},
        ),
        end_stream_param=EndStreamParam(
            end_dependencies={},
            end_stream_variable_selector_mapping={},
        ),
    )
    node = HttpRequestNode(
        id="1",
        config={"id": "1", "data": data.model_dump()},
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=0),
    )

    # The file download always yields b"test"; the fake POST echoes the sent content back.
    monkeypatch.setattr(
        "core.workflow.nodes.http_request.executor.file_manager.download",
        lambda *args, **kwargs: b"test",
    )
    monkeypatch.setattr(
        "core.helper.ssrf_proxy.post",
        lambda *args, **kwargs: httpx.Response(200, content=kwargs["content"]),
    )

    result = node._run()
    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs["body"] == "test"
def test_http_request_node_form_with_file(monkeypatch):
    """A form-data body with a file field and a text field is split into ``files`` and ``data``."""
    # Request with one file field and one text field.
    form_body = HttpRequestNodeBody(
        type="form-data",
        data=[
            BodyData(key="file", type="file", file=["1111", "file"]),
            BodyData(key="name", type="text", value="test"),
        ],
    )
    data = HttpRequestNodeData(
        title="test",
        method="post",
        url="http://example.org/post",
        authorization=HttpRequestNodeAuthorization(type="no-auth"),
        headers="",
        params="",
        body=form_body,
    )

    # Variable pool containing the referenced file.
    pool = VariablePool(system_variables={}, user_inputs={})
    image_file = File(
        tenant_id="1",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.LOCAL_FILE,
        related_id="1111",
    )
    pool.add(["1111", "file"], FileVariable(name="file", value=image_file))

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config={},
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.SERVICE_API,
        call_depth=0,
    )
    graph = Graph(
        root_node_id="1",
        answer_stream_generate_routes=AnswerStreamGenerateRoute(
            answer_dependencies={},
            answer_generate_route={},
        ),
        end_stream_param=EndStreamParam(
            end_dependencies={},
            end_stream_variable_selector_mapping={},
        ),
    )
    node = HttpRequestNode(
        id="1",
        config={"id": "1", "data": data.model_dump()},
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=0),
    )

    # The file download always yields b"test".
    monkeypatch.setattr(
        "core.workflow.nodes.http_request.executor.file_manager.download",
        lambda *args, **kwargs: b"test",
    )

    def fake_post(*args, **kwargs):
        # Verify what the node hands to the HTTP layer, then answer with an empty body.
        assert kwargs["data"] == {"name": "test"}
        assert kwargs["files"] == {"file": (None, b"test", "application/octet-stream")}
        return httpx.Response(200, content=b"")

    monkeypatch.setattr("core.helper.ssrf_proxy.post", fake_post)

    result = node._run()
    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs["body"] == ""

View File

@ -0,0 +1,502 @@
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunStreamChunkEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine
from models.enums import UserFrom
from models.workflow import WorkflowType
class ContinueOnErrorTestHelper:
@staticmethod
def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a code node configuration"""
node = {
"id": "node",
"data": {
"outputs": {"result": {"type": "number"}},
"error_strategy": error_strategy,
"title": "code",
"variables": [],
"code_language": "python3",
"code": "\n".join([line[4:] for line in code.split("\n")]),
"type": "code",
},
}
if default_value:
node["data"]["default_value"] = default_value
return node
@staticmethod
def get_http_node(
error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False
):
"""Helper method to create a http node configuration"""
authorization = (
{
"type": "api-key",
"config": {
"type": "basic",
"api_key": "ak-xxx",
"header": "api-key",
},
}
if authorization_success
else {
"type": "api-key",
# missing config field
}
)
node = {
"id": "node",
"data": {
"title": "http",
"desc": "",
"method": "get",
"url": "http://example.com",
"authorization": authorization,
"headers": "X-Header:123",
"params": "A:b",
"body": None,
"type": "http-request",
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node
@staticmethod
def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a http node configuration"""
node = {
"id": "node",
"data": {
"type": "http-request",
"title": "HTTP Request",
"desc": "",
"variables": [],
"method": "get",
"url": "https://api.github.com/issues",
"authorization": {"type": "no-auth", "config": None},
"headers": "",
"params": "",
"body": {"type": "none", "data": []},
"timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0},
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node
@staticmethod
def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a tool node configuration"""
node = {
"id": "node",
"data": {
"title": "a",
"desc": "a",
"provider_id": "maths",
"provider_type": "builtin",
"provider_name": "maths",
"tool_name": "eval_expression",
"tool_label": "eval_expression",
"tool_configurations": {},
"tool_parameters": {
"expression": {
"type": "variable",
"value": ["1", "123", "args1"],
}
},
"type": "tool",
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node
@staticmethod
def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a llm node configuration"""
node = {
"id": "node",
"data": {
"title": "123",
"type": "llm",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"prompt_template": [
{"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node
@staticmethod
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
"""Helper method to create a graph engine instance for testing"""
graph = Graph.init(graph_config=graph_config)
variable_pool = {
"system_variables": {
SystemVariableKey.QUERY: "clear",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
"user_inputs": user_inputs or {"uid": "takato"},
}
return GraphEngine(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
max_execution_steps=500,
max_execution_time=1200,
)
# Linear edge list: start -> node -> answer, both edges leaving through the
# "source" handle. Used by the tests that exercise the default-value strategy.
DEFAULT_VALUE_EDGE = [
    {
        "id": "start-source-node-target",
        "source": "start",
        "target": "node",
        "sourceHandle": "source",
    },
    {
        "id": "node-source-answer-target",
        "source": "node",
        "target": "answer",
        "sourceHandle": "source",
    },
]
# Branching edge list: start -> node, then node -> success through the "source"
# handle and node -> error through the "fail-branch" handle. The fail-branch
# edge is deliberately last so a test can drop it with FAIL_BRANCH_EDGES[:-1].
FAIL_BRANCH_EDGES = [
    {
        "id": "start-source-node-target",
        "source": "start",
        "target": "node",
        "sourceHandle": "source",
    },
    {
        "id": "node-true-success-target",
        "source": "node",
        "target": "success",
        "sourceHandle": "source",
    },
    {
        "id": "node-false-error-target",
        "source": "node",
        "target": "error",
        "sourceHandle": "fail-branch",
    },
]
def test_code_default_value_continue_on_error():
    """A failing code node with the default-value strategy feeds its default into the answer."""
    # Every snippet line carries four leading spaces because get_code_node drops
    # the first four characters of each line; without them the code is truncated.
    error_code = """
    def main() -> dict:
        return {
            "result": 1 / 0,
        }
    """

    graph_config = {
        "edges": DEFAULT_VALUE_EDGE,
        "nodes": [
            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
            ContinueOnErrorTestHelper.get_code_node(
                error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}]
            ),
        ],
    }

    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
    events = list(graph_engine.run())

    # The node error is reported, the run partially succeeds with the default
    # value rendered by the answer node, and exactly one chunk is streamed.
    assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
    assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events)
    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_code_fail_branch_continue_on_error():
    """A failing code node with the fail-branch strategy routes the run to the error answer."""
    # Every snippet line carries four leading spaces because get_code_node drops
    # the first four characters of each line; without them the code is truncated.
    error_code = """
    def main() -> dict:
        return {
            "result": 1 / 0,
        }
    """

    graph_config = {
        "edges": FAIL_BRANCH_EDGES,
        "nodes": [
            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
            {
                "data": {"title": "success", "type": "answer", "answer": "node node run successfully"},
                "id": "success",
            },
            {
                "data": {"title": "error", "type": "answer", "answer": "node node run failed"},
                "id": "error",
            },
            ContinueOnErrorTestHelper.get_code_node(error_code),
        ],
    }

    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
    events = list(graph_engine.run())

    # Exactly one chunk is streamed, the node error is reported, and the run
    # partially succeeds with the error branch's answer.
    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
    assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
    assert any(
        isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events
    )
def test_http_node_default_value_continue_on_error():
    """HTTP node with the default-value strategy: the configured default reaches the answer node."""
    start_node = {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}
    answer_node = {"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"}
    http_node = ContinueOnErrorTestHelper.get_http_node(
        "default-value", [{"key": "response", "type": "string", "value": "http node got error response"}]
    )
    graph_config = {"edges": DEFAULT_VALUE_EDGE, "nodes": [start_node, answer_node, http_node]}

    events = list(ContinueOnErrorTestHelper.create_test_graph_engine(graph_config).run())
    stream_chunks = [event for event in events if isinstance(event, NodeRunStreamChunkEvent)]

    assert any(isinstance(event, NodeRunExceptionEvent) for event in events)
    assert any(
        isinstance(event, GraphRunPartialSucceededEvent) and event.outputs == {"answer": "http node got error response"}
        for event in events
    )
    assert len(stream_chunks) == 1
def test_http_node_fail_branch_continue_on_error():
    """HTTP node with the fail-branch strategy: the run ends on the error answer."""
    start_node = {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}
    success_node = {
        "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
        "id": "success",
    }
    error_node = {"data": {"title": "error", "type": "answer", "answer": "HTTP request failed"}, "id": "error"}
    graph_config = {
        "edges": FAIL_BRANCH_EDGES,
        "nodes": [start_node, success_node, error_node, ContinueOnErrorTestHelper.get_http_node()],
    }

    events = list(ContinueOnErrorTestHelper.create_test_graph_engine(graph_config).run())
    stream_chunks = [event for event in events if isinstance(event, NodeRunStreamChunkEvent)]

    assert any(isinstance(event, NodeRunExceptionEvent) for event in events)
    assert any(
        isinstance(event, GraphRunPartialSucceededEvent) and event.outputs == {"answer": "HTTP request failed"}
        for event in events
    )
    assert len(stream_chunks) == 1
def test_tool_node_default_value_continue_on_error():
    """Tool node with the default-value strategy: the configured default reaches the answer node."""
    start_node = {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}
    answer_node = {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}
    tool_node = ContinueOnErrorTestHelper.get_tool_node(
        "default-value", [{"key": "result", "type": "string", "value": "default tool result"}]
    )
    graph_config = {"edges": DEFAULT_VALUE_EDGE, "nodes": [start_node, answer_node, tool_node]}

    events = list(ContinueOnErrorTestHelper.create_test_graph_engine(graph_config).run())
    stream_chunks = [event for event in events if isinstance(event, NodeRunStreamChunkEvent)]

    assert any(isinstance(event, NodeRunExceptionEvent) for event in events)
    assert any(
        isinstance(event, GraphRunPartialSucceededEvent) and event.outputs == {"answer": "default tool result"}
        for event in events
    )
    assert len(stream_chunks) == 1
def test_tool_node_fail_branch_continue_on_error():
    """Test tool node with fail-branch error strategy"""
    graph_config = {
        "edges": FAIL_BRANCH_EDGES,
        "nodes": [
            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
            {
                "data": {"title": "success", "type": "answer", "answer": "tool execute successful"},
                "id": "success",
            },
            {
                "data": {"title": "error", "type": "answer", "answer": "tool execute failed"},
                "id": "error",
            },
            ContinueOnErrorTestHelper.get_tool_node(),
        ],
    }

    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
    events = list(graph_engine.run())

    # The node error is reported, the run partially succeeds with the error
    # branch's answer, and exactly one chunk is streamed.
    assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
    assert any(
        isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events
    )
    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_llm_node_default_value_continue_on_error():
    """LLM node with the default-value strategy: the configured default reaches the answer node."""
    start_node = {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}
    answer_node = {"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"}
    llm_node = ContinueOnErrorTestHelper.get_llm_node(
        "default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}]
    )
    graph_config = {"edges": DEFAULT_VALUE_EDGE, "nodes": [start_node, answer_node, llm_node]}

    events = list(ContinueOnErrorTestHelper.create_test_graph_engine(graph_config).run())
    stream_chunks = [event for event in events if isinstance(event, NodeRunStreamChunkEvent)]

    assert any(isinstance(event, NodeRunExceptionEvent) for event in events)
    assert any(
        isinstance(event, GraphRunPartialSucceededEvent) and event.outputs == {"answer": "default LLM response"}
        for event in events
    )
    assert len(stream_chunks) == 1
def test_llm_node_fail_branch_continue_on_error():
    """LLM node with the fail-branch strategy: the run ends on the error answer."""
    start_node = {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}
    success_node = {
        "data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
        "id": "success",
    }
    error_node = {"data": {"title": "error", "type": "answer", "answer": "LLM request failed"}, "id": "error"}
    graph_config = {
        "edges": FAIL_BRANCH_EDGES,
        "nodes": [start_node, success_node, error_node, ContinueOnErrorTestHelper.get_llm_node()],
    }

    events = list(ContinueOnErrorTestHelper.create_test_graph_engine(graph_config).run())
    stream_chunks = [event for event in events if isinstance(event, NodeRunStreamChunkEvent)]

    assert any(isinstance(event, NodeRunExceptionEvent) for event in events)
    assert any(
        isinstance(event, GraphRunPartialSucceededEvent) and event.outputs == {"answer": "LLM request failed"}
        for event in events
    )
    assert len(stream_chunks) == 1
def test_status_code_error_http_node_fail_branch_continue_on_error():
    """HTTP node from ``get_error_status_code_http_node`` with fail-branch: the run ends on the error answer."""
    # NOTE(review): the node targets a real URL; presumably this needs network
    # access or a mock provided elsewhere - confirm.
    start_node = {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}
    success_node = {
        "data": {"title": "success", "type": "answer", "answer": "http execute successful"},
        "id": "success",
    }
    error_node = {"data": {"title": "error", "type": "answer", "answer": "http execute failed"}, "id": "error"}
    graph_config = {
        "edges": FAIL_BRANCH_EDGES,
        "nodes": [
            start_node,
            success_node,
            error_node,
            ContinueOnErrorTestHelper.get_error_status_code_http_node(),
        ],
    }

    events = list(ContinueOnErrorTestHelper.create_test_graph_engine(graph_config).run())
    stream_chunks = [event for event in events if isinstance(event, NodeRunStreamChunkEvent)]

    assert any(isinstance(event, NodeRunExceptionEvent) for event in events)
    assert any(
        isinstance(event, GraphRunPartialSucceededEvent) and event.outputs == {"answer": "http execute failed"}
        for event in events
    )
    assert len(stream_chunks) == 1
def test_variable_pool_error_type_variable():
    """After a failed HTTP node, the pool exposes ``error_message`` and ``error_type`` under the node id."""
    graph_config = {
        "edges": FAIL_BRANCH_EDGES,
        "nodes": [
            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
            {
                "data": {"title": "success", "type": "answer", "answer": "http execute successful"},
                "id": "success",
            },
            {
                "data": {"title": "error", "type": "answer", "answer": "http execute failed"},
                "id": "error",
            },
            ContinueOnErrorTestHelper.get_error_status_code_http_node(),
        ],
    }

    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
    # Drain the event stream; only the resulting variable pool matters here.
    list(graph_engine.run())

    variable_pool = graph_engine.graph_runtime_state.variable_pool
    error_message = variable_pool.get(["node", "error_message"])
    error_type = variable_pool.get(["node", "error_type"])

    # Identity checks against None: ``!=`` could be answered by a custom __eq__.
    assert error_message is not None
    # Fail with a clear assertion instead of an AttributeError on ``.value``.
    assert error_type is not None
    assert error_type.value == "HTTPResponseCodeError"
def test_no_node_in_fail_branch_continue_on_error():
    """HTTP node with fail-branch but no fail-branch edge: partial success with empty outputs."""
    start_node = {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}
    success_node = {
        "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
        "id": "success",
    }
    graph_config = {
        # Drop the last edge, i.e. the one leaving through the "fail-branch" handle.
        "edges": FAIL_BRANCH_EDGES[:-1],
        "nodes": [start_node, success_node, ContinueOnErrorTestHelper.get_http_node()],
    }

    events = list(ContinueOnErrorTestHelper.create_test_graph_engine(graph_config).run())
    stream_chunks = [event for event in events if isinstance(event, NodeRunStreamChunkEvent)]

    assert any(isinstance(event, NodeRunExceptionEvent) for event in events)
    assert any(isinstance(event, GraphRunPartialSucceededEvent) and event.outputs == {} for event in events)
    assert len(stream_chunks) == 0

View File

@ -6,13 +6,17 @@ from extensions.storage.base_storage import BaseStorage
def get_example_folder() -> str:
return "/dify"
return "~/dify"
def get_example_bucket() -> str:
    """Return the bucket name used by the storage tests."""
    bucket_name = "dify"
    return bucket_name
def get_opendal_bucket() -> str:
    """Return the relative folder path used as the OpenDAL storage root in tests."""
    bucket_root = "./dify"
    return bucket_root
def get_example_filename() -> str:
    """Return the file name used by the storage tests."""
    filename = "test.txt"
    return filename
@ -22,14 +26,14 @@ def get_example_data() -> bytes:
def get_example_filepath() -> str:
return "/test"
return "~/test"
class BaseStorageTest:
@pytest.fixture(autouse=True)
def setup_method(self):
def setup_method(self, *args, **kwargs):
"""Should be implemented in child classes to setup specific storage."""
self.storage = BaseStorage()
self.storage: BaseStorage
def test_save(self):
"""Test saving data."""

View File

@ -1,18 +0,0 @@
from collections.abc import Generator
import pytest
from extensions.storage.local_fs_storage import LocalFsStorage
from tests.unit_tests.oss.__mock.base import (
BaseStorageTest,
get_example_folder,
)
from tests.unit_tests.oss.__mock.local import setup_local_fs_mock
class TestLocalFS(BaseStorageTest):
    """``BaseStorageTest`` subclass that supplies a ``LocalFsStorage`` as ``self.storage``."""

    @pytest.fixture(autouse=True)
    def setup_method(self, setup_local_fs_mock):
        """Executed before each test method."""
        local_storage = LocalFsStorage()
        # Point the storage at the shared example folder before exposing it.
        local_storage.folder = get_example_folder()
        self.storage = local_storage

View File

@ -0,0 +1,88 @@
import os
from collections.abc import Generator
from pathlib import Path
import pytest
from configs.middleware.storage.opendal_storage_config import OpenDALScheme
from extensions.storage.opendal_storage import OpenDALStorage
from tests.unit_tests.oss.__mock.base import (
get_example_data,
get_example_filename,
get_example_filepath,
get_opendal_bucket,
)
class TestOpenDAL:
    """Tests for ``OpenDALStorage`` using the ``FS`` scheme rooted at ``get_opendal_bucket()``."""

    @pytest.fixture(autouse=True)
    def setup_method(self, *args, **kwargs):
        """Executed before each test method."""
        self.storage = OpenDALStorage(
            scheme=OpenDALScheme.FS,
            root=get_opendal_bucket(),
        )

    @pytest.fixture(scope="class", autouse=True)
    def teardown_class(self, request):
        """Remove the bucket folder before the class's tests run and again after they finish.

        The up-front cleanup keeps ``test_save_and_exists`` independent of files
        left behind by an earlier run; the finalizer performs the actual teardown.
        """
        import shutil  # local import: only this fixture needs it

        def cleanup():
            folder = Path(get_opendal_bucket())
            if folder.is_dir():
                # rmtree also removes non-empty sub-directories, which rmdir cannot.
                shutil.rmtree(folder)

        cleanup()
        # Run the same cleanup once every test in the class has finished.
        request.addfinalizer(cleanup)

    def test_save_and_exists(self):
        """Test saving data and checking existence."""
        filename = get_example_filename()
        data = get_example_data()

        assert not self.storage.exists(filename)

        self.storage.save(filename, data)

        assert self.storage.exists(filename)

    def test_load_once(self):
        """Test loading data once."""
        filename = get_example_filename()
        data = get_example_data()

        self.storage.save(filename, data)
        loaded_data = self.storage.load_once(filename)

        assert loaded_data == data

    def test_load_stream(self):
        """Test loading data as a stream."""
        filename = get_example_filename()
        data = get_example_data()

        self.storage.save(filename, data)
        generator = self.storage.load_stream(filename)

        assert isinstance(generator, Generator)
        assert next(generator) == data

    def test_download(self):
        """Test downloading data to a file."""
        filename = get_example_filename()
        # The download target is a path inside the bucket folder itself.
        filepath = str(Path(get_opendal_bucket()) / filename)
        data = get_example_data()

        self.storage.save(filename, data)
        self.storage.download(filename, filepath)

    def test_delete(self):
        """Test deleting a file."""
        filename = get_example_filename()
        data = get_example_data()

        self.storage.save(filename, data)
        assert self.storage.exists(filename)

        self.storage.delete(filename)
        assert not self.storage.exists(filename)