mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
Merge branch 'main' into feat/node-execution-retry
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import google.generativeai.types.generation_types as generation_config_types
|
||||
import pytest
|
||||
@ -6,11 +7,10 @@ from _pytest.monkeypatch import MonkeyPatch
|
||||
from google.ai import generativelanguage as glm
|
||||
from google.ai.generativelanguage_v1beta.types import content as gag_content
|
||||
from google.generativeai import GenerativeModel
|
||||
from google.generativeai.client import _ClientManager, configure
|
||||
from google.generativeai.types import GenerateContentResponse, content_types, safety_types
|
||||
from google.generativeai.types.generation_types import BaseGenerateContentResponse
|
||||
|
||||
current_api_key = ""
|
||||
from extensions import ext_redis
|
||||
|
||||
|
||||
class MockGoogleResponseClass:
|
||||
@ -57,11 +57,6 @@ class MockGoogleClass:
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> GenerateContentResponse:
|
||||
global current_api_key
|
||||
|
||||
if len(current_api_key) < 16:
|
||||
raise Exception("Invalid API key")
|
||||
|
||||
if stream:
|
||||
return MockGoogleClass.generate_content_stream()
|
||||
|
||||
@ -75,33 +70,29 @@ class MockGoogleClass:
|
||||
def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
|
||||
return [MockGoogleResponseCandidateClass()]
|
||||
|
||||
def make_client(self: _ClientManager, name: str):
|
||||
global current_api_key
|
||||
|
||||
if name.endswith("_async"):
|
||||
name = name.split("_")[0]
|
||||
cls = getattr(glm, name.title() + "ServiceAsyncClient")
|
||||
else:
|
||||
cls = getattr(glm, name.title() + "ServiceClient")
|
||||
def mock_configure(api_key: str):
|
||||
if len(api_key) < 16:
|
||||
raise Exception("Invalid API key")
|
||||
|
||||
# Attempt to configure using defaults.
|
||||
if not self.client_config:
|
||||
configure()
|
||||
|
||||
client_options = self.client_config.get("client_options", None)
|
||||
if client_options:
|
||||
current_api_key = client_options.api_key
|
||||
class MockFileState:
|
||||
def __init__(self):
|
||||
self.name = "FINISHED"
|
||||
|
||||
def nop(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
original_init = cls.__init__
|
||||
cls.__init__ = nop
|
||||
client: glm.GenerativeServiceClient = cls(**self.client_config)
|
||||
cls.__init__ = original_init
|
||||
class MockGoogleFile:
|
||||
def __init__(self, name: str = "mock_file_name"):
|
||||
self.name = name
|
||||
self.state = MockFileState()
|
||||
|
||||
if not self.default_metadata:
|
||||
return client
|
||||
|
||||
def mock_get_file(name: str) -> MockGoogleFile:
|
||||
return MockGoogleFile(name)
|
||||
|
||||
|
||||
def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile:
|
||||
return MockGoogleFile()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -109,8 +100,17 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
|
||||
monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
|
||||
monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
|
||||
monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
|
||||
monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
|
||||
monkeypatch.setattr("google.generativeai.configure", mock_configure)
|
||||
monkeypatch.setattr("google.generativeai.get_file", mock_get_file)
|
||||
monkeypatch.setattr("google.generativeai.upload_file", mock_upload_file)
|
||||
|
||||
yield
|
||||
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_mock_redis() -> None:
|
||||
ext_redis.redis_client.get = MagicMock(return_value=None)
|
||||
ext_redis.redis_client.setex = MagicMock(return_value=None)
|
||||
ext_redis.redis_client.exists = MagicMock(return_value=True)
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -12,11 +12,11 @@ def tidb_vector():
|
||||
return TiDBVector(
|
||||
collection_name="test_collection",
|
||||
config=TiDBVectorConfig(
|
||||
host="xxx.eu-central-1.xxx.aws.tidbcloud.com",
|
||||
port="4000",
|
||||
user="xxx.root",
|
||||
password="xxxxxx",
|
||||
database="dify",
|
||||
host="localhost",
|
||||
port=4000,
|
||||
user="root",
|
||||
password="",
|
||||
database="test",
|
||||
program_name="langgenius/dify",
|
||||
),
|
||||
)
|
||||
@ -27,35 +27,14 @@ class TiDBVectorTest(AbstractVectorTest):
|
||||
super().__init__()
|
||||
self.vector = vector
|
||||
|
||||
def text_exists(self):
|
||||
exist = self.vector.text_exists(self.example_doc_id)
|
||||
assert exist == False
|
||||
|
||||
def search_by_vector(self):
|
||||
hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
|
||||
assert len(hits_by_vector) == 0
|
||||
|
||||
def search_by_full_text(self):
|
||||
hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) == 0
|
||||
|
||||
def get_ids_by_metadata_field(self):
|
||||
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
|
||||
assert len(ids) == 0
|
||||
ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
|
||||
assert len(ids) == 1
|
||||
|
||||
|
||||
def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_session):
|
||||
def test_tidb_vector(setup_mock_redis, tidb_vector):
|
||||
TiDBVectorTest(vector=tidb_vector).run_all_tests()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.Session", new_callable=MagicMock) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_tidbvector_mock(tidb_vector, mock_session):
|
||||
with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine"):
|
||||
with patch.object(tidb_vector._engine, "connect"):
|
||||
yield tidb_vector
|
||||
|
||||
@ -1,20 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from extensions.storage.opendal_storage import is_r2_endpoint
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("endpoint", "expected"),
|
||||
[
|
||||
("https://bucket.r2.cloudflarestorage.com", True),
|
||||
("https://custom-domain.r2.cloudflarestorage.com/", True),
|
||||
("https://bucket.r2.cloudflarestorage.com/path", True),
|
||||
("https://s3.amazonaws.com", False),
|
||||
("https://storage.googleapis.com", False),
|
||||
("http://localhost:9000", False),
|
||||
("invalid-url", False),
|
||||
("", False),
|
||||
],
|
||||
)
|
||||
def test_is_r2_endpoint(endpoint: str, expected: bool):
|
||||
assert is_r2_endpoint(endpoint) == expected
|
||||
@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
@ -126,6 +127,7 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
|
||||
|
||||
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
|
||||
model_config_mock, _, messages, inputs, context = get_chat_model_args
|
||||
dify_config.MULTIMODAL_SEND_FORMAT = "url"
|
||||
|
||||
files = [
|
||||
File(
|
||||
@ -140,7 +142,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
|
||||
mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url))
|
||||
mock_get_encoded_string.return_value = ImagePromptMessageContent(
|
||||
url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
|
||||
)
|
||||
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
|
||||
prompt_template=messages,
|
||||
inputs=inputs,
|
||||
|
||||
@ -48,7 +48,7 @@ def test_executor_with_json_body_and_number_variable():
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.params == []
|
||||
assert executor.json == {"number": 42}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
@ -101,7 +101,7 @@ def test_executor_with_json_body_and_object_variable():
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.params == []
|
||||
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
@ -156,7 +156,7 @@ def test_executor_with_json_body_and_nested_object_variable():
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.params == []
|
||||
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
@ -195,7 +195,7 @@ def test_extract_selectors_from_template_with_newline():
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
assert executor.params == {"test": "line1\nline2"}
|
||||
assert executor.params == [("test", "line1\nline2")]
|
||||
|
||||
|
||||
def test_executor_with_form_data():
|
||||
@ -244,7 +244,7 @@ def test_executor_with_form_data():
|
||||
assert executor.url == "https://api.example.com/upload"
|
||||
assert "Content-Type" in executor.headers
|
||||
assert "multipart/form-data" in executor.headers["Content-Type"]
|
||||
assert executor.params == {}
|
||||
assert executor.params == []
|
||||
assert executor.json is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
@ -265,3 +265,72 @@ def test_executor_with_form_data():
|
||||
assert "Hello, World!" in raw_request
|
||||
assert "number_field" in raw_request
|
||||
assert "42" in raw_request
|
||||
|
||||
|
||||
def test_init_headers():
|
||||
def create_executor(headers: str) -> Executor:
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
url="http://example.com",
|
||||
headers=headers,
|
||||
params="",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
)
|
||||
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||
return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
|
||||
|
||||
executor = create_executor("aa\n cc:")
|
||||
executor._init_headers()
|
||||
assert executor.headers == {"aa": "", "cc": ""}
|
||||
|
||||
executor = create_executor("aa:bb\n cc:dd")
|
||||
executor._init_headers()
|
||||
assert executor.headers == {"aa": "bb", "cc": "dd"}
|
||||
|
||||
executor = create_executor("aa:bb\n cc:dd\n")
|
||||
executor._init_headers()
|
||||
assert executor.headers == {"aa": "bb", "cc": "dd"}
|
||||
|
||||
executor = create_executor("aa:bb\n\n cc : dd\n\n")
|
||||
executor._init_headers()
|
||||
assert executor.headers == {"aa": "bb", "cc": "dd"}
|
||||
|
||||
|
||||
def test_init_params():
|
||||
def create_executor(params: str) -> Executor:
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
url="http://example.com",
|
||||
headers="",
|
||||
params=params,
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
)
|
||||
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||
return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
|
||||
|
||||
# Test basic key-value pairs
|
||||
executor = create_executor("key1:value1\nkey2:value2")
|
||||
executor._init_params()
|
||||
assert executor.params == [("key1", "value1"), ("key2", "value2")]
|
||||
|
||||
# Test empty values
|
||||
executor = create_executor("key1:\nkey2:")
|
||||
executor._init_params()
|
||||
assert executor.params == [("key1", ""), ("key2", "")]
|
||||
|
||||
# Test duplicate keys (which is allowed for params)
|
||||
executor = create_executor("key1:value1\nkey1:value2")
|
||||
executor._init_params()
|
||||
assert executor.params == [("key1", "value1"), ("key1", "value2")]
|
||||
|
||||
# Test whitespace handling
|
||||
executor = create_executor(" key1 : value1 \n key2 : value2 ")
|
||||
executor._init_params()
|
||||
assert executor.params == [("key1", "value1"), ("key2", "value2")]
|
||||
|
||||
# Test empty lines and extra whitespace
|
||||
executor = create_executor("key1:value1\n\nkey2:value2\n\n")
|
||||
executor._init_params()
|
||||
assert executor.params == [("key1", "value1"), ("key2", "value2")]
|
||||
|
||||
@ -14,18 +14,10 @@ from core.workflow.nodes.http_request import (
|
||||
HttpRequestNodeBody,
|
||||
HttpRequestNodeData,
|
||||
)
|
||||
from core.workflow.nodes.http_request.executor import _plain_text_to_dict
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
def test_plain_text_to_dict():
|
||||
assert _plain_text_to_dict("aa\n cc:") == {"aa": "", "cc": ""}
|
||||
assert _plain_text_to_dict("aa:bb\n cc:dd") == {"aa": "bb", "cc": "dd"}
|
||||
assert _plain_text_to_dict("aa:bb\n cc:dd\n") == {"aa": "bb", "cc": "dd"}
|
||||
assert _plain_text_to_dict("aa:bb\n\n cc : dd\n\n") == {"aa": "bb", "cc": "dd"}
|
||||
|
||||
|
||||
def test_http_request_node_binary_file(monkeypatch):
|
||||
data = HttpRequestNodeData(
|
||||
title="test",
|
||||
|
||||
@ -18,8 +18,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
|
||||
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
||||
@ -249,8 +248,7 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
|
||||
|
||||
def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
# Setup dify config
|
||||
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
|
||||
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
|
||||
dify_config.MULTIMODAL_SEND_FORMAT = "url"
|
||||
|
||||
# Generate fake values for prompt template
|
||||
fake_assistant_prompt = faker.sentence()
|
||||
@ -328,6 +326,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
filename="test1.jpg",
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=fake_remote_url,
|
||||
extension=".jpg",
|
||||
mime_type="image/jpg",
|
||||
)
|
||||
],
|
||||
vision_enabled=True,
|
||||
@ -361,7 +361,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
UserPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data=fake_query),
|
||||
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
|
||||
ImagePromptMessageContent(
|
||||
url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
|
||||
),
|
||||
]
|
||||
),
|
||||
],
|
||||
@ -384,7 +386,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
expected_messages=[
|
||||
UserPromptMessage(
|
||||
content=[
|
||||
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
|
||||
ImagePromptMessageContent(
|
||||
url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
@ -397,6 +401,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
filename="test1.jpg",
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=fake_remote_url,
|
||||
extension=".jpg",
|
||||
mime_type="image/jpg",
|
||||
)
|
||||
},
|
||||
),
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from configs.middleware.storage.opendal_storage_config import OpenDALScheme
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
from tests.unit_tests.oss.__mock.base import (
|
||||
get_example_data,
|
||||
get_example_filename,
|
||||
get_example_filepath,
|
||||
get_opendal_bucket,
|
||||
)
|
||||
|
||||
@ -19,7 +16,7 @@ class TestOpenDAL:
|
||||
def setup_method(self, *args, **kwargs):
|
||||
"""Executed before each test method."""
|
||||
self.storage = OpenDALStorage(
|
||||
scheme=OpenDALScheme.FS,
|
||||
scheme="fs",
|
||||
root=get_opendal_bucket(),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user