Merge branch 'main' into feat/node-execution-retry

This commit is contained in:
Novice Lee
2024-12-18 09:38:18 +08:00
179 changed files with 3286 additions and 1295 deletions

View File

@ -1,4 +1,5 @@
from collections.abc import Generator
from unittest.mock import MagicMock
import google.generativeai.types.generation_types as generation_config_types
import pytest
@ -6,11 +7,10 @@ from _pytest.monkeypatch import MonkeyPatch
from google.ai import generativelanguage as glm
from google.ai.generativelanguage_v1beta.types import content as gag_content
from google.generativeai import GenerativeModel
from google.generativeai.client import _ClientManager, configure
from google.generativeai.types import GenerateContentResponse, content_types, safety_types
from google.generativeai.types.generation_types import BaseGenerateContentResponse
current_api_key = ""
from extensions import ext_redis
class MockGoogleResponseClass:
@ -57,11 +57,6 @@ class MockGoogleClass:
stream: bool = False,
**kwargs,
) -> GenerateContentResponse:
global current_api_key
if len(current_api_key) < 16:
raise Exception("Invalid API key")
if stream:
return MockGoogleClass.generate_content_stream()
@ -75,33 +70,29 @@ class MockGoogleClass:
def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
return [MockGoogleResponseCandidateClass()]
def make_client(self: _ClientManager, name: str):
global current_api_key
if name.endswith("_async"):
name = name.split("_")[0]
cls = getattr(glm, name.title() + "ServiceAsyncClient")
else:
cls = getattr(glm, name.title() + "ServiceClient")
def mock_configure(api_key: str):
if len(api_key) < 16:
raise Exception("Invalid API key")
# Attempt to configure using defaults.
if not self.client_config:
configure()
client_options = self.client_config.get("client_options", None)
if client_options:
current_api_key = client_options.api_key
class MockFileState:
def __init__(self):
self.name = "FINISHED"
def nop(self, *args, **kwargs):
pass
original_init = cls.__init__
cls.__init__ = nop
client: glm.GenerativeServiceClient = cls(**self.client_config)
cls.__init__ = original_init
class MockGoogleFile:
def __init__(self, name: str = "mock_file_name"):
self.name = name
self.state = MockFileState()
if not self.default_metadata:
return client
def mock_get_file(name: str) -> MockGoogleFile:
return MockGoogleFile(name)
def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile:
return MockGoogleFile()
@pytest.fixture
@ -109,8 +100,17 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
monkeypatch.setattr("google.generativeai.configure", mock_configure)
monkeypatch.setattr("google.generativeai.get_file", mock_get_file)
monkeypatch.setattr("google.generativeai.upload_file", mock_upload_file)
yield
monkeypatch.undo()
@pytest.fixture
def setup_mock_redis() -> None:
ext_redis.redis_client.get = MagicMock(return_value=None)
ext_redis.redis_client.setex = MagicMock(return_value=None)
ext_redis.redis_client.exists = MagicMock(return_value=True)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -12,11 +12,11 @@ def tidb_vector():
return TiDBVector(
collection_name="test_collection",
config=TiDBVectorConfig(
host="xxx.eu-central-1.xxx.aws.tidbcloud.com",
port="4000",
user="xxx.root",
password="xxxxxx",
database="dify",
host="localhost",
port=4000,
user="root",
password="",
database="test",
program_name="langgenius/dify",
),
)
@ -27,35 +27,14 @@ class TiDBVectorTest(AbstractVectorTest):
super().__init__()
self.vector = vector
def text_exists(self):
exist = self.vector.text_exists(self.example_doc_id)
assert exist == False
def search_by_vector(self):
hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
assert len(hits_by_vector) == 0
def search_by_full_text(self):
hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
assert len(ids) == 0
ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
assert len(ids) == 1
def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_session):
def test_tidb_vector(setup_mock_redis, tidb_vector):
TiDBVectorTest(vector=tidb_vector).run_all_tests()
@pytest.fixture
def mock_session():
with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.Session", new_callable=MagicMock) as mock_session:
yield mock_session
@pytest.fixture
def setup_tidbvector_mock(tidb_vector, mock_session):
with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine"):
with patch.object(tidb_vector._engine, "connect"):
yield tidb_vector

View File

@ -1,20 +0,0 @@
import pytest
from extensions.storage.opendal_storage import is_r2_endpoint
@pytest.mark.parametrize(
("endpoint", "expected"),
[
("https://bucket.r2.cloudflarestorage.com", True),
("https://custom-domain.r2.cloudflarestorage.com/", True),
("https://bucket.r2.cloudflarestorage.com/path", True),
("https://s3.amazonaws.com", False),
("https://storage.googleapis.com", False),
("http://localhost:9000", False),
("invalid-url", False),
("", False),
],
)
def test_is_r2_endpoint(endpoint: str, expected: bool):
assert is_r2_endpoint(endpoint) == expected

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
from configs import dify_config
from core.app.app_config.entities import ModelConfigEntity
from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig
from core.memory.token_buffer_memory import TokenBufferMemory
@ -126,6 +127,7 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
model_config_mock, _, messages, inputs, context = get_chat_model_args
dify_config.MULTIMODAL_SEND_FORMAT = "url"
files = [
File(
@ -140,7 +142,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url))
mock_get_encoded_string.return_value = ImagePromptMessageContent(
url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
)
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template=messages,
inputs=inputs,

View File

@ -48,7 +48,7 @@ def test_executor_with_json_body_and_number_variable():
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == {}
assert executor.params == []
assert executor.json == {"number": 42}
assert executor.data is None
assert executor.files is None
@ -101,7 +101,7 @@ def test_executor_with_json_body_and_object_variable():
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == {}
assert executor.params == []
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
assert executor.data is None
assert executor.files is None
@ -156,7 +156,7 @@ def test_executor_with_json_body_and_nested_object_variable():
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == {}
assert executor.params == []
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
assert executor.data is None
assert executor.files is None
@ -195,7 +195,7 @@ def test_extract_selectors_from_template_with_newline():
variable_pool=variable_pool,
)
assert executor.params == {"test": "line1\nline2"}
assert executor.params == [("test", "line1\nline2")]
def test_executor_with_form_data():
@ -244,7 +244,7 @@ def test_executor_with_form_data():
assert executor.url == "https://api.example.com/upload"
assert "Content-Type" in executor.headers
assert "multipart/form-data" in executor.headers["Content-Type"]
assert executor.params == {}
assert executor.params == []
assert executor.json is None
assert executor.files is None
assert executor.content is None
@ -265,3 +265,72 @@ def test_executor_with_form_data():
assert "Hello, World!" in raw_request
assert "number_field" in raw_request
assert "42" in raw_request
def test_init_headers():
def create_executor(headers: str) -> Executor:
node_data = HttpRequestNodeData(
title="test",
method="get",
url="http://example.com",
headers=headers,
params="",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
executor = create_executor("aa\n cc:")
executor._init_headers()
assert executor.headers == {"aa": "", "cc": ""}
executor = create_executor("aa:bb\n cc:dd")
executor._init_headers()
assert executor.headers == {"aa": "bb", "cc": "dd"}
executor = create_executor("aa:bb\n cc:dd\n")
executor._init_headers()
assert executor.headers == {"aa": "bb", "cc": "dd"}
executor = create_executor("aa:bb\n\n cc : dd\n\n")
executor._init_headers()
assert executor.headers == {"aa": "bb", "cc": "dd"}
def test_init_params():
def create_executor(params: str) -> Executor:
node_data = HttpRequestNodeData(
title="test",
method="get",
url="http://example.com",
headers="",
params=params,
authorization=HttpRequestNodeAuthorization(type="no-auth"),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
# Test basic key-value pairs
executor = create_executor("key1:value1\nkey2:value2")
executor._init_params()
assert executor.params == [("key1", "value1"), ("key2", "value2")]
# Test empty values
executor = create_executor("key1:\nkey2:")
executor._init_params()
assert executor.params == [("key1", ""), ("key2", "")]
# Test duplicate keys (which is allowed for params)
executor = create_executor("key1:value1\nkey1:value2")
executor._init_params()
assert executor.params == [("key1", "value1"), ("key1", "value2")]
# Test whitespace handling
executor = create_executor(" key1 : value1 \n key2 : value2 ")
executor._init_params()
assert executor.params == [("key1", "value1"), ("key2", "value2")]
# Test empty lines and extra whitespace
executor = create_executor("key1:value1\n\nkey2:value2\n\n")
executor._init_params()
assert executor.params == [("key1", "value1"), ("key2", "value2")]

View File

@ -14,18 +14,10 @@ from core.workflow.nodes.http_request import (
HttpRequestNodeBody,
HttpRequestNodeData,
)
from core.workflow.nodes.http_request.executor import _plain_text_to_dict
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
def test_plain_text_to_dict():
assert _plain_text_to_dict("aa\n cc:") == {"aa": "", "cc": ""}
assert _plain_text_to_dict("aa:bb\n cc:dd") == {"aa": "bb", "cc": "dd"}
assert _plain_text_to_dict("aa:bb\n cc:dd\n") == {"aa": "bb", "cc": "dd"}
assert _plain_text_to_dict("aa:bb\n\n cc : dd\n\n") == {"aa": "bb", "cc": "dd"}
def test_http_request_node_binary_file(monkeypatch):
data = HttpRequestNodeData(
title="test",

View File

@ -18,8 +18,7 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
@ -249,8 +248,7 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# Setup dify config
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
dify_config.MULTIMODAL_SEND_FORMAT = "url"
# Generate fake values for prompt template
fake_assistant_prompt = faker.sentence()
@ -328,6 +326,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
extension=".jpg",
mime_type="image/jpg",
)
],
vision_enabled=True,
@ -361,7 +361,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
UserPromptMessage(
content=[
TextPromptMessageContent(data=fake_query),
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
ImagePromptMessageContent(
url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
),
]
),
],
@ -384,7 +386,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
expected_messages=[
UserPromptMessage(
content=[
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
ImagePromptMessageContent(
url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
),
]
),
]
@ -397,6 +401,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
extension=".jpg",
mime_type="image/jpg",
)
},
),

View File

@ -1,15 +1,12 @@
import os
from collections.abc import Generator
from pathlib import Path
import pytest
from configs.middleware.storage.opendal_storage_config import OpenDALScheme
from extensions.storage.opendal_storage import OpenDALStorage
from tests.unit_tests.oss.__mock.base import (
get_example_data,
get_example_filename,
get_example_filepath,
get_opendal_bucket,
)
@ -19,7 +16,7 @@ class TestOpenDAL:
def setup_method(self, *args, **kwargs):
"""Executed before each test method."""
self.storage = OpenDALStorage(
scheme=OpenDALScheme.FS,
scheme="fs",
root=get_opendal_bucket(),
)