mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
test: unit test cases for core.variables, core.plugin, core.prompt module (#32637)
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -13,6 +14,8 @@ from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from models.model import Conversation
|
||||
@ -188,3 +191,328 @@ def get_chat_model_args():
|
||||
context = "I am superman."
|
||||
|
||||
return model_config_mock, memory_config, prompt_messages, inputs, context
|
||||
|
||||
|
||||
def test_get_prompt_dispatches_completion_and_chat_and_invalid():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config = MagicMock(spec=ModelConfigEntity)
|
||||
completion_template = CompletionModelPromptTemplate(text="Hello {{name}}", edition_type="basic")
|
||||
chat_template = [ChatModelMessage(text="Hello {{name}}", role=PromptMessageRole.USER, edition_type="basic")]
|
||||
|
||||
transform._get_completion_model_prompt_messages = MagicMock(return_value=[UserPromptMessage(content="c")])
|
||||
transform._get_chat_model_prompt_messages = MagicMock(return_value=[UserPromptMessage(content="h")])
|
||||
|
||||
completion_result = transform.get_prompt(
|
||||
prompt_template=completion_template,
|
||||
inputs={"name": "john"},
|
||||
query="q",
|
||||
files=[],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
assert completion_result[0].content == "c"
|
||||
|
||||
chat_result = transform.get_prompt(
|
||||
prompt_template=chat_template,
|
||||
inputs={"name": "john"},
|
||||
query="q",
|
||||
files=[],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
assert chat_result[0].content == "h"
|
||||
|
||||
invalid_result = transform.get_prompt(
|
||||
prompt_template=cast(list, ["not-chat-model-message"]),
|
||||
inputs={"name": "john"},
|
||||
query="q",
|
||||
files=[],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
assert invalid_result == []
|
||||
|
||||
|
||||
def test_completion_prompt_jinja2_with_files():
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
transform = AdvancedPromptTransform()
|
||||
completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2")
|
||||
|
||||
file = File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
with (
|
||||
patch("core.prompt.advanced_prompt_transform.Jinja2Formatter.format", return_value="Hi John"),
|
||||
patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content,
|
||||
):
|
||||
to_content.return_value = ImagePromptMessageContent(
|
||||
url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
|
||||
)
|
||||
messages = transform._get_completion_model_prompt_messages(
|
||||
prompt_template=completion_template,
|
||||
inputs={"name": "John"},
|
||||
query="",
|
||||
files=[file],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert isinstance(messages[0].content, list)
|
||||
assert messages[0].content[0].data == "https://example.com/image.jpg"
|
||||
assert isinstance(messages[0].content[1], TextPromptMessageContent)
|
||||
assert messages[0].content[1].data == "Hi John"
|
||||
|
||||
|
||||
def test_completion_prompt_basic_sets_query_variable():
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
transform = AdvancedPromptTransform()
|
||||
template = CompletionModelPromptTemplate(text="Q={{#query#}}", edition_type="basic")
|
||||
|
||||
messages = transform._get_completion_model_prompt_messages(
|
||||
prompt_template=template,
|
||||
inputs={},
|
||||
query="what?",
|
||||
files=[],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert messages[0].content == "Q=what?"
|
||||
|
||||
|
||||
def test_chat_prompt_with_variable_template_and_context():
|
||||
transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
prompt_template = [ChatModelMessage(text="sys={{#node.name#}} ctx={{#context#}}", role=PromptMessageRole.SYSTEM)]
|
||||
|
||||
messages = transform._get_chat_model_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
inputs={"#node.name#": "john"},
|
||||
query=None,
|
||||
files=[],
|
||||
context="context-text",
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert isinstance(messages[0], SystemPromptMessage)
|
||||
assert messages[0].content == "sys=john ctx=context-text"
|
||||
|
||||
|
||||
def test_chat_prompt_jinja2_branch_and_invalid_edition():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
prompt_template = [ChatModelMessage(text="Hello {{name}}", role=PromptMessageRole.USER, edition_type="jinja2")]
|
||||
|
||||
with patch("core.prompt.advanced_prompt_transform.Jinja2Formatter.format", return_value="Hello John"):
|
||||
messages = transform._get_chat_model_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
inputs={"name": "John"},
|
||||
query=None,
|
||||
files=[],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
assert messages[0].content == "Hello John"
|
||||
|
||||
bad_prompt_template = [ChatModelMessage.model_construct(text="bad", role=PromptMessageRole.USER, edition_type="x")]
|
||||
with pytest.raises(ValueError, match="Invalid edition type"):
|
||||
transform._get_chat_model_prompt_messages(
|
||||
prompt_template=bad_prompt_template,
|
||||
inputs={},
|
||||
query=None,
|
||||
files=[],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
|
||||
def test_chat_prompt_query_template_and_query_only_branch():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
memory_config = MemoryConfig(
|
||||
window=MemoryConfig.WindowConfig(enabled=False),
|
||||
query_prompt_template="query={{#sys.query#}} ctx={{#context#}}",
|
||||
)
|
||||
prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)]
|
||||
|
||||
messages = transform._get_chat_model_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="what",
|
||||
files=[],
|
||||
context="ctx",
|
||||
memory_config=memory_config,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
assert messages[-1].content == "query={{#sys.query#}} ctx=ctx"
|
||||
|
||||
|
||||
def test_chat_prompt_memory_with_files_and_query():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
|
||||
memory = MagicMock(spec=TokenBufferMemory)
|
||||
prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)]
|
||||
file = File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
transform._append_chat_histories = MagicMock(
|
||||
side_effect=lambda memory, memory_config, prompt_messages, **kwargs: prompt_messages
|
||||
)
|
||||
with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
|
||||
to_content.return_value = ImagePromptMessageContent(
|
||||
url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
|
||||
)
|
||||
messages = transform._get_chat_model_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[file],
|
||||
context=None,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert isinstance(messages[-1].content, list)
|
||||
assert messages[-1].content[1].data == "q"
|
||||
|
||||
|
||||
def test_chat_prompt_files_without_query_updates_last_user_or_appends_new():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
file = File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
prompt_with_last_user = [ChatModelMessage(text="u", role=PromptMessageRole.USER)]
|
||||
with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
|
||||
to_content.return_value = ImagePromptMessageContent(
|
||||
url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
|
||||
)
|
||||
messages = transform._get_chat_model_prompt_messages(
|
||||
prompt_template=prompt_with_last_user,
|
||||
inputs={},
|
||||
query=None,
|
||||
files=[file],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
assert isinstance(messages[-1].content, list)
|
||||
assert messages[-1].content[1].data == "u"
|
||||
|
||||
prompt_without_last_user = [ChatModelMessage(text="s", role=PromptMessageRole.SYSTEM)]
|
||||
with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
|
||||
to_content.return_value = ImagePromptMessageContent(
|
||||
url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
|
||||
)
|
||||
messages = transform._get_chat_model_prompt_messages(
|
||||
prompt_template=prompt_without_last_user,
|
||||
inputs={},
|
||||
query=None,
|
||||
files=[file],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
assert isinstance(messages[-1], UserPromptMessage)
|
||||
assert isinstance(messages[-1].content, list)
|
||||
assert messages[-1].content[1].data == ""
|
||||
|
||||
|
||||
def test_chat_prompt_files_with_query_branch():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
file = File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
|
||||
to_content.return_value = ImagePromptMessageContent(
|
||||
url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
|
||||
)
|
||||
messages = transform._get_chat_model_prompt_messages(
|
||||
prompt_template=[],
|
||||
inputs={},
|
||||
query="query-text",
|
||||
files=[file],
|
||||
context=None,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert isinstance(messages[-1].content, list)
|
||||
assert messages[-1].content[1].data == "query-text"
|
||||
|
||||
|
||||
def test_set_context_query_histories_variable_helpers():
|
||||
transform = AdvancedPromptTransform()
|
||||
parser_context = PromptTemplateParser(template="{{#context#}}")
|
||||
parser_query = PromptTemplateParser(template="{{#query#}}")
|
||||
parser_hist = PromptTemplateParser(template="{{#histories#}}")
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
memory_config = MemoryConfig(
|
||||
role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
|
||||
window=MemoryConfig.WindowConfig(enabled=False),
|
||||
)
|
||||
|
||||
assert transform._set_context_variable(None, parser_context, {})["#context#"] == ""
|
||||
assert transform._set_query_variable("", parser_query, {})["#query#"] == ""
|
||||
assert transform._set_query_variable("x", parser_query, {})["#query#"] == "x"
|
||||
assert (
|
||||
transform._set_histories_variable(
|
||||
memory=None, # type: ignore[arg-type]
|
||||
memory_config=memory_config,
|
||||
raw_prompt="{{#histories#}}",
|
||||
role_prefix=memory_config.role_prefix, # type: ignore[arg-type]
|
||||
parser=parser_hist,
|
||||
prompt_inputs={},
|
||||
model_config=model_config_mock,
|
||||
)["#histories#"]
|
||||
== ""
|
||||
)
|
||||
|
||||
@ -2,12 +2,14 @@ from uuid import uuid4
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
|
||||
|
||||
class MockMessage:
|
||||
def __init__(self, id, parent_message_id):
|
||||
def __init__(self, id, parent_message_id, answer="answer"):
|
||||
self.id = id
|
||||
self.parent_message_id = parent_message_id
|
||||
self.answer = answer
|
||||
|
||||
def __getitem__(self, item):
|
||||
return getattr(self, item)
|
||||
@ -89,3 +91,44 @@ def test_extract_thread_messages_mixed_with_legacy_messages():
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 4
|
||||
assert [msg["id"] for msg in result] == [id5, id4, id2, id1]
|
||||
|
||||
|
||||
def test_extract_thread_messages_breaks_when_parent_is_none():
|
||||
id1, id2 = str(uuid4()), str(uuid4())
|
||||
messages = [MockMessage(id2, None), MockMessage(id1, UUID_NIL)]
|
||||
|
||||
result = extract_thread_messages(messages)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == id2
|
||||
|
||||
|
||||
def test_get_thread_messages_length_excludes_newly_created_empty_answer(mocker):
|
||||
id1, id2 = str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
MockMessage(id2, id1, answer=""), # newest generated message should be excluded
|
||||
MockMessage(id1, UUID_NIL, answer="ok"),
|
||||
]
|
||||
|
||||
mock_scalars = mocker.patch("core.prompt.utils.get_thread_messages_length.db.session.scalars")
|
||||
mock_scalars.return_value.all.return_value = messages
|
||||
|
||||
length = get_thread_messages_length("conversation-1")
|
||||
|
||||
assert length == 1
|
||||
mock_scalars.assert_called_once()
|
||||
|
||||
|
||||
def test_get_thread_messages_length_keeps_non_empty_latest_answer(mocker):
|
||||
id1, id2 = str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
MockMessage(id2, id1, answer="latest-answer"),
|
||||
MockMessage(id1, UUID_NIL, answer="older-answer"),
|
||||
]
|
||||
|
||||
mock_scalars = mocker.patch("core.prompt.utils.get_thread_messages_length.db.session.scalars")
|
||||
mock_scalars.return_value.all.return_value = messages
|
||||
|
||||
length = get_thread_messages_length("conversation-2")
|
||||
|
||||
assert length == 2
|
||||
|
||||
@ -1,6 +1,11 @@
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
AudioPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
@ -25,3 +30,82 @@ def test_dump_prompt_message():
|
||||
)
|
||||
data = prompt.model_dump()
|
||||
assert data["content"][0].get("url") == example_url
|
||||
|
||||
|
||||
def test_prompt_messages_to_prompt_for_saving_chat_mode():
|
||||
chat_messages = [
|
||||
UserPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="hello "),
|
||||
ImagePromptMessageContent(
|
||||
url="https://example.com/image1.jpg",
|
||||
format="jpg",
|
||||
mime_type="image/jpeg",
|
||||
detail=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
),
|
||||
AudioPromptMessageContent(
|
||||
url="https://example.com/audio1.mp3",
|
||||
format="mp3",
|
||||
mime_type="audio/mpeg",
|
||||
),
|
||||
TextPromptMessageContent(data="world"),
|
||||
]
|
||||
),
|
||||
AssistantPromptMessage(
|
||||
content="assistant-text",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "tool-1",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": '{"q":"python"}'},
|
||||
}
|
||||
],
|
||||
),
|
||||
ToolPromptMessage(content="tool-output", name="search", tool_call_id="tool-1"),
|
||||
UserPromptMessage.model_construct(role="unknown", content="skip"), # type: ignore[arg-type]
|
||||
]
|
||||
|
||||
prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving(ModelMode.CHAT, chat_messages)
|
||||
|
||||
assert len(prompts) == 3
|
||||
assert prompts[0]["role"] == "user"
|
||||
assert prompts[0]["text"] == "hello world"
|
||||
assert prompts[0]["files"][0]["type"] == "image"
|
||||
assert prompts[0]["files"][1]["type"] == "audio"
|
||||
|
||||
assert prompts[1]["role"] == "assistant"
|
||||
assert prompts[1]["text"] == "assistant-text"
|
||||
assert prompts[1]["tool_calls"][0]["function"]["name"] == "search"
|
||||
assert prompts[2]["role"] == "tool"
|
||||
|
||||
|
||||
def test_prompt_messages_to_prompt_for_saving_completion_mode_with_and_without_files():
|
||||
completion_message_with_files = UserPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="first "),
|
||||
TextPromptMessageContent(data="second"),
|
||||
ImagePromptMessageContent(
|
||||
url="https://example.com/image2.jpg",
|
||||
format="jpg",
|
||||
mime_type="image/jpeg",
|
||||
detail=ImagePromptMessageContent.DETAIL.LOW,
|
||||
),
|
||||
]
|
||||
)
|
||||
prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
ModelMode.COMPLETION, [completion_message_with_files]
|
||||
)
|
||||
assert prompts == [
|
||||
{
|
||||
"role": "user",
|
||||
"text": "first second",
|
||||
"files": prompts[0]["files"],
|
||||
}
|
||||
]
|
||||
assert prompts[0]["files"][0]["type"] == "image"
|
||||
|
||||
completion_message_text_only = UserPromptMessage(content="plain text")
|
||||
prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
ModelMode.COMPLETION, [completion_message_text_only]
|
||||
)
|
||||
assert prompts == [{"role": "user", "text": "plain text"}]
|
||||
|
||||
@ -1,4 +1,10 @@
|
||||
# from unittest.mock import MagicMock
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
|
||||
# from core.app.app_config.entities import ModelConfigEntity
|
||||
# from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
@ -9,44 +15,217 @@
|
||||
# from core.prompt.prompt_transform import PromptTransform
|
||||
|
||||
|
||||
# def test__calculate_rest_token():
|
||||
# model_schema_mock = MagicMock(spec=AIModelEntity)
|
||||
# parameter_rule_mock = MagicMock(spec=ParameterRule)
|
||||
# parameter_rule_mock.name = "max_tokens"
|
||||
# model_schema_mock.parameter_rules = [parameter_rule_mock]
|
||||
# model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62}
|
||||
class TestPromptTransform:
|
||||
def test_resolve_model_runtime_requires_model_config_or_instance(self):
|
||||
transform = PromptTransform()
|
||||
|
||||
# large_language_model_mock = MagicMock(spec=LargeLanguageModel)
|
||||
# large_language_model_mock.get_num_tokens.return_value = 6
|
||||
with pytest.raises(ValueError, match="Either model_config or model_instance must be provided."):
|
||||
transform._resolve_model_runtime()
|
||||
|
||||
# provider_mock = MagicMock(spec=ProviderEntity)
|
||||
# provider_mock.provider = "openai"
|
||||
def test_resolve_model_runtime_builds_model_instance_from_model_config(self):
|
||||
transform = PromptTransform()
|
||||
fake_model_schema = SimpleNamespace(model_properties={}, parameter_rules=[])
|
||||
fake_model_type_instance = MagicMock()
|
||||
fake_model_type_instance.get_model_schema.return_value = fake_model_schema
|
||||
fake_model_instance = SimpleNamespace(
|
||||
model_type_instance=fake_model_type_instance,
|
||||
model_name="resolved-model",
|
||||
credentials=None,
|
||||
parameters=None,
|
||||
stop=None,
|
||||
)
|
||||
model_config = SimpleNamespace(
|
||||
provider_model_bundle=object(),
|
||||
model="config-model",
|
||||
credentials={"api_key": "secret"},
|
||||
parameters={"temperature": 0.1},
|
||||
stop=["END"],
|
||||
model_schema=SimpleNamespace(model_properties={}, parameter_rules=[]),
|
||||
)
|
||||
|
||||
# provider_configuration_mock = MagicMock(spec=ProviderConfiguration)
|
||||
# provider_configuration_mock.provider = provider_mock
|
||||
# provider_configuration_mock.model_settings = None
|
||||
with patch(
|
||||
"core.prompt.prompt_transform.ModelInstance", return_value=fake_model_instance
|
||||
) as model_instance_cls:
|
||||
model_instance, model_schema = transform._resolve_model_runtime(model_config=model_config)
|
||||
|
||||
# provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
|
||||
# provider_model_bundle_mock.model_type_instance = large_language_model_mock
|
||||
# provider_model_bundle_mock.configuration = provider_configuration_mock
|
||||
model_instance_cls.assert_called_once_with(
|
||||
provider_model_bundle=model_config.provider_model_bundle,
|
||||
model=model_config.model,
|
||||
)
|
||||
fake_model_type_instance.get_model_schema.assert_called_once_with(
|
||||
model="resolved-model",
|
||||
credentials={"api_key": "secret"},
|
||||
)
|
||||
assert model_instance is fake_model_instance
|
||||
assert model_instance.credentials == {"api_key": "secret"}
|
||||
assert model_instance.parameters == {"temperature": 0.1}
|
||||
assert model_instance.stop == ["END"]
|
||||
assert model_schema is fake_model_schema
|
||||
|
||||
# model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
# model_config_mock.model = "gpt-4"
|
||||
# model_config_mock.credentials = {}
|
||||
# model_config_mock.parameters = {"max_tokens": 50}
|
||||
# model_config_mock.model_schema = model_schema_mock
|
||||
# model_config_mock.provider_model_bundle = provider_model_bundle_mock
|
||||
def test_resolve_model_runtime_uses_model_config_schema_fallback(self):
|
||||
transform = PromptTransform()
|
||||
fallback_schema = SimpleNamespace(model_properties={}, parameter_rules=[])
|
||||
fake_model_type_instance = MagicMock()
|
||||
fake_model_type_instance.get_model_schema.return_value = None
|
||||
model_instance = SimpleNamespace(
|
||||
model_type_instance=fake_model_type_instance,
|
||||
model_name="resolved-model",
|
||||
credentials={"api_key": "secret"},
|
||||
parameters={},
|
||||
)
|
||||
model_config = SimpleNamespace(model_schema=fallback_schema)
|
||||
|
||||
# prompt_transform = PromptTransform()
|
||||
resolved_model_instance, resolved_schema = transform._resolve_model_runtime(
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
# prompt_messages = [UserPromptMessage(content="Hello, how are you?")]
|
||||
# rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock)
|
||||
assert resolved_model_instance is model_instance
|
||||
assert resolved_schema is fallback_schema
|
||||
|
||||
# # Validate based on the mock configuration and expected logic
|
||||
# expected_rest_tokens = (
|
||||
# model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
|
||||
# - model_config_mock.parameters["max_tokens"]
|
||||
# - large_language_model_mock.get_num_tokens.return_value
|
||||
# )
|
||||
# assert rest_tokens == expected_rest_tokens
|
||||
# assert rest_tokens == 6
|
||||
def test_resolve_model_runtime_raises_when_schema_missing_without_model_config(self):
|
||||
transform = PromptTransform()
|
||||
fake_model_type_instance = MagicMock()
|
||||
fake_model_type_instance.get_model_schema.return_value = None
|
||||
model_instance = SimpleNamespace(
|
||||
model_type_instance=fake_model_type_instance,
|
||||
model_name="resolved-model",
|
||||
credentials={"api_key": "secret"},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Model schema not found for the provided model instance."):
|
||||
transform._resolve_model_runtime(model_instance=model_instance)
|
||||
|
||||
def test_calculate_rest_token_defaults_when_context_size_missing(self):
|
||||
transform = PromptTransform()
|
||||
fake_model_instance = SimpleNamespace(parameters={}, get_llm_num_tokens=lambda _: 0)
|
||||
fake_model_schema = SimpleNamespace(model_properties={}, parameter_rules=[])
|
||||
transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema))
|
||||
model_config = SimpleNamespace(
|
||||
model_schema=SimpleNamespace(model_properties={}, parameter_rules=[]),
|
||||
provider_model_bundle=object(),
|
||||
model="test-model",
|
||||
parameters={},
|
||||
)
|
||||
|
||||
rest = transform._calculate_rest_token([], model_config=model_config)
|
||||
|
||||
assert rest == 2000
|
||||
|
||||
def test_calculate_rest_token_uses_max_tokens_and_clamps_to_zero(self):
|
||||
transform = PromptTransform()
|
||||
|
||||
parameter_rule = SimpleNamespace(name="max_tokens", use_template=None)
|
||||
fake_model_instance = SimpleNamespace(parameters={"max_tokens": 50}, get_llm_num_tokens=lambda _: 95)
|
||||
fake_model_schema = SimpleNamespace(
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 100},
|
||||
parameter_rules=[parameter_rule],
|
||||
)
|
||||
transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema))
|
||||
model_config = SimpleNamespace(
|
||||
model_schema=SimpleNamespace(
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 100},
|
||||
parameter_rules=[parameter_rule],
|
||||
),
|
||||
provider_model_bundle=object(),
|
||||
model="test-model",
|
||||
parameters={"max_tokens": 50},
|
||||
)
|
||||
|
||||
rest = transform._calculate_rest_token([SimpleNamespace()], model_config=model_config)
|
||||
|
||||
assert rest == 0
|
||||
|
||||
def test_calculate_rest_token_supports_use_template_parameter(self):
|
||||
transform = PromptTransform()
|
||||
|
||||
parameter_rule = SimpleNamespace(name="generation_max", use_template="max_tokens")
|
||||
fake_model_instance = SimpleNamespace(parameters={"max_tokens": 30}, get_llm_num_tokens=lambda _: 20)
|
||||
fake_model_schema = SimpleNamespace(
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 200},
|
||||
parameter_rules=[parameter_rule],
|
||||
)
|
||||
transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema))
|
||||
model_config = SimpleNamespace(
|
||||
model_schema=SimpleNamespace(
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 200},
|
||||
parameter_rules=[parameter_rule],
|
||||
),
|
||||
provider_model_bundle=object(),
|
||||
model="test-model",
|
||||
parameters={"max_tokens": 30},
|
||||
)
|
||||
|
||||
rest = transform._calculate_rest_token([SimpleNamespace()], model_config=model_config)
|
||||
|
||||
assert rest == 150
|
||||
|
||||
def test_get_history_messages_from_memory_with_and_without_window(self):
|
||||
transform = PromptTransform()
|
||||
memory = MagicMock()
|
||||
memory.get_history_prompt_text.return_value = "history"
|
||||
|
||||
memory_config_with_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=3))
|
||||
result = transform._get_history_messages_from_memory(
|
||||
memory=memory,
|
||||
memory_config=memory_config_with_window,
|
||||
max_token_limit=100,
|
||||
human_prefix="Human",
|
||||
ai_prefix="Assistant",
|
||||
)
|
||||
|
||||
assert result == "history"
|
||||
memory.get_history_prompt_text.assert_called_with(
|
||||
max_token_limit=100,
|
||||
human_prefix="Human",
|
||||
ai_prefix="Assistant",
|
||||
message_limit=3,
|
||||
)
|
||||
|
||||
memory.reset_mock()
|
||||
memory_config_no_window = SimpleNamespace(window=SimpleNamespace(enabled=False, size=2))
|
||||
transform._get_history_messages_from_memory(
|
||||
memory=memory,
|
||||
memory_config=memory_config_no_window,
|
||||
max_token_limit=50,
|
||||
)
|
||||
memory.get_history_prompt_text.assert_called_with(max_token_limit=50)
|
||||
|
||||
def test_get_history_messages_list_from_memory_with_and_without_window(self):
|
||||
transform = PromptTransform()
|
||||
memory = MagicMock()
|
||||
memory.get_history_prompt_messages.return_value = ["m1", "m2"]
|
||||
|
||||
memory_config_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=2))
|
||||
result = transform._get_history_messages_list_from_memory(memory, memory_config_window, 120)
|
||||
assert result == ["m1", "m2"]
|
||||
memory.get_history_prompt_messages.assert_called_with(max_token_limit=120, message_limit=2)
|
||||
|
||||
memory.reset_mock()
|
||||
memory.get_history_prompt_messages.return_value = ["only"]
|
||||
memory_config_no_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=0))
|
||||
result = transform._get_history_messages_list_from_memory(memory, memory_config_no_window, 10)
|
||||
assert result == ["only"]
|
||||
memory.get_history_prompt_messages.assert_called_with(max_token_limit=10, message_limit=None)
|
||||
|
||||
def test_append_chat_histories_extends_prompt_messages(self, monkeypatch):
|
||||
transform = PromptTransform()
|
||||
memory = MagicMock()
|
||||
memory_config = SimpleNamespace(window=SimpleNamespace(enabled=False, size=None))
|
||||
|
||||
monkeypatch.setattr(transform, "_calculate_rest_token", lambda prompt_messages, **kwargs: 99)
|
||||
monkeypatch.setattr(
|
||||
transform,
|
||||
"_get_history_messages_list_from_memory",
|
||||
lambda memory, memory_config, max_token_limit: ["h1", "h2"],
|
||||
)
|
||||
|
||||
result = transform._append_chat_histories(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
prompt_messages=["p1"],
|
||||
model_config=SimpleNamespace(),
|
||||
)
|
||||
|
||||
assert result == ["p1", "h1", "h2"]
|
||||
|
||||
@ -1,9 +1,29 @@
|
||||
from unittest.mock import MagicMock
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.prompt.prompt_templates.advanced_prompt_templates import (
|
||||
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
||||
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
|
||||
BAICHUAN_CONTEXT,
|
||||
CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||
COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
||||
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
|
||||
CONTEXT,
|
||||
)
|
||||
from core.prompt.simple_prompt_transform import SimplePromptTransform
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from models.model import AppMode, Conversation
|
||||
|
||||
|
||||
@ -244,3 +264,178 @@ def test__get_completion_model_prompt_messages():
|
||||
assert len(prompt_messages) == 1
|
||||
assert stops == prompt_rules.get("stops")
|
||||
assert prompt_messages[0].content == real_prompt
|
||||
|
||||
|
||||
def test_get_prompt_dispatches_chat_and_completion():
|
||||
transform = SimplePromptTransform()
|
||||
model_config_chat = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
model_config_chat.mode = "chat"
|
||||
model_config_completion = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
model_config_completion.mode = "completion"
|
||||
prompt_entity = SimpleNamespace(simple_prompt_template="hello")
|
||||
|
||||
transform._get_chat_model_prompt_messages = MagicMock(return_value=(["chat-msg"], None))
|
||||
transform._get_completion_model_prompt_messages = MagicMock(return_value=(["completion-msg"], ["stop"]))
|
||||
|
||||
chat_messages, chat_stops = transform.get_prompt(
|
||||
app_mode=AppMode.CHAT,
|
||||
prompt_template_entity=prompt_entity,
|
||||
inputs={"n": 1},
|
||||
query="q",
|
||||
files=[],
|
||||
context=None,
|
||||
memory=None,
|
||||
model_config=model_config_chat,
|
||||
)
|
||||
assert chat_messages == ["chat-msg"]
|
||||
assert chat_stops is None
|
||||
|
||||
completion_messages, completion_stops = transform.get_prompt(
|
||||
app_mode=AppMode.CHAT,
|
||||
prompt_template_entity=prompt_entity,
|
||||
inputs={"n": 1},
|
||||
query="q",
|
||||
files=[],
|
||||
context=None,
|
||||
memory=None,
|
||||
model_config=model_config_completion,
|
||||
)
|
||||
assert completion_messages == ["completion-msg"]
|
||||
assert completion_stops == ["stop"]
|
||||
|
||||
|
||||
def test_get_prompt_str_and_rules_type_validation_errors():
|
||||
transform = SimplePromptTransform()
|
||||
model_config = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
model_config.provider = "openai"
|
||||
model_config.model = "gpt-4"
|
||||
valid_prompt_template = SimplePromptTransform().get_prompt_template(
|
||||
AppMode.CHAT, "openai", "gpt-4", "", False, False
|
||||
)["prompt_template"]
|
||||
|
||||
bad_custom_keys = {
|
||||
"prompt_template": valid_prompt_template,
|
||||
"custom_variable_keys": "not-list",
|
||||
"special_variable_keys": [],
|
||||
"prompt_rules": {},
|
||||
}
|
||||
transform.get_prompt_template = MagicMock(return_value=bad_custom_keys)
|
||||
with pytest.raises(TypeError, match="custom_variable_keys"):
|
||||
transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None)
|
||||
|
||||
bad_special_keys = {
|
||||
**bad_custom_keys,
|
||||
"custom_variable_keys": [],
|
||||
"special_variable_keys": "not-list",
|
||||
}
|
||||
transform.get_prompt_template = MagicMock(return_value=bad_special_keys)
|
||||
with pytest.raises(TypeError, match="special_variable_keys"):
|
||||
transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None)
|
||||
|
||||
bad_prompt_template = {
|
||||
**bad_custom_keys,
|
||||
"custom_variable_keys": [],
|
||||
"special_variable_keys": [],
|
||||
"prompt_template": 123,
|
||||
}
|
||||
transform.get_prompt_template = MagicMock(return_value=bad_prompt_template)
|
||||
with pytest.raises(TypeError, match="PromptTemplateParser"):
|
||||
transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None)
|
||||
|
||||
bad_prompt_rules = {
|
||||
**bad_custom_keys,
|
||||
"custom_variable_keys": [],
|
||||
"special_variable_keys": [],
|
||||
"prompt_template": valid_prompt_template,
|
||||
"prompt_rules": "not-dict",
|
||||
}
|
||||
transform.get_prompt_template = MagicMock(return_value=bad_prompt_rules)
|
||||
with pytest.raises(TypeError, match="prompt_rules"):
|
||||
transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None)
|
||||
|
||||
|
||||
def test_chat_model_prompt_messages_uses_prompt_when_query_empty():
|
||||
transform = SimplePromptTransform()
|
||||
model_config = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
transform._get_prompt_str_and_rules = MagicMock(return_value=("prompt-text", {}))
|
||||
transform._get_last_user_message = MagicMock(return_value=UserPromptMessage(content="prompt-text"))
|
||||
|
||||
prompt_messages, _ = transform._get_chat_model_prompt_messages(
|
||||
app_mode=AppMode.CHAT,
|
||||
pre_prompt="",
|
||||
inputs={},
|
||||
query="",
|
||||
files=[],
|
||||
context=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
assert prompt_messages[0].content == "prompt-text"
|
||||
transform._get_last_user_message.assert_called_once_with("prompt-text", [], None, None)
|
||||
|
||||
|
||||
def test_completion_model_prompt_messages_empty_stops_becomes_none():
|
||||
transform = SimplePromptTransform()
|
||||
model_config = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
transform._get_prompt_str_and_rules = MagicMock(return_value=("prompt", {"stops": []}))
|
||||
|
||||
prompt_messages, stops = transform._get_completion_model_prompt_messages(
|
||||
app_mode=AppMode.CHAT,
|
||||
pre_prompt="",
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
context=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
assert len(prompt_messages) == 1
|
||||
assert stops is None
|
||||
|
||||
|
||||
def test_get_last_user_message_with_files_and_context_files():
|
||||
transform = SimplePromptTransform()
|
||||
file = SimpleNamespace()
|
||||
context_file = SimpleNamespace()
|
||||
|
||||
with patch("core.prompt.simple_prompt_transform.file_manager.to_prompt_message_content") as to_content:
|
||||
to_content.side_effect = [
|
||||
ImagePromptMessageContent(url="https://example.com/a.jpg", format="jpg", mime_type="image/jpg"),
|
||||
ImagePromptMessageContent(url="https://example.com/b.jpg", format="jpg", mime_type="image/jpg"),
|
||||
]
|
||||
message = transform._get_last_user_message(
|
||||
prompt="hello",
|
||||
files=[file],
|
||||
context_files=[context_file],
|
||||
image_detail_config=None,
|
||||
)
|
||||
|
||||
assert isinstance(message.content, list)
|
||||
assert message.content[0].data == "https://example.com/a.jpg"
|
||||
assert message.content[1].data == "https://example.com/b.jpg"
|
||||
assert isinstance(message.content[2], TextPromptMessageContent)
|
||||
assert message.content[2].data == "hello"
|
||||
|
||||
|
||||
def test_prompt_file_name_branches():
|
||||
transform = SimplePromptTransform()
|
||||
|
||||
assert transform._prompt_file_name(AppMode.CHAT, "openai", "gpt-4") == "common_chat"
|
||||
assert transform._prompt_file_name(AppMode.COMPLETION, "openai", "gpt-4") == "common_completion"
|
||||
assert transform._prompt_file_name(AppMode.COMPLETION, "baichuan", "Baichuan2") == "baichuan_completion"
|
||||
assert transform._prompt_file_name(AppMode.CHAT, "huggingface_hub", "baichuan-13b") == "baichuan_chat"
|
||||
|
||||
|
||||
def test_advanced_prompt_templates_constants_are_importable():
|
||||
assert isinstance(CONTEXT, str)
|
||||
assert isinstance(BAICHUAN_CONTEXT, str)
|
||||
assert "completion_prompt_config" in CHAT_APP_COMPLETION_PROMPT_CONFIG
|
||||
assert "chat_prompt_config" in CHAT_APP_CHAT_PROMPT_CONFIG
|
||||
assert "chat_prompt_config" in COMPLETION_APP_CHAT_PROMPT_CONFIG
|
||||
assert "completion_prompt_config" in COMPLETION_APP_COMPLETION_PROMPT_CONFIG
|
||||
assert "completion_prompt_config" in BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG
|
||||
assert "chat_prompt_config" in BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG
|
||||
assert "chat_prompt_config" in BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
|
||||
assert "completion_prompt_config" in BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
|
||||
|
||||
Reference in New Issue
Block a user