mirror of
https://github.com/langgenius/dify.git
synced 2026-04-26 05:35:58 +08:00
feat: add tool call based structured output
This commit is contained in:
@ -6,6 +6,8 @@ from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import (
|
||||
_get_default_value_for_type,
|
||||
fill_defaults_from_schema,
|
||||
invoke_llm_with_pydantic_model,
|
||||
invoke_llm_with_structured_output,
|
||||
)
|
||||
@ -530,3 +532,304 @@ def test_structured_output_with_pydantic_model_validation_error():
|
||||
output_model=ExampleOutput,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
class TestGetDefaultValueForType:
|
||||
"""Test cases for _get_default_value_for_type function"""
|
||||
|
||||
def test_string_type(self):
|
||||
assert _get_default_value_for_type("string") == ""
|
||||
|
||||
def test_object_type(self):
|
||||
assert _get_default_value_for_type("object") == {}
|
||||
|
||||
def test_array_type(self):
|
||||
assert _get_default_value_for_type("array") == []
|
||||
|
||||
def test_number_type(self):
|
||||
assert _get_default_value_for_type("number") == 0
|
||||
|
||||
def test_integer_type(self):
|
||||
assert _get_default_value_for_type("integer") == 0
|
||||
|
||||
def test_boolean_type(self):
|
||||
assert _get_default_value_for_type("boolean") is False
|
||||
|
||||
def test_null_type(self):
|
||||
assert _get_default_value_for_type("null") is None
|
||||
|
||||
def test_none_type(self):
|
||||
assert _get_default_value_for_type(None) is None
|
||||
|
||||
def test_unknown_type(self):
|
||||
assert _get_default_value_for_type("unknown") is None
|
||||
|
||||
def test_union_type_string_null(self):
|
||||
# ["string", "null"] should return "" (first non-null type)
|
||||
assert _get_default_value_for_type(["string", "null"]) == ""
|
||||
|
||||
def test_union_type_null_first(self):
|
||||
# ["null", "integer"] should return 0 (first non-null type)
|
||||
assert _get_default_value_for_type(["null", "integer"]) == 0
|
||||
|
||||
def test_union_type_only_null(self):
|
||||
# ["null"] should return None
|
||||
assert _get_default_value_for_type(["null"]) is None
|
||||
|
||||
|
||||
class TestFillDefaultsFromSchema:
|
||||
"""Test cases for fill_defaults_from_schema function"""
|
||||
|
||||
def test_simple_required_fields(self):
|
||||
"""Test filling simple required fields"""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"email": {"type": "string"},
|
||||
},
|
||||
"required": ["name", "age"],
|
||||
}
|
||||
output = {"name": "Alice"}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
assert result == {"name": "Alice", "age": 0}
|
||||
# email is not required, so it should not be added
|
||||
assert "email" not in result
|
||||
|
||||
def test_non_required_fields_not_filled(self):
|
||||
"""Test that non-required fields are not filled"""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"required_field": {"type": "string"},
|
||||
"optional_field": {"type": "string"},
|
||||
},
|
||||
"required": ["required_field"],
|
||||
}
|
||||
output = {}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
assert result == {"required_field": ""}
|
||||
assert "optional_field" not in result
|
||||
|
||||
def test_nested_object_required_fields(self):
|
||||
"""Test filling nested object required fields"""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"email": {"type": "string"},
|
||||
"address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"street": {"type": "string"},
|
||||
"zipcode": {"type": "string"},
|
||||
},
|
||||
"required": ["city", "street"],
|
||||
},
|
||||
},
|
||||
"required": ["name", "email", "address"],
|
||||
},
|
||||
},
|
||||
"required": ["user"],
|
||||
}
|
||||
output = {
|
||||
"user": {
|
||||
"name": "Alice",
|
||||
"address": {
|
||||
"city": "Beijing",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
assert result == {
|
||||
"user": {
|
||||
"name": "Alice",
|
||||
"email": "", # filled because required
|
||||
"address": {
|
||||
"city": "Beijing",
|
||||
"street": "", # filled because required
|
||||
# zipcode not filled because not required
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
def test_missing_nested_object_created(self):
|
||||
"""Test that missing required nested objects are created"""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {"type": "string"},
|
||||
"updated_at": {"type": "string"},
|
||||
},
|
||||
"required": ["created_at"],
|
||||
},
|
||||
},
|
||||
"required": ["metadata"],
|
||||
}
|
||||
output = {}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
assert result == {
|
||||
"metadata": {
|
||||
"created_at": "",
|
||||
}
|
||||
}
|
||||
|
||||
def test_all_types_default_values(self):
|
||||
"""Test default values for all types"""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"str_field": {"type": "string"},
|
||||
"int_field": {"type": "integer"},
|
||||
"num_field": {"type": "number"},
|
||||
"bool_field": {"type": "boolean"},
|
||||
"arr_field": {"type": "array"},
|
||||
"obj_field": {"type": "object"},
|
||||
},
|
||||
"required": ["str_field", "int_field", "num_field", "bool_field", "arr_field", "obj_field"],
|
||||
}
|
||||
output = {}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
assert result == {
|
||||
"str_field": "",
|
||||
"int_field": 0,
|
||||
"num_field": 0,
|
||||
"bool_field": False,
|
||||
"arr_field": [],
|
||||
"obj_field": {},
|
||||
}
|
||||
|
||||
def test_existing_values_preserved(self):
|
||||
"""Test that existing values are not overwritten"""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "integer"},
|
||||
},
|
||||
"required": ["name", "count"],
|
||||
}
|
||||
output = {"name": "Bob", "count": 42}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
assert result == {"name": "Bob", "count": 42}
|
||||
|
||||
def test_complex_nested_structure(self):
|
||||
"""Test complex nested structure with multiple levels"""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"email": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"street": {"type": "string"},
|
||||
"zipcode": {"type": "string"},
|
||||
},
|
||||
"required": ["city", "street"],
|
||||
},
|
||||
},
|
||||
"required": ["name", "email", "address"],
|
||||
},
|
||||
"tags": {"type": "array"},
|
||||
"orders": {"type": "array"},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {"type": "string"},
|
||||
"updated_at": {"type": "string"},
|
||||
},
|
||||
"required": ["created_at"],
|
||||
},
|
||||
"is_active": {"type": "boolean"},
|
||||
"notes": {"type": "string"},
|
||||
},
|
||||
"required": ["user", "tags", "metadata", "is_active"],
|
||||
}
|
||||
output = {
|
||||
"user": {
|
||||
"name": "Alice",
|
||||
"age": 25,
|
||||
"address": {
|
||||
"city": "Beijing",
|
||||
},
|
||||
},
|
||||
"orders": [{"id": 1}],
|
||||
"metadata": {
|
||||
"updated_at": "2024-01-01",
|
||||
},
|
||||
}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
expected = {
|
||||
"user": {
|
||||
"name": "Alice",
|
||||
"email": "", # required, filled
|
||||
"age": 25, # not required but exists
|
||||
"address": {
|
||||
"city": "Beijing",
|
||||
"street": "", # required, filled
|
||||
# zipcode not required
|
||||
},
|
||||
},
|
||||
"tags": [], # required, filled
|
||||
"orders": [{"id": 1}], # not required but exists
|
||||
"metadata": {
|
||||
"created_at": "", # required, filled
|
||||
"updated_at": "2024-01-01", # exists
|
||||
},
|
||||
"is_active": False, # required, filled
|
||||
# notes not required
|
||||
}
|
||||
assert result == expected
|
||||
|
||||
def test_empty_schema(self):
|
||||
"""Test with empty schema"""
|
||||
schema = {}
|
||||
output = {"any": "value"}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
assert result == {"any": "value"}
|
||||
|
||||
def test_schema_without_required(self):
|
||||
"""Test schema without required field"""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"optional1": {"type": "string"},
|
||||
"optional2": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
output = {}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
# No required fields, so nothing should be added
|
||||
assert result == {}
|
||||
|
||||
Reference in New Issue
Block a user