mirror of
https://github.com/langgenius/dify.git
synced 2026-05-01 16:08:04 +08:00
Merge branch 'main' into fix/chore-fix
This commit is contained in:
@ -59,6 +59,8 @@ def test_dify_config(example_env_file):
|
||||
# annotated field with configured value
|
||||
assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30
|
||||
|
||||
assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3
|
||||
|
||||
|
||||
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
|
||||
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
|
||||
|
||||
@ -1,20 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from extensions.storage.opendal_storage import is_r2_endpoint
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("endpoint", "expected"),
|
||||
[
|
||||
("https://bucket.r2.cloudflarestorage.com", True),
|
||||
("https://custom-domain.r2.cloudflarestorage.com/", True),
|
||||
("https://bucket.r2.cloudflarestorage.com/path", True),
|
||||
("https://s3.amazonaws.com", False),
|
||||
("https://storage.googleapis.com", False),
|
||||
("http://localhost:9000", False),
|
||||
("invalid-url", False),
|
||||
("", False),
|
||||
],
|
||||
)
|
||||
def test_is_r2_endpoint(endpoint: str, expected: bool):
|
||||
assert is_r2_endpoint(endpoint) == expected
|
||||
@ -2,6 +2,8 @@ import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.variables import (
|
||||
ArrayFileVariable,
|
||||
ArrayVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
ObjectVariable,
|
||||
@ -81,3 +83,8 @@ def test_variable_to_object():
|
||||
assert var.to_object() == 3.14
|
||||
var = SecretVariable(name="secret", value="secret_value")
|
||||
assert var.to_object() == "secret_value"
|
||||
|
||||
|
||||
def test_array_file_variable_is_array_variable():
|
||||
var = ArrayFileVariable(name="files", value=[])
|
||||
assert isinstance(var, ArrayVariable)
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
@ -126,6 +127,7 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
|
||||
|
||||
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
|
||||
model_config_mock, _, messages, inputs, context = get_chat_model_args
|
||||
dify_config.MULTIMODAL_SEND_FORMAT = "url"
|
||||
|
||||
files = [
|
||||
File(
|
||||
@ -134,13 +136,16 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
storage_key="",
|
||||
)
|
||||
]
|
||||
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
|
||||
mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url))
|
||||
mock_get_encoded_string.return_value = ImagePromptMessageContent(
|
||||
url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
|
||||
)
|
||||
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
|
||||
prompt_template=messages,
|
||||
inputs=inputs,
|
||||
|
||||
@ -1,34 +1,9 @@
|
||||
import json
|
||||
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType, FileUploadConfig
|
||||
from core.file import File, FileTransferMethod, FileType, FileUploadConfig
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
def test_file_loads_and_dumps():
|
||||
file = File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
)
|
||||
|
||||
file_dict = file.model_dump()
|
||||
assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY
|
||||
assert file_dict["type"] == file.type.value
|
||||
assert isinstance(file_dict["type"], str)
|
||||
assert file_dict["transfer_method"] == file.transfer_method.value
|
||||
assert isinstance(file_dict["transfer_method"], str)
|
||||
assert "_extra_config" not in file_dict
|
||||
|
||||
file_obj = File.model_validate(file_dict)
|
||||
assert file_obj.id == file.id
|
||||
assert file_obj.tenant_id == file.tenant_id
|
||||
assert file_obj.type == file.type
|
||||
assert file_obj.transfer_method == file.transfer_method
|
||||
assert file_obj.remote_url == file.remote_url
|
||||
|
||||
|
||||
def test_file_to_dict():
|
||||
file = File(
|
||||
id="file1",
|
||||
@ -36,10 +11,11 @@ def test_file_to_dict():
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
storage_key="storage_key",
|
||||
)
|
||||
|
||||
file_dict = file.to_dict()
|
||||
assert "_extra_config" not in file_dict
|
||||
assert "_storage_key" not in file_dict
|
||||
assert "url" in file_dict
|
||||
|
||||
|
||||
|
||||
@ -488,14 +488,12 @@ def test_run_branch(mock_close, mock_remove):
|
||||
items = []
|
||||
generator = graph_engine.run()
|
||||
for item in generator:
|
||||
# print(type(item), item)
|
||||
items.append(item)
|
||||
|
||||
assert len(items) == 10
|
||||
assert items[3].route_node_state.node_id == "if-else-1"
|
||||
assert items[4].route_node_state.node_id == "if-else-1"
|
||||
assert isinstance(items[5], NodeRunStreamChunkEvent)
|
||||
assert items[5].chunk_content == "1 "
|
||||
assert isinstance(items[6], NodeRunStreamChunkEvent)
|
||||
assert items[6].chunk_content == "takato"
|
||||
assert items[7].route_node_state.node_id == "answer-1"
|
||||
|
||||
@ -51,6 +51,7 @@ def test_http_request_node_binary_file(monkeypatch):
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1111",
|
||||
storage_key="",
|
||||
),
|
||||
),
|
||||
)
|
||||
@ -138,6 +139,7 @@ def test_http_request_node_form_with_file(monkeypatch):
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1111",
|
||||
storage_key="",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@ -18,11 +18,11 @@ from core.model_runtime.entities.message_entities import (
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
|
||||
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment, StringSegment
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||
@ -158,6 +158,7 @@ def test_fetch_files_with_file_segment(llm_node):
|
||||
filename="test.jpg",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1",
|
||||
storage_key="",
|
||||
)
|
||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
|
||||
|
||||
@ -174,6 +175,7 @@ def test_fetch_files_with_array_file_segment(llm_node):
|
||||
filename="test1.jpg",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1",
|
||||
storage_key="",
|
||||
),
|
||||
File(
|
||||
id="2",
|
||||
@ -182,6 +184,7 @@ def test_fetch_files_with_array_file_segment(llm_node):
|
||||
filename="test2.jpg",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="2",
|
||||
storage_key="",
|
||||
),
|
||||
]
|
||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
|
||||
@ -225,14 +228,15 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
|
||||
filename="test1.jpg",
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=fake_remote_url,
|
||||
storage_key="",
|
||||
)
|
||||
]
|
||||
|
||||
fake_query = faker.sentence()
|
||||
|
||||
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
||||
user_query=fake_query,
|
||||
user_files=files,
|
||||
sys_query=fake_query,
|
||||
sys_files=files,
|
||||
context=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
@ -249,8 +253,7 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
|
||||
|
||||
def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
# Setup dify config
|
||||
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
|
||||
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
|
||||
dify_config.MULTIMODAL_SEND_FORMAT = "url"
|
||||
|
||||
# Generate fake values for prompt template
|
||||
fake_assistant_prompt = faker.sentence()
|
||||
@ -285,8 +288,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
test_scenarios = [
|
||||
LLMNodeTestScenario(
|
||||
description="No files",
|
||||
user_query=fake_query,
|
||||
user_files=[],
|
||||
sys_query=fake_query,
|
||||
sys_files=[],
|
||||
features=[],
|
||||
vision_enabled=False,
|
||||
vision_detail=None,
|
||||
@ -320,14 +323,17 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
),
|
||||
LLMNodeTestScenario(
|
||||
description="User files",
|
||||
user_query=fake_query,
|
||||
user_files=[
|
||||
sys_query=fake_query,
|
||||
sys_files=[
|
||||
File(
|
||||
tenant_id="test",
|
||||
type=FileType.IMAGE,
|
||||
filename="test1.jpg",
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=fake_remote_url,
|
||||
extension=".jpg",
|
||||
mime_type="image/jpg",
|
||||
storage_key="",
|
||||
)
|
||||
],
|
||||
vision_enabled=True,
|
||||
@ -361,15 +367,17 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
UserPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data=fake_query),
|
||||
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
|
||||
ImagePromptMessageContent(
|
||||
url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
|
||||
),
|
||||
]
|
||||
),
|
||||
],
|
||||
),
|
||||
LLMNodeTestScenario(
|
||||
description="Prompt template with variable selector of File",
|
||||
user_query=fake_query,
|
||||
user_files=[],
|
||||
sys_query=fake_query,
|
||||
sys_files=[],
|
||||
vision_enabled=False,
|
||||
vision_detail=fake_vision_detail,
|
||||
features=[ModelFeature.VISION],
|
||||
@ -384,7 +392,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
expected_messages=[
|
||||
UserPromptMessage(
|
||||
content=[
|
||||
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
|
||||
ImagePromptMessageContent(
|
||||
url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
@ -397,6 +407,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
filename="test1.jpg",
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=fake_remote_url,
|
||||
extension=".jpg",
|
||||
mime_type="image/jpg",
|
||||
storage_key="",
|
||||
)
|
||||
},
|
||||
),
|
||||
@ -411,8 +424,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
|
||||
# Call the method under test
|
||||
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
||||
user_query=scenario.user_query,
|
||||
user_files=scenario.user_files,
|
||||
sys_query=scenario.sys_query,
|
||||
sys_files=scenario.sys_files,
|
||||
context=fake_context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
@ -429,3 +442,29 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
assert (
|
||||
prompt_messages == scenario.expected_messages
|
||||
), f"Message content mismatch in scenario: {scenario.description}"
|
||||
|
||||
|
||||
def test_handle_list_messages_basic(llm_node):
|
||||
messages = [
|
||||
LLMNodeChatModelMessage(
|
||||
text="Hello, {#context#}",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
]
|
||||
context = "world"
|
||||
jinja2_variables = []
|
||||
variable_pool = llm_node.graph_runtime_state.variable_pool
|
||||
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
|
||||
|
||||
result = llm_node._handle_list_messages(
|
||||
messages=messages,
|
||||
context=context,
|
||||
jinja2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
vision_detail_config=vision_detail_config,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], UserPromptMessage)
|
||||
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
|
||||
|
||||
@ -12,8 +12,8 @@ class LLMNodeTestScenario(BaseModel):
|
||||
"""Test scenario for LLM node testing."""
|
||||
|
||||
description: str = Field(..., description="Description of the test scenario")
|
||||
user_query: str = Field(..., description="User query input")
|
||||
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
|
||||
sys_query: str = Field(..., description="User query input")
|
||||
sys_files: Sequence[File] = Field(default_factory=list, description="List of user files")
|
||||
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
|
||||
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
|
||||
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
|
||||
|
||||
@ -2,7 +2,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
)
|
||||
@ -14,7 +13,9 @@ from models.workflow import WorkflowType
|
||||
|
||||
class ContinueOnErrorTestHelper:
|
||||
@staticmethod
|
||||
def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
|
||||
def get_code_node(
|
||||
code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {}
|
||||
):
|
||||
"""Helper method to create a code node configuration"""
|
||||
node = {
|
||||
"id": "node",
|
||||
@ -26,6 +27,7 @@ class ContinueOnErrorTestHelper:
|
||||
"code_language": "python3",
|
||||
"code": "\n".join([line[4:] for line in code.split("\n")]),
|
||||
"type": "code",
|
||||
**retry_config,
|
||||
},
|
||||
}
|
||||
if default_value:
|
||||
@ -34,7 +36,10 @@ class ContinueOnErrorTestHelper:
|
||||
|
||||
@staticmethod
|
||||
def get_http_node(
|
||||
error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False
|
||||
error_strategy: str = "fail-branch",
|
||||
default_value: dict | None = None,
|
||||
authorization_success: bool = False,
|
||||
retry_config: dict = {},
|
||||
):
|
||||
"""Helper method to create a http node configuration"""
|
||||
authorization = (
|
||||
@ -65,6 +70,7 @@ class ContinueOnErrorTestHelper:
|
||||
"body": None,
|
||||
"type": "http-request",
|
||||
"error_strategy": error_strategy,
|
||||
**retry_config,
|
||||
},
|
||||
}
|
||||
if default_value:
|
||||
|
||||
@ -248,6 +248,7 @@ def test_array_file_contains_file_name():
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1",
|
||||
filename="ab",
|
||||
storage_key="",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@ -57,6 +57,7 @@ def test_filter_files_by_type(list_operator_node):
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related1",
|
||||
storage_key="",
|
||||
),
|
||||
File(
|
||||
filename="document1.pdf",
|
||||
@ -64,6 +65,7 @@ def test_filter_files_by_type(list_operator_node):
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related2",
|
||||
storage_key="",
|
||||
),
|
||||
File(
|
||||
filename="image2.png",
|
||||
@ -71,6 +73,7 @@ def test_filter_files_by_type(list_operator_node):
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related3",
|
||||
storage_key="",
|
||||
),
|
||||
File(
|
||||
filename="audio1.mp3",
|
||||
@ -78,6 +81,7 @@ def test_filter_files_by_type(list_operator_node):
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related4",
|
||||
storage_key="",
|
||||
),
|
||||
]
|
||||
variable = ArrayFileSegment(value=files)
|
||||
@ -130,6 +134,7 @@ def test_get_file_extract_string_func():
|
||||
mime_type="text/plain",
|
||||
remote_url="https://example.com/test_file.txt",
|
||||
related_id="test_related_id",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
# Test each case
|
||||
@ -150,6 +155,7 @@ def test_get_file_extract_string_func():
|
||||
mime_type=None,
|
||||
remote_url=None,
|
||||
related_id="test_related_id",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
assert _get_file_extract_string_func(key="name")(empty_file) == ""
|
||||
|
||||
73
api/tests/unit_tests/core/workflow/nodes/test_retry.py
Normal file
73
api/tests/unit_tests/core/workflow/nodes/test_retry.py
Normal file
@ -0,0 +1,73 @@
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunRetryEvent,
|
||||
)
|
||||
from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper
|
||||
|
||||
DEFAULT_VALUE_EDGE = [
|
||||
{
|
||||
"id": "start-source-node-target",
|
||||
"source": "start",
|
||||
"target": "node",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
{
|
||||
"id": "node-source-answer-target",
|
||||
"source": "node",
|
||||
"target": "answer",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_retry_default_value_partial_success():
|
||||
"""retry default value node with partial success status"""
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_http_node(
|
||||
"default-value",
|
||||
[{"key": "result", "type": "string", "value": "http node got error response"}],
|
||||
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||
assert events[-1].outputs == {"answer": "http node got error response"}
|
||||
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
|
||||
assert len(events) == 11
|
||||
|
||||
|
||||
def test_retry_failed():
|
||||
"""retry failed with success status"""
|
||||
error_code = """
|
||||
def main() -> dict:
|
||||
return {
|
||||
"result": 1 / 0,
|
||||
}
|
||||
"""
|
||||
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_http_node(
|
||||
None,
|
||||
None,
|
||||
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
|
||||
),
|
||||
],
|
||||
}
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
|
||||
assert len(events) == 8
|
||||
@ -19,6 +19,7 @@ def file():
|
||||
related_id="test_related_id",
|
||||
remote_url="test_url",
|
||||
filename="test_file.txt",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -4,8 +4,8 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from oss2 import Bucket
|
||||
from oss2.models import GetObjectResult, PutObjectResult
|
||||
from oss2 import Bucket # type: ignore
|
||||
from oss2.models import GetObjectResult, PutObjectResult # type: ignore
|
||||
|
||||
from tests.unit_tests.oss.__mock.base import (
|
||||
get_example_bucket,
|
||||
|
||||
@ -3,8 +3,8 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from qcloud_cos import CosS3Client
|
||||
from qcloud_cos.streambody import StreamBody
|
||||
from qcloud_cos import CosS3Client # type: ignore
|
||||
from qcloud_cos.streambody import StreamBody # type: ignore
|
||||
|
||||
from tests.unit_tests.oss.__mock.base import (
|
||||
get_example_bucket,
|
||||
|
||||
@ -4,8 +4,8 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from tos import TosClientV2
|
||||
from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput
|
||||
from tos import TosClientV2 # type: ignore
|
||||
from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore
|
||||
|
||||
from tests.unit_tests.oss.__mock.base import (
|
||||
get_example_bucket,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from oss2 import Auth
|
||||
from oss2 import Auth # type: ignore
|
||||
|
||||
from extensions.storage.aliyun_oss_storage import AliyunOssStorage
|
||||
from tests.unit_tests.oss.__mock.aliyun_oss import setup_aliyun_oss_mock
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from configs.middleware.storage.opendal_storage_config import OpenDALScheme
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
from tests.unit_tests.oss.__mock.base import (
|
||||
get_example_data,
|
||||
get_example_filename,
|
||||
get_example_filepath,
|
||||
get_opendal_bucket,
|
||||
)
|
||||
|
||||
@ -19,7 +16,7 @@ class TestOpenDAL:
|
||||
def setup_method(self, *args, **kwargs):
|
||||
"""Executed before each test method."""
|
||||
self.storage = OpenDALStorage(
|
||||
scheme=OpenDALScheme.FS,
|
||||
scheme="fs",
|
||||
root=get_opendal_bucket(),
|
||||
)
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from qcloud_cos import CosConfig
|
||||
from qcloud_cos import CosConfig # type: ignore
|
||||
|
||||
from extensions.storage.tencent_cos_storage import TencentCosStorage
|
||||
from tests.unit_tests.oss.__mock.base import (
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from tos import TosClientV2
|
||||
from tos import TosClientV2 # type: ignore
|
||||
|
||||
from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
|
||||
from tests.unit_tests.oss.__mock.base import (
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
from yaml import YAMLError
|
||||
from yaml import YAMLError # type: ignore
|
||||
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
|
||||
Reference in New Issue
Block a user