mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
merge
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
import os
|
||||
|
||||
from flask import Flask
|
||||
from packaging.version import Version
|
||||
from yarl import URL
|
||||
|
||||
from configs.app_config import DifyConfig
|
||||
@ -40,6 +41,9 @@ def test_dify_config(monkeypatch):
|
||||
|
||||
assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3
|
||||
|
||||
# values from pyproject.toml
|
||||
assert Version(config.project.version) >= Version("1.0.0")
|
||||
|
||||
|
||||
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
|
||||
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
|
||||
|
||||
1
api/tests/unit_tests/core/ops/__init__.py
Normal file
1
api/tests/unit_tests/core/ops/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Unit tests for core ops module
|
||||
385
api/tests/unit_tests/core/ops/test_config_entity.py
Normal file
385
api/tests/unit_tests/core/ops/test_config_entity.py
Normal file
@ -0,0 +1,385 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.ops.entities.config_entity import (
|
||||
AliyunConfig,
|
||||
ArizeConfig,
|
||||
LangfuseConfig,
|
||||
LangSmithConfig,
|
||||
OpikConfig,
|
||||
PhoenixConfig,
|
||||
TracingProviderEnum,
|
||||
WeaveConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestTracingProviderEnum:
|
||||
"""Test cases for TracingProviderEnum"""
|
||||
|
||||
def test_enum_values(self):
|
||||
"""Test that all expected enum values are present"""
|
||||
assert TracingProviderEnum.ARIZE == "arize"
|
||||
assert TracingProviderEnum.PHOENIX == "phoenix"
|
||||
assert TracingProviderEnum.LANGFUSE == "langfuse"
|
||||
assert TracingProviderEnum.LANGSMITH == "langsmith"
|
||||
assert TracingProviderEnum.OPIK == "opik"
|
||||
assert TracingProviderEnum.WEAVE == "weave"
|
||||
assert TracingProviderEnum.ALIYUN == "aliyun"
|
||||
|
||||
|
||||
class TestArizeConfig:
|
||||
"""Test cases for ArizeConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Arize configuration"""
|
||||
config = ArizeConfig(
|
||||
api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
|
||||
)
|
||||
assert config.api_key == "test_key"
|
||||
assert config.space_id == "test_space"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.arize.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = ArizeConfig()
|
||||
assert config.api_key is None
|
||||
assert config.space_id is None
|
||||
assert config.project is None
|
||||
assert config.endpoint == "https://otlp.arize.com"
|
||||
|
||||
def test_project_validation_empty(self):
|
||||
"""Test project validation with empty value"""
|
||||
config = ArizeConfig(project="")
|
||||
assert config.project == "default"
|
||||
|
||||
def test_project_validation_none(self):
|
||||
"""Test project validation with None value"""
|
||||
config = ArizeConfig(project=None)
|
||||
assert config.project == "default"
|
||||
|
||||
def test_endpoint_validation_empty(self):
|
||||
"""Test endpoint validation with empty value"""
|
||||
config = ArizeConfig(endpoint="")
|
||||
assert config.endpoint == "https://otlp.arize.com"
|
||||
|
||||
def test_endpoint_validation_with_path(self):
|
||||
"""Test endpoint validation normalizes URL by removing path"""
|
||||
config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
|
||||
assert config.endpoint == "https://custom.arize.com"
|
||||
|
||||
def test_endpoint_validation_invalid_scheme(self):
|
||||
"""Test endpoint validation rejects invalid schemes"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
ArizeConfig(endpoint="ftp://invalid.com")
|
||||
|
||||
def test_endpoint_validation_no_scheme(self):
|
||||
"""Test endpoint validation rejects URLs without scheme"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
ArizeConfig(endpoint="invalid.com")
|
||||
|
||||
|
||||
class TestPhoenixConfig:
|
||||
"""Test cases for PhoenixConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Phoenix configuration"""
|
||||
config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
|
||||
assert config.api_key == "test_key"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.phoenix.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = PhoenixConfig()
|
||||
assert config.api_key is None
|
||||
assert config.project is None
|
||||
assert config.endpoint == "https://app.phoenix.arize.com"
|
||||
|
||||
def test_project_validation_empty(self):
|
||||
"""Test project validation with empty value"""
|
||||
config = PhoenixConfig(project="")
|
||||
assert config.project == "default"
|
||||
|
||||
def test_endpoint_validation_with_path(self):
|
||||
"""Test endpoint validation normalizes URL by removing path"""
|
||||
config = PhoenixConfig(endpoint="https://custom.phoenix.com/api/v1")
|
||||
assert config.endpoint == "https://custom.phoenix.com"
|
||||
|
||||
|
||||
class TestLangfuseConfig:
|
||||
"""Test cases for LangfuseConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Langfuse configuration"""
|
||||
config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
|
||||
assert config.public_key == "public_key"
|
||||
assert config.secret_key == "secret_key"
|
||||
assert config.host == "https://custom.langfuse.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = LangfuseConfig(public_key="public", secret_key="secret")
|
||||
assert config.host == "https://api.langfuse.com"
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig(public_key="public")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig(secret_key="secret")
|
||||
|
||||
def test_host_validation_empty(self):
|
||||
"""Test host validation with empty value"""
|
||||
config = LangfuseConfig(public_key="public", secret_key="secret", host="")
|
||||
assert config.host == "https://api.langfuse.com"
|
||||
|
||||
|
||||
class TestLangSmithConfig:
|
||||
"""Test cases for LangSmithConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid LangSmith configuration"""
|
||||
config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
|
||||
assert config.api_key == "test_key"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.smith.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = LangSmithConfig(api_key="key", project="project")
|
||||
assert config.endpoint == "https://api.smith.langchain.com"
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig(api_key="key")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig(project="project")
|
||||
|
||||
def test_endpoint_validation_https_only(self):
|
||||
"""Test endpoint validation only allows HTTPS"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")
|
||||
|
||||
|
||||
class TestOpikConfig:
|
||||
"""Test cases for OpikConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Opik configuration"""
|
||||
config = OpikConfig(
|
||||
api_key="test_key",
|
||||
project="test_project",
|
||||
workspace="test_workspace",
|
||||
url="https://custom.comet.com/opik/api/",
|
||||
)
|
||||
assert config.api_key == "test_key"
|
||||
assert config.project == "test_project"
|
||||
assert config.workspace == "test_workspace"
|
||||
assert config.url == "https://custom.comet.com/opik/api/"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = OpikConfig()
|
||||
assert config.api_key is None
|
||||
assert config.project is None
|
||||
assert config.workspace is None
|
||||
assert config.url == "https://www.comet.com/opik/api/"
|
||||
|
||||
def test_project_validation_empty(self):
|
||||
"""Test project validation with empty value"""
|
||||
config = OpikConfig(project="")
|
||||
assert config.project == "Default Project"
|
||||
|
||||
def test_url_validation_empty(self):
|
||||
"""Test URL validation with empty value"""
|
||||
config = OpikConfig(url="")
|
||||
assert config.url == "https://www.comet.com/opik/api/"
|
||||
|
||||
def test_url_validation_missing_suffix(self):
|
||||
"""Test URL validation requires /api/ suffix"""
|
||||
with pytest.raises(ValidationError, match="URL should end with /api/"):
|
||||
OpikConfig(url="https://custom.comet.com/opik/")
|
||||
|
||||
def test_url_validation_invalid_scheme(self):
|
||||
"""Test URL validation rejects invalid schemes"""
|
||||
with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
|
||||
OpikConfig(url="ftp://custom.comet.com/opik/api/")
|
||||
|
||||
|
||||
class TestWeaveConfig:
|
||||
"""Test cases for WeaveConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Weave configuration"""
|
||||
config = WeaveConfig(
|
||||
api_key="test_key",
|
||||
entity="test_entity",
|
||||
project="test_project",
|
||||
endpoint="https://custom.wandb.ai",
|
||||
host="https://custom.host.com",
|
||||
)
|
||||
assert config.api_key == "test_key"
|
||||
assert config.entity == "test_entity"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.wandb.ai"
|
||||
assert config.host == "https://custom.host.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = WeaveConfig(api_key="key", project="project")
|
||||
assert config.entity is None
|
||||
assert config.endpoint == "https://trace.wandb.ai"
|
||||
assert config.host is None
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig(api_key="key")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig(project="project")
|
||||
|
||||
def test_endpoint_validation_https_only(self):
|
||||
"""Test endpoint validation only allows HTTPS"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")
|
||||
|
||||
def test_host_validation_optional(self):
|
||||
"""Test host validation is optional but validates when provided"""
|
||||
config = WeaveConfig(api_key="key", project="project", host=None)
|
||||
assert config.host is None
|
||||
|
||||
config = WeaveConfig(api_key="key", project="project", host="")
|
||||
assert config.host == ""
|
||||
|
||||
config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
|
||||
assert config.host == "https://valid.host.com"
|
||||
|
||||
def test_host_validation_invalid_scheme(self):
|
||||
"""Test host validation rejects invalid schemes when provided"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")
|
||||
|
||||
|
||||
class TestAliyunConfig:
|
||||
"""Test cases for AliyunConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Aliyun configuration"""
|
||||
config = AliyunConfig(
|
||||
app_name="test_app",
|
||||
license_key="test_license_key",
|
||||
endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
|
||||
)
|
||||
assert config.app_name == "test_app"
|
||||
assert config.license_key == "test_license_key"
|
||||
assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
assert config.app_name == "dify_app"
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(license_key="test_license")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
def test_app_name_validation_empty(self):
|
||||
"""Test app_name validation with empty value"""
|
||||
config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
|
||||
)
|
||||
assert config.app_name == "dify_app"
|
||||
|
||||
def test_endpoint_validation_empty(self):
|
||||
"""Test endpoint validation with empty value"""
|
||||
config = AliyunConfig(license_key="test_license", endpoint="")
|
||||
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_endpoint_validation_with_path(self):
|
||||
"""Test endpoint validation normalizes URL by removing path"""
|
||||
config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
|
||||
)
|
||||
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_endpoint_validation_invalid_scheme(self):
|
||||
"""Test endpoint validation rejects invalid schemes"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
def test_endpoint_validation_no_scheme(self):
|
||||
"""Test endpoint validation rejects URLs without scheme"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
def test_license_key_required(self):
|
||||
"""Test that license_key is required and cannot be empty"""
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
|
||||
class TestConfigIntegration:
|
||||
"""Integration tests for configuration classes"""
|
||||
|
||||
def test_all_configs_can_be_instantiated(self):
|
||||
"""Test that all config classes can be instantiated with valid data"""
|
||||
configs = [
|
||||
ArizeConfig(api_key="key"),
|
||||
PhoenixConfig(api_key="key"),
|
||||
LangfuseConfig(public_key="public", secret_key="secret"),
|
||||
LangSmithConfig(api_key="key", project="project"),
|
||||
OpikConfig(api_key="key"),
|
||||
WeaveConfig(api_key="key", project="project"),
|
||||
AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com"),
|
||||
]
|
||||
|
||||
for config in configs:
|
||||
assert config is not None
|
||||
|
||||
def test_url_normalization_consistency(self):
|
||||
"""Test that URL normalization works consistently across configs"""
|
||||
# Test that paths are removed from endpoints
|
||||
arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test")
|
||||
phoenix_config = PhoenixConfig(endpoint="https://phoenix.com/api/v2/")
|
||||
aliyun_config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
|
||||
)
|
||||
|
||||
assert arize_config.endpoint == "https://arize.com"
|
||||
assert phoenix_config.endpoint == "https://phoenix.com"
|
||||
assert aliyun_config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_project_default_values(self):
|
||||
"""Test that project default values are set correctly"""
|
||||
arize_config = ArizeConfig(project="")
|
||||
phoenix_config = PhoenixConfig(project="")
|
||||
opik_config = OpikConfig(project="")
|
||||
aliyun_config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
|
||||
)
|
||||
|
||||
assert arize_config.project == "default"
|
||||
assert phoenix_config.project == "default"
|
||||
assert opik_config.project == "Default Project"
|
||||
assert aliyun_config.app_name == "dify_app"
|
||||
138
api/tests/unit_tests/core/ops/test_utils.py
Normal file
138
api/tests/unit_tests/core/ops/test_utils.py
Normal file
@ -0,0 +1,138 @@
|
||||
import pytest
|
||||
|
||||
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
|
||||
|
||||
|
||||
class TestValidateUrl:
|
||||
"""Test cases for validate_url function"""
|
||||
|
||||
def test_valid_https_url(self):
|
||||
"""Test valid HTTPS URL"""
|
||||
result = validate_url("https://example.com", "https://default.com")
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_valid_http_url(self):
|
||||
"""Test valid HTTP URL"""
|
||||
result = validate_url("http://example.com", "https://default.com")
|
||||
assert result == "http://example.com"
|
||||
|
||||
def test_url_with_path_removed(self):
|
||||
"""Test that URL path is removed during normalization"""
|
||||
result = validate_url("https://example.com/api/v1/test", "https://default.com")
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_url_with_query_removed(self):
|
||||
"""Test that URL query parameters are removed"""
|
||||
result = validate_url("https://example.com?param=value", "https://default.com")
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_url_with_fragment_removed(self):
|
||||
"""Test that URL fragments are removed"""
|
||||
result = validate_url("https://example.com#section", "https://default.com")
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_empty_url_returns_default(self):
|
||||
"""Test empty URL returns default"""
|
||||
result = validate_url("", "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_none_url_returns_default(self):
|
||||
"""Test None URL returns default"""
|
||||
result = validate_url(None, "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_whitespace_url_returns_default(self):
|
||||
"""Test whitespace URL returns default"""
|
||||
result = validate_url(" ", "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_invalid_scheme_raises_error(self):
|
||||
"""Test invalid scheme raises ValueError"""
|
||||
with pytest.raises(ValueError, match="URL scheme must be one of"):
|
||||
validate_url("ftp://example.com", "https://default.com")
|
||||
|
||||
def test_no_scheme_raises_error(self):
|
||||
"""Test URL without scheme raises ValueError"""
|
||||
with pytest.raises(ValueError, match="URL scheme must be one of"):
|
||||
validate_url("example.com", "https://default.com")
|
||||
|
||||
def test_custom_allowed_schemes(self):
|
||||
"""Test custom allowed schemes"""
|
||||
result = validate_url("https://example.com", "https://default.com", allowed_schemes=("https",))
|
||||
assert result == "https://example.com"
|
||||
|
||||
with pytest.raises(ValueError, match="URL scheme must be one of"):
|
||||
validate_url("http://example.com", "https://default.com", allowed_schemes=("https",))
|
||||
|
||||
|
||||
class TestValidateUrlWithPath:
|
||||
"""Test cases for validate_url_with_path function"""
|
||||
|
||||
def test_valid_url_with_path(self):
|
||||
"""Test valid URL with path"""
|
||||
result = validate_url_with_path("https://example.com/api/v1", "https://default.com")
|
||||
assert result == "https://example.com/api/v1"
|
||||
|
||||
def test_valid_url_with_required_suffix(self):
|
||||
"""Test valid URL with required suffix"""
|
||||
result = validate_url_with_path("https://example.com/api/", "https://default.com", required_suffix="/api/")
|
||||
assert result == "https://example.com/api/"
|
||||
|
||||
def test_url_without_required_suffix_raises_error(self):
|
||||
"""Test URL without required suffix raises error"""
|
||||
with pytest.raises(ValueError, match="URL should end with /api/"):
|
||||
validate_url_with_path("https://example.com/api", "https://default.com", required_suffix="/api/")
|
||||
|
||||
def test_empty_url_returns_default(self):
|
||||
"""Test empty URL returns default"""
|
||||
result = validate_url_with_path("", "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_none_url_returns_default(self):
|
||||
"""Test None URL returns default"""
|
||||
result = validate_url_with_path(None, "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_invalid_scheme_raises_error(self):
|
||||
"""Test invalid scheme raises ValueError"""
|
||||
with pytest.raises(ValueError, match="URL must start with https:// or http://"):
|
||||
validate_url_with_path("ftp://example.com", "https://default.com")
|
||||
|
||||
def test_no_scheme_raises_error(self):
|
||||
"""Test URL without scheme raises ValueError"""
|
||||
with pytest.raises(ValueError, match="URL must start with https:// or http://"):
|
||||
validate_url_with_path("example.com", "https://default.com")
|
||||
|
||||
|
||||
class TestValidateProjectName:
|
||||
"""Test cases for validate_project_name function"""
|
||||
|
||||
def test_valid_project_name(self):
|
||||
"""Test valid project name"""
|
||||
result = validate_project_name("my-project", "default")
|
||||
assert result == "my-project"
|
||||
|
||||
def test_empty_project_name_returns_default(self):
|
||||
"""Test empty project name returns default"""
|
||||
result = validate_project_name("", "default")
|
||||
assert result == "default"
|
||||
|
||||
def test_none_project_name_returns_default(self):
|
||||
"""Test None project name returns default"""
|
||||
result = validate_project_name(None, "default")
|
||||
assert result == "default"
|
||||
|
||||
def test_whitespace_project_name_returns_default(self):
|
||||
"""Test whitespace project name returns default"""
|
||||
result = validate_project_name(" ", "default")
|
||||
assert result == "default"
|
||||
|
||||
def test_project_name_with_whitespace_trimmed(self):
|
||||
"""Test project name with whitespace is trimmed"""
|
||||
result = validate_project_name(" my-project ", "default")
|
||||
assert result == "my-project"
|
||||
|
||||
def test_custom_default_name(self):
|
||||
"""Test custom default name"""
|
||||
result = validate_project_name("", "Custom Default")
|
||||
assert result == "Custom Default"
|
||||
@ -1,3 +1,4 @@
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -19,6 +20,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
@ -172,6 +174,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
|
||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
@ -183,7 +186,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
@ -299,6 +302,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
@ -310,7 +314,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
@ -479,6 +483,7 @@ def test_run_branch(mock_close, mock_remove):
|
||||
user_inputs={"uid": "takato"},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
@ -490,7 +495,7 @@ def test_run_branch(mock_close, mock_remove):
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
@ -813,6 +818,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
|
||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
@ -824,7 +830,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
@ -11,6 +13,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
NodeRunStreamChunkEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
@ -163,15 +166,16 @@ class ContinueOnErrorTestHelper:
|
||||
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
|
||||
"""Helper method to create a graph engine instance for testing"""
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
variable_pool = {
|
||||
"system_variables": {
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey.QUERY: "clear",
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
||||
SystemVariableKey.USER_ID: "aaa",
|
||||
},
|
||||
"user_inputs": user_inputs or {"uid": "takato"},
|
||||
}
|
||||
user_inputs=user_inputs or {"uid": "takato"},
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
return GraphEngine(
|
||||
tenant_id="111",
|
||||
@ -184,7 +188,7 @@ class ContinueOnErrorTestHelper:
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
|
||||
74
api/tests/unit_tests/libs/test_password.py
Normal file
74
api/tests/unit_tests/libs/test_password.py
Normal file
@ -0,0 +1,74 @@
|
||||
import base64
|
||||
import binascii
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.password import compare_password, hash_password, valid_password
|
||||
|
||||
|
||||
class TestValidPassword:
|
||||
"""Test password format validation"""
|
||||
|
||||
def test_should_accept_valid_passwords(self):
|
||||
"""Test accepting valid password formats"""
|
||||
assert valid_password("password123") == "password123"
|
||||
assert valid_password("test1234") == "test1234"
|
||||
assert valid_password("Test123456") == "Test123456"
|
||||
|
||||
def test_should_reject_invalid_passwords(self):
|
||||
"""Test rejecting invalid password formats"""
|
||||
# Too short
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
valid_password("abc123")
|
||||
assert "Password must contain letters and numbers" in str(exc_info.value)
|
||||
|
||||
# No numbers
|
||||
with pytest.raises(ValueError):
|
||||
valid_password("abcdefgh")
|
||||
|
||||
# No letters
|
||||
with pytest.raises(ValueError):
|
||||
valid_password("12345678")
|
||||
|
||||
# Empty
|
||||
with pytest.raises(ValueError):
|
||||
valid_password("")
|
||||
|
||||
|
||||
class TestPasswordHashing:
|
||||
"""Test password hashing and comparison"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup test data"""
|
||||
self.password = "test123password"
|
||||
self.salt = os.urandom(16)
|
||||
self.salt_base64 = base64.b64encode(self.salt).decode()
|
||||
|
||||
password_hash = hash_password(self.password, self.salt)
|
||||
self.password_hash_base64 = base64.b64encode(password_hash).decode()
|
||||
|
||||
def test_should_verify_correct_password(self):
|
||||
"""Test correct password verification"""
|
||||
result = compare_password(self.password, self.password_hash_base64, self.salt_base64)
|
||||
assert result is True
|
||||
|
||||
def test_should_reject_wrong_password(self):
|
||||
"""Test rejection of incorrect passwords"""
|
||||
result = compare_password("wrongpassword", self.password_hash_base64, self.salt_base64)
|
||||
assert result is False
|
||||
|
||||
def test_should_handle_invalid_base64(self):
|
||||
"""Test handling of invalid base64 data"""
|
||||
# Invalid base64 hash
|
||||
with pytest.raises(binascii.Error):
|
||||
compare_password(self.password, "invalid_base64!", self.salt_base64)
|
||||
|
||||
# Invalid base64 salt
|
||||
with pytest.raises(binascii.Error):
|
||||
compare_password(self.password, self.password_hash_base64, "invalid_base64!")
|
||||
|
||||
def test_should_be_case_sensitive(self):
|
||||
"""Test password case sensitivity"""
|
||||
result = compare_password(self.password.upper(), self.password_hash_base64, self.salt_base64)
|
||||
assert result is False
|
||||
@ -0,0 +1,465 @@
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
LLMUsage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
|
||||
|
||||
def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage:
|
||||
"""Create a mock LLMUsage with all required fields"""
|
||||
return LLMUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("1"),
|
||||
prompt_price=Decimal(str(prompt_tokens)) * Decimal("0.001"),
|
||||
completion_tokens=completion_tokens,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("1"),
|
||||
completion_price=Decimal(str(completion_tokens)) * Decimal("0.002"),
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
total_price=Decimal(str(prompt_tokens)) * Decimal("0.001") + Decimal(str(completion_tokens)) * Decimal("0.002"),
|
||||
currency="USD",
|
||||
latency=1.5,
|
||||
)
|
||||
|
||||
|
||||
def get_model_entity(provider: str, model_name: str, support_structure_output: bool = False) -> AIModelEntity:
|
||||
"""Create a mock AIModelEntity for testing"""
|
||||
model_schema = MagicMock()
|
||||
model_schema.model = model_name
|
||||
model_schema.provider = provider
|
||||
model_schema.model_type = ModelType.LLM
|
||||
model_schema.model_provider = provider
|
||||
model_schema.model_name = model_name
|
||||
model_schema.support_structure_output = support_structure_output
|
||||
model_schema.parameter_rules = []
|
||||
|
||||
return model_schema
|
||||
|
||||
|
||||
def get_model_instance() -> MagicMock:
|
||||
"""Create a mock ModelInstance for testing"""
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.provider = "openai"
|
||||
mock_instance.credentials = {}
|
||||
return mock_instance
|
||||
|
||||
|
||||
def test_structured_output_parser():
|
||||
"""Test cases for invoke_llm_with_structured_output function"""
|
||||
|
||||
testcases = [
|
||||
# Test case 1: Model with native structured output support, non-streaming
|
||||
{
|
||||
"name": "native_structured_output_non_streaming",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content='{"name": "test"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 2: Model with native structured output support, streaming
|
||||
{
|
||||
"name": "native_structured_output_streaming",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content='{"name":'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=' "test"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 3: Model without native structured output support, non-streaming
|
||||
{
|
||||
"name": "prompt_based_structured_output_non_streaming",
|
||||
"provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"support_structure_output": False,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="claude-3-sonnet",
|
||||
message=AssistantPromptMessage(content='{"answer": "test response"}'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=8),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 4: Model without native structured output support, streaming
|
||||
{
|
||||
"name": "prompt_based_structured_output_streaming",
|
||||
"provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"support_structure_output": False,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="claude-3-sonnet",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content='{"answer": "test'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="claude-3-sonnet",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=' response"}'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=5),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 5: Streaming with list content
|
||||
{
|
||||
"name": "streaming_with_list_content",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data='{"data":'),
|
||||
]
|
||||
),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data=' "value"}'),
|
||||
]
|
||||
),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 6: Error case - non-string LLM response content (non-streaming)
|
||||
{
|
||||
"name": "error_non_string_content_non_streaming",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content=None), # Non-string content
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": None,
|
||||
"should_raise": True,
|
||||
"expected_error": OutputParserError,
|
||||
},
|
||||
# Test case 7: JSON repair scenario
|
||||
{
|
||||
"name": "json_repair_scenario",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content='{"name": "test"'), # Invalid JSON - missing closing brace
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 8: Model with parameter rules for response format
|
||||
{
|
||||
"name": "model_with_parameter_rules",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"result": {"type": "string"}}},
|
||||
"parameter_rules": [
|
||||
MagicMock(name="response_format", options=["json_schema"], required=False),
|
||||
],
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content='{"result": "success"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 9: Model without native support but with JSON response format rules
|
||||
{
|
||||
"name": "non_native_with_json_rules",
|
||||
"provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"support_structure_output": False,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"output": {"type": "string"}}},
|
||||
"parameter_rules": [
|
||||
MagicMock(name="response_format", options=["JSON"], required=False),
|
||||
],
|
||||
"expected_llm_response": LLMResult(
|
||||
model="claude-3-sonnet",
|
||||
message=AssistantPromptMessage(content='{"output": "result"}'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=8),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
]
|
||||
|
||||
for case in testcases:
|
||||
print(f"Running test case: {case['name']}")
|
||||
|
||||
# Setup model entity
|
||||
model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"])
|
||||
|
||||
# Add parameter rules if specified
|
||||
if "parameter_rules" in case:
|
||||
model_schema.parameter_rules = case["parameter_rules"]
|
||||
|
||||
# Setup model instance
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = case["expected_llm_response"]
|
||||
|
||||
# Setup prompt messages
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content="You are a helpful assistant."),
|
||||
UserPromptMessage(content="Generate a response according to the schema."),
|
||||
]
|
||||
|
||||
if case["should_raise"]:
|
||||
# Test error cases
|
||||
with pytest.raises(case["expected_error"]): # noqa: PT012
|
||||
if case["stream"]:
|
||||
result_generator = invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
stream=case["stream"],
|
||||
)
|
||||
# Consume the generator to trigger the error
|
||||
list(result_generator)
|
||||
else:
|
||||
invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
stream=case["stream"],
|
||||
)
|
||||
else:
|
||||
# Test successful cases
|
||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||
# Configure json_repair mock for cases that need it
|
||||
if case["name"] == "json_repair_scenario":
|
||||
mock_json_repair.return_value = {"name": "test"}
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
stream=case["stream"],
|
||||
model_parameters={"temperature": 0.7, "max_tokens": 100},
|
||||
user="test_user",
|
||||
)
|
||||
|
||||
if case["expected_result_type"] == "generator":
|
||||
# Test streaming results
|
||||
assert hasattr(result, "__iter__")
|
||||
chunks = list(result)
|
||||
assert len(chunks) > 0
|
||||
|
||||
# Verify all chunks are LLMResultChunkWithStructuredOutput
|
||||
for chunk in chunks[:-1]: # All except last
|
||||
assert isinstance(chunk, LLMResultChunkWithStructuredOutput)
|
||||
assert chunk.model == case["model_name"]
|
||||
|
||||
# Last chunk should have structured output
|
||||
last_chunk = chunks[-1]
|
||||
assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput)
|
||||
assert last_chunk.structured_output is not None
|
||||
assert isinstance(last_chunk.structured_output, dict)
|
||||
else:
|
||||
# Test non-streaming results
|
||||
assert isinstance(result, case["expected_result_type"])
|
||||
assert result.model == case["model_name"]
|
||||
assert result.structured_output is not None
|
||||
assert isinstance(result.structured_output, dict)
|
||||
|
||||
# Verify model_instance.invoke_llm was called with correct parameters
|
||||
model_instance.invoke_llm.assert_called_once()
|
||||
call_args = model_instance.invoke_llm.call_args
|
||||
|
||||
assert call_args.kwargs["stream"] == case["stream"]
|
||||
assert call_args.kwargs["user"] == "test_user"
|
||||
assert "temperature" in call_args.kwargs["model_parameters"]
|
||||
assert "max_tokens" in call_args.kwargs["model_parameters"]
|
||||
|
||||
|
||||
def test_parse_structured_output_edge_cases():
|
||||
"""Test edge cases for structured output parsing"""
|
||||
|
||||
# Test case with list that contains dict (reasoning model scenario)
|
||||
testcase_list_with_dict = {
|
||||
"name": "list_with_dict_parsing",
|
||||
"provider": "deepseek",
|
||||
"model_name": "deepseek-r1",
|
||||
"support_structure_output": False,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"thought": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="deepseek-r1",
|
||||
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
}
|
||||
|
||||
# Setup for list parsing test
|
||||
model_schema = get_model_entity(
|
||||
testcase_list_with_dict["provider"],
|
||||
testcase_list_with_dict["model_name"],
|
||||
testcase_list_with_dict["support_structure_output"],
|
||||
)
|
||||
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = testcase_list_with_dict["expected_llm_response"]
|
||||
|
||||
prompt_messages = [UserPromptMessage(content="Test reasoning")]
|
||||
|
||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||
# Mock json_repair to return a list with dict
|
||||
mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"]
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=testcase_list_with_dict["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=testcase_list_with_dict["json_schema"],
|
||||
stream=testcase_list_with_dict["stream"],
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||
assert result.structured_output == {"thought": "reasoning process"}
|
||||
|
||||
|
||||
def test_model_specific_schema_preparation():
|
||||
"""Test schema preparation for different model types"""
|
||||
|
||||
# Test Gemini model
|
||||
gemini_case = {
|
||||
"provider": "google",
|
||||
"model_name": "gemini-pro",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"result": {"type": "boolean"}}, "additionalProperties": False},
|
||||
}
|
||||
|
||||
model_schema = get_model_entity(
|
||||
gemini_case["provider"], gemini_case["model_name"], gemini_case["support_structure_output"]
|
||||
)
|
||||
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = LLMResult(
|
||||
model="gemini-pro",
|
||||
message=AssistantPromptMessage(content='{"result": "true"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
)
|
||||
|
||||
prompt_messages = [UserPromptMessage(content="Test")]
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=gemini_case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=gemini_case["json_schema"],
|
||||
stream=gemini_case["stream"],
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||
|
||||
# Verify model_instance.invoke_llm was called and check the schema preparation
|
||||
model_instance.invoke_llm.assert_called_once()
|
||||
call_args = model_instance.invoke_llm.call_args
|
||||
|
||||
# For Gemini, the schema should not have additionalProperties and boolean should be converted to string
|
||||
assert "json_schema" in call_args.kwargs["model_parameters"]
|
||||
Reference in New Issue
Block a user