Merge branch 'main' into fix/chore-fix

This commit is contained in:
Yeuoly
2024-12-24 21:28:56 +08:00
734 changed files with 7911 additions and 5007 deletions

View File

@ -59,6 +59,8 @@ def test_dify_config(example_env_file):
# annotated field with configured value
assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30
assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.

View File

@ -1,20 +0,0 @@
import pytest
from extensions.storage.opendal_storage import is_r2_endpoint
@pytest.mark.parametrize(
("endpoint", "expected"),
[
("https://bucket.r2.cloudflarestorage.com", True),
("https://custom-domain.r2.cloudflarestorage.com/", True),
("https://bucket.r2.cloudflarestorage.com/path", True),
("https://s3.amazonaws.com", False),
("https://storage.googleapis.com", False),
("http://localhost:9000", False),
("invalid-url", False),
("", False),
],
)
def test_is_r2_endpoint(endpoint: str, expected: bool):
assert is_r2_endpoint(endpoint) == expected

View File

@ -2,6 +2,8 @@ import pytest
from pydantic import ValidationError
from core.variables import (
ArrayFileVariable,
ArrayVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
@ -81,3 +83,8 @@ def test_variable_to_object():
assert var.to_object() == 3.14
var = SecretVariable(name="secret", value="secret_value")
assert var.to_object() == "secret_value"
def test_array_file_variable_is_array_variable():
var = ArrayFileVariable(name="files", value=[])
assert isinstance(var, ArrayVariable)

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
from configs import dify_config
from core.app.app_config.entities import ModelConfigEntity
from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig
from core.memory.token_buffer_memory import TokenBufferMemory
@ -126,6 +127,7 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
model_config_mock, _, messages, inputs, context = get_chat_model_args
dify_config.MULTIMODAL_SEND_FORMAT = "url"
files = [
File(
@ -134,13 +136,16 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="",
)
]
prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url))
mock_get_encoded_string.return_value = ImagePromptMessageContent(
url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
)
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template=messages,
inputs=inputs,

View File

@ -1,34 +1,9 @@
import json
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType, FileUploadConfig
from core.file import File, FileTransferMethod, FileType, FileUploadConfig
from models.workflow import Workflow
def test_file_loads_and_dumps():
file = File(
id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
)
file_dict = file.model_dump()
assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY
assert file_dict["type"] == file.type.value
assert isinstance(file_dict["type"], str)
assert file_dict["transfer_method"] == file.transfer_method.value
assert isinstance(file_dict["transfer_method"], str)
assert "_extra_config" not in file_dict
file_obj = File.model_validate(file_dict)
assert file_obj.id == file.id
assert file_obj.tenant_id == file.tenant_id
assert file_obj.type == file.type
assert file_obj.transfer_method == file.transfer_method
assert file_obj.remote_url == file.remote_url
def test_file_to_dict():
file = File(
id="file1",
@ -36,10 +11,11 @@ def test_file_to_dict():
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="storage_key",
)
file_dict = file.to_dict()
assert "_extra_config" not in file_dict
assert "_storage_key" not in file_dict
assert "url" in file_dict

View File

@ -488,14 +488,12 @@ def test_run_branch(mock_close, mock_remove):
items = []
generator = graph_engine.run()
for item in generator:
# print(type(item), item)
items.append(item)
assert len(items) == 10
assert items[3].route_node_state.node_id == "if-else-1"
assert items[4].route_node_state.node_id == "if-else-1"
assert isinstance(items[5], NodeRunStreamChunkEvent)
assert items[5].chunk_content == "1 "
assert isinstance(items[6], NodeRunStreamChunkEvent)
assert items[6].chunk_content == "takato"
assert items[7].route_node_state.node_id == "answer-1"

View File

@ -51,6 +51,7 @@ def test_http_request_node_binary_file(monkeypatch):
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1111",
storage_key="",
),
),
)
@ -138,6 +139,7 @@ def test_http_request_node_form_with_file(monkeypatch):
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1111",
storage_key="",
),
),
)

View File

@ -18,11 +18,11 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment, StringSegment
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
@ -158,6 +158,7 @@ def test_fetch_files_with_file_segment(llm_node):
filename="test.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
storage_key="",
)
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
@ -174,6 +175,7 @@ def test_fetch_files_with_array_file_segment(llm_node):
filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
storage_key="",
),
File(
id="2",
@ -182,6 +184,7 @@ def test_fetch_files_with_array_file_segment(llm_node):
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
storage_key="",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
@ -225,14 +228,15 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
storage_key="",
)
]
fake_query = faker.sentence()
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=fake_query,
user_files=files,
sys_query=fake_query,
sys_files=files,
context=None,
memory=None,
model_config=model_config,
@ -249,8 +253,7 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# Setup dify config
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
dify_config.MULTIMODAL_SEND_FORMAT = "url"
# Generate fake values for prompt template
fake_assistant_prompt = faker.sentence()
@ -285,8 +288,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
test_scenarios = [
LLMNodeTestScenario(
description="No files",
user_query=fake_query,
user_files=[],
sys_query=fake_query,
sys_files=[],
features=[],
vision_enabled=False,
vision_detail=None,
@ -320,14 +323,17 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
),
LLMNodeTestScenario(
description="User files",
user_query=fake_query,
user_files=[
sys_query=fake_query,
sys_files=[
File(
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
extension=".jpg",
mime_type="image/jpg",
storage_key="",
)
],
vision_enabled=True,
@ -361,15 +367,17 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
UserPromptMessage(
content=[
TextPromptMessageContent(data=fake_query),
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
ImagePromptMessageContent(
url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
),
]
),
],
),
LLMNodeTestScenario(
description="Prompt template with variable selector of File",
user_query=fake_query,
user_files=[],
sys_query=fake_query,
sys_files=[],
vision_enabled=False,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
@ -384,7 +392,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
expected_messages=[
UserPromptMessage(
content=[
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
ImagePromptMessageContent(
url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
),
]
),
]
@ -397,6 +407,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
extension=".jpg",
mime_type="image/jpg",
storage_key="",
)
},
),
@ -411,8 +424,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# Call the method under test
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=scenario.user_query,
user_files=scenario.user_files,
sys_query=scenario.sys_query,
sys_files=scenario.sys_files,
context=fake_context,
memory=memory,
model_config=model_config,
@ -429,3 +442,29 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
assert (
prompt_messages == scenario.expected_messages
), f"Message content mismatch in scenario: {scenario.description}"
def test_handle_list_messages_basic(llm_node):
messages = [
LLMNodeChatModelMessage(
text="Hello, {#context#}",
role=PromptMessageRole.USER,
edition_type="basic",
)
]
context = "world"
jinja2_variables = []
variable_pool = llm_node.graph_runtime_state.variable_pool
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
result = llm_node._handle_list_messages(
messages=messages,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail_config,
)
assert len(result) == 1
assert isinstance(result[0], UserPromptMessage)
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]

View File

@ -12,8 +12,8 @@ class LLMNodeTestScenario(BaseModel):
"""Test scenario for LLM node testing."""
description: str = Field(..., description="Description of the test scenario")
user_query: str = Field(..., description="User query input")
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
sys_query: str = Field(..., description="User query input")
sys_files: Sequence[File] = Field(default_factory=list, description="List of user files")
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")

View File

@ -2,7 +2,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunStreamChunkEvent,
)
@ -14,7 +13,9 @@ from models.workflow import WorkflowType
class ContinueOnErrorTestHelper:
@staticmethod
def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
def get_code_node(
code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {}
):
"""Helper method to create a code node configuration"""
node = {
"id": "node",
@ -26,6 +27,7 @@ class ContinueOnErrorTestHelper:
"code_language": "python3",
"code": "\n".join([line[4:] for line in code.split("\n")]),
"type": "code",
**retry_config,
},
}
if default_value:
@ -34,7 +36,10 @@ class ContinueOnErrorTestHelper:
@staticmethod
def get_http_node(
error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False
error_strategy: str = "fail-branch",
default_value: dict | None = None,
authorization_success: bool = False,
retry_config: dict = {},
):
"""Helper method to create a http node configuration"""
authorization = (
@ -65,6 +70,7 @@ class ContinueOnErrorTestHelper:
"body": None,
"type": "http-request",
"error_strategy": error_strategy,
**retry_config,
},
}
if default_value:

View File

@ -248,6 +248,7 @@ def test_array_file_contains_file_name():
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
filename="ab",
storage_key="",
),
],
)

View File

@ -57,6 +57,7 @@ def test_filter_files_by_type(list_operator_node):
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related1",
storage_key="",
),
File(
filename="document1.pdf",
@ -64,6 +65,7 @@ def test_filter_files_by_type(list_operator_node):
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related2",
storage_key="",
),
File(
filename="image2.png",
@ -71,6 +73,7 @@ def test_filter_files_by_type(list_operator_node):
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related3",
storage_key="",
),
File(
filename="audio1.mp3",
@ -78,6 +81,7 @@ def test_filter_files_by_type(list_operator_node):
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related4",
storage_key="",
),
]
variable = ArrayFileSegment(value=files)
@ -130,6 +134,7 @@ def test_get_file_extract_string_func():
mime_type="text/plain",
remote_url="https://example.com/test_file.txt",
related_id="test_related_id",
storage_key="",
)
# Test each case
@ -150,6 +155,7 @@ def test_get_file_extract_string_func():
mime_type=None,
remote_url=None,
related_id="test_related_id",
storage_key="",
)
assert _get_file_extract_string_func(key="name")(empty_file) == ""

View File

@ -0,0 +1,73 @@
from core.workflow.graph_engine.entities.event import (
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
NodeRunRetryEvent,
)
from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper
DEFAULT_VALUE_EDGE = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-source-answer-target",
"source": "node",
"target": "answer",
"sourceHandle": "source",
},
]
def test_retry_default_value_partial_success():
"""retry default value node with partial success status"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
"default-value",
[{"key": "result", "type": "string", "value": "http node got error response"}],
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert events[-1].outputs == {"answer": "http node got error response"}
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
assert len(events) == 11
def test_retry_failed():
"""retry failed with success status"""
error_code = """
def main() -> dict:
return {
"result": 1 / 0,
}
"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
None,
None,
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
assert len(events) == 8

View File

@ -19,6 +19,7 @@ def file():
related_id="test_related_id",
remote_url="test_url",
filename="test_file.txt",
storage_key="",
)

View File

@ -4,8 +4,8 @@ from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from oss2 import Bucket
from oss2.models import GetObjectResult, PutObjectResult
from oss2 import Bucket # type: ignore
from oss2.models import GetObjectResult, PutObjectResult # type: ignore
from tests.unit_tests.oss.__mock.base import (
get_example_bucket,

View File

@ -3,8 +3,8 @@ from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from qcloud_cos import CosS3Client
from qcloud_cos.streambody import StreamBody
from qcloud_cos import CosS3Client # type: ignore
from qcloud_cos.streambody import StreamBody # type: ignore
from tests.unit_tests.oss.__mock.base import (
get_example_bucket,

View File

@ -4,8 +4,8 @@ from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from tos import TosClientV2
from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput
from tos import TosClientV2 # type: ignore
from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore
from tests.unit_tests.oss.__mock.base import (
get_example_bucket,

View File

@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from oss2 import Auth
from oss2 import Auth # type: ignore
from extensions.storage.aliyun_oss_storage import AliyunOssStorage
from tests.unit_tests.oss.__mock.aliyun_oss import setup_aliyun_oss_mock

View File

@ -1,15 +1,12 @@
import os
from collections.abc import Generator
from pathlib import Path
import pytest
from configs.middleware.storage.opendal_storage_config import OpenDALScheme
from extensions.storage.opendal_storage import OpenDALStorage
from tests.unit_tests.oss.__mock.base import (
get_example_data,
get_example_filename,
get_example_filepath,
get_opendal_bucket,
)
@ -19,7 +16,7 @@ class TestOpenDAL:
def setup_method(self, *args, **kwargs):
"""Executed before each test method."""
self.storage = OpenDALStorage(
scheme=OpenDALScheme.FS,
scheme="fs",
root=get_opendal_bucket(),
)

View File

@ -1,7 +1,7 @@
from unittest.mock import patch
import pytest
from qcloud_cos import CosConfig
from qcloud_cos import CosConfig # type: ignore
from extensions.storage.tencent_cos_storage import TencentCosStorage
from tests.unit_tests.oss.__mock.base import (

View File

@ -1,5 +1,5 @@
import pytest
from tos import TosClientV2
from tos import TosClientV2 # type: ignore
from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
from tests.unit_tests.oss.__mock.base import (

View File

@ -1,7 +1,7 @@
from textwrap import dedent
import pytest
from yaml import YAMLError
from yaml import YAMLError # type: ignore
from core.tools.utils.yaml_utils import load_yaml_file