mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
feat: implement file structured output
This commit is contained in:
@ -1,269 +1,120 @@
|
||||
"""
|
||||
Unit tests for file reference detection and conversion.
|
||||
Unit tests for sandbox file path detection and conversion.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.llm_generator.output_parser.file_ref import (
|
||||
FILE_REF_FORMAT,
|
||||
convert_file_refs_in_output,
|
||||
detect_file_ref_fields,
|
||||
is_file_ref_property,
|
||||
FILE_PATH_DESCRIPTION_SUFFIX,
|
||||
FILE_PATH_FORMAT,
|
||||
adapt_schema_for_sandbox_file_paths,
|
||||
convert_sandbox_file_paths_in_output,
|
||||
detect_file_path_fields,
|
||||
is_file_path_property,
|
||||
)
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
|
||||
|
||||
class TestIsFileRefProperty:
|
||||
"""Tests for is_file_ref_property function."""
|
||||
def _build_file(file_id: str) -> File:
|
||||
return File(
|
||||
id=file_id,
|
||||
tenant_id="tenant_123",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
filename="test.png",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=128,
|
||||
related_id=file_id,
|
||||
storage_key="sandbox/path",
|
||||
)
|
||||
|
||||
def test_valid_file_ref(self):
|
||||
schema = {"type": "string", "format": FILE_REF_FORMAT}
|
||||
assert is_file_ref_property(schema) is True
|
||||
|
||||
class TestIsFilePathProperty:
|
||||
def test_valid_file_path_format(self):
|
||||
schema = {"type": "string", "format": FILE_PATH_FORMAT}
|
||||
assert is_file_path_property(schema) is True
|
||||
|
||||
def test_accepts_snake_case_format(self):
|
||||
schema = {"type": "string", "format": "file_path"}
|
||||
assert is_file_path_property(schema) is True
|
||||
|
||||
def test_invalid_type(self):
|
||||
schema = {"type": "number", "format": FILE_REF_FORMAT}
|
||||
assert is_file_ref_property(schema) is False
|
||||
schema = {"type": "number", "format": FILE_PATH_FORMAT}
|
||||
assert is_file_path_property(schema) is False
|
||||
|
||||
def test_missing_format(self):
|
||||
schema = {"type": "string"}
|
||||
assert is_file_ref_property(schema) is False
|
||||
assert is_file_path_property(schema) is False
|
||||
|
||||
def test_wrong_format(self):
|
||||
schema = {"type": "string", "format": "uuid"}
|
||||
assert is_file_ref_property(schema) is False
|
||||
assert is_file_path_property(schema) is False
|
||||
|
||||
|
||||
class TestDetectFileRefFields:
|
||||
"""Tests for detect_file_ref_fields function."""
|
||||
|
||||
def test_simple_file_ref(self):
|
||||
class TestDetectFilePathFields:
|
||||
def test_detects_nested_file_paths(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"image": {"type": "string", "format": FILE_PATH_FORMAT},
|
||||
"files": {"type": "array", "items": {"type": "string", "format": FILE_PATH_FORMAT}},
|
||||
"meta": {"type": "object", "properties": {"doc": {"type": "string", "format": FILE_PATH_FORMAT}}},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == ["image"]
|
||||
|
||||
def test_multiple_file_refs(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"document": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert set(paths) == {"image", "document"}
|
||||
|
||||
def test_array_of_file_refs(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == ["files[*]"]
|
||||
|
||||
def test_nested_file_ref(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == ["data.image"]
|
||||
|
||||
def test_no_file_refs(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "number"},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == []
|
||||
assert set(detect_file_path_fields(schema)) == {"image", "files[*]", "meta.doc"}
|
||||
|
||||
def test_empty_schema(self):
|
||||
schema = {}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == []
|
||||
assert detect_file_path_fields({}) == []
|
||||
|
||||
def test_mixed_schema(self):
|
||||
|
||||
class TestAdaptSchemaForSandboxFilePaths:
|
||||
def test_appends_description(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"documents": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert set(paths) == {"image", "documents[*]"}
|
||||
|
||||
|
||||
class TestConvertFileRefsInOutput:
|
||||
"""Tests for convert_file_refs_in_output function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file(self):
|
||||
"""Create a mock File object with all required attributes."""
|
||||
file = MagicMock(spec=File)
|
||||
file.type = FileType.IMAGE
|
||||
file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||
file.related_id = "test-related-id"
|
||||
file.remote_url = None
|
||||
file.tenant_id = "tenant_123"
|
||||
file.id = None
|
||||
file.filename = "test.png"
|
||||
file.extension = ".png"
|
||||
file.mime_type = "image/png"
|
||||
file.size = 1024
|
||||
file.dify_model_identity = "__dify__file__"
|
||||
return file
|
||||
|
||||
@pytest.fixture
|
||||
def mock_build_from_mapping(self, mock_file):
|
||||
"""Mock the build_from_mapping function."""
|
||||
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
|
||||
mock.return_value = mock_file
|
||||
yield mock
|
||||
|
||||
def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"image": file_id}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"image": {"type": "string", "format": FILE_PATH_FORMAT, "description": "Pick a file"},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
adapted, fields = adapt_schema_for_sandbox_file_paths(schema)
|
||||
|
||||
# Result should be wrapped in FileSegment
|
||||
assert isinstance(result["image"], FileSegment)
|
||||
assert result["image"].value == mock_file
|
||||
mock_build_from_mapping.assert_called_once_with(
|
||||
mapping={"transfer_method": "tool_file", "tool_file_id": file_id},
|
||||
tenant_id="tenant_123",
|
||||
)
|
||||
assert set(fields) == {"image"}
|
||||
adapted_image = adapted["properties"]["image"]
|
||||
assert adapted_image["type"] == "string"
|
||||
assert adapted_image["format"] == FILE_PATH_FORMAT
|
||||
assert FILE_PATH_DESCRIPTION_SUFFIX in adapted_image["description"]
|
||||
|
||||
def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file):
|
||||
file_id1 = str(uuid.uuid4())
|
||||
file_id2 = str(uuid.uuid4())
|
||||
output = {"files": [file_id1, file_id2]}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
|
||||
class TestConvertSandboxFilePaths:
|
||||
def test_convert_sandbox_file_paths(self):
|
||||
output = {
|
||||
"image": "a.png",
|
||||
"files": ["b.png", "c.png"],
|
||||
"meta": {"doc": "d.pdf"},
|
||||
"name": "demo",
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
def resolver(path: str) -> File:
|
||||
return _build_file(path)
|
||||
|
||||
# Result should be wrapped in ArrayFileSegment
|
||||
assert isinstance(result["files"], ArrayFileSegment)
|
||||
assert list(result["files"].value) == [mock_file, mock_file]
|
||||
assert mock_build_from_mapping.call_count == 2
|
||||
converted, files = convert_sandbox_file_paths_in_output(output, ["image", "files[*]", "meta.doc"], resolver)
|
||||
|
||||
def test_no_conversion_without_file_refs(self):
|
||||
output = {"name": "test", "count": 5}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "number"},
|
||||
},
|
||||
}
|
||||
assert isinstance(converted["image"], FileSegment)
|
||||
assert isinstance(converted["files"], ArrayFileSegment)
|
||||
assert isinstance(converted["meta"]["doc"], FileSegment)
|
||||
assert converted["name"] == "demo"
|
||||
assert [file.id for file in files] == ["a.png", "b.png", "c.png", "d.pdf"]
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
def test_invalid_path_value_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
convert_sandbox_file_paths_in_output({"image": 123}, ["image"], _build_file)
|
||||
|
||||
assert result == {"name": "test", "count": 5}
|
||||
def test_no_file_paths_returns_output(self):
|
||||
output = {"name": "demo"}
|
||||
converted, files = convert_sandbox_file_paths_in_output(output, [], _build_file)
|
||||
|
||||
def test_invalid_uuid_returns_none(self):
|
||||
output = {"image": "not-a-valid-uuid"}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["image"] is None
|
||||
|
||||
def test_file_not_found_returns_none(self):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"image": file_id}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
|
||||
mock.side_effect = ValueError("File not found")
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["image"] is None
|
||||
|
||||
def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"query": "search term", "image": file_id, "count": 10}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"count": {"type": "number"},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["query"] == "search term"
|
||||
assert isinstance(result["image"], FileSegment)
|
||||
assert result["image"].value == mock_file
|
||||
assert result["count"] == 10
|
||||
|
||||
def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
original = {"image": file_id}
|
||||
output = dict(original)
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
# Original should still contain the string ID
|
||||
assert original["image"] == file_id
|
||||
assert converted == output
|
||||
assert files == []
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -6,23 +7,18 @@ from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import (
|
||||
_get_default_value_for_type,
|
||||
fill_defaults_from_schema,
|
||||
invoke_llm_with_pydantic_model,
|
||||
invoke_llm_with_structured_output,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
LLMUsage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
@ -57,7 +53,7 @@ def get_model_entity(provider: str, model_name: str, support_structure_output: b
|
||||
model_schema.support_structure_output = support_structure_output
|
||||
model_schema.parameter_rules = []
|
||||
|
||||
return model_schema
|
||||
return cast(AIModelEntity, model_schema)
|
||||
|
||||
|
||||
def get_model_instance() -> MagicMock:
|
||||
@ -71,7 +67,7 @@ def get_model_instance() -> MagicMock:
|
||||
def test_structured_output_parser():
|
||||
"""Test cases for invoke_llm_with_structured_output function"""
|
||||
|
||||
testcases = [
|
||||
testcases: list[dict[str, Any]] = [
|
||||
# Test case 1: Model with native structured output support, non-streaming
|
||||
{
|
||||
"name": "native_structured_output_non_streaming",
|
||||
@ -88,39 +84,6 @@ def test_structured_output_parser():
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 2: Model with native structured output support, streaming
|
||||
{
|
||||
"name": "native_structured_output_streaming",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content='{"name":'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=' "test"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 3: Model without native structured output support, non-streaming
|
||||
{
|
||||
"name": "prompt_based_structured_output_non_streaming",
|
||||
@ -137,80 +100,6 @@ def test_structured_output_parser():
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 4: Model without native structured output support, streaming
|
||||
{
|
||||
"name": "prompt_based_structured_output_streaming",
|
||||
"provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"support_structure_output": False,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="claude-3-sonnet",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content='{"answer": "test'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="claude-3-sonnet",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=' response"}'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=5),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 5: Streaming with list content
|
||||
{
|
||||
"name": "streaming_with_list_content",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data='{"data":'),
|
||||
]
|
||||
),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data=' "value"}'),
|
||||
]
|
||||
),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 6: Error case - non-string LLM response content (non-streaming)
|
||||
{
|
||||
"name": "error_non_string_content_non_streaming",
|
||||
@ -290,7 +179,7 @@ def test_structured_output_parser():
|
||||
|
||||
# Add parameter rules if specified
|
||||
if "parameter_rules" in case:
|
||||
model_schema.parameter_rules = case["parameter_rules"]
|
||||
cast(Any, model_schema).parameter_rules = case["parameter_rules"]
|
||||
|
||||
# Setup model instance
|
||||
model_instance = get_model_instance()
|
||||
@ -304,25 +193,14 @@ def test_structured_output_parser():
|
||||
|
||||
if case["should_raise"]:
|
||||
# Test error cases
|
||||
with pytest.raises(case["expected_error"]): # noqa: PT012
|
||||
if case["stream"]:
|
||||
result_generator = invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
)
|
||||
# Consume the generator to trigger the error
|
||||
list(result_generator)
|
||||
else:
|
||||
invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
)
|
||||
with pytest.raises(case["expected_error"]):
|
||||
invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
)
|
||||
else:
|
||||
# Test successful cases
|
||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||
@ -340,34 +218,18 @@ def test_structured_output_parser():
|
||||
user="test_user",
|
||||
)
|
||||
|
||||
if case["expected_result_type"] == "generator":
|
||||
# Test streaming results
|
||||
assert hasattr(result, "__iter__")
|
||||
chunks = list(result)
|
||||
assert len(chunks) > 0
|
||||
|
||||
# Verify all chunks are LLMResultChunkWithStructuredOutput
|
||||
for chunk in chunks[:-1]: # All except last
|
||||
assert isinstance(chunk, LLMResultChunkWithStructuredOutput)
|
||||
assert chunk.model == case["model_name"]
|
||||
|
||||
# Last chunk should have structured output
|
||||
last_chunk = chunks[-1]
|
||||
assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput)
|
||||
assert last_chunk.structured_output is not None
|
||||
assert isinstance(last_chunk.structured_output, dict)
|
||||
else:
|
||||
# Test non-streaming results
|
||||
assert isinstance(result, case["expected_result_type"])
|
||||
assert result.model == case["model_name"]
|
||||
assert result.structured_output is not None
|
||||
assert isinstance(result.structured_output, dict)
|
||||
# Test non-streaming results
|
||||
expected_type = cast(type, case["expected_result_type"])
|
||||
assert isinstance(result, expected_type)
|
||||
assert result.model == case["model_name"]
|
||||
assert result.structured_output is not None
|
||||
assert isinstance(result.structured_output, dict)
|
||||
|
||||
# Verify model_instance.invoke_llm was called with correct parameters
|
||||
model_instance.invoke_llm.assert_called_once()
|
||||
call_args = model_instance.invoke_llm.call_args
|
||||
|
||||
assert call_args.kwargs["stream"] == case["stream"]
|
||||
assert call_args.kwargs["stream"] is False
|
||||
assert call_args.kwargs["user"] == "test_user"
|
||||
assert "temperature" in call_args.kwargs["model_parameters"]
|
||||
assert "max_tokens" in call_args.kwargs["model_parameters"]
|
||||
@ -377,7 +239,7 @@ def test_parse_structured_output_edge_cases():
|
||||
"""Test edge cases for structured output parsing"""
|
||||
|
||||
# Test case with list that contains dict (reasoning model scenario)
|
||||
testcase_list_with_dict = {
|
||||
testcase_list_with_dict: dict[str, Any] = {
|
||||
"name": "list_with_dict_parsing",
|
||||
"provider": "deepseek",
|
||||
"model_name": "deepseek-r1",
|
||||
@ -425,7 +287,7 @@ def test_model_specific_schema_preparation():
|
||||
"""Test schema preparation for different model types"""
|
||||
|
||||
# Test Gemini model
|
||||
gemini_case = {
|
||||
gemini_case: dict[str, Any] = {
|
||||
"provider": "google",
|
||||
"model_name": "gemini-pro",
|
||||
"support_structure_output": True,
|
||||
@ -493,46 +355,6 @@ def test_structured_output_with_pydantic_model_non_streaming():
|
||||
assert result.name == "test"
|
||||
|
||||
|
||||
def test_structured_output_with_pydantic_model_streaming():
|
||||
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
|
||||
model_instance = get_model_instance()
|
||||
|
||||
def mock_streaming_response():
|
||||
yield LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content='{"name":'),
|
||||
usage=create_mock_usage(prompt_tokens=8, completion_tokens=2),
|
||||
),
|
||||
)
|
||||
yield LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=' "test"}'),
|
||||
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
|
||||
),
|
||||
)
|
||||
|
||||
model_instance.invoke_llm.return_value = mock_streaming_response()
|
||||
|
||||
result = invoke_llm_with_pydantic_model(
|
||||
provider="openai",
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=[UserPromptMessage(content="Return a JSON object with name.")],
|
||||
output_model=ExampleOutput
|
||||
)
|
||||
|
||||
assert isinstance(result, ExampleOutput)
|
||||
assert result.name == "test"
|
||||
|
||||
|
||||
def test_structured_output_with_pydantic_model_validation_error():
|
||||
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
|
||||
model_instance = get_model_instance()
|
||||
@ -552,55 +374,12 @@ def test_structured_output_with_pydantic_model_validation_error():
|
||||
)
|
||||
|
||||
|
||||
class TestGetDefaultValueForType:
|
||||
"""Test cases for _get_default_value_for_type function"""
|
||||
|
||||
def test_string_type(self):
|
||||
assert _get_default_value_for_type("string") == ""
|
||||
|
||||
def test_object_type(self):
|
||||
assert _get_default_value_for_type("object") == {}
|
||||
|
||||
def test_array_type(self):
|
||||
assert _get_default_value_for_type("array") == []
|
||||
|
||||
def test_number_type(self):
|
||||
assert _get_default_value_for_type("number") == 0
|
||||
|
||||
def test_integer_type(self):
|
||||
assert _get_default_value_for_type("integer") == 0
|
||||
|
||||
def test_boolean_type(self):
|
||||
assert _get_default_value_for_type("boolean") is False
|
||||
|
||||
def test_null_type(self):
|
||||
assert _get_default_value_for_type("null") is None
|
||||
|
||||
def test_none_type(self):
|
||||
assert _get_default_value_for_type(None) is None
|
||||
|
||||
def test_unknown_type(self):
|
||||
assert _get_default_value_for_type("unknown") is None
|
||||
|
||||
def test_union_type_string_null(self):
|
||||
# ["string", "null"] should return "" (first non-null type)
|
||||
assert _get_default_value_for_type(["string", "null"]) == ""
|
||||
|
||||
def test_union_type_null_first(self):
|
||||
# ["null", "integer"] should return 0 (first non-null type)
|
||||
assert _get_default_value_for_type(["null", "integer"]) == 0
|
||||
|
||||
def test_union_type_only_null(self):
|
||||
# ["null"] should return None
|
||||
assert _get_default_value_for_type(["null"]) is None
|
||||
|
||||
|
||||
class TestFillDefaultsFromSchema:
|
||||
"""Test cases for fill_defaults_from_schema function"""
|
||||
|
||||
def test_simple_required_fields(self):
|
||||
"""Test filling simple required fields"""
|
||||
schema = {
|
||||
schema: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
@ -609,7 +388,7 @@ class TestFillDefaultsFromSchema:
|
||||
},
|
||||
"required": ["name", "age"],
|
||||
}
|
||||
output = {"name": "Alice"}
|
||||
output: dict[str, Any] = {"name": "Alice"}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
@ -619,7 +398,7 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_non_required_fields_not_filled(self):
|
||||
"""Test that non-required fields are not filled"""
|
||||
schema = {
|
||||
schema: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"required_field": {"type": "string"},
|
||||
@ -627,7 +406,7 @@ class TestFillDefaultsFromSchema:
|
||||
},
|
||||
"required": ["required_field"],
|
||||
}
|
||||
output = {}
|
||||
output: dict[str, Any] = {}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
@ -636,7 +415,7 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_nested_object_required_fields(self):
|
||||
"""Test filling nested object required fields"""
|
||||
schema = {
|
||||
schema: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user": {
|
||||
@ -659,7 +438,7 @@ class TestFillDefaultsFromSchema:
|
||||
},
|
||||
"required": ["user"],
|
||||
}
|
||||
output = {
|
||||
output: dict[str, Any] = {
|
||||
"user": {
|
||||
"name": "Alice",
|
||||
"address": {
|
||||
@ -684,7 +463,7 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_missing_nested_object_created(self):
|
||||
"""Test that missing required nested objects are created"""
|
||||
schema = {
|
||||
schema: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metadata": {
|
||||
@ -698,7 +477,7 @@ class TestFillDefaultsFromSchema:
|
||||
},
|
||||
"required": ["metadata"],
|
||||
}
|
||||
output = {}
|
||||
output: dict[str, Any] = {}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
@ -710,7 +489,7 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_all_types_default_values(self):
|
||||
"""Test default values for all types"""
|
||||
schema = {
|
||||
schema: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"str_field": {"type": "string"},
|
||||
@ -722,7 +501,7 @@ class TestFillDefaultsFromSchema:
|
||||
},
|
||||
"required": ["str_field", "int_field", "num_field", "bool_field", "arr_field", "obj_field"],
|
||||
}
|
||||
output = {}
|
||||
output: dict[str, Any] = {}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
@ -737,7 +516,7 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_existing_values_preserved(self):
|
||||
"""Test that existing values are not overwritten"""
|
||||
schema = {
|
||||
schema: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
@ -745,7 +524,7 @@ class TestFillDefaultsFromSchema:
|
||||
},
|
||||
"required": ["name", "count"],
|
||||
}
|
||||
output = {"name": "Bob", "count": 42}
|
||||
output: dict[str, Any] = {"name": "Bob", "count": 42}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
@ -753,7 +532,7 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_complex_nested_structure(self):
|
||||
"""Test complex nested structure with multiple levels"""
|
||||
schema = {
|
||||
schema: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user": {
|
||||
@ -789,7 +568,7 @@ class TestFillDefaultsFromSchema:
|
||||
},
|
||||
"required": ["user", "tags", "metadata", "is_active"],
|
||||
}
|
||||
output = {
|
||||
output: dict[str, Any] = {
|
||||
"user": {
|
||||
"name": "Alice",
|
||||
"age": 25,
|
||||
@ -829,8 +608,8 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_empty_schema(self):
|
||||
"""Test with empty schema"""
|
||||
schema = {}
|
||||
output = {"any": "value"}
|
||||
schema: dict[str, Any] = {}
|
||||
output: dict[str, Any] = {"any": "value"}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
@ -838,14 +617,14 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_schema_without_required(self):
|
||||
"""Test schema without required field"""
|
||||
schema = {
|
||||
schema: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"optional1": {"type": "string"},
|
||||
"optional2": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
output = {}
|
||||
output: dict[str, Any] = {}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user