mirror of
https://github.com/langgenius/dify.git
synced 2026-04-22 03:37:44 +08:00
refactor: unify structured output with pydantic model
Signed-off-by: Stream <Stream_2@qq.com>
This commit is contained in:
@ -2,9 +2,13 @@ from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.llm_generator.output_parser.structured_output import (
|
||||
invoke_llm_with_pydantic_model,
|
||||
invoke_llm_with_structured_output,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
@ -461,3 +465,68 @@ def test_model_specific_schema_preparation():
|
||||
|
||||
# For Gemini, the schema should not have additionalProperties and boolean should be converted to string
|
||||
assert "json_schema" in call_args.kwargs["model_parameters"]
|
||||
|
||||
|
||||
class ExampleOutput(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
def test_structured_output_with_pydantic_model():
|
||||
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content='{"name": "test"}'),
|
||||
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
|
||||
)
|
||||
|
||||
prompt_messages = [UserPromptMessage(content="Return a JSON object with name.")]
|
||||
|
||||
result = invoke_llm_with_pydantic_model(
|
||||
provider="openai",
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
output_model=ExampleOutput,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||
assert result.structured_output == {"name": "test"}
|
||||
|
||||
|
||||
def test_structured_output_with_pydantic_model_streaming_rejected():
|
||||
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
|
||||
model_instance = get_model_instance()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
invoke_llm_with_pydantic_model(
|
||||
provider="openai",
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
output_model=ExampleOutput,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
||||
def test_structured_output_with_pydantic_model_validation_error():
|
||||
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content='{"name": 123}'),
|
||||
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
|
||||
)
|
||||
|
||||
with pytest.raises(OutputParserError):
|
||||
invoke_llm_with_pydantic_model(
|
||||
provider="openai",
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
output_model=ExampleOutput,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user