Joel
2025-07-08 17:25:40 +08:00
160 changed files with 7071 additions and 550 deletions

View File

@@ -1,6 +1,7 @@
import os
from flask import Flask
from packaging.version import Version
from yarl import URL
from configs.app_config import DifyConfig
@@ -40,6 +41,9 @@ def test_dify_config(monkeypatch):
assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3
# values from pyproject.toml
assert Version(config.project.version) >= Version("1.0.0")
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
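
The note above flags a real pitfall: `pymilvus` eagerly loads any `.env` it finds into `os.environ`, and ambient environment variables override the defaults under test. A minimal sketch of one way to isolate such a test with pytest's `monkeypatch` fixture; the variable names cleared below are illustrative placeholders, not taken from this commit:

import os

import pytest


@pytest.fixture
def _clean_env(monkeypatch: pytest.MonkeyPatch):
    # Clear variables that a stray `.env` file may have injected into
    # os.environ; the names here are hypothetical examples.
    for name in ("WORKFLOW_PARALLEL_DEPTH_LIMIT", "MILVUS_URI"):
        monkeypatch.delenv(name, raising=False)


def test_defaults_with_clean_env(_clean_env):
    # With the ambient variables cleared, config defaults load predictably.
    assert os.environ.get("WORKFLOW_PARALLEL_DEPTH_LIMIT") is None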

View File

@@ -0,0 +1 @@
# Unit tests for core ops module

View File

@@ -0,0 +1,385 @@
import pytest
from pydantic import ValidationError
from core.ops.entities.config_entity import (
AliyunConfig,
ArizeConfig,
LangfuseConfig,
LangSmithConfig,
OpikConfig,
PhoenixConfig,
TracingProviderEnum,
WeaveConfig,
)
class TestTracingProviderEnum:
"""Test cases for TracingProviderEnum"""
def test_enum_values(self):
"""Test that all expected enum values are present"""
assert TracingProviderEnum.ARIZE == "arize"
assert TracingProviderEnum.PHOENIX == "phoenix"
assert TracingProviderEnum.LANGFUSE == "langfuse"
assert TracingProviderEnum.LANGSMITH == "langsmith"
assert TracingProviderEnum.OPIK == "opik"
assert TracingProviderEnum.WEAVE == "weave"
assert TracingProviderEnum.ALIYUN == "aliyun"
class TestArizeConfig:
"""Test cases for ArizeConfig"""
def test_valid_config(self):
"""Test valid Arize configuration"""
config = ArizeConfig(
api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
)
assert config.api_key == "test_key"
assert config.space_id == "test_space"
assert config.project == "test_project"
assert config.endpoint == "https://custom.arize.com"
def test_default_values(self):
"""Test default values are set correctly"""
config = ArizeConfig()
assert config.api_key is None
assert config.space_id is None
assert config.project is None
assert config.endpoint == "https://otlp.arize.com"
def test_project_validation_empty(self):
"""Test project validation with empty value"""
config = ArizeConfig(project="")
assert config.project == "default"
def test_project_validation_none(self):
"""Test project validation with None value"""
config = ArizeConfig(project=None)
assert config.project == "default"
def test_endpoint_validation_empty(self):
"""Test endpoint validation with empty value"""
config = ArizeConfig(endpoint="")
assert config.endpoint == "https://otlp.arize.com"
def test_endpoint_validation_with_path(self):
"""Test endpoint validation normalizes URL by removing path"""
config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
assert config.endpoint == "https://custom.arize.com"
def test_endpoint_validation_invalid_scheme(self):
"""Test endpoint validation rejects invalid schemes"""
with pytest.raises(ValidationError, match="URL scheme must be one of"):
ArizeConfig(endpoint="ftp://invalid.com")
def test_endpoint_validation_no_scheme(self):
"""Test endpoint validation rejects URLs without scheme"""
with pytest.raises(ValidationError, match="URL scheme must be one of"):
ArizeConfig(endpoint="invalid.com")
class TestPhoenixConfig:
"""Test cases for PhoenixConfig"""
def test_valid_config(self):
"""Test valid Phoenix configuration"""
config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
assert config.api_key == "test_key"
assert config.project == "test_project"
assert config.endpoint == "https://custom.phoenix.com"
def test_default_values(self):
"""Test default values are set correctly"""
config = PhoenixConfig()
assert config.api_key is None
assert config.project is None
assert config.endpoint == "https://app.phoenix.arize.com"
def test_project_validation_empty(self):
"""Test project validation with empty value"""
config = PhoenixConfig(project="")
assert config.project == "default"
def test_endpoint_validation_with_path(self):
"""Test endpoint validation normalizes URL by removing path"""
config = PhoenixConfig(endpoint="https://custom.phoenix.com/api/v1")
assert config.endpoint == "https://custom.phoenix.com"
class TestLangfuseConfig:
"""Test cases for LangfuseConfig"""
def test_valid_config(self):
"""Test valid Langfuse configuration"""
config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
assert config.public_key == "public_key"
assert config.secret_key == "secret_key"
assert config.host == "https://custom.langfuse.com"
def test_default_values(self):
"""Test default values are set correctly"""
config = LangfuseConfig(public_key="public", secret_key="secret")
assert config.host == "https://api.langfuse.com"
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangfuseConfig()
with pytest.raises(ValidationError):
LangfuseConfig(public_key="public")
with pytest.raises(ValidationError):
LangfuseConfig(secret_key="secret")
def test_host_validation_empty(self):
"""Test host validation with empty value"""
config = LangfuseConfig(public_key="public", secret_key="secret", host="")
assert config.host == "https://api.langfuse.com"
class TestLangSmithConfig:
"""Test cases for LangSmithConfig"""
def test_valid_config(self):
"""Test valid LangSmith configuration"""
config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
assert config.api_key == "test_key"
assert config.project == "test_project"
assert config.endpoint == "https://custom.smith.com"
def test_default_values(self):
"""Test default values are set correctly"""
config = LangSmithConfig(api_key="key", project="project")
assert config.endpoint == "https://api.smith.langchain.com"
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangSmithConfig()
with pytest.raises(ValidationError):
LangSmithConfig(api_key="key")
with pytest.raises(ValidationError):
LangSmithConfig(project="project")
def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""
with pytest.raises(ValidationError, match="URL scheme must be one of"):
LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")
class TestOpikConfig:
"""Test cases for OpikConfig"""
def test_valid_config(self):
"""Test valid Opik configuration"""
config = OpikConfig(
api_key="test_key",
project="test_project",
workspace="test_workspace",
url="https://custom.comet.com/opik/api/",
)
assert config.api_key == "test_key"
assert config.project == "test_project"
assert config.workspace == "test_workspace"
assert config.url == "https://custom.comet.com/opik/api/"
def test_default_values(self):
"""Test default values are set correctly"""
config = OpikConfig()
assert config.api_key is None
assert config.project is None
assert config.workspace is None
assert config.url == "https://www.comet.com/opik/api/"
def test_project_validation_empty(self):
"""Test project validation with empty value"""
config = OpikConfig(project="")
assert config.project == "Default Project"
def test_url_validation_empty(self):
"""Test URL validation with empty value"""
config = OpikConfig(url="")
assert config.url == "https://www.comet.com/opik/api/"
def test_url_validation_missing_suffix(self):
"""Test URL validation requires /api/ suffix"""
with pytest.raises(ValidationError, match="URL should end with /api/"):
OpikConfig(url="https://custom.comet.com/opik/")
def test_url_validation_invalid_scheme(self):
"""Test URL validation rejects invalid schemes"""
with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
OpikConfig(url="ftp://custom.comet.com/opik/api/")
class TestWeaveConfig:
"""Test cases for WeaveConfig"""
def test_valid_config(self):
"""Test valid Weave configuration"""
config = WeaveConfig(
api_key="test_key",
entity="test_entity",
project="test_project",
endpoint="https://custom.wandb.ai",
host="https://custom.host.com",
)
assert config.api_key == "test_key"
assert config.entity == "test_entity"
assert config.project == "test_project"
assert config.endpoint == "https://custom.wandb.ai"
assert config.host == "https://custom.host.com"
def test_default_values(self):
"""Test default values are set correctly"""
config = WeaveConfig(api_key="key", project="project")
assert config.entity is None
assert config.endpoint == "https://trace.wandb.ai"
assert config.host is None
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
WeaveConfig()
with pytest.raises(ValidationError):
WeaveConfig(api_key="key")
with pytest.raises(ValidationError):
WeaveConfig(project="project")
def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""
with pytest.raises(ValidationError, match="URL scheme must be one of"):
WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")
def test_host_validation_optional(self):
"""Test host validation is optional but validates when provided"""
config = WeaveConfig(api_key="key", project="project", host=None)
assert config.host is None
config = WeaveConfig(api_key="key", project="project", host="")
assert config.host == ""
config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
assert config.host == "https://valid.host.com"
def test_host_validation_invalid_scheme(self):
"""Test host validation rejects invalid schemes when provided"""
with pytest.raises(ValidationError, match="URL scheme must be one of"):
WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")
class TestAliyunConfig:
"""Test cases for AliyunConfig"""
def test_valid_config(self):
"""Test valid Aliyun configuration"""
config = AliyunConfig(
app_name="test_app",
license_key="test_license_key",
endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
)
assert config.app_name == "test_app"
assert config.license_key == "test_license_key"
assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"
def test_default_values(self):
"""Test default values are set correctly"""
config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
assert config.app_name == "dify_app"
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
AliyunConfig()
with pytest.raises(ValidationError):
AliyunConfig(license_key="test_license")
with pytest.raises(ValidationError):
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
def test_app_name_validation_empty(self):
"""Test app_name validation with empty value"""
config = AliyunConfig(
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
)
assert config.app_name == "dify_app"
def test_endpoint_validation_empty(self):
"""Test endpoint validation with empty value"""
config = AliyunConfig(license_key="test_license", endpoint="")
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
def test_endpoint_validation_with_path(self):
"""Test endpoint validation normalizes URL by removing path"""
config = AliyunConfig(
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
)
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
def test_endpoint_validation_invalid_scheme(self):
"""Test endpoint validation rejects invalid schemes"""
with pytest.raises(ValidationError, match="URL scheme must be one of"):
AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")
def test_endpoint_validation_no_scheme(self):
"""Test endpoint validation rejects URLs without scheme"""
with pytest.raises(ValidationError, match="URL scheme must be one of"):
AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")
def test_license_key_required(self):
"""Test that license_key is required and cannot be empty"""
with pytest.raises(ValidationError):
AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
class TestConfigIntegration:
"""Integration tests for configuration classes"""
def test_all_configs_can_be_instantiated(self):
"""Test that all config classes can be instantiated with valid data"""
configs = [
ArizeConfig(api_key="key"),
PhoenixConfig(api_key="key"),
LangfuseConfig(public_key="public", secret_key="secret"),
LangSmithConfig(api_key="key", project="project"),
OpikConfig(api_key="key"),
WeaveConfig(api_key="key", project="project"),
AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com"),
]
for config in configs:
assert config is not None
def test_url_normalization_consistency(self):
"""Test that URL normalization works consistently across configs"""
# Test that paths are removed from endpoints
arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test")
phoenix_config = PhoenixConfig(endpoint="https://phoenix.com/api/v2/")
aliyun_config = AliyunConfig(
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
)
assert arize_config.endpoint == "https://arize.com"
assert phoenix_config.endpoint == "https://phoenix.com"
assert aliyun_config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
def test_project_default_values(self):
"""Test that project default values are set correctly"""
arize_config = ArizeConfig(project="")
phoenix_config = PhoenixConfig(project="")
opik_config = OpikConfig(project="")
aliyun_config = AliyunConfig(
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
)
assert arize_config.project == "default"
assert phoenix_config.project == "default"
assert opik_config.project == "Default Project"
assert aliyun_config.app_name == "dify_app"
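
Taken together, these tests pin down a shared validation contract: endpoints are normalized to scheme plus host, schemes are whitelisted, and empty project names fall back to a provider-specific default. A plausible sketch of how a config entity could wire the helpers from core.ops.utils (exercised in the next file) into Pydantic v2 validators; this is a guess at the shape, not the repository's actual class:

from pydantic import BaseModel, field_validator

from core.ops.utils import validate_project_name, validate_url


class SketchArizeConfig(BaseModel):
    # Field names mirror the tested ArizeConfig; the validator wiring is assumed.
    api_key: str | None = None
    space_id: str | None = None
    project: str | None = None
    endpoint: str = "https://otlp.arize.com"

    @field_validator("project", mode="before")
    @classmethod
    def _project_or_default(cls, v):
        # Empty or None project names collapse to the provider default.
        return validate_project_name(v, "default")

    @field_validator("endpoint", mode="before")
    @classmethod
    def _normalized_endpoint(cls, v):
        # Blank endpoints fall back to the default; paths are stripped and
        # a ValueError here surfaces as the ValidationError the tests match.
        return validate_url(v, "https://otlp.arize.com")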

View File

@@ -0,0 +1,138 @@
import pytest
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
class TestValidateUrl:
"""Test cases for validate_url function"""
def test_valid_https_url(self):
"""Test valid HTTPS URL"""
result = validate_url("https://example.com", "https://default.com")
assert result == "https://example.com"
def test_valid_http_url(self):
"""Test valid HTTP URL"""
result = validate_url("http://example.com", "https://default.com")
assert result == "http://example.com"
def test_url_with_path_removed(self):
"""Test that URL path is removed during normalization"""
result = validate_url("https://example.com/api/v1/test", "https://default.com")
assert result == "https://example.com"
def test_url_with_query_removed(self):
"""Test that URL query parameters are removed"""
result = validate_url("https://example.com?param=value", "https://default.com")
assert result == "https://example.com"
def test_url_with_fragment_removed(self):
"""Test that URL fragments are removed"""
result = validate_url("https://example.com#section", "https://default.com")
assert result == "https://example.com"
def test_empty_url_returns_default(self):
"""Test empty URL returns default"""
result = validate_url("", "https://default.com")
assert result == "https://default.com"
def test_none_url_returns_default(self):
"""Test None URL returns default"""
result = validate_url(None, "https://default.com")
assert result == "https://default.com"
def test_whitespace_url_returns_default(self):
"""Test whitespace URL returns default"""
result = validate_url(" ", "https://default.com")
assert result == "https://default.com"
def test_invalid_scheme_raises_error(self):
"""Test invalid scheme raises ValueError"""
with pytest.raises(ValueError, match="URL scheme must be one of"):
validate_url("ftp://example.com", "https://default.com")
def test_no_scheme_raises_error(self):
"""Test URL without scheme raises ValueError"""
with pytest.raises(ValueError, match="URL scheme must be one of"):
validate_url("example.com", "https://default.com")
def test_custom_allowed_schemes(self):
"""Test custom allowed schemes"""
result = validate_url("https://example.com", "https://default.com", allowed_schemes=("https",))
assert result == "https://example.com"
with pytest.raises(ValueError, match="URL scheme must be one of"):
validate_url("http://example.com", "https://default.com", allowed_schemes=("https",))
class TestValidateUrlWithPath:
"""Test cases for validate_url_with_path function"""
def test_valid_url_with_path(self):
"""Test valid URL with path"""
result = validate_url_with_path("https://example.com/api/v1", "https://default.com")
assert result == "https://example.com/api/v1"
def test_valid_url_with_required_suffix(self):
"""Test valid URL with required suffix"""
result = validate_url_with_path("https://example.com/api/", "https://default.com", required_suffix="/api/")
assert result == "https://example.com/api/"
def test_url_without_required_suffix_raises_error(self):
"""Test URL without required suffix raises error"""
with pytest.raises(ValueError, match="URL should end with /api/"):
validate_url_with_path("https://example.com/api", "https://default.com", required_suffix="/api/")
def test_empty_url_returns_default(self):
"""Test empty URL returns default"""
result = validate_url_with_path("", "https://default.com")
assert result == "https://default.com"
def test_none_url_returns_default(self):
"""Test None URL returns default"""
result = validate_url_with_path(None, "https://default.com")
assert result == "https://default.com"
def test_invalid_scheme_raises_error(self):
"""Test invalid scheme raises ValueError"""
with pytest.raises(ValueError, match="URL must start with https:// or http://"):
validate_url_with_path("ftp://example.com", "https://default.com")
def test_no_scheme_raises_error(self):
"""Test URL without scheme raises ValueError"""
with pytest.raises(ValueError, match="URL must start with https:// or http://"):
validate_url_with_path("example.com", "https://default.com")
class TestValidateProjectName:
"""Test cases for validate_project_name function"""
def test_valid_project_name(self):
"""Test valid project name"""
result = validate_project_name("my-project", "default")
assert result == "my-project"
def test_empty_project_name_returns_default(self):
"""Test empty project name returns default"""
result = validate_project_name("", "default")
assert result == "default"
def test_none_project_name_returns_default(self):
"""Test None project name returns default"""
result = validate_project_name(None, "default")
assert result == "default"
def test_whitespace_project_name_returns_default(self):
"""Test whitespace project name returns default"""
result = validate_project_name(" ", "default")
assert result == "default"
def test_project_name_with_whitespace_trimmed(self):
"""Test project name with whitespace is trimmed"""
result = validate_project_name(" my-project ", "default")
assert result == "my-project"
def test_custom_default_name(self):
"""Test custom default name"""
result = validate_project_name("", "Custom Default")
assert result == "Custom Default"
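
These cases fully characterize validate_url's contract. A minimal implementation consistent with them, assuming nothing beyond the behavior asserted above (the function shipped in core.ops.utils may differ):

from urllib.parse import urlparse


def validate_url(url, default_url, allowed_schemes=("https", "http")):
    # Blank or missing input falls back to the caller's default.
    if url is None or not url.strip():
        return default_url
    parsed = urlparse(url.strip())
    if parsed.scheme not in allowed_schemes:
        raise ValueError(f"URL scheme must be one of {allowed_schemes}")
    # Normalize: keep scheme and host, drop path, query, and fragment.
    return f"{parsed.scheme}://{parsed.netloc}"

By the same tests, validate_url_with_path would differ mainly in keeping the path and enforcing the optional required_suffix, and validate_project_name reduces to strip-or-default.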

View File

@@ -1,3 +1,4 @@
import time
from unittest.mock import patch
import pytest
@@ -19,6 +20,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.code.code_node import CodeNode
@@ -172,6 +174,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@@ -183,7 +186,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
-variable_pool=variable_pool,
+graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
@@ -299,6 +302,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
user_inputs={},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@@ -310,7 +314,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
-variable_pool=variable_pool,
+graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
@@ -479,6 +483,7 @@ def test_run_branch(mock_close, mock_remove):
user_inputs={"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@@ -490,7 +495,7 @@ def test_run_branch(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
-variable_pool=variable_pool,
+graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
@@ -813,6 +818,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@@ -824,7 +830,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
-variable_pool=variable_pool,
+graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)

View File

@@ -1,7 +1,9 @@
import time
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
@@ -11,6 +13,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunStreamChunkEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
@@ -163,15 +166,16 @@ class ContinueOnErrorTestHelper:
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
"""Helper method to create a graph engine instance for testing"""
graph = Graph.init(graph_config=graph_config)
-variable_pool = {
-"system_variables": {
+variable_pool = VariablePool(
+system_variables={
SystemVariableKey.QUERY: "clear",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
"user_inputs": user_inputs or {"uid": "takato"},
}
user_inputs=user_inputs or {"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
return GraphEngine(
tenant_id="111",
@@ -184,7 +188,7 @@ class ContinueOnErrorTestHelper:
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
-variable_pool=variable_pool,
+graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
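
Every hunk in the two files above makes the same mechanical change: GraphEngine no longer takes a variable_pool directly; the pool now travels inside a GraphRuntimeState stamped with a start time. A condensed sketch of the migration, with empty mappings standing in for real inputs:

import time

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState

# Build the pool as before, then wrap it with a start timestamp.
variable_pool = VariablePool(system_variables={}, user_inputs={})
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# GraphEngine(..., graph_runtime_state=graph_runtime_state, ...)  # was: variable_pool=variable_pool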

View File

@@ -0,0 +1,74 @@
import base64
import binascii
import os
import pytest
from libs.password import compare_password, hash_password, valid_password
class TestValidPassword:
"""Test password format validation"""
def test_should_accept_valid_passwords(self):
"""Test accepting valid password formats"""
assert valid_password("password123") == "password123"
assert valid_password("test1234") == "test1234"
assert valid_password("Test123456") == "Test123456"
def test_should_reject_invalid_passwords(self):
"""Test rejecting invalid password formats"""
# Too short
with pytest.raises(ValueError) as exc_info:
valid_password("abc123")
assert "Password must contain letters and numbers" in str(exc_info.value)
# No numbers
with pytest.raises(ValueError):
valid_password("abcdefgh")
# No letters
with pytest.raises(ValueError):
valid_password("12345678")
# Empty
with pytest.raises(ValueError):
valid_password("")
class TestPasswordHashing:
"""Test password hashing and comparison"""
def setup_method(self):
"""Setup test data"""
self.password = "test123password"
self.salt = os.urandom(16)
self.salt_base64 = base64.b64encode(self.salt).decode()
password_hash = hash_password(self.password, self.salt)
self.password_hash_base64 = base64.b64encode(password_hash).decode()
def test_should_verify_correct_password(self):
"""Test correct password verification"""
result = compare_password(self.password, self.password_hash_base64, self.salt_base64)
assert result is True
def test_should_reject_wrong_password(self):
"""Test rejection of incorrect passwords"""
result = compare_password("wrongpassword", self.password_hash_base64, self.salt_base64)
assert result is False
def test_should_handle_invalid_base64(self):
"""Test handling of invalid base64 data"""
# Invalid base64 hash
with pytest.raises(binascii.Error):
compare_password(self.password, "invalid_base64!", self.salt_base64)
# Invalid base64 salt
with pytest.raises(binascii.Error):
compare_password(self.password, self.password_hash_base64, "invalid_base64!")
def test_should_be_case_sensitive(self):
"""Test password case sensitivity"""
result = compare_password(self.password.upper(), self.password_hash_base64, self.salt_base64)
assert result is False
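
The hashing tests treat libs.password as a black box: hash_password(password, salt) returns raw bytes, and compare_password takes a base64-encoded hash and salt and must surface binascii.Error on malformed input. A plausible stand-in that satisfies exactly those assertions, assuming a PBKDF2-HMAC-SHA256 scheme; the real module's KDF and iteration count may differ:

import base64
import hashlib
import hmac


def hash_password(password: str, salt: bytes) -> bytes:
    # The iteration count is an illustrative guess, not the library's value.
    return hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 10_000)


def compare_password(password: str, password_hash_base64: str, salt_base64: str) -> bool:
    # b64decode(..., validate=True) raises binascii.Error on malformed
    # input, which is what test_should_handle_invalid_base64 relies on.
    salt = base64.b64decode(salt_base64, validate=True)
    expected = base64.b64decode(password_hash_base64, validate=True)
    # Constant-time comparison; a case-changed password hashes differently,
    # matching test_should_be_case_sensitive.
    return hmac.compare_digest(hash_password(password, salt), expected)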

View File

@@ -0,0 +1,465 @@
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
LLMUsage,
)
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage:
"""Create a mock LLMUsage with all required fields"""
return LLMUsage(
prompt_tokens=prompt_tokens,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1"),
prompt_price=Decimal(str(prompt_tokens)) * Decimal("0.001"),
completion_tokens=completion_tokens,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1"),
completion_price=Decimal(str(completion_tokens)) * Decimal("0.002"),
total_tokens=prompt_tokens + completion_tokens,
total_price=Decimal(str(prompt_tokens)) * Decimal("0.001") + Decimal(str(completion_tokens)) * Decimal("0.002"),
currency="USD",
latency=1.5,
)
def get_model_entity(provider: str, model_name: str, support_structure_output: bool = False) -> AIModelEntity:
"""Create a mock AIModelEntity for testing"""
model_schema = MagicMock()
model_schema.model = model_name
model_schema.provider = provider
model_schema.model_type = ModelType.LLM
model_schema.model_provider = provider
model_schema.model_name = model_name
model_schema.support_structure_output = support_structure_output
model_schema.parameter_rules = []
return model_schema
def get_model_instance() -> MagicMock:
"""Create a mock ModelInstance for testing"""
mock_instance = MagicMock()
mock_instance.provider = "openai"
mock_instance.credentials = {}
return mock_instance
def test_structured_output_parser():
"""Test cases for invoke_llm_with_structured_output function"""
testcases = [
# Test case 1: Model with native structured output support, non-streaming
{
"name": "native_structured_output_non_streaming",
"provider": "openai",
"model_name": "gpt-4o",
"support_structure_output": True,
"stream": False,
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
"expected_llm_response": LLMResult(
model="gpt-4o",
message=AssistantPromptMessage(content='{"name": "test"}'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
),
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
},
# Test case 2: Model with native structured output support, streaming
{
"name": "native_structured_output_streaming",
"provider": "openai",
"model_name": "gpt-4o",
"support_structure_output": True,
"stream": True,
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
"expected_llm_response": [
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content='{"name":'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
),
),
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=' "test"}'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
),
),
],
"expected_result_type": "generator",
"should_raise": False,
},
# Test case 3: Model without native structured output support, non-streaming
{
"name": "prompt_based_structured_output_non_streaming",
"provider": "anthropic",
"model_name": "claude-3-sonnet",
"support_structure_output": False,
"stream": False,
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
"expected_llm_response": LLMResult(
model="claude-3-sonnet",
message=AssistantPromptMessage(content='{"answer": "test response"}'),
usage=create_mock_usage(prompt_tokens=15, completion_tokens=8),
),
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
},
# Test case 4: Model without native structured output support, streaming
{
"name": "prompt_based_structured_output_streaming",
"provider": "anthropic",
"model_name": "claude-3-sonnet",
"support_structure_output": False,
"stream": True,
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
"expected_llm_response": [
LLMResultChunk(
model="claude-3-sonnet",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content='{"answer": "test'),
usage=create_mock_usage(prompt_tokens=15, completion_tokens=3),
),
),
LLMResultChunk(
model="claude-3-sonnet",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=' response"}'),
usage=create_mock_usage(prompt_tokens=15, completion_tokens=5),
),
),
],
"expected_result_type": "generator",
"should_raise": False,
},
# Test case 5: Streaming with list content
{
"name": "streaming_with_list_content",
"provider": "openai",
"model_name": "gpt-4o",
"support_structure_output": True,
"stream": True,
"json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
"expected_llm_response": [
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data='{"data":'),
]
),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
),
),
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data=' "value"}'),
]
),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
),
),
],
"expected_result_type": "generator",
"should_raise": False,
},
# Test case 6: Error case - non-string LLM response content (non-streaming)
{
"name": "error_non_string_content_non_streaming",
"provider": "openai",
"model_name": "gpt-4o",
"support_structure_output": True,
"stream": False,
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
"expected_llm_response": LLMResult(
model="gpt-4o",
message=AssistantPromptMessage(content=None), # Non-string content
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
),
"expected_result_type": None,
"should_raise": True,
"expected_error": OutputParserError,
},
# Test case 7: JSON repair scenario
{
"name": "json_repair_scenario",
"provider": "openai",
"model_name": "gpt-4o",
"support_structure_output": True,
"stream": False,
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
"expected_llm_response": LLMResult(
model="gpt-4o",
message=AssistantPromptMessage(content='{"name": "test"'), # Invalid JSON - missing closing brace
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
),
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
},
# Test case 8: Model with parameter rules for response format
{
"name": "model_with_parameter_rules",
"provider": "openai",
"model_name": "gpt-4o",
"support_structure_output": True,
"stream": False,
"json_schema": {"type": "object", "properties": {"result": {"type": "string"}}},
"parameter_rules": [
MagicMock(name="response_format", options=["json_schema"], required=False),
],
"expected_llm_response": LLMResult(
model="gpt-4o",
message=AssistantPromptMessage(content='{"result": "success"}'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
),
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
},
# Test case 9: Model without native support but with JSON response format rules
{
"name": "non_native_with_json_rules",
"provider": "anthropic",
"model_name": "claude-3-sonnet",
"support_structure_output": False,
"stream": False,
"json_schema": {"type": "object", "properties": {"output": {"type": "string"}}},
"parameter_rules": [
MagicMock(name="response_format", options=["JSON"], required=False),
],
"expected_llm_response": LLMResult(
model="claude-3-sonnet",
message=AssistantPromptMessage(content='{"output": "result"}'),
usage=create_mock_usage(prompt_tokens=15, completion_tokens=8),
),
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
},
]
for case in testcases:
print(f"Running test case: {case['name']}")
# Setup model entity
model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"])
# Add parameter rules if specified
if "parameter_rules" in case:
model_schema.parameter_rules = case["parameter_rules"]
# Setup model instance
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = case["expected_llm_response"]
# Setup prompt messages
prompt_messages = [
SystemPromptMessage(content="You are a helpful assistant."),
UserPromptMessage(content="Generate a response according to the schema."),
]
if case["should_raise"]:
# Test error cases
with pytest.raises(case["expected_error"]): # noqa: PT012
if case["stream"]:
result_generator = invoke_llm_with_structured_output(
provider=case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=case["json_schema"],
stream=case["stream"],
)
# Consume the generator to trigger the error
list(result_generator)
else:
invoke_llm_with_structured_output(
provider=case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=case["json_schema"],
stream=case["stream"],
)
else:
# Test successful cases
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
# Configure json_repair mock for cases that need it
if case["name"] == "json_repair_scenario":
mock_json_repair.return_value = {"name": "test"}
result = invoke_llm_with_structured_output(
provider=case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=case["json_schema"],
stream=case["stream"],
model_parameters={"temperature": 0.7, "max_tokens": 100},
user="test_user",
)
if case["expected_result_type"] == "generator":
# Test streaming results
assert hasattr(result, "__iter__")
chunks = list(result)
assert len(chunks) > 0
# Verify all chunks are LLMResultChunkWithStructuredOutput
for chunk in chunks[:-1]: # All except last
assert isinstance(chunk, LLMResultChunkWithStructuredOutput)
assert chunk.model == case["model_name"]
# Last chunk should have structured output
last_chunk = chunks[-1]
assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput)
assert last_chunk.structured_output is not None
assert isinstance(last_chunk.structured_output, dict)
else:
# Test non-streaming results
assert isinstance(result, case["expected_result_type"])
assert result.model == case["model_name"]
assert result.structured_output is not None
assert isinstance(result.structured_output, dict)
# Verify model_instance.invoke_llm was called with correct parameters
model_instance.invoke_llm.assert_called_once()
call_args = model_instance.invoke_llm.call_args
assert call_args.kwargs["stream"] == case["stream"]
assert call_args.kwargs["user"] == "test_user"
assert "temperature" in call_args.kwargs["model_parameters"]
assert "max_tokens" in call_args.kwargs["model_parameters"]
def test_parse_structured_output_edge_cases():
"""Test edge cases for structured output parsing"""
# Test case with list that contains dict (reasoning model scenario)
testcase_list_with_dict = {
"name": "list_with_dict_parsing",
"provider": "deepseek",
"model_name": "deepseek-r1",
"support_structure_output": False,
"stream": False,
"json_schema": {"type": "object", "properties": {"thought": {"type": "string"}}},
"expected_llm_response": LLMResult(
model="deepseek-r1",
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
),
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
}
# Setup for list parsing test
model_schema = get_model_entity(
testcase_list_with_dict["provider"],
testcase_list_with_dict["model_name"],
testcase_list_with_dict["support_structure_output"],
)
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = testcase_list_with_dict["expected_llm_response"]
prompt_messages = [UserPromptMessage(content="Test reasoning")]
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
# Mock json_repair to return a list with dict
mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"]
result = invoke_llm_with_structured_output(
provider=testcase_list_with_dict["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=testcase_list_with_dict["json_schema"],
stream=testcase_list_with_dict["stream"],
)
assert isinstance(result, LLMResultWithStructuredOutput)
assert result.structured_output == {"thought": "reasoning process"}
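
The list-with-dict case encodes a small recovery rule for reasoning models whose repaired output is a JSON array rather than an object. One rule consistent with the assertion above, stated here as an assumption about the parser rather than its actual code, is to take the first dict element:

parsed = [{"thought": "reasoning process"}, "other content"]
structured_output = next((item for item in parsed if isinstance(item, dict)), None)
assert structured_output == {"thought": "reasoning process"}
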
def test_model_specific_schema_preparation():
"""Test schema preparation for different model types"""
# Test Gemini model
gemini_case = {
"provider": "google",
"model_name": "gemini-pro",
"support_structure_output": True,
"stream": False,
"json_schema": {"type": "object", "properties": {"result": {"type": "boolean"}}, "additionalProperties": False},
}
model_schema = get_model_entity(
gemini_case["provider"], gemini_case["model_name"], gemini_case["support_structure_output"]
)
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = LLMResult(
model="gemini-pro",
message=AssistantPromptMessage(content='{"result": "true"}'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
)
prompt_messages = [UserPromptMessage(content="Test")]
result = invoke_llm_with_structured_output(
provider=gemini_case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=gemini_case["json_schema"],
stream=gemini_case["stream"],
)
assert isinstance(result, LLMResultWithStructuredOutput)
# Verify model_instance.invoke_llm was called and check the schema preparation
model_instance.invoke_llm.assert_called_once()
call_args = model_instance.invoke_llm.call_args
# For Gemini, the schema should not have additionalProperties and boolean should be converted to string
assert "json_schema" in call_args.kwargs["model_parameters"]