Merge branch 'main' into feat/mcp

This commit is contained in:
Novice
2025-07-09 09:41:42 +08:00
234 changed files with 8742 additions and 1254 deletions

View File

@@ -1,6 +1,7 @@
import os
from flask import Flask
from packaging.version import Version
from yarl import URL
from configs.app_config import DifyConfig
@@ -40,6 +41,9 @@ def test_dify_config(monkeypatch):
assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3
# values from pyproject.toml
assert Version(config.project.version) >= Version("1.0.0")
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.

View File

@@ -0,0 +1,259 @@
from collections.abc import Mapping, Sequence
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.variables.segments import ArrayFileSegment, FileSegment
class TestWorkflowResponseConverterFetchFilesFromVariableValue:
    """Test class for WorkflowResponseConverter._fetch_files_from_variable_value method

    Covers the supported input shapes (FileSegment, ArrayFileSegment, dict,
    list) plus edge cases: empty containers, wrong/missing identity markers,
    nested structures, and unsupported scalar items inside lists.
    """

    def create_test_file(self, file_id: str = "test_file_1") -> File:
        """Create a test File object"""
        return File(
            id=file_id,
            tenant_id="test_tenant",
            type=FileType.DOCUMENT,
            transfer_method=FileTransferMethod.LOCAL_FILE,
            related_id="related_123",
            filename=f"{file_id}.txt",
            extension=".txt",
            mime_type="text/plain",
            size=1024,
            storage_key="storage_key_123",
        )

    def create_file_dict(self, file_id: str = "test_file_dict") -> dict:
        """Create a file dictionary with correct dify_model_identity"""
        return {
            "dify_model_identity": FILE_MODEL_IDENTITY,
            "id": file_id,
            "tenant_id": "test_tenant",
            "type": "document",
            "transfer_method": "local_file",
            "related_id": "related_456",
            "filename": f"{file_id}.txt",
            "extension": ".txt",
            "mime_type": "text/plain",
            "size": 2048,
            "url": "http://example.com/file.txt",
        }

    def test_fetch_files_from_variable_value_with_none(self):
        """Test with None input"""
        # The method signature expects Union[dict, list, Segment], but the
        # implementation also tolerates None, so we pass None directly and
        # silence the type checker.
        result = WorkflowResponseConverter._fetch_files_from_variable_value(None)  # type: ignore
        assert result == []

    def test_fetch_files_from_variable_value_with_empty_dict(self):
        """Test with empty dictionary"""
        result = WorkflowResponseConverter._fetch_files_from_variable_value({})
        assert result == []

    def test_fetch_files_from_variable_value_with_empty_list(self):
        """Test with empty list"""
        result = WorkflowResponseConverter._fetch_files_from_variable_value([])
        assert result == []

    def test_fetch_files_from_variable_value_with_file_segment(self):
        """Test with valid FileSegment"""
        test_file = self.create_test_file("segment_file")
        file_segment = FileSegment(value=test_file)
        result = WorkflowResponseConverter._fetch_files_from_variable_value(file_segment)
        assert len(result) == 1
        assert isinstance(result[0], dict)
        assert result[0]["id"] == "segment_file"
        assert result[0]["dify_model_identity"] == FILE_MODEL_IDENTITY

    def test_fetch_files_from_variable_value_with_array_file_segment_single(self):
        """Test with ArrayFileSegment containing single file"""
        test_file = self.create_test_file("array_file_1")
        array_segment = ArrayFileSegment(value=[test_file])
        result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
        assert len(result) == 1
        assert isinstance(result[0], dict)
        assert result[0]["id"] == "array_file_1"

    def test_fetch_files_from_variable_value_with_array_file_segment_multiple(self):
        """Test with ArrayFileSegment containing multiple files"""
        test_file_1 = self.create_test_file("array_file_1")
        test_file_2 = self.create_test_file("array_file_2")
        array_segment = ArrayFileSegment(value=[test_file_1, test_file_2])
        result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
        # Order of the input array is expected to be preserved.
        assert len(result) == 2
        assert result[0]["id"] == "array_file_1"
        assert result[1]["id"] == "array_file_2"

    def test_fetch_files_from_variable_value_with_array_file_segment_empty(self):
        """Test with ArrayFileSegment containing empty array"""
        array_segment = ArrayFileSegment(value=[])
        result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
        assert result == []

    def test_fetch_files_from_variable_value_with_list_of_file_dicts(self):
        """Test with list containing file dictionaries"""
        file_dict_1 = self.create_file_dict("list_file_1")
        file_dict_2 = self.create_file_dict("list_file_2")
        test_list = [file_dict_1, file_dict_2]
        result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
        assert len(result) == 2
        assert result[0]["id"] == "list_file_1"
        assert result[1]["id"] == "list_file_2"

    def test_fetch_files_from_variable_value_with_list_of_file_objects(self):
        """Test with list containing File objects"""
        file_obj_1 = self.create_test_file("list_obj_1")
        file_obj_2 = self.create_test_file("list_obj_2")
        test_list = [file_obj_1, file_obj_2]
        result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
        assert len(result) == 2
        assert result[0]["id"] == "list_obj_1"
        assert result[1]["id"] == "list_obj_2"

    def test_fetch_files_from_variable_value_with_list_mixed_valid_invalid(self):
        """Test with list containing mix of valid files and invalid items"""
        file_dict = self.create_file_dict("mixed_file")
        invalid_dict = {"not_a_file": "value"}
        test_list = [file_dict, invalid_dict, "string_item", 123]
        result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
        # Only the dict carrying the file identity marker should survive.
        assert len(result) == 1
        assert result[0]["id"] == "mixed_file"

    def test_fetch_files_from_variable_value_with_list_nested_structures(self):
        """Test with list containing nested structures"""
        file_dict = self.create_file_dict("nested_file")
        nested_list = [file_dict, ["inner_list"]]
        test_list = [nested_list, {"nested": "dict"}]
        result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
        # Should not process nested structures in list items
        assert result == []

    def test_fetch_files_from_variable_value_with_dict_incorrect_identity(self):
        """Test with dictionary having incorrect dify_model_identity"""
        invalid_dict = {"dify_model_identity": "wrong_identity", "id": "invalid_file", "filename": "test.txt"}
        result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
        assert result == []

    def test_fetch_files_from_variable_value_with_dict_missing_identity(self):
        """Test with dictionary missing dify_model_identity"""
        invalid_dict = {"id": "no_identity_file", "filename": "test.txt"}
        result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
        assert result == []

    def test_fetch_files_from_variable_value_with_dict_file_object(self):
        """Test with dictionary containing File object"""
        file_obj = self.create_test_file("dict_obj_file")
        test_dict = {"file_key": file_obj}
        result = WorkflowResponseConverter._fetch_files_from_variable_value(test_dict)
        # Should not extract File objects from dict values
        assert result == []

    def test_fetch_files_from_variable_value_with_mixed_data_types(self):
        """Test with various mixed data types"""
        mixed_data = {"string": "text", "number": 42, "boolean": True, "null": None, "dify_model_identity": "wrong"}
        result = WorkflowResponseConverter._fetch_files_from_variable_value(mixed_data)
        assert result == []

    def test_fetch_files_from_variable_value_with_invalid_objects(self):
        """Test with invalid objects that are not supported types"""
        # Test with an invalid dict that doesn't match expected patterns
        invalid_dict = {"custom_key": "custom_value"}
        result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
        assert result == []

    def test_fetch_files_from_variable_value_with_string_input(self):
        """Test with string input (unsupported type)"""
        # NOTE(review): despite the name, this passes an empty list because a
        # bare string is outside the Union[dict, list, Segment] signature.
        result = WorkflowResponseConverter._fetch_files_from_variable_value([])
        assert result == []

    def test_fetch_files_from_variable_value_with_number_input(self):
        """Test with number input (unsupported type)"""
        # Test with list containing numbers (should be ignored)
        result = WorkflowResponseConverter._fetch_files_from_variable_value([42, "string", None])
        assert result == []

    def test_fetch_files_from_variable_value_return_type_is_sequence(self):
        """Test that return type is Sequence[Mapping[str, Any]]"""
        file_dict = self.create_file_dict("type_test_file")
        result = WorkflowResponseConverter._fetch_files_from_variable_value(file_dict)
        assert isinstance(result, Sequence)
        assert len(result) == 1
        assert isinstance(result[0], Mapping)
        assert all(isinstance(key, str) for key in result[0])

    def test_fetch_files_from_variable_value_preserves_file_properties(self):
        """Test that all file properties are preserved in the result"""
        original_file = self.create_test_file("property_test")
        file_segment = FileSegment(value=original_file)
        result = WorkflowResponseConverter._fetch_files_from_variable_value(file_segment)
        assert len(result) == 1
        file_dict = result[0]
        assert file_dict["id"] == "property_test"
        assert file_dict["tenant_id"] == "test_tenant"
        assert file_dict["type"] == "document"
        assert file_dict["transfer_method"] == "local_file"
        assert file_dict["filename"] == "property_test.txt"
        assert file_dict["extension"] == ".txt"
        assert file_dict["mime_type"] == "text/plain"
        assert file_dict["size"] == 1024

    def test_fetch_files_from_variable_value_with_complex_nested_scenario(self):
        """Test complex scenario with nested valid and invalid data"""
        file_dict = self.create_file_dict("complex_file")
        file_obj = self.create_test_file("complex_obj")
        # Complex nested structure
        complex_data = [
            file_dict,  # Valid file dict
            file_obj,  # Valid file object
            {  # Invalid dict
                "not_file": "data",
                "nested": {"deep": "value"},
            },
            [  # Nested list (should be ignored)
                self.create_file_dict("nested_file")
            ],
            "string",  # Invalid string
            None,  # None value
            42,  # Invalid number
        ]
        result = WorkflowResponseConverter._fetch_files_from_variable_value(complex_data)
        # Only the two top-level file entries are extracted, in input order.
        assert len(result) == 2
        assert result[0]["id"] == "complex_file"
        assert result[1]["id"] == "complex_obj"

View File

@@ -0,0 +1,194 @@
from unittest.mock import patch
from urllib.parse import parse_qs, urlparse
import pytest
from core.helper.url_signer import SignedUrlParams, UrlSigner
class TestUrlSigner:
    """Test cases for UrlSigner class

    Exercises signed-parameter generation, full-URL construction, signature
    verification (positive and negative), and determinism under mocked
    time/randomness. Each test patches configs.dify_config.SECRET_KEY so no
    real application secret is required.
    """

    @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
    def test_should_generate_signed_url_params(self):
        """Test generation of signed URL parameters with all required fields"""
        sign_key = "test-sign-key"
        prefix = "test-prefix"
        params = UrlSigner.get_signed_url_params(sign_key, prefix)
        # Verify the returned object and required fields
        assert isinstance(params, SignedUrlParams)
        assert params.sign_key == sign_key
        assert params.timestamp is not None
        assert params.nonce is not None
        assert params.sign is not None
        # Verify nonce format (32 character hex string)
        assert len(params.nonce) == 32
        assert all(c in "0123456789abcdef" for c in params.nonce)

    @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
    def test_should_generate_complete_signed_url(self):
        """Test generation of complete signed URL with query parameters"""
        base_url = "https://example.com/api/test"
        sign_key = "test-sign-key"
        prefix = "test-prefix"
        signed_url = UrlSigner.get_signed_url(base_url, sign_key, prefix)
        # Parse URL and verify structure
        parsed = urlparse(signed_url)
        assert f"{parsed.scheme}://{parsed.netloc}{parsed.path}" == base_url
        # Verify query parameters
        query_params = parse_qs(parsed.query)
        assert "timestamp" in query_params
        assert "nonce" in query_params
        assert "sign" in query_params
        # Verify each parameter has exactly one value
        assert len(query_params["timestamp"]) == 1
        assert len(query_params["nonce"]) == 1
        assert len(query_params["sign"]) == 1
        # Verify parameter values are not empty
        assert query_params["timestamp"][0]
        assert query_params["nonce"][0]
        assert query_params["sign"][0]

    @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
    def test_should_verify_valid_signature(self):
        """Test verification of valid signature"""
        sign_key = "test-sign-key"
        prefix = "test-prefix"
        # Generate and verify signature
        params = UrlSigner.get_signed_url_params(sign_key, prefix)
        is_valid = UrlSigner.verify(
            sign_key=sign_key, timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix=prefix
        )
        assert is_valid is True

    @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
    @pytest.mark.parametrize(
        ("field", "modifier"),
        [
            ("sign_key", lambda _: "wrong-sign-key"),
            ("timestamp", lambda t: str(int(t) + 1000)),
            ("nonce", lambda _: "different-nonce-123456789012345"),
            ("prefix", lambda _: "wrong-prefix"),
            ("sign", lambda s: s + "tampered"),
        ],
    )
    def test_should_reject_invalid_signature_params(self, field, modifier):
        """Test signature verification rejects invalid parameters

        Each parametrized case corrupts exactly one field of an otherwise
        valid signature and expects verification to fail.
        """
        sign_key = "test-sign-key"
        prefix = "test-prefix"
        # Generate valid signed parameters
        params = UrlSigner.get_signed_url_params(sign_key, prefix)
        # Prepare verification parameters
        verify_params = {
            "sign_key": sign_key,
            "timestamp": params.timestamp,
            "nonce": params.nonce,
            "sign": params.sign,
            "prefix": prefix,
        }
        # Modify the specific field
        verify_params[field] = modifier(verify_params[field])
        # Verify should fail
        is_valid = UrlSigner.verify(**verify_params)
        assert is_valid is False

    @patch("configs.dify_config.SECRET_KEY", None)
    def test_should_raise_error_without_secret_key(self):
        """Test that signing fails when SECRET_KEY is not configured"""
        with pytest.raises(Exception) as exc_info:
            UrlSigner.get_signed_url_params("key", "prefix")
        assert "SECRET_KEY is not set" in str(exc_info.value)

    @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
    def test_should_generate_unique_signatures(self):
        """Test that different inputs produce different signatures"""
        params1 = UrlSigner.get_signed_url_params("key1", "prefix1")
        params2 = UrlSigner.get_signed_url_params("key2", "prefix2")
        # Different inputs should produce different signatures
        assert params1.sign != params2.sign
        assert params1.nonce != params2.nonce

    @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
    def test_should_handle_special_characters(self):
        """Test handling of special characters in parameters"""
        special_cases = [
            "test with spaces",
            "test/with/slashes",
            "test中文字符",
        ]
        for sign_key in special_cases:
            params = UrlSigner.get_signed_url_params(sign_key, "prefix")
            # Should generate valid signature and verify correctly
            is_valid = UrlSigner.verify(
                sign_key=sign_key, timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix="prefix"
            )
            assert is_valid is True

    @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
    def test_should_ensure_nonce_randomness(self):
        """Test that nonce is random for each generation - critical for security"""
        sign_key = "test-sign-key"
        prefix = "test-prefix"
        # Generate multiple nonces
        nonces = set()
        for _ in range(5):
            params = UrlSigner.get_signed_url_params(sign_key, prefix)
            nonces.add(params.nonce)
        # All nonces should be unique
        assert len(nonces) == 5

    @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
    @patch("time.time", return_value=1234567890)
    @patch("os.urandom", return_value=b"\xab\xcd\xef\x12\x34\x56\x78\x90\xab\xcd\xef\x12\x34\x56\x78\x90")
    def test_should_produce_consistent_signatures(self, mock_urandom, mock_time):
        """Test that same inputs produce same signature - ensures deterministic behavior"""
        # NOTE(review): this patches the global time.time / os.urandom, which
        # only takes effect if UrlSigner resolves them via module attribute
        # lookup (`import time; time.time()`); confirm against url_signer's
        # import style.
        sign_key = "test-sign-key"
        prefix = "test-prefix"
        # Generate signature multiple times with same inputs (time and nonce are mocked)
        params1 = UrlSigner.get_signed_url_params(sign_key, prefix)
        params2 = UrlSigner.get_signed_url_params(sign_key, prefix)
        # With mocked time and random, should produce identical results
        assert params1.timestamp == params2.timestamp
        assert params1.nonce == params2.nonce
        assert params1.sign == params2.sign
        # Verify the signature is valid
        assert UrlSigner.verify(
            sign_key=sign_key, timestamp=params1.timestamp, nonce=params1.nonce, sign=params1.sign, prefix=prefix
        )

    @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
    def test_should_handle_empty_strings(self):
        """Test handling of empty string parameters - common edge case"""
        # Empty sign_key and prefix should still work
        params = UrlSigner.get_signed_url_params("", "")
        assert params.sign is not None
        # Should verify correctly
        is_valid = UrlSigner.verify(
            sign_key="", timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix=""
        )
        assert is_valid is True

View File

@@ -0,0 +1 @@
# Unit tests for core ops module

View File

@@ -0,0 +1,385 @@
import pytest
from pydantic import ValidationError
from core.ops.entities.config_entity import (
AliyunConfig,
ArizeConfig,
LangfuseConfig,
LangSmithConfig,
OpikConfig,
PhoenixConfig,
TracingProviderEnum,
WeaveConfig,
)
class TestTracingProviderEnum:
    """Verify the members of TracingProviderEnum map to their provider names."""

    def test_enum_values(self):
        """Each provider member should compare equal to its lowercase name."""
        expected = {
            TracingProviderEnum.ARIZE: "arize",
            TracingProviderEnum.PHOENIX: "phoenix",
            TracingProviderEnum.LANGFUSE: "langfuse",
            TracingProviderEnum.LANGSMITH: "langsmith",
            TracingProviderEnum.OPIK: "opik",
            TracingProviderEnum.WEAVE: "weave",
            TracingProviderEnum.ALIYUN: "aliyun",
        }
        for member, value in expected.items():
            assert member == value
class TestArizeConfig:
    """Unit tests covering defaults, fallbacks, and URL validation of ArizeConfig."""

    def test_valid_config(self):
        """Explicitly supplied fields should be stored unchanged."""
        cfg = ArizeConfig(
            api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
        )
        assert cfg.api_key == "test_key"
        assert cfg.space_id == "test_space"
        assert cfg.project == "test_project"
        assert cfg.endpoint == "https://custom.arize.com"

    def test_default_values(self):
        """Omitting every field should fall back to the documented defaults."""
        cfg = ArizeConfig()
        assert cfg.api_key is None
        assert cfg.space_id is None
        assert cfg.project is None
        assert cfg.endpoint == "https://otlp.arize.com"

    def test_project_validation_empty(self):
        """An empty project name should be replaced by the "default" project."""
        assert ArizeConfig(project="").project == "default"

    def test_project_validation_none(self):
        """A None project should be replaced by the "default" project."""
        assert ArizeConfig(project=None).project == "default"

    def test_endpoint_validation_empty(self):
        """An empty endpoint should fall back to the default OTLP endpoint."""
        assert ArizeConfig(endpoint="").endpoint == "https://otlp.arize.com"

    def test_endpoint_validation_with_path(self):
        """Any path component should be stripped from the endpoint."""
        assert ArizeConfig(endpoint="https://custom.arize.com/api/v1").endpoint == "https://custom.arize.com"

    def test_endpoint_validation_invalid_scheme(self):
        """Schemes other than http/https should be rejected."""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            ArizeConfig(endpoint="ftp://invalid.com")

    def test_endpoint_validation_no_scheme(self):
        """A bare host without a scheme should be rejected."""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            ArizeConfig(endpoint="invalid.com")
class TestPhoenixConfig:
    """Unit tests covering defaults, fallbacks, and URL validation of PhoenixConfig."""

    def test_valid_config(self):
        """Explicitly supplied fields should be stored unchanged."""
        cfg = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
        assert cfg.api_key == "test_key"
        assert cfg.project == "test_project"
        assert cfg.endpoint == "https://custom.phoenix.com"

    def test_default_values(self):
        """Omitting every field should fall back to the documented defaults."""
        cfg = PhoenixConfig()
        assert cfg.api_key is None
        assert cfg.project is None
        assert cfg.endpoint == "https://app.phoenix.arize.com"

    def test_project_validation_empty(self):
        """An empty project name should be replaced by the "default" project."""
        assert PhoenixConfig(project="").project == "default"

    def test_endpoint_validation_with_path(self):
        """Any path component should be stripped from the endpoint."""
        assert PhoenixConfig(endpoint="https://custom.phoenix.com/api/v1").endpoint == "https://custom.phoenix.com"
class TestLangfuseConfig:
    """Unit tests covering required keys and host fallback of LangfuseConfig."""

    def test_valid_config(self):
        """Explicitly supplied fields should be stored unchanged."""
        cfg = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
        assert cfg.public_key == "public_key"
        assert cfg.secret_key == "secret_key"
        assert cfg.host == "https://custom.langfuse.com"

    def test_default_values(self):
        """The host should default to the hosted Langfuse API."""
        assert LangfuseConfig(public_key="public", secret_key="secret").host == "https://api.langfuse.com"

    def test_missing_required_fields(self):
        """public_key and secret_key are both mandatory."""
        for kwargs in ({}, {"public_key": "public"}, {"secret_key": "secret"}):
            with pytest.raises(ValidationError):
                LangfuseConfig(**kwargs)

    def test_host_validation_empty(self):
        """An empty host should fall back to the default Langfuse host."""
        cfg = LangfuseConfig(public_key="public", secret_key="secret", host="")
        assert cfg.host == "https://api.langfuse.com"
class TestLangSmithConfig:
    """Unit tests covering required keys and endpoint rules of LangSmithConfig."""

    def test_valid_config(self):
        """Explicitly supplied fields should be stored unchanged."""
        cfg = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
        assert cfg.api_key == "test_key"
        assert cfg.project == "test_project"
        assert cfg.endpoint == "https://custom.smith.com"

    def test_default_values(self):
        """The endpoint should default to the hosted LangSmith API."""
        assert LangSmithConfig(api_key="key", project="project").endpoint == "https://api.smith.langchain.com"

    def test_missing_required_fields(self):
        """api_key and project are both mandatory."""
        for kwargs in ({}, {"api_key": "key"}, {"project": "project"}):
            with pytest.raises(ValidationError):
                LangSmithConfig(**kwargs)

    def test_endpoint_validation_https_only(self):
        """Plain HTTP endpoints should be rejected."""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")
class TestOpikConfig:
    """Unit tests covering defaults and URL suffix/scheme rules of OpikConfig."""

    def test_valid_config(self):
        """Explicitly supplied fields should be stored unchanged."""
        cfg = OpikConfig(
            api_key="test_key",
            project="test_project",
            workspace="test_workspace",
            url="https://custom.comet.com/opik/api/",
        )
        assert cfg.api_key == "test_key"
        assert cfg.project == "test_project"
        assert cfg.workspace == "test_workspace"
        assert cfg.url == "https://custom.comet.com/opik/api/"

    def test_default_values(self):
        """Every field is optional; defaults should apply when omitted."""
        cfg = OpikConfig()
        assert cfg.api_key is None
        assert cfg.project is None
        assert cfg.workspace is None
        assert cfg.url == "https://www.comet.com/opik/api/"

    def test_project_validation_empty(self):
        """An empty project should fall back to "Default Project"."""
        assert OpikConfig(project="").project == "Default Project"

    def test_url_validation_empty(self):
        """An empty URL should fall back to the default Comet endpoint."""
        assert OpikConfig(url="").url == "https://www.comet.com/opik/api/"

    def test_url_validation_missing_suffix(self):
        """URLs that do not end with /api/ should be rejected."""
        with pytest.raises(ValidationError, match="URL should end with /api/"):
            OpikConfig(url="https://custom.comet.com/opik/")

    def test_url_validation_invalid_scheme(self):
        """Only http/https URLs should be accepted."""
        with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
            OpikConfig(url="ftp://custom.comet.com/opik/api/")
class TestWeaveConfig:
    """Unit tests covering required keys, defaults, and host rules of WeaveConfig."""

    def test_valid_config(self):
        """Explicitly supplied fields should be stored unchanged."""
        cfg = WeaveConfig(
            api_key="test_key",
            entity="test_entity",
            project="test_project",
            endpoint="https://custom.wandb.ai",
            host="https://custom.host.com",
        )
        assert cfg.api_key == "test_key"
        assert cfg.entity == "test_entity"
        assert cfg.project == "test_project"
        assert cfg.endpoint == "https://custom.wandb.ai"
        assert cfg.host == "https://custom.host.com"

    def test_default_values(self):
        """Optional fields should default sensibly when omitted."""
        cfg = WeaveConfig(api_key="key", project="project")
        assert cfg.entity is None
        assert cfg.endpoint == "https://trace.wandb.ai"
        assert cfg.host is None

    def test_missing_required_fields(self):
        """api_key and project are both mandatory."""
        for kwargs in ({}, {"api_key": "key"}, {"project": "project"}):
            with pytest.raises(ValidationError):
                WeaveConfig(**kwargs)

    def test_endpoint_validation_https_only(self):
        """Plain HTTP endpoints should be rejected."""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")

    def test_host_validation_optional(self):
        """host is optional: None and empty are preserved, valid URLs accepted."""
        assert WeaveConfig(api_key="key", project="project", host=None).host is None
        assert WeaveConfig(api_key="key", project="project", host="").host == ""
        cfg = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
        assert cfg.host == "https://valid.host.com"

    def test_host_validation_invalid_scheme(self):
        """A non-empty host with an invalid scheme should be rejected."""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")
class TestAliyunConfig:
    """Unit tests covering required keys, defaults, and endpoint rules of AliyunConfig."""

    def test_valid_config(self):
        """Explicitly supplied fields should be stored unchanged."""
        cfg = AliyunConfig(
            app_name="test_app",
            license_key="test_license_key",
            endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
        )
        assert cfg.app_name == "test_app"
        assert cfg.license_key == "test_license_key"
        assert cfg.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"

    def test_default_values(self):
        """app_name should default to "dify_app" when omitted."""
        cfg = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
        assert cfg.app_name == "dify_app"

    def test_missing_required_fields(self):
        """license_key and endpoint are both mandatory."""
        for kwargs in (
            {},
            {"license_key": "test_license"},
            {"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"},
        ):
            with pytest.raises(ValidationError):
                AliyunConfig(**kwargs)

    def test_app_name_validation_empty(self):
        """An empty app_name should fall back to "dify_app"."""
        cfg = AliyunConfig(
            license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
        )
        assert cfg.app_name == "dify_app"

    def test_endpoint_validation_empty(self):
        """An empty endpoint should fall back to the default Aliyun endpoint."""
        cfg = AliyunConfig(license_key="test_license", endpoint="")
        assert cfg.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"

    def test_endpoint_validation_with_path(self):
        """Any path component should be stripped from the endpoint."""
        cfg = AliyunConfig(
            license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
        )
        assert cfg.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"

    def test_endpoint_validation_invalid_scheme(self):
        """Schemes other than http/https should be rejected."""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")

    def test_endpoint_validation_no_scheme(self):
        """A bare host without a scheme should be rejected."""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")

    def test_license_key_required(self):
        """An empty license_key should be rejected, not silently accepted."""
        with pytest.raises(ValidationError):
            AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
class TestConfigIntegration:
    """Cross-provider checks that the tracing configs behave consistently."""

    def test_all_configs_can_be_instantiated(self):
        """Every provider config should accept a minimal valid payload."""
        instances = [
            ArizeConfig(api_key="key"),
            PhoenixConfig(api_key="key"),
            LangfuseConfig(public_key="public", secret_key="secret"),
            LangSmithConfig(api_key="key", project="project"),
            OpikConfig(api_key="key"),
            WeaveConfig(api_key="key", project="project"),
            AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com"),
        ]
        assert all(instance is not None for instance in instances)

    def test_url_normalization_consistency(self):
        """Endpoint paths should be stripped the same way in every config."""
        arize = ArizeConfig(endpoint="https://arize.com/api/v1/test")
        phoenix = PhoenixConfig(endpoint="https://phoenix.com/api/v2/")
        aliyun = AliyunConfig(
            license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
        )
        assert arize.endpoint == "https://arize.com"
        assert phoenix.endpoint == "https://phoenix.com"
        assert aliyun.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"

    def test_project_default_values(self):
        """Empty project/app names should fall back to each provider's default."""
        arize = ArizeConfig(project="")
        phoenix = PhoenixConfig(project="")
        opik = OpikConfig(project="")
        aliyun = AliyunConfig(
            license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
        )
        assert arize.project == "default"
        assert phoenix.project == "default"
        assert opik.project == "Default Project"
        assert aliyun.app_name == "dify_app"

View File

@@ -0,0 +1,138 @@
import pytest
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
class TestValidateUrl:
    """Unit tests for validate_url normalization, fallbacks, and scheme checks."""

    def test_valid_https_url(self):
        """A plain HTTPS URL should be returned unchanged."""
        assert validate_url("https://example.com", "https://default.com") == "https://example.com"

    def test_valid_http_url(self):
        """A plain HTTP URL should be accepted as well."""
        assert validate_url("http://example.com", "https://default.com") == "http://example.com"

    def test_url_with_path_removed(self):
        """Path components should be stripped during normalization."""
        assert validate_url("https://example.com/api/v1/test", "https://default.com") == "https://example.com"

    def test_url_with_query_removed(self):
        """Query strings should be stripped during normalization."""
        assert validate_url("https://example.com?param=value", "https://default.com") == "https://example.com"

    def test_url_with_fragment_removed(self):
        """Fragments should be stripped during normalization."""
        assert validate_url("https://example.com#section", "https://default.com") == "https://example.com"

    def test_empty_url_returns_default(self):
        """An empty string should yield the default URL."""
        assert validate_url("", "https://default.com") == "https://default.com"

    def test_none_url_returns_default(self):
        """None should yield the default URL."""
        assert validate_url(None, "https://default.com") == "https://default.com"

    def test_whitespace_url_returns_default(self):
        """A whitespace-only string should yield the default URL."""
        assert validate_url(" ", "https://default.com") == "https://default.com"

    def test_invalid_scheme_raises_error(self):
        """Schemes outside the allowed set should raise ValueError."""
        with pytest.raises(ValueError, match="URL scheme must be one of"):
            validate_url("ftp://example.com", "https://default.com")

    def test_no_scheme_raises_error(self):
        """A bare host without a scheme should raise ValueError."""
        with pytest.raises(ValueError, match="URL scheme must be one of"):
            validate_url("example.com", "https://default.com")

    def test_custom_allowed_schemes(self):
        """The allowed_schemes parameter should restrict the accepted schemes."""
        accepted = validate_url("https://example.com", "https://default.com", allowed_schemes=("https",))
        assert accepted == "https://example.com"
        with pytest.raises(ValueError, match="URL scheme must be one of"):
            validate_url("http://example.com", "https://default.com", allowed_schemes=("https",))
class TestValidateUrlWithPath:
    """Unit tests for the validate_url_with_path helper."""

    def test_valid_url_with_path(self):
        """A URL with a path component is preserved as-is."""
        assert validate_url_with_path("https://example.com/api/v1", "https://default.com") == "https://example.com/api/v1"

    def test_valid_url_with_required_suffix(self):
        """A URL ending in the required suffix is accepted unchanged."""
        checked = validate_url_with_path("https://example.com/api/", "https://default.com", required_suffix="/api/")
        assert checked == "https://example.com/api/"

    def test_url_without_required_suffix_raises_error(self):
        """A URL missing the required suffix is rejected."""
        with pytest.raises(ValueError, match="URL should end with /api/"):
            validate_url_with_path("https://example.com/api", "https://default.com", required_suffix="/api/")

    def test_empty_url_returns_default(self):
        """An empty string falls back to the default URL."""
        assert validate_url_with_path("", "https://default.com") == "https://default.com"

    def test_none_url_returns_default(self):
        """None falls back to the default URL."""
        assert validate_url_with_path(None, "https://default.com") == "https://default.com"

    def test_invalid_scheme_raises_error(self):
        """A disallowed scheme such as ftp is rejected."""
        with pytest.raises(ValueError, match="URL must start with https:// or http://"):
            validate_url_with_path("ftp://example.com", "https://default.com")

    def test_no_scheme_raises_error(self):
        """A URL lacking any scheme is rejected."""
        with pytest.raises(ValueError, match="URL must start with https:// or http://"):
            validate_url_with_path("example.com", "https://default.com")
class TestValidateProjectName:
    """Unit tests for the validate_project_name helper."""

    def test_valid_project_name(self):
        """A normal project name passes through unchanged."""
        assert validate_project_name("my-project", "default") == "my-project"

    def test_empty_project_name_returns_default(self):
        """An empty string falls back to the default name."""
        assert validate_project_name("", "default") == "default"

    def test_none_project_name_returns_default(self):
        """None falls back to the default name."""
        assert validate_project_name(None, "default") == "default"

    def test_whitespace_project_name_returns_default(self):
        """A whitespace-only name falls back to the default name."""
        assert validate_project_name(" ", "default") == "default"

    def test_project_name_with_whitespace_trimmed(self):
        """Surrounding whitespace is stripped from the name."""
        assert validate_project_name(" my-project ", "default") == "my-project"

    def test_custom_default_name(self):
        """Any caller-supplied default name is honored for empty input."""
        assert validate_project_name("", "Custom Default") == "Custom Default"

View File

@ -1,3 +1,4 @@
import time
from unittest.mock import patch
import pytest
@ -19,6 +20,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.code.code_node import CodeNode
@ -172,6 +174,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@ -183,7 +186,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
@ -299,6 +302,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
user_inputs={},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@ -310,7 +314,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
@ -479,6 +483,7 @@ def test_run_branch(mock_close, mock_remove):
user_inputs={"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@ -490,7 +495,7 @@ def test_run_branch(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
@ -813,6 +818,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@ -824,7 +830,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)

View File

@ -1,7 +1,9 @@
import time
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
@ -11,6 +13,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunStreamChunkEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
@ -163,15 +166,16 @@ class ContinueOnErrorTestHelper:
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
"""Helper method to create a graph engine instance for testing"""
graph = Graph.init(graph_config=graph_config)
variable_pool = {
"system_variables": {
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "clear",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
"user_inputs": user_inputs or {"uid": "takato"},
}
user_inputs=user_inputs or {"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
return GraphEngine(
tenant_id="111",
@ -184,7 +188,7 @@ class ContinueOnErrorTestHelper:
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)

View File

@ -0,0 +1,74 @@
import base64
import binascii
import os
import pytest
from libs.password import compare_password, hash_password, valid_password
class TestValidPassword:
    """Validation rules for password format."""

    def test_should_accept_valid_passwords(self):
        """Passwords mixing letters and digits of sufficient length pass."""
        for candidate in ("password123", "test1234", "Test123456"):
            assert valid_password(candidate) == candidate

    def test_should_reject_invalid_passwords(self):
        """Short, letter-only, digit-only, and empty passwords are rejected."""
        # Too short to satisfy the minimum length requirement.
        with pytest.raises(ValueError) as exc_info:
            valid_password("abc123")
        assert "Password must contain letters and numbers" in str(exc_info.value)
        # Letters only, digits only, and the empty string all fail.
        for bad in ("abcdefgh", "12345678", ""):
            with pytest.raises(ValueError):
                valid_password(bad)
class TestPasswordHashing:
    """Hashing and comparison round-trip behavior."""

    def setup_method(self):
        """Hash a known password with a random salt and keep base64 copies."""
        self.password = "test123password"
        self.salt = os.urandom(16)
        self.salt_base64 = base64.b64encode(self.salt).decode()
        digest = hash_password(self.password, self.salt)
        self.password_hash_base64 = base64.b64encode(digest).decode()

    def test_should_verify_correct_password(self):
        """The original password compares equal to its stored hash."""
        assert compare_password(self.password, self.password_hash_base64, self.salt_base64) is True

    def test_should_reject_wrong_password(self):
        """A different password does not match the stored hash."""
        assert compare_password("wrongpassword", self.password_hash_base64, self.salt_base64) is False

    def test_should_handle_invalid_base64(self):
        """Malformed base64 in either the hash or the salt raises binascii.Error."""
        with pytest.raises(binascii.Error):
            compare_password(self.password, "invalid_base64!", self.salt_base64)
        with pytest.raises(binascii.Error):
            compare_password(self.password, self.password_hash_base64, "invalid_base64!")

    def test_should_be_case_sensitive(self):
        """Changing the password's case breaks the match."""
        assert compare_password(self.password.upper(), self.password_hash_base64, self.salt_base64) is False

View File

@ -6,12 +6,11 @@ from unittest.mock import Mock, patch
import pytest
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables.types import SegmentType
from core.variables import StringSegment
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
from services.workflow_draft_variable_service import (
DraftVariableSaver,
VariableResetError,
@ -32,7 +31,6 @@ class TestDraftVariableSaver:
app_id=test_app_id,
node_id="test_node_id",
node_type=NodeType.START,
invoke_from=InvokeFrom.DEBUGGER,
node_execution_id="test_execution_id",
)
assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False
@ -79,7 +77,6 @@ class TestDraftVariableSaver:
app_id=test_app_id,
node_id=_NODE_ID,
node_type=NodeType.START,
invoke_from=InvokeFrom.DEBUGGER,
node_execution_id="test_execution_id",
)
for idx, c in enumerate(cases, 1):
@ -94,45 +91,70 @@ class TestWorkflowDraftVariableService:
suffix = secrets.token_hex(6)
return f"test_app_id_{suffix}"
def _create_test_workflow(self, app_id: str) -> Workflow:
"""Create a real Workflow instance for testing"""
return Workflow.new(
tenant_id="test_tenant_id",
app_id=app_id,
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features="{}",
created_by="test_user_id",
environment_variables=[],
conversation_variables=[],
)
def test_reset_conversation_variable(self):
"""Test resetting a conversation variable"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
# Create mock variable
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.CONVERSATION
mock_variable.id = "var-id"
mock_variable.name = "test_var"
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create real conversation variable
test_value = StringSegment(value="test_value")
variable = WorkflowDraftVariable.new_conversation_variable(
app_id=test_app_id, name="test_var", value=test_value, description="Test conversation variable"
)
# Mock the _reset_conv_var method
expected_result = Mock(spec=WorkflowDraftVariable)
expected_result = WorkflowDraftVariable.new_conversation_variable(
app_id=test_app_id,
name="test_var",
value=StringSegment(value="reset_value"),
)
with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv:
result = service.reset_variable(mock_workflow, mock_variable)
result = service.reset_variable(workflow, variable)
mock_reset_conv.assert_called_once_with(mock_workflow, mock_variable)
mock_reset_conv.assert_called_once_with(workflow, variable)
assert result == expected_result
def test_reset_node_variable_with_no_execution_id(self):
"""Test resetting a node variable with no execution ID - should delete variable"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
# Create mock variable with no execution ID
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
mock_variable.node_execution_id = None
mock_variable.id = "var-id"
mock_variable.name = "test_var"
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
result = service._reset_node_var(mock_workflow, mock_variable)
# Create real node variable with no execution ID
test_value = StringSegment(value="test_value")
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id,
node_id="test_node_id",
name="test_var",
value=test_value,
node_execution_id="exec-id", # Set initially
)
# Manually set to None to simulate the test condition
variable.node_execution_id = None
result = service._reset_node_var_or_sys_var(workflow, variable)
# Should delete the variable and return None
mock_session.delete.assert_called_once_with(instance=mock_variable)
mock_session.delete.assert_called_once_with(instance=variable)
mock_session.flush.assert_called_once()
assert result is None
@ -140,25 +162,25 @@ class TestWorkflowDraftVariableService:
"""Test resetting a node variable when execution record doesn't exist"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
# Create mock variable with execution ID
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
mock_variable.node_execution_id = "exec-id"
mock_variable.id = "var-id"
mock_variable.name = "test_var"
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create real node variable with execution ID
test_value = StringSegment(value="test_value")
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
# Mock session.scalars to return None (no execution record found)
mock_scalars = Mock()
mock_scalars.first.return_value = None
mock_session.scalars.return_value = mock_scalars
result = service._reset_node_var(mock_workflow, mock_variable)
result = service._reset_node_var_or_sys_var(workflow, variable)
# Should delete the variable and return None
mock_session.delete.assert_called_once_with(instance=mock_variable)
mock_session.delete.assert_called_once_with(instance=variable)
mock_session.flush.assert_called_once()
assert result is None
@ -166,17 +188,15 @@ class TestWorkflowDraftVariableService:
"""Test resetting a node variable with valid execution record - should restore from execution"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
# Create mock variable with execution ID
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
mock_variable.node_execution_id = "exec-id"
mock_variable.id = "var-id"
mock_variable.name = "test_var"
mock_variable.node_id = "node-id"
mock_variable.value_type = SegmentType.STRING
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create real node variable with execution ID
test_value = StringSegment(value="original_value")
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
@ -190,33 +210,164 @@ class TestWorkflowDraftVariableService:
# Mock workflow methods
mock_node_config = {"type": "test_node"}
mock_workflow.get_node_config_by_id.return_value = mock_node_config
mock_workflow.get_node_type_from_node_config.return_value = NodeType.LLM
with (
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config),
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM),
):
result = service._reset_node_var_or_sys_var(workflow, variable)
result = service._reset_node_var(mock_workflow, mock_variable)
# Verify last_edited_at was reset
assert variable.last_edited_at is None
# Verify session.flush was called
mock_session.flush.assert_called()
# Verify variable.set_value was called with the correct value
mock_variable.set_value.assert_called_once()
# Verify last_edited_at was reset
assert mock_variable.last_edited_at is None
# Verify session.flush was called
mock_session.flush.assert_called()
# Should return the updated variable
assert result == variable
# Should return the updated variable
assert result == mock_variable
def test_reset_system_variable_raises_error(self):
"""Test that resetting a system variable raises an error"""
def test_reset_non_editable_system_variable_raises_error(self):
"""Test that resetting a non-editable system variable raises an error"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.SYS # Not a valid enum value for this test
mock_variable.id = "var-id"
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
with pytest.raises(VariableResetError) as exc_info:
service.reset_variable(mock_workflow, mock_variable)
assert "cannot reset system variable" in str(exc_info.value)
assert "variable_id=var-id" in str(exc_info.value)
# Create a non-editable system variable (workflow_id is not editable)
test_value = StringSegment(value="test_workflow_id")
variable = WorkflowDraftVariable.new_sys_variable(
app_id=test_app_id,
name="workflow_id", # This is not in _EDITABLE_SYSTEM_VARIABLE
value=test_value,
node_execution_id="exec-id",
editable=False, # Non-editable system variable
)
# Mock the service to properly check system variable editability
with patch.object(service, "reset_variable") as mock_reset:
def side_effect(wf, var):
if var.get_variable_type() == DraftVariableType.SYS and not is_system_variable_editable(var.name):
raise VariableResetError(f"cannot reset system variable, variable_id={var.id}")
return var
mock_reset.side_effect = side_effect
with pytest.raises(VariableResetError) as exc_info:
service.reset_variable(workflow, variable)
assert "cannot reset system variable" in str(exc_info.value)
assert f"variable_id={variable.id}" in str(exc_info.value)
def test_reset_editable_system_variable_succeeds(self):
"""Test that resetting an editable system variable succeeds"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create an editable system variable (files is editable)
test_value = StringSegment(value="[]")
variable = WorkflowDraftVariable.new_sys_variable(
app_id=test_app_id,
name="files", # This is in _EDITABLE_SYSTEM_VARIABLE
value=test_value,
node_execution_id="exec-id",
editable=True, # Editable system variable
)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"sys.files": "[]"}
# Mock session.scalars to return the execution record
mock_scalars = Mock()
mock_scalars.first.return_value = mock_execution
mock_session.scalars.return_value = mock_scalars
result = service._reset_node_var_or_sys_var(workflow, variable)
# Should succeed and return the variable
assert result == variable
assert variable.last_edited_at is None
mock_session.flush.assert_called()
def test_reset_query_system_variable_succeeds(self):
"""Test that resetting query system variable (another editable one) succeeds"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create an editable system variable (query is editable)
test_value = StringSegment(value="original query")
variable = WorkflowDraftVariable.new_sys_variable(
app_id=test_app_id,
name="query", # This is in _EDITABLE_SYSTEM_VARIABLE
value=test_value,
node_execution_id="exec-id",
editable=True, # Editable system variable
)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"sys.query": "reset query"}
# Mock session.scalars to return the execution record
mock_scalars = Mock()
mock_scalars.first.return_value = mock_execution
mock_session.scalars.return_value = mock_scalars
result = service._reset_node_var_or_sys_var(workflow, variable)
# Should succeed and return the variable
assert result == variable
assert variable.last_edited_at is None
mock_session.flush.assert_called()
def test_system_variable_editability_check(self):
"""Test the system variable editability function directly"""
# Test editable system variables
assert is_system_variable_editable("files") == True
assert is_system_variable_editable("query") == True
# Test non-editable system variables
assert is_system_variable_editable("workflow_id") == False
assert is_system_variable_editable("conversation_id") == False
assert is_system_variable_editable("user_id") == False
def test_workflow_draft_variable_factory_methods(self):
"""Test that factory methods create proper instances"""
test_app_id = self._get_test_app_id()
test_value = StringSegment(value="test_value")
# Test conversation variable factory
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=test_app_id, name="conv_var", value=test_value, description="Test conversation variable"
)
assert conv_var.get_variable_type() == DraftVariableType.CONVERSATION
assert conv_var.editable == True
assert conv_var.node_execution_id is None
# Test system variable factory
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=test_app_id, name="workflow_id", value=test_value, node_execution_id="exec-id", editable=False
)
assert sys_var.get_variable_type() == DraftVariableType.SYS
assert sys_var.editable == False
assert sys_var.node_execution_id == "exec-id"
# Test node variable factory
node_var = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id,
node_id="node-id",
name="node_var",
value=test_value,
node_execution_id="exec-id",
visible=True,
editable=True,
)
assert node_var.get_variable_type() == DraftVariableType.NODE
assert node_var.visible == True
assert node_var.editable == True
assert node_var.node_execution_id == "exec-id"

View File

@ -0,0 +1,465 @@
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
LLMUsage,
)
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage:
    """Build a fully-populated LLMUsage for the given token counts.

    Uses fixed unit prices (0.001/prompt token, 0.002/completion token, USD)
    so derived totals are deterministic across test runs.
    """
    prompt_unit = Decimal("0.001")
    completion_unit = Decimal("0.002")
    prompt_cost = Decimal(str(prompt_tokens)) * prompt_unit
    completion_cost = Decimal(str(completion_tokens)) * completion_unit
    return LLMUsage(
        prompt_tokens=prompt_tokens,
        prompt_unit_price=prompt_unit,
        prompt_price_unit=Decimal("1"),
        prompt_price=prompt_cost,
        completion_tokens=completion_tokens,
        completion_unit_price=completion_unit,
        completion_price_unit=Decimal("1"),
        completion_price=completion_cost,
        total_tokens=prompt_tokens + completion_tokens,
        total_price=prompt_cost + completion_cost,
        currency="USD",
        latency=1.5,
    )
def get_model_entity(provider: str, model_name: str, support_structure_output: bool = False) -> AIModelEntity:
    """Build a MagicMock standing in for an AIModelEntity with the given traits."""
    entity = MagicMock()
    entity.model = model_name
    entity.provider = provider
    entity.model_type = ModelType.LLM
    entity.model_provider = provider
    entity.model_name = model_name
    entity.support_structure_output = support_structure_output
    # No parameter rules by default; individual test cases override this.
    entity.parameter_rules = []
    return entity
def get_model_instance() -> MagicMock:
    """Build a MagicMock standing in for a ModelInstance with default attributes."""
    instance = MagicMock()
    instance.provider = "openai"
    instance.credentials = {}
    return instance
def _named_parameter_rule(rule_name: str, options: list, required: bool = False) -> MagicMock:
    """Build a mock parameter rule whose ``.name`` attribute really is *rule_name*.

    ``MagicMock(name=...)`` must not be used for this: the ``name`` argument is
    reserved by the Mock constructor (it only labels the mock's repr), so
    ``.name`` on such a mock resolves to a child mock rather than the string,
    and any ``rule.name == "response_format"`` comparison in the code under
    test silently fails. Assigning ``.name`` after construction sets a real
    attribute.
    """
    rule = MagicMock(options=options, required=required)
    rule.name = rule_name
    return rule


def test_structured_output_parser():
    """Table-driven test cases for invoke_llm_with_structured_output.

    Covers: native vs. prompt-based structured output, streaming vs.
    non-streaming responses, list-typed message content, JSON repair of
    malformed output, parameter-rule driven response formats, and the error
    path for non-string LLM content.
    """
    testcases = [
        # Test case 1: Model with native structured output support, non-streaming
        {
            "name": "native_structured_output_non_streaming",
            "provider": "openai",
            "model_name": "gpt-4o",
            "support_structure_output": True,
            "stream": False,
            "json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
            "expected_llm_response": LLMResult(
                model="gpt-4o",
                message=AssistantPromptMessage(content='{"name": "test"}'),
                usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
            ),
            "expected_result_type": LLMResultWithStructuredOutput,
            "should_raise": False,
        },
        # Test case 2: Model with native structured output support, streaming
        {
            "name": "native_structured_output_streaming",
            "provider": "openai",
            "model_name": "gpt-4o",
            "support_structure_output": True,
            "stream": True,
            "json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
            "expected_llm_response": [
                LLMResultChunk(
                    model="gpt-4o",
                    prompt_messages=[UserPromptMessage(content="test")],
                    system_fingerprint="test",
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(content='{"name":'),
                        usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
                    ),
                ),
                LLMResultChunk(
                    model="gpt-4o",
                    prompt_messages=[UserPromptMessage(content="test")],
                    system_fingerprint="test",
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(content=' "test"}'),
                        usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
                    ),
                ),
            ],
            "expected_result_type": "generator",
            "should_raise": False,
        },
        # Test case 3: Model without native structured output support, non-streaming
        {
            "name": "prompt_based_structured_output_non_streaming",
            "provider": "anthropic",
            "model_name": "claude-3-sonnet",
            "support_structure_output": False,
            "stream": False,
            "json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
            "expected_llm_response": LLMResult(
                model="claude-3-sonnet",
                message=AssistantPromptMessage(content='{"answer": "test response"}'),
                usage=create_mock_usage(prompt_tokens=15, completion_tokens=8),
            ),
            "expected_result_type": LLMResultWithStructuredOutput,
            "should_raise": False,
        },
        # Test case 4: Model without native structured output support, streaming
        {
            "name": "prompt_based_structured_output_streaming",
            "provider": "anthropic",
            "model_name": "claude-3-sonnet",
            "support_structure_output": False,
            "stream": True,
            "json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
            "expected_llm_response": [
                LLMResultChunk(
                    model="claude-3-sonnet",
                    prompt_messages=[UserPromptMessage(content="test")],
                    system_fingerprint="test",
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(content='{"answer": "test'),
                        usage=create_mock_usage(prompt_tokens=15, completion_tokens=3),
                    ),
                ),
                LLMResultChunk(
                    model="claude-3-sonnet",
                    prompt_messages=[UserPromptMessage(content="test")],
                    system_fingerprint="test",
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(content=' response"}'),
                        usage=create_mock_usage(prompt_tokens=15, completion_tokens=5),
                    ),
                ),
            ],
            "expected_result_type": "generator",
            "should_raise": False,
        },
        # Test case 5: Streaming with list content
        {
            "name": "streaming_with_list_content",
            "provider": "openai",
            "model_name": "gpt-4o",
            "support_structure_output": True,
            "stream": True,
            "json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
            "expected_llm_response": [
                LLMResultChunk(
                    model="gpt-4o",
                    prompt_messages=[UserPromptMessage(content="test")],
                    system_fingerprint="test",
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content=[
                                TextPromptMessageContent(data='{"data":'),
                            ]
                        ),
                        usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
                    ),
                ),
                LLMResultChunk(
                    model="gpt-4o",
                    prompt_messages=[UserPromptMessage(content="test")],
                    system_fingerprint="test",
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content=[
                                TextPromptMessageContent(data=' "value"}'),
                            ]
                        ),
                        usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
                    ),
                ),
            ],
            "expected_result_type": "generator",
            "should_raise": False,
        },
        # Test case 6: Error case - non-string LLM response content (non-streaming)
        {
            "name": "error_non_string_content_non_streaming",
            "provider": "openai",
            "model_name": "gpt-4o",
            "support_structure_output": True,
            "stream": False,
            "json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
            "expected_llm_response": LLMResult(
                model="gpt-4o",
                message=AssistantPromptMessage(content=None),  # Non-string content
                usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
            ),
            "expected_result_type": None,
            "should_raise": True,
            "expected_error": OutputParserError,
        },
        # Test case 7: JSON repair scenario
        {
            "name": "json_repair_scenario",
            "provider": "openai",
            "model_name": "gpt-4o",
            "support_structure_output": True,
            "stream": False,
            "json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
            "expected_llm_response": LLMResult(
                model="gpt-4o",
                message=AssistantPromptMessage(content='{"name": "test"'),  # Invalid JSON - missing closing brace
                usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
            ),
            "expected_result_type": LLMResultWithStructuredOutput,
            "should_raise": False,
        },
        # Test case 8: Model with parameter rules for response format
        {
            "name": "model_with_parameter_rules",
            "provider": "openai",
            "model_name": "gpt-4o",
            "support_structure_output": True,
            "stream": False,
            "json_schema": {"type": "object", "properties": {"result": {"type": "string"}}},
            # Fixed: MagicMock(name=...) does not set the .name attribute, so the
            # rule would never match "response_format"; use the helper instead.
            "parameter_rules": [
                _named_parameter_rule("response_format", ["json_schema"]),
            ],
            "expected_llm_response": LLMResult(
                model="gpt-4o",
                message=AssistantPromptMessage(content='{"result": "success"}'),
                usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
            ),
            "expected_result_type": LLMResultWithStructuredOutput,
            "should_raise": False,
        },
        # Test case 9: Model without native support but with JSON response format rules
        {
            "name": "non_native_with_json_rules",
            "provider": "anthropic",
            "model_name": "claude-3-sonnet",
            "support_structure_output": False,
            "stream": False,
            "json_schema": {"type": "object", "properties": {"output": {"type": "string"}}},
            # Fixed: same MagicMock(name=...) pitfall as test case 8.
            "parameter_rules": [
                _named_parameter_rule("response_format", ["JSON"]),
            ],
            "expected_llm_response": LLMResult(
                model="claude-3-sonnet",
                message=AssistantPromptMessage(content='{"output": "result"}'),
                usage=create_mock_usage(prompt_tokens=15, completion_tokens=8),
            ),
            "expected_result_type": LLMResultWithStructuredOutput,
            "should_raise": False,
        },
    ]

    for case in testcases:
        # Printed so a failing assertion can be attributed to its case.
        print(f"Running test case: {case['name']}")

        # Setup model entity
        model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"])

        # Add parameter rules if specified
        if "parameter_rules" in case:
            model_schema.parameter_rules = case["parameter_rules"]

        # Setup model instance (fresh per case so call counts don't leak)
        model_instance = get_model_instance()
        model_instance.invoke_llm.return_value = case["expected_llm_response"]

        # Setup prompt messages
        prompt_messages = [
            SystemPromptMessage(content="You are a helpful assistant."),
            UserPromptMessage(content="Generate a response according to the schema."),
        ]

        if case["should_raise"]:
            # Test error cases
            with pytest.raises(case["expected_error"]):  # noqa: PT012
                if case["stream"]:
                    result_generator = invoke_llm_with_structured_output(
                        provider=case["provider"],
                        model_schema=model_schema,
                        model_instance=model_instance,
                        prompt_messages=prompt_messages,
                        json_schema=case["json_schema"],
                        stream=case["stream"],
                    )
                    # Consume the generator to trigger the error
                    list(result_generator)
                else:
                    invoke_llm_with_structured_output(
                        provider=case["provider"],
                        model_schema=model_schema,
                        model_instance=model_instance,
                        prompt_messages=prompt_messages,
                        json_schema=case["json_schema"],
                        stream=case["stream"],
                    )
        else:
            # Test successful cases
            with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
                # Configure json_repair mock for cases that need it
                if case["name"] == "json_repair_scenario":
                    mock_json_repair.return_value = {"name": "test"}

                result = invoke_llm_with_structured_output(
                    provider=case["provider"],
                    model_schema=model_schema,
                    model_instance=model_instance,
                    prompt_messages=prompt_messages,
                    json_schema=case["json_schema"],
                    stream=case["stream"],
                    model_parameters={"temperature": 0.7, "max_tokens": 100},
                    user="test_user",
                )

                if case["expected_result_type"] == "generator":
                    # Test streaming results
                    assert hasattr(result, "__iter__")
                    chunks = list(result)
                    assert len(chunks) > 0

                    # Verify all chunks are LLMResultChunkWithStructuredOutput
                    for chunk in chunks[:-1]:  # All except last
                        assert isinstance(chunk, LLMResultChunkWithStructuredOutput)
                        assert chunk.model == case["model_name"]

                    # Last chunk should have structured output
                    last_chunk = chunks[-1]
                    assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput)
                    assert last_chunk.structured_output is not None
                    assert isinstance(last_chunk.structured_output, dict)
                else:
                    # Test non-streaming results
                    assert isinstance(result, case["expected_result_type"])
                    assert result.model == case["model_name"]
                    assert result.structured_output is not None
                    assert isinstance(result.structured_output, dict)

                # Verify model_instance.invoke_llm was called with correct parameters
                model_instance.invoke_llm.assert_called_once()
                call_args = model_instance.invoke_llm.call_args
                assert call_args.kwargs["stream"] == case["stream"]
                assert call_args.kwargs["user"] == "test_user"
                assert "temperature" in call_args.kwargs["model_parameters"]
                assert "max_tokens" in call_args.kwargs["model_parameters"]
def test_parse_structured_output_edge_cases():
    """Verify structured-output parsing when the raw LLM payload is a JSON list.

    A reasoning model (deepseek-r1 here) may wrap the target object inside a
    list together with extra content; the parser is expected to surface the
    first dict element of that list as the structured output.
    """
    # Scenario definition: model without native structured-output support,
    # forcing the prompt-based parsing path through json_repair.
    case = {
        "name": "list_with_dict_parsing",
        "provider": "deepseek",
        "model_name": "deepseek-r1",
        "support_structure_output": False,
        "stream": False,
        "json_schema": {"type": "object", "properties": {"thought": {"type": "string"}}},
        "expected_llm_response": LLMResult(
            model="deepseek-r1",
            message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
            usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
        ),
        "expected_result_type": LLMResultWithStructuredOutput,
        "should_raise": False,
    }

    # Build the model entity/instance pair and stub the raw LLM reply.
    schema_entity = get_model_entity(
        case["provider"],
        case["model_name"],
        case["support_structure_output"],
    )
    instance = get_model_instance()
    instance.invoke_llm.return_value = case["expected_llm_response"]
    messages = [UserPromptMessage(content="Test reasoning")]

    with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as repaired_loads:
        # json_repair yields a list whose first element is the real payload.
        repaired_loads.return_value = [{"thought": "reasoning process"}, "other content"]
        outcome = invoke_llm_with_structured_output(
            provider=case["provider"],
            model_schema=schema_entity,
            model_instance=instance,
            prompt_messages=messages,
            json_schema=case["json_schema"],
            stream=case["stream"],
        )

    assert isinstance(outcome, LLMResultWithStructuredOutput)
    # Only the dict element of the list survives as the structured output.
    assert outcome.structured_output == {"thought": "reasoning process"}
def test_model_specific_schema_preparation():
    """Test schema preparation for different model types.

    Exercises the Gemini path: per the assertion comments below, the schema
    handed to the model is expected to have ``additionalProperties`` stripped
    and boolean types converted to strings before being passed through
    ``model_parameters`` — TODO confirm against the invoker implementation.
    """
    # Gemini fixture: the schema deliberately includes "additionalProperties"
    # so the preparation step has something to remove.
    gemini_case = {
        "provider": "google",
        "model_name": "gemini-pro",
        "support_structure_output": True,
        "stream": False,
        "json_schema": {"type": "object", "properties": {"result": {"type": "boolean"}}, "additionalProperties": False},
    }
    model_schema = get_model_entity(
        gemini_case["provider"], gemini_case["model_name"], gemini_case["support_structure_output"]
    )
    model_instance = get_model_instance()
    # Canned LLM reply: the boolean arrives as the string "true", matching the
    # boolean-to-string conversion this test is about.
    model_instance.invoke_llm.return_value = LLMResult(
        model="gemini-pro",
        message=AssistantPromptMessage(content='{"result": "true"}'),
        usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
    )
    prompt_messages = [UserPromptMessage(content="Test")]
    result = invoke_llm_with_structured_output(
        provider=gemini_case["provider"],
        model_schema=model_schema,
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        json_schema=gemini_case["json_schema"],
        stream=gemini_case["stream"],
    )
    assert isinstance(result, LLMResultWithStructuredOutput)
    # Verify model_instance.invoke_llm was called and check the schema preparation
    model_instance.invoke_llm.assert_called_once()
    call_args = model_instance.invoke_llm.call_args
    # For Gemini, the schema should not have additionalProperties and boolean should be converted to string
    assert "json_schema" in call_args.kwargs["model_parameters"]