test: unit test cases for core.variables, core.plugin, core.prompt module (#32637)

This commit is contained in:
Rajat Agarwal
2026-03-12 08:59:02 +05:30
committed by GitHub
parent 135b3a15a6
commit 07e19c0748
24 changed files with 3526 additions and 97 deletions

View File

@ -1,3 +1,4 @@
from typing import cast
from unittest.mock import MagicMock, patch
import pytest
@ -13,6 +14,8 @@ from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessageRole,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from models.model import Conversation
@ -188,3 +191,328 @@ def get_chat_model_args():
context = "I am superman."
return model_config_mock, memory_config, prompt_messages, inputs, context
def test_get_prompt_dispatches_completion_and_chat_and_invalid():
    """get_prompt routes completion templates, chat templates, and returns [] otherwise."""
    transform = AdvancedPromptTransform()
    model_config = MagicMock(spec=ModelConfigEntity)
    transform._get_completion_model_prompt_messages = MagicMock(return_value=[UserPromptMessage(content="c")])
    transform._get_chat_model_prompt_messages = MagicMock(return_value=[UserPromptMessage(content="h")])

    # Every dispatch call shares the same non-template arguments.
    shared_kwargs = dict(
        inputs={"name": "john"},
        query="q",
        files=[],
        context=None,
        memory_config=None,
        memory=None,
        model_config=model_config,
    )

    completion_template = CompletionModelPromptTemplate(text="Hello {{name}}", edition_type="basic")
    completion_result = transform.get_prompt(prompt_template=completion_template, **shared_kwargs)
    assert completion_result[0].content == "c"

    chat_template = [ChatModelMessage(text="Hello {{name}}", role=PromptMessageRole.USER, edition_type="basic")]
    chat_result = transform.get_prompt(prompt_template=chat_template, **shared_kwargs)
    assert chat_result[0].content == "h"

    # A list whose elements are not ChatModelMessage falls through to an empty result.
    invalid_result = transform.get_prompt(prompt_template=cast(list, ["not-chat-model-message"]), **shared_kwargs)
    assert invalid_result == []
def test_completion_prompt_jinja2_with_files():
    """A jinja2 completion template renders text and prepends file content to the message."""
    transform = AdvancedPromptTransform()
    model_config_mock = MagicMock(spec=ModelConfigEntity)
    template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2")
    image_file = File(
        id="file1",
        tenant_id="tenant1",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url="https://example.com/image.jpg",
        storage_key="",
    )
    with (
        patch("core.prompt.advanced_prompt_transform.Jinja2Formatter.format", return_value="Hi John"),
        patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content,
    ):
        to_content.return_value = ImagePromptMessageContent(
            url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
        )
        messages = transform._get_completion_model_prompt_messages(
            prompt_template=template,
            inputs={"name": "John"},
            query="",
            files=[image_file],
            context=None,
            memory_config=None,
            memory=None,
            model_config=model_config_mock,
        )

    assert len(messages) == 1
    content = messages[0].content
    # File content comes first, rendered text last.
    assert isinstance(content, list)
    assert content[0].data == "https://example.com/image.jpg"
    assert isinstance(content[1], TextPromptMessageContent)
    assert content[1].data == "Hi John"
def test_completion_prompt_basic_sets_query_variable():
    """The {{#query#}} placeholder in a basic completion template is filled with the query."""
    transform = AdvancedPromptTransform()
    config = MagicMock(spec=ModelConfigEntity)
    result = transform._get_completion_model_prompt_messages(
        prompt_template=CompletionModelPromptTemplate(text="Q={{#query#}}", edition_type="basic"),
        inputs={},
        query="what?",
        files=[],
        context=None,
        memory_config=None,
        memory=None,
        model_config=config,
    )
    assert result[0].content == "Q=what?"
def test_chat_prompt_with_variable_template_and_context():
    """Variable-template mode substitutes {{#var#}} inputs and the {{#context#}} placeholder."""
    transform = AdvancedPromptTransform(with_variable_tmpl=True)
    config = MagicMock(spec=ModelConfigEntity)
    template = [ChatModelMessage(text="sys={{#node.name#}} ctx={{#context#}}", role=PromptMessageRole.SYSTEM)]
    messages = transform._get_chat_model_prompt_messages(
        prompt_template=template,
        inputs={"#node.name#": "john"},
        query=None,
        files=[],
        context="context-text",
        memory_config=None,
        memory=None,
        model_config=config,
    )
    assert len(messages) == 1
    assert isinstance(messages[0], SystemPromptMessage)
    assert messages[0].content == "sys=john ctx=context-text"
def test_chat_prompt_jinja2_branch_and_invalid_edition():
    """Jinja2 chat messages go through the Jinja2 formatter; unknown edition types raise."""
    transform = AdvancedPromptTransform()
    config = MagicMock(spec=ModelConfigEntity)

    jinja_template = [ChatModelMessage(text="Hello {{name}}", role=PromptMessageRole.USER, edition_type="jinja2")]
    with patch("core.prompt.advanced_prompt_transform.Jinja2Formatter.format", return_value="Hello John"):
        messages = transform._get_chat_model_prompt_messages(
            prompt_template=jinja_template,
            inputs={"name": "John"},
            query=None,
            files=[],
            context=None,
            memory_config=None,
            memory=None,
            model_config=config,
        )
    assert messages[0].content == "Hello John"

    # model_construct bypasses validation so the invalid edition type reaches the error branch.
    invalid_template = [ChatModelMessage.model_construct(text="bad", role=PromptMessageRole.USER, edition_type="x")]
    with pytest.raises(ValueError, match="Invalid edition type"):
        transform._get_chat_model_prompt_messages(
            prompt_template=invalid_template,
            inputs={},
            query=None,
            files=[],
            context=None,
            memory_config=None,
            memory=None,
            model_config=config,
        )
def test_chat_prompt_query_template_and_query_only_branch():
    """A memory_config query template becomes the final message, with only #context# substituted."""
    transform = AdvancedPromptTransform()
    config = MagicMock(spec=ModelConfigEntity)
    memory_config = MemoryConfig(
        window=MemoryConfig.WindowConfig(enabled=False),
        query_prompt_template="query={{#sys.query#}} ctx={{#context#}}",
    )
    messages = transform._get_chat_model_prompt_messages(
        prompt_template=[ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)],
        inputs={},
        query="what",
        files=[],
        context="ctx",
        memory_config=memory_config,
        memory=None,
        model_config=config,
    )
    # {{#sys.query#}} is left verbatim here; {{#context#}} is filled in.
    assert messages[-1].content == "query={{#sys.query#}} ctx=ctx"
def test_chat_prompt_memory_with_files_and_query():
    """With memory and files, the final user message carries file content followed by the query."""
    transform = AdvancedPromptTransform()
    config = MagicMock(spec=ModelConfigEntity)
    memory = MagicMock(spec=TokenBufferMemory)
    memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
    image_file = File(
        id="file1",
        tenant_id="tenant1",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url="https://example.com/image.jpg",
        storage_key="",
    )
    # Short-circuit history insertion: hand the prompt messages back unchanged.
    transform._append_chat_histories = MagicMock(
        side_effect=lambda memory, memory_config, prompt_messages, **kwargs: prompt_messages
    )
    with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
        to_content.return_value = ImagePromptMessageContent(
            url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
        )
        messages = transform._get_chat_model_prompt_messages(
            prompt_template=[ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)],
            inputs={},
            query="q",
            files=[image_file],
            context=None,
            memory_config=memory_config,
            memory=memory,
            model_config=config,
        )

    last_message = messages[-1]
    assert isinstance(last_message.content, list)
    assert last_message.content[1].data == "q"
def test_chat_prompt_files_without_query_updates_last_user_or_appends_new():
    """Without a query, files merge into a trailing user message, or a new one is appended."""
    transform = AdvancedPromptTransform()
    config = MagicMock(spec=ModelConfigEntity)
    image_file = File(
        id="file1",
        tenant_id="tenant1",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url="https://example.com/image.jpg",
        storage_key="",
    )

    def build_messages(template):
        # Both scenarios differ only in the prompt template.
        with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
            to_content.return_value = ImagePromptMessageContent(
                url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
            )
            return transform._get_chat_model_prompt_messages(
                prompt_template=template,
                inputs={},
                query=None,
                files=[image_file],
                context=None,
                memory_config=None,
                memory=None,
                model_config=config,
            )

    # Case 1: template ends with a user message -> its text joins the file content.
    messages = build_messages([ChatModelMessage(text="u", role=PromptMessageRole.USER)])
    assert isinstance(messages[-1].content, list)
    assert messages[-1].content[1].data == "u"

    # Case 2: no trailing user message -> a fresh user message with empty text is appended.
    messages = build_messages([ChatModelMessage(text="s", role=PromptMessageRole.SYSTEM)])
    assert isinstance(messages[-1], UserPromptMessage)
    assert isinstance(messages[-1].content, list)
    assert messages[-1].content[1].data == ""
def test_chat_prompt_files_with_query_branch():
    """With files plus a query and no template, a user message of [file, query] is produced."""
    transform = AdvancedPromptTransform()
    config = MagicMock(spec=ModelConfigEntity)
    image_file = File(
        id="file1",
        tenant_id="tenant1",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url="https://example.com/image.jpg",
        storage_key="",
    )
    with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content:
        to_content.return_value = ImagePromptMessageContent(
            url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg"
        )
        messages = transform._get_chat_model_prompt_messages(
            prompt_template=[],
            inputs={},
            query="query-text",
            files=[image_file],
            context=None,
            memory_config=None,
            memory=None,
            model_config=config,
        )

    last_message = messages[-1]
    assert isinstance(last_message.content, list)
    assert last_message.content[1].data == "query-text"
def test_set_context_query_histories_variable_helpers():
    """The _set_*_variable helpers populate prompt inputs, defaulting to empty strings."""
    transform = AdvancedPromptTransform()
    config = MagicMock(spec=ModelConfigEntity)
    memory_config = MemoryConfig(
        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
        window=MemoryConfig.WindowConfig(enabled=False),
    )
    context_parser = PromptTemplateParser(template="{{#context#}}")
    query_parser = PromptTemplateParser(template="{{#query#}}")
    histories_parser = PromptTemplateParser(template="{{#histories#}}")

    # Missing context and empty query both collapse to "".
    assert transform._set_context_variable(None, context_parser, {})["#context#"] == ""
    assert transform._set_query_variable("", query_parser, {})["#query#"] == ""
    assert transform._set_query_variable("x", query_parser, {})["#query#"] == "x"

    histories_inputs = transform._set_histories_variable(
        memory=None,  # type: ignore[arg-type]
        memory_config=memory_config,
        raw_prompt="{{#histories#}}",
        role_prefix=memory_config.role_prefix,  # type: ignore[arg-type]
        parser=histories_parser,
        prompt_inputs={},
        model_config=config,
    )
    assert histories_inputs["#histories#"] == ""

View File

@ -2,12 +2,14 @@ from uuid import uuid4
from constants import UUID_NIL
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
class MockMessage:
    """Minimal stand-in for a Message row used by the thread-message tests.

    Supports both attribute access (``msg.id``) and dict-style lookup
    (``msg["id"]``), mirroring how the code under test reads messages.
    """

    # NOTE: the scraped diff showed both the old and new __init__ signatures;
    # this is the new one, which adds `answer` with a non-empty default so
    # existing call sites keep working.
    def __init__(self, id, parent_message_id, answer="answer"):
        self.id = id
        self.parent_message_id = parent_message_id
        self.answer = answer

    def __getitem__(self, item):
        # Delegate subscription to attribute lookup so msg["x"] == msg.x.
        return getattr(self, item)
@ -89,3 +91,44 @@ def test_extract_thread_messages_mixed_with_legacy_messages():
result = extract_thread_messages(messages)
assert len(result) == 4
assert [msg["id"] for msg in result] == [id5, id4, id2, id1]
def test_extract_thread_messages_breaks_when_parent_is_none():
    """Extraction stops as soon as it reaches a message with no parent id."""
    older_id, newest_id = str(uuid4()), str(uuid4())
    thread = [MockMessage(newest_id, None), MockMessage(older_id, UUID_NIL)]
    extracted = extract_thread_messages(thread)
    assert len(extracted) == 1
    assert extracted[0].id == newest_id
def test_get_thread_messages_length_excludes_newly_created_empty_answer(mocker):
    """The newest message is not counted while its answer is still empty."""
    first_id, latest_id = str(uuid4()), str(uuid4())
    thread = [
        MockMessage(latest_id, first_id, answer=""),  # freshly created; answer pending
        MockMessage(first_id, UUID_NIL, answer="ok"),
    ]
    scalars = mocker.patch("core.prompt.utils.get_thread_messages_length.db.session.scalars")
    scalars.return_value.all.return_value = thread
    assert get_thread_messages_length("conversation-1") == 1
    scalars.assert_called_once()
def test_get_thread_messages_length_keeps_non_empty_latest_answer(mocker):
    """A latest message with a real answer is counted like any other message."""
    first_id, latest_id = str(uuid4()), str(uuid4())
    thread = [
        MockMessage(latest_id, first_id, answer="latest-answer"),
        MockMessage(first_id, UUID_NIL, answer="older-answer"),
    ]
    scalars = mocker.patch("core.prompt.utils.get_thread_messages_length.db.session.scalars")
    scalars.return_value.all.return_value = thread
    assert get_thread_messages_length("conversation-2") == 2

View File

@ -1,6 +1,11 @@
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
AudioPromptMessageContent,
ImagePromptMessageContent,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
@ -25,3 +30,82 @@ def test_dump_prompt_message():
)
data = prompt.model_dump()
assert data["content"][0].get("url") == example_url
def test_prompt_messages_to_prompt_for_saving_chat_mode():
    """Chat-mode saving joins text parts, collects files, keeps tool calls, drops unknown roles."""
    user_message = UserPromptMessage(
        content=[
            TextPromptMessageContent(data="hello "),
            ImagePromptMessageContent(
                url="https://example.com/image1.jpg",
                format="jpg",
                mime_type="image/jpeg",
                detail=ImagePromptMessageContent.DETAIL.HIGH,
            ),
            AudioPromptMessageContent(
                url="https://example.com/audio1.mp3",
                format="mp3",
                mime_type="audio/mpeg",
            ),
            TextPromptMessageContent(data="world"),
        ]
    )
    assistant_message = AssistantPromptMessage(
        content="assistant-text",
        tool_calls=[
            {
                "id": "tool-1",
                "type": "function",
                "function": {"name": "search", "arguments": '{"q":"python"}'},
            }
        ],
    )
    tool_message = ToolPromptMessage(content="tool-output", name="search", tool_call_id="tool-1")
    # model_construct skips validation, producing an unrecognized role that must be skipped.
    unknown_role_message = UserPromptMessage.model_construct(role="unknown", content="skip")  # type: ignore[arg-type]

    saved = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
        ModelMode.CHAT,
        [user_message, assistant_message, tool_message, unknown_role_message],
    )

    assert len(saved) == 3
    user_entry, assistant_entry, tool_entry = saved
    assert user_entry["role"] == "user"
    assert user_entry["text"] == "hello world"
    assert user_entry["files"][0]["type"] == "image"
    assert user_entry["files"][1]["type"] == "audio"
    assert assistant_entry["role"] == "assistant"
    assert assistant_entry["text"] == "assistant-text"
    assert assistant_entry["tool_calls"][0]["function"]["name"] == "search"
    assert tool_entry["role"] == "tool"
def test_prompt_messages_to_prompt_for_saving_completion_mode_with_and_without_files():
    """Completion-mode saving joins text parts and only emits 'files' when files are present."""
    message_with_files = UserPromptMessage(
        content=[
            TextPromptMessageContent(data="first "),
            TextPromptMessageContent(data="second"),
            ImagePromptMessageContent(
                url="https://example.com/image2.jpg",
                format="jpg",
                mime_type="image/jpeg",
                detail=ImagePromptMessageContent.DETAIL.LOW,
            ),
        ]
    )
    saved = PromptMessageUtil.prompt_messages_to_prompt_for_saving(ModelMode.COMPLETION, [message_with_files])
    # Compare everything except the opaque file payload, then spot-check the file type.
    assert saved == [{"role": "user", "text": "first second", "files": saved[0]["files"]}]
    assert saved[0]["files"][0]["type"] == "image"

    text_only_message = UserPromptMessage(content="plain text")
    saved = PromptMessageUtil.prompt_messages_to_prompt_for_saving(ModelMode.COMPLETION, [text_only_message])
    assert saved == [{"role": "user", "text": "plain text"}]

View File

@ -1,4 +1,10 @@
# from unittest.mock import MagicMock
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.prompt.prompt_transform import PromptTransform
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
# from core.app.app_config.entities import ModelConfigEntity
# from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@ -9,44 +15,217 @@
# from core.prompt.prompt_transform import PromptTransform
# def test__calculate_rest_token():
# model_schema_mock = MagicMock(spec=AIModelEntity)
# parameter_rule_mock = MagicMock(spec=ParameterRule)
# parameter_rule_mock.name = "max_tokens"
# model_schema_mock.parameter_rules = [parameter_rule_mock]
# model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62}
class TestPromptTransform:
    """Unit tests for PromptTransform's model-runtime resolution and token accounting.

    The scraped source interleaved large amounts of commented-out legacy test code
    (a diff-rendering artifact) inside these methods; that dead code is removed here
    and the live test logic is kept unchanged.
    """

    def test_resolve_model_runtime_requires_model_config_or_instance(self):
        """Calling with neither model_config nor model_instance raises ValueError."""
        transform = PromptTransform()
        with pytest.raises(ValueError, match="Either model_config or model_instance must be provided."):
            transform._resolve_model_runtime()

    def test_resolve_model_runtime_builds_model_instance_from_model_config(self):
        """A model_config alone produces a ModelInstance and copies credentials/params/stop onto it."""
        transform = PromptTransform()
        fake_model_schema = SimpleNamespace(model_properties={}, parameter_rules=[])
        fake_model_type_instance = MagicMock()
        fake_model_type_instance.get_model_schema.return_value = fake_model_schema
        fake_model_instance = SimpleNamespace(
            model_type_instance=fake_model_type_instance,
            model_name="resolved-model",
            credentials=None,
            parameters=None,
            stop=None,
        )
        model_config = SimpleNamespace(
            provider_model_bundle=object(),
            model="config-model",
            credentials={"api_key": "secret"},
            parameters={"temperature": 0.1},
            stop=["END"],
            model_schema=SimpleNamespace(model_properties={}, parameter_rules=[]),
        )
        with patch(
            "core.prompt.prompt_transform.ModelInstance", return_value=fake_model_instance
        ) as model_instance_cls:
            model_instance, model_schema = transform._resolve_model_runtime(model_config=model_config)
        model_instance_cls.assert_called_once_with(
            provider_model_bundle=model_config.provider_model_bundle,
            model=model_config.model,
        )
        # The schema is looked up under the *resolved* model name, not the config's.
        fake_model_type_instance.get_model_schema.assert_called_once_with(
            model="resolved-model",
            credentials={"api_key": "secret"},
        )
        assert model_instance is fake_model_instance
        assert model_instance.credentials == {"api_key": "secret"}
        assert model_instance.parameters == {"temperature": 0.1}
        assert model_instance.stop == ["END"]
        assert model_schema is fake_model_schema

    def test_resolve_model_runtime_uses_model_config_schema_fallback(self):
        """When the instance yields no schema, model_config.model_schema is used instead."""
        transform = PromptTransform()
        fallback_schema = SimpleNamespace(model_properties={}, parameter_rules=[])
        fake_model_type_instance = MagicMock()
        fake_model_type_instance.get_model_schema.return_value = None
        model_instance = SimpleNamespace(
            model_type_instance=fake_model_type_instance,
            model_name="resolved-model",
            credentials={"api_key": "secret"},
            parameters={},
        )
        model_config = SimpleNamespace(model_schema=fallback_schema)
        resolved_model_instance, resolved_schema = transform._resolve_model_runtime(
            model_config=model_config,
            model_instance=model_instance,
        )
        assert resolved_model_instance is model_instance
        assert resolved_schema is fallback_schema

    def test_resolve_model_runtime_raises_when_schema_missing_without_model_config(self):
        """No schema from the instance and no model_config fallback is an error."""
        transform = PromptTransform()
        fake_model_type_instance = MagicMock()
        fake_model_type_instance.get_model_schema.return_value = None
        model_instance = SimpleNamespace(
            model_type_instance=fake_model_type_instance,
            model_name="resolved-model",
            credentials={"api_key": "secret"},
            parameters={},
        )
        with pytest.raises(ValueError, match="Model schema not found for the provided model instance."):
            transform._resolve_model_runtime(model_instance=model_instance)

    def test_calculate_rest_token_defaults_when_context_size_missing(self):
        """With no CONTEXT_SIZE property the rest-token budget falls back to 2000."""
        transform = PromptTransform()
        fake_model_instance = SimpleNamespace(parameters={}, get_llm_num_tokens=lambda _: 0)
        fake_model_schema = SimpleNamespace(model_properties={}, parameter_rules=[])
        transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema))
        model_config = SimpleNamespace(
            model_schema=SimpleNamespace(model_properties={}, parameter_rules=[]),
            provider_model_bundle=object(),
            model="test-model",
            parameters={},
        )
        rest = transform._calculate_rest_token([], model_config=model_config)
        assert rest == 2000

    def test_calculate_rest_token_uses_max_tokens_and_clamps_to_zero(self):
        """context_size - max_tokens - prompt_tokens is clamped at zero (100 - 50 - 95)."""
        transform = PromptTransform()
        parameter_rule = SimpleNamespace(name="max_tokens", use_template=None)
        fake_model_instance = SimpleNamespace(parameters={"max_tokens": 50}, get_llm_num_tokens=lambda _: 95)
        fake_model_schema = SimpleNamespace(
            model_properties={ModelPropertyKey.CONTEXT_SIZE: 100},
            parameter_rules=[parameter_rule],
        )
        transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema))
        model_config = SimpleNamespace(
            model_schema=SimpleNamespace(
                model_properties={ModelPropertyKey.CONTEXT_SIZE: 100},
                parameter_rules=[parameter_rule],
            ),
            provider_model_bundle=object(),
            model="test-model",
            parameters={"max_tokens": 50},
        )
        rest = transform._calculate_rest_token([SimpleNamespace()], model_config=model_config)
        assert rest == 0

    def test_calculate_rest_token_supports_use_template_parameter(self):
        """A rule aliased via use_template='max_tokens' is honored (200 - 30 - 20 = 150)."""
        transform = PromptTransform()
        parameter_rule = SimpleNamespace(name="generation_max", use_template="max_tokens")
        fake_model_instance = SimpleNamespace(parameters={"max_tokens": 30}, get_llm_num_tokens=lambda _: 20)
        fake_model_schema = SimpleNamespace(
            model_properties={ModelPropertyKey.CONTEXT_SIZE: 200},
            parameter_rules=[parameter_rule],
        )
        transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema))
        model_config = SimpleNamespace(
            model_schema=SimpleNamespace(
                model_properties={ModelPropertyKey.CONTEXT_SIZE: 200},
                parameter_rules=[parameter_rule],
            ),
            provider_model_bundle=object(),
            model="test-model",
            parameters={"max_tokens": 30},
        )
        rest = transform._calculate_rest_token([SimpleNamespace()], model_config=model_config)
        assert rest == 150

    def test_get_history_messages_from_memory_with_and_without_window(self):
        """An enabled window forwards message_limit; a disabled one omits the extra kwargs."""
        transform = PromptTransform()
        memory = MagicMock()
        memory.get_history_prompt_text.return_value = "history"
        memory_config_with_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=3))
        result = transform._get_history_messages_from_memory(
            memory=memory,
            memory_config=memory_config_with_window,
            max_token_limit=100,
            human_prefix="Human",
            ai_prefix="Assistant",
        )
        assert result == "history"
        memory.get_history_prompt_text.assert_called_with(
            max_token_limit=100,
            human_prefix="Human",
            ai_prefix="Assistant",
            message_limit=3,
        )
        memory.reset_mock()
        memory_config_no_window = SimpleNamespace(window=SimpleNamespace(enabled=False, size=2))
        transform._get_history_messages_from_memory(
            memory=memory,
            memory_config=memory_config_no_window,
            max_token_limit=50,
        )
        memory.get_history_prompt_text.assert_called_with(max_token_limit=50)

    def test_get_history_messages_list_from_memory_with_and_without_window(self):
        """message_limit is forwarded for a positive window size, else passed as None."""
        transform = PromptTransform()
        memory = MagicMock()
        memory.get_history_prompt_messages.return_value = ["m1", "m2"]
        memory_config_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=2))
        result = transform._get_history_messages_list_from_memory(memory, memory_config_window, 120)
        assert result == ["m1", "m2"]
        memory.get_history_prompt_messages.assert_called_with(max_token_limit=120, message_limit=2)
        memory.reset_mock()
        memory.get_history_prompt_messages.return_value = ["only"]
        # size=0 counts as "no window" even though enabled is True.
        memory_config_no_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=0))
        result = transform._get_history_messages_list_from_memory(memory, memory_config_no_window, 10)
        assert result == ["only"]
        memory.get_history_prompt_messages.assert_called_with(max_token_limit=10, message_limit=None)

    def test_append_chat_histories_extends_prompt_messages(self, monkeypatch):
        """History messages fetched from memory are appended after the existing prompts."""
        transform = PromptTransform()
        memory = MagicMock()
        memory_config = SimpleNamespace(window=SimpleNamespace(enabled=False, size=None))
        monkeypatch.setattr(transform, "_calculate_rest_token", lambda prompt_messages, **kwargs: 99)
        monkeypatch.setattr(
            transform,
            "_get_history_messages_list_from_memory",
            lambda memory, memory_config, max_token_limit: ["h1", "h2"],
        )
        result = transform._append_chat_histories(
            memory=memory,
            memory_config=memory_config,
            prompt_messages=["p1"],
            model_config=SimpleNamespace(),
        )
        assert result == ["p1", "h1", "h2"]

View File

@ -1,9 +1,29 @@
from unittest.mock import MagicMock
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_CONTEXT,
CHAT_APP_CHAT_PROMPT_CONFIG,
CHAT_APP_COMPLETION_PROMPT_CONFIG,
COMPLETION_APP_CHAT_PROMPT_CONFIG,
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
CONTEXT,
)
from core.prompt.simple_prompt_transform import SimplePromptTransform
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
TextPromptMessageContent,
UserPromptMessage,
)
from models.model import AppMode, Conversation
@ -244,3 +264,178 @@ def test__get_completion_model_prompt_messages():
assert len(prompt_messages) == 1
assert stops == prompt_rules.get("stops")
assert prompt_messages[0].content == real_prompt
def test_get_prompt_dispatches_chat_and_completion():
    """get_prompt selects the chat or completion helper based on the model's mode."""
    transform = SimplePromptTransform()
    transform._get_chat_model_prompt_messages = MagicMock(return_value=(["chat-msg"], None))
    transform._get_completion_model_prompt_messages = MagicMock(return_value=(["completion-msg"], ["stop"]))
    prompt_entity = SimpleNamespace(simple_prompt_template="hello")

    def config_for(mode):
        config = MagicMock(spec=ModelConfigWithCredentialsEntity)
        config.mode = mode
        return config

    def invoke(model_config):
        return transform.get_prompt(
            app_mode=AppMode.CHAT,
            prompt_template_entity=prompt_entity,
            inputs={"n": 1},
            query="q",
            files=[],
            context=None,
            memory=None,
            model_config=model_config,
        )

    chat_messages, chat_stops = invoke(config_for("chat"))
    assert chat_messages == ["chat-msg"]
    assert chat_stops is None

    completion_messages, completion_stops = invoke(config_for("completion"))
    assert completion_messages == ["completion-msg"]
    assert completion_stops == ["stop"]
def test_get_prompt_str_and_rules_type_validation_errors():
    """_get_prompt_str_and_rules rejects template dicts whose pieces have the wrong types."""
    transform = SimplePromptTransform()
    model_config = MagicMock(spec=ModelConfigWithCredentialsEntity)
    model_config.provider = "openai"
    model_config.model = "gpt-4"
    valid_prompt_template = SimplePromptTransform().get_prompt_template(
        AppMode.CHAT, "openai", "gpt-4", "", False, False
    )["prompt_template"]
    # A fully valid template dict; each case below corrupts exactly one field.
    valid_template_dict = {
        "prompt_template": valid_prompt_template,
        "custom_variable_keys": [],
        "special_variable_keys": [],
        "prompt_rules": {},
    }

    def assert_type_error(overrides, match):
        transform.get_prompt_template = MagicMock(return_value={**valid_template_dict, **overrides})
        with pytest.raises(TypeError, match=match):
            transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None)

    assert_type_error({"custom_variable_keys": "not-list"}, "custom_variable_keys")
    assert_type_error({"special_variable_keys": "not-list"}, "special_variable_keys")
    assert_type_error({"prompt_template": 123}, "PromptTemplateParser")
    assert_type_error({"prompt_rules": "not-dict"}, "prompt_rules")
def test_chat_model_prompt_messages_uses_prompt_when_query_empty():
    """With an empty query, the rendered prompt itself becomes the last user message."""
    transform = SimplePromptTransform()
    config = MagicMock(spec=ModelConfigWithCredentialsEntity)
    transform._get_prompt_str_and_rules = MagicMock(return_value=("prompt-text", {}))
    transform._get_last_user_message = MagicMock(return_value=UserPromptMessage(content="prompt-text"))
    messages, _ = transform._get_chat_model_prompt_messages(
        app_mode=AppMode.CHAT,
        pre_prompt="",
        inputs={},
        query="",
        files=[],
        context=None,
        memory=None,
        model_config=config,
    )
    assert messages[0].content == "prompt-text"
    # The prompt string (not the empty query) is handed to the last-user-message builder.
    transform._get_last_user_message.assert_called_once_with("prompt-text", [], None, None)
def test_completion_model_prompt_messages_empty_stops_becomes_none():
    """An empty 'stops' list in the prompt rules is normalized to None."""
    transform = SimplePromptTransform()
    config = MagicMock(spec=ModelConfigWithCredentialsEntity)
    transform._get_prompt_str_and_rules = MagicMock(return_value=("prompt", {"stops": []}))
    messages, stops = transform._get_completion_model_prompt_messages(
        app_mode=AppMode.CHAT,
        pre_prompt="",
        inputs={},
        query="q",
        files=[],
        context=None,
        memory=None,
        model_config=config,
    )
    assert len(messages) == 1
    assert stops is None
def test_get_last_user_message_with_files_and_context_files():
    """Files and context files become content parts placed before the prompt text."""
    transform = SimplePromptTransform()
    with patch("core.prompt.simple_prompt_transform.file_manager.to_prompt_message_content") as to_content:
        to_content.side_effect = [
            ImagePromptMessageContent(url="https://example.com/a.jpg", format="jpg", mime_type="image/jpg"),
            ImagePromptMessageContent(url="https://example.com/b.jpg", format="jpg", mime_type="image/jpg"),
        ]
        message = transform._get_last_user_message(
            prompt="hello",
            files=[SimpleNamespace()],
            context_files=[SimpleNamespace()],
            image_detail_config=None,
        )
    parts = message.content
    assert isinstance(parts, list)
    assert parts[0].data == "https://example.com/a.jpg"
    assert parts[1].data == "https://example.com/b.jpg"
    assert isinstance(parts[2], TextPromptMessageContent)
    assert parts[2].data == "hello"
def test_prompt_file_name_branches():
    """_prompt_file_name chooses common vs. baichuan templates by app mode and model name."""
    transform = SimplePromptTransform()
    expectations = [
        ((AppMode.CHAT, "openai", "gpt-4"), "common_chat"),
        ((AppMode.COMPLETION, "openai", "gpt-4"), "common_completion"),
        ((AppMode.COMPLETION, "baichuan", "Baichuan2"), "baichuan_completion"),
        ((AppMode.CHAT, "huggingface_hub", "baichuan-13b"), "baichuan_chat"),
    ]
    for args, expected in expectations:
        assert transform._prompt_file_name(*args) == expected
def test_advanced_prompt_templates_constants_are_importable():
    """Smoke-check the advanced prompt template constants exist with the expected keys."""
    assert isinstance(CONTEXT, str)
    assert isinstance(BAICHUAN_CONTEXT, str)
    completion_configs = (
        CHAT_APP_COMPLETION_PROMPT_CONFIG,
        COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
        BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
        BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
    )
    chat_configs = (
        CHAT_APP_CHAT_PROMPT_CONFIG,
        COMPLETION_APP_CHAT_PROMPT_CONFIG,
        BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
        BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
    )
    assert all("completion_prompt_config" in config for config in completion_configs)
    assert all("chat_prompt_config" in config for config in chat_configs)