mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
test: added test for core token buffer memory and model runtime (#32512)
Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com>
This commit is contained in:
969
api/tests/unit_tests/core/memory/test_token_buffer_memory.py
Normal file
969
api/tests/unit_tests/core/memory/test_token_buffer_memory.py
Normal file
@ -0,0 +1,969 @@
|
||||
"""Comprehensive unit tests for core/memory/token_buffer_memory.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from dify_graph.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from models.model import AppMode
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers / shared fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_conversation(mode: AppMode = AppMode.CHAT) -> MagicMock:
    """Build a minimal Conversation stand-in for the given app mode."""
    conversation = MagicMock()
    conversation.configure_mock(
        id=str(uuid4()),
        mode=mode,
        model_config={},
    )
    return conversation
|
||||
|
||||
|
||||
def _make_model_instance() -> MagicMock:
|
||||
"""Return a ModelInstance mock whose token counter returns a constant."""
|
||||
mi = MagicMock()
|
||||
mi.get_llm_num_tokens.return_value = 100
|
||||
return mi
|
||||
|
||||
|
||||
def _make_message(answer: str = "hello", answer_tokens: int = 5) -> MagicMock:
|
||||
msg = MagicMock()
|
||||
msg.id = str(uuid4())
|
||||
msg.query = "user query"
|
||||
msg.answer = answer
|
||||
msg.answer_tokens = answer_tokens
|
||||
msg.workflow_run_id = str(uuid4())
|
||||
msg.created_at = MagicMock()
|
||||
return msg
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for __init__ and workflow_run_repo property
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestInit:
    """Constructor behaviour and the lazily-created workflow_run_repo property."""

    def test_init_stores_conversation_and_model_instance(self):
        """__init__ keeps the exact objects passed in and defers repo creation."""
        conversation = _make_conversation()
        model_instance = _make_model_instance()
        memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

        assert memory.conversation is conversation
        assert memory.model_instance is model_instance
        assert memory._workflow_run_repo is None

    def test_workflow_run_repo_is_created_lazily(self):
        """First access of workflow_run_repo builds the repository via the factory."""
        memory = TokenBufferMemory(
            conversation=_make_conversation(),
            model_instance=_make_model_instance(),
        )

        expected_repo = MagicMock()
        with (
            patch("core.memory.token_buffer_memory.sessionmaker"),
            patch("core.memory.token_buffer_memory.db") as mock_db,
            patch(
                "core.memory.token_buffer_memory.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
                return_value=expected_repo,
            ),
        ):
            mock_db.engine = MagicMock()
            created = memory.workflow_run_repo

        # The factory result is both returned and cached on the instance.
        assert created is expected_repo
        assert memory._workflow_run_repo is expected_repo

    def test_workflow_run_repo_cached_after_first_access(self):
        """A pre-populated repo short-circuits the factory entirely."""
        memory = TokenBufferMemory(
            conversation=_make_conversation(),
            model_instance=_make_model_instance(),
        )

        cached_repo = MagicMock()
        memory._workflow_run_repo = cached_repo

        with patch(
            "core.memory.token_buffer_memory.DifyAPIRepositoryFactory.create_api_workflow_run_repository"
        ) as mock_factory:
            fetched = memory.workflow_run_repo

        mock_factory.assert_not_called()
        assert fetched is cached_repo
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for _build_prompt_message_with_files
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestBuildPromptMessageWithFiles:
    """Tests for the private _build_prompt_message_with_files method."""

    # ------------------------------------------------------------------
    # Mode: CHAT / AGENT_CHAT / COMPLETION (simple branch)
    # ------------------------------------------------------------------

    @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
    def test_chat_mode_no_files_user_message(self, mode):
        """When file_extra_config is falsy or app_record is None → plain UserPromptMessage."""
        conv = _make_conversation(mode)
        mi = _make_model_instance()
        mem = TokenBufferMemory(conversation=conv, model_instance=mi)

        with patch(
            "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
            return_value=None,  # falsy → file_objs = []
        ):
            result = mem._build_prompt_message_with_files(
                message_files=[],
                text_content="hello",
                message=_make_message(),
                app_record=MagicMock(),
                is_user_message=True,
            )

        # No files → the content stays a plain string, not a content list.
        assert isinstance(result, UserPromptMessage)
        assert result.content == "hello"

    @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
    def test_chat_mode_no_files_assistant_message(self, mode):
        """Plain AssistantPromptMessage when no files and is_user_message=False."""
        conv = _make_conversation(mode)
        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())

        with patch(
            "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
            return_value=None,
        ):
            result = mem._build_prompt_message_with_files(
                message_files=[],
                text_content="ai reply",
                message=_make_message(),
                app_record=None,
                is_user_message=False,
            )

        assert isinstance(result, AssistantPromptMessage)
        assert result.content == "ai reply"

    @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
    def test_chat_mode_with_files_user_message(self, mode):
        """When files are present, returns UserPromptMessage with list content."""
        conv = _make_conversation(mode)
        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())

        mock_file_extra_config = MagicMock()
        mock_file_extra_config.image_config = None  # no detail override

        mock_file_obj = MagicMock()
        # Must be a real entity so Pydantic's tagged union discriminator can validate it
        real_image_content = ImagePromptMessageContent(
            url="http://example.com/img.png", format="png", mime_type="image/png"
        )

        mock_message_file = MagicMock()
        mock_app_record = MagicMock()
        mock_app_record.tenant_id = "tenant-1"

        # Patch the whole file pipeline: upload config → file object → prompt content.
        with (
            patch(
                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
                return_value=mock_file_extra_config,
            ),
            patch(
                "core.memory.token_buffer_memory.file_factory.build_from_message_file",
                return_value=mock_file_obj,
            ),
            patch(
                "core.memory.token_buffer_memory.file_manager.to_prompt_message_content",
                return_value=real_image_content,
            ),
        ):
            result = mem._build_prompt_message_with_files(
                message_files=[mock_message_file],
                text_content="user text",
                message=_make_message(),
                app_record=mock_app_record,
                is_user_message=True,
            )

        assert isinstance(result, UserPromptMessage)
        assert isinstance(result.content, list)
        # Last element should be TextPromptMessageContent
        assert isinstance(result.content[-1], TextPromptMessageContent)
        assert result.content[-1].data == "user text"

    @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
    def test_chat_mode_with_files_assistant_message(self, mode):
        """When files are present, returns AssistantPromptMessage with list content."""
        conv = _make_conversation(mode)
        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())

        mock_file_extra_config = MagicMock()
        mock_file_extra_config.image_config = None

        mock_file_obj = MagicMock()
        real_image_content = ImagePromptMessageContent(
            url="http://example.com/img.png", format="png", mime_type="image/png"
        )
        mock_app_record = MagicMock()
        mock_app_record.tenant_id = "tenant-1"

        with (
            patch(
                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
                return_value=mock_file_extra_config,
            ),
            patch(
                "core.memory.token_buffer_memory.file_factory.build_from_message_file",
                return_value=mock_file_obj,
            ),
            patch(
                "core.memory.token_buffer_memory.file_manager.to_prompt_message_content",
                return_value=real_image_content,
            ),
        ):
            result = mem._build_prompt_message_with_files(
                message_files=[MagicMock()],
                text_content="ai text",
                message=_make_message(),
                app_record=mock_app_record,
                is_user_message=False,
            )

        assert isinstance(result, AssistantPromptMessage)
        assert isinstance(result.content, list)

    @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
    def test_chat_mode_with_files_image_detail_overridden(self, mode):
        """When image_config.detail is set, detail is taken from config."""
        conv = _make_conversation(mode)
        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())

        mock_image_config = MagicMock()
        mock_image_config.detail = ImagePromptMessageContent.DETAIL.LOW

        mock_file_extra_config = MagicMock()
        mock_file_extra_config.image_config = mock_image_config

        mock_app_record = MagicMock()
        mock_app_record.tenant_id = "tenant-1"

        real_image_content = ImagePromptMessageContent(
            url="http://example.com/img.png", format="png", mime_type="image/png"
        )

        with (
            patch(
                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
                return_value=mock_file_extra_config,
            ),
            patch(
                "core.memory.token_buffer_memory.file_factory.build_from_message_file",
                return_value=MagicMock(),
            ),
            patch(
                "core.memory.token_buffer_memory.file_manager.to_prompt_message_content",
                return_value=real_image_content,
            ) as mock_to_prompt,
        ):
            mem._build_prompt_message_with_files(
                message_files=[MagicMock()],
                text_content="user text",
                message=_make_message(),
                app_record=mock_app_record,
                is_user_message=True,
            )
            # Ensure the LOW detail was passed through
            mock_to_prompt.assert_called_once_with(
                mock_to_prompt.call_args[0][0], image_detail_config=ImagePromptMessageContent.DETAIL.LOW
            )

    @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
    def test_chat_mode_app_record_none_returns_empty_file_objs(self, mode):
        """app_record=None path → file_objs stays empty → plain messages."""
        conv = _make_conversation(mode)
        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())

        mock_file_extra_config = MagicMock()

        with patch(
            "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
            return_value=mock_file_extra_config,
        ):
            result = mem._build_prompt_message_with_files(
                message_files=[MagicMock()],
                text_content="hello",
                message=_make_message(),
                app_record=None,  # <-- forces the else branch → file_objs = []
                is_user_message=True,
            )

        assert isinstance(result, UserPromptMessage)
        assert result.content == "hello"

    # ------------------------------------------------------------------
    # Mode: ADVANCED_CHAT / WORKFLOW
    # ------------------------------------------------------------------

    @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def test_workflow_mode_no_app_raises(self, mode):
        """Raises ValueError when conversation.app is falsy."""
        conv = _make_conversation(mode)
        conv.app = None
        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())

        with pytest.raises(ValueError, match="App not found for conversation"):
            mem._build_prompt_message_with_files(
                message_files=[],
                text_content="text",
                message=_make_message(),
                app_record=MagicMock(),
                is_user_message=True,
            )

    @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def test_workflow_mode_no_workflow_run_id_raises(self, mode):
        """Raises ValueError when message.workflow_run_id is falsy."""
        conv = _make_conversation(mode)
        conv.app = MagicMock()

        message = _make_message()
        message.workflow_run_id = None  # force missing

        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())

        with pytest.raises(ValueError, match="Workflow run ID not found"):
            mem._build_prompt_message_with_files(
                message_files=[],
                text_content="text",
                message=message,
                app_record=MagicMock(),
                is_user_message=True,
            )

    @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def test_workflow_mode_workflow_run_not_found_raises(self, mode):
        """Raises ValueError when workflow_run_repo returns None."""
        conv = _make_conversation(mode)
        mock_app = MagicMock()
        conv.app = mock_app

        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
        # Bypass lazy repo construction by injecting the mock directly.
        mem._workflow_run_repo = MagicMock()
        mem._workflow_run_repo.get_workflow_run_by_id.return_value = None

        with pytest.raises(ValueError, match="Workflow run not found"):
            mem._build_prompt_message_with_files(
                message_files=[],
                text_content="text",
                message=_make_message(),
                app_record=MagicMock(),
                is_user_message=True,
            )

    @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def test_workflow_mode_workflow_not_found_raises(self, mode):
        """Raises ValueError when Workflow lookup returns None."""
        conv = _make_conversation(mode)
        conv.app = MagicMock()

        mock_workflow_run = MagicMock()
        mock_workflow_run.workflow_id = str(uuid4())

        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
        mem._workflow_run_repo = MagicMock()
        mem._workflow_run_repo.get_workflow_run_by_id.return_value = mock_workflow_run

        with (
            patch("core.memory.token_buffer_memory.db") as mock_db,
        ):
            mock_db.session.scalar.return_value = None  # workflow not found

            with pytest.raises(ValueError, match="Workflow not found"):
                mem._build_prompt_message_with_files(
                    message_files=[],
                    text_content="text",
                    message=_make_message(),
                    app_record=MagicMock(),
                    is_user_message=True,
                )

    @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def test_workflow_mode_success_no_files_user(self, mode):
        """Happy path: workflow mode, no message files → plain UserPromptMessage."""
        conv = _make_conversation(mode)
        conv.app = MagicMock()

        mock_workflow_run = MagicMock()
        mock_workflow_run.workflow_id = str(uuid4())

        mock_workflow = MagicMock()
        mock_workflow.features_dict = {}

        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
        mem._workflow_run_repo = MagicMock()
        mem._workflow_run_repo.get_workflow_run_by_id.return_value = mock_workflow_run

        with (
            patch("core.memory.token_buffer_memory.db") as mock_db,
            patch(
                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
                return_value=None,
            ),
        ):
            mock_db.session.scalar.return_value = mock_workflow

            result = mem._build_prompt_message_with_files(
                message_files=[],
                text_content="wf text",
                message=_make_message(),
                app_record=MagicMock(),
                is_user_message=True,
            )

        assert isinstance(result, UserPromptMessage)
        assert result.content == "wf text"

    # ------------------------------------------------------------------
    # Invalid mode
    # ------------------------------------------------------------------

    def test_invalid_mode_raises_assertion(self):
        """Any unknown AppMode raises AssertionError."""
        conv = _make_conversation()
        conv.mode = "unknown_mode"  # not in any set
        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())

        with pytest.raises(AssertionError, match="Invalid app mode"):
            mem._build_prompt_message_with_files(
                message_files=[],
                text_content="text",
                message=_make_message(),
                app_record=MagicMock(),
                is_user_message=True,
            )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for get_history_prompt_messages
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestGetHistoryPromptMessages:
    """Tests for get_history_prompt_messages."""

    def _make_memory(self, mode: AppMode = AppMode.CHAT) -> TokenBufferMemory:
        # Shared helper: memory over a mock conversation that has an app attached.
        conv = _make_conversation(mode)
        conv.app = MagicMock()
        return TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())

    def test_returns_empty_when_no_messages(self):
        mem = self._make_memory()
        with patch("core.memory.token_buffer_memory.db") as mock_db:
            mock_db.session.scalars.return_value.all.return_value = []
            result = mem.get_history_prompt_messages()
            assert result == []

    def test_skips_first_message_without_answer(self):
        """The newest message (index 0 after extraction) without answer and tokens==0 is skipped."""
        mem = self._make_memory()

        msg_no_answer = _make_message(answer="", answer_tokens=0)
        msg_no_answer.parent_message_id = None  # ensures extract_thread_messages returns it

        with (
            patch("core.memory.token_buffer_memory.db") as mock_db,
            patch(
                "core.memory.token_buffer_memory.extract_thread_messages",
                return_value=[msg_no_answer],
            ),
        ):
            mock_db.session.scalars.return_value.all.side_effect = [
                [msg_no_answer],  # first call: messages query
                [],  # second call: user files query (never hit, but safe)
            ]
            result = mem.get_history_prompt_messages()

        assert result == []

    def test_message_with_answer_not_skipped(self):
        """A message with a non-empty answer is NOT popped."""
        mem = self._make_memory()

        msg = _make_message(answer="some answer", answer_tokens=10)
        msg.parent_message_id = None

        with (
            patch("core.memory.token_buffer_memory.db") as mock_db,
            patch(
                "core.memory.token_buffer_memory.extract_thread_messages",
                return_value=[msg],
            ),
            patch(
                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
                return_value=None,
            ),
        ):
            # user files query → empty; assistant files query → empty
            mock_db.session.scalars.return_value.all.return_value = []
            result = mem.get_history_prompt_messages()

        assert len(result) == 2  # one user + one assistant

    def test_message_limit_default_is_500(self):
        """When message_limit is None the stmt is limited to 500."""
        mem = self._make_memory()
        with (
            patch("core.memory.token_buffer_memory.db") as mock_db,
            patch("core.memory.token_buffer_memory.select") as mock_select,
            patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
        ):
            # Mirror the select().where().order_by().limit() chain so the final
            # .limit(n) call can be asserted on.
            mock_stmt = MagicMock()
            mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
            mock_stmt.limit.return_value = mock_stmt
            mock_db.session.scalars.return_value.all.return_value = []

            mem.get_history_prompt_messages(message_limit=None)
            mock_stmt.limit.assert_called_with(500)

    def test_message_limit_clipped_to_500(self):
        """A message_limit > 500 is clamped to 500."""
        mem = self._make_memory()
        with (
            patch("core.memory.token_buffer_memory.db") as mock_db,
            patch("core.memory.token_buffer_memory.select") as mock_select,
            patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
        ):
            mock_stmt = MagicMock()
            mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
            mock_stmt.limit.return_value = mock_stmt
            mock_db.session.scalars.return_value.all.return_value = []

            mem.get_history_prompt_messages(message_limit=9999)
            mock_stmt.limit.assert_called_with(500)

    def test_message_limit_positive_used(self):
        """A positive message_limit < 500 is used as-is."""
        mem = self._make_memory()
        with (
            patch("core.memory.token_buffer_memory.db") as mock_db,
            patch("core.memory.token_buffer_memory.select") as mock_select,
            patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
        ):
            mock_stmt = MagicMock()
            mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
            mock_stmt.limit.return_value = mock_stmt
            mock_db.session.scalars.return_value.all.return_value = []

            mem.get_history_prompt_messages(message_limit=10)
            mock_stmt.limit.assert_called_with(10)

    def test_message_limit_zero_uses_default(self):
        """message_limit=0 triggers the else branch → default 500."""
        mem = self._make_memory()
        with (
            patch("core.memory.token_buffer_memory.db") as mock_db,
            patch("core.memory.token_buffer_memory.select") as mock_select,
            patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
        ):
            mock_stmt = MagicMock()
            mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
            mock_stmt.limit.return_value = mock_stmt
            mock_db.session.scalars.return_value.all.return_value = []

            mem.get_history_prompt_messages(message_limit=0)
            mock_stmt.limit.assert_called_with(500)

    def test_user_files_cause_build_with_files_call(self):
        """When user_files is non-empty _build_prompt_message_with_files is invoked."""
        mem = self._make_memory()
        msg = _make_message()
        msg.parent_message_id = None

        mock_user_file = MagicMock()
        mock_user_prompt = UserPromptMessage(content="from build")
        mock_assistant_prompt = AssistantPromptMessage(content="answer")

        # Tracks which query the code is issuing; order is messages → user
        # files → assistant files, so the stub keys off the call index.
        call_count = {"n": 0}

        def scalars_side_effect(stmt):
            r = MagicMock()
            if call_count["n"] == 0:
                # messages query
                r.all.return_value = [msg]
            elif call_count["n"] == 1:
                # user files
                r.all.return_value = [mock_user_file]
            else:
                # assistant files
                r.all.return_value = []
            call_count["n"] += 1
            return r

        with (
            patch("core.memory.token_buffer_memory.db") as mock_db,
            patch(
                "core.memory.token_buffer_memory.extract_thread_messages",
                return_value=[msg],
            ),
            patch.object(
                mem,
                "_build_prompt_message_with_files",
                side_effect=[mock_user_prompt, mock_assistant_prompt],
            ) as mock_build,
            patch(
                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
                return_value=None,
            ),
        ):
            mock_db.session.scalars.side_effect = scalars_side_effect
            result = mem.get_history_prompt_messages()

        assert mock_build.call_count >= 1
        # First call should be user message
        first_call_kwargs = mock_build.call_args_list[0][1]
        assert first_call_kwargs["is_user_message"] is True

    def test_assistant_files_cause_build_with_files_call(self):
        """When assistant_files is non-empty, build is called with is_user_message=False."""
        mem = self._make_memory()
        msg = _make_message()
        msg.parent_message_id = None

        mock_assistant_file = MagicMock()
        mock_user_prompt = UserPromptMessage(content="query")
        mock_assistant_prompt = AssistantPromptMessage(content="built")

        call_count = {"n": 0}

        def scalars_side_effect(stmt):
            r = MagicMock()
            if call_count["n"] == 0:
                r.all.return_value = [msg]
            elif call_count["n"] == 1:
                r.all.return_value = []  # no user files
            else:
                r.all.return_value = [mock_assistant_file]
            call_count["n"] += 1
            return r

        with (
            patch("core.memory.token_buffer_memory.db") as mock_db,
            patch(
                "core.memory.token_buffer_memory.extract_thread_messages",
                return_value=[msg],
            ),
            patch.object(
                mem,
                "_build_prompt_message_with_files",
                return_value=mock_assistant_prompt,
            ) as mock_build,
        ):
            mock_db.session.scalars.side_effect = scalars_side_effect
            result = mem.get_history_prompt_messages()

        mock_build.assert_called_once()
        call_kwargs = mock_build.call_args[1]
        assert call_kwargs["is_user_message"] is False

    def test_token_pruning_removes_oldest_messages(self):
        """If tokens exceed limit, oldest messages are removed until within limit."""
        conv = _make_conversation()
        conv.app = MagicMock()

        # Model returns tokens that decrease only after removing pairs
        token_values = [3000, 1500]  # first call over limit, second within
        mi = MagicMock()
        mi.get_llm_num_tokens.side_effect = token_values

        mem = TokenBufferMemory(conversation=conv, model_instance=mi)

        msg = _make_message()
        msg.parent_message_id = None

        call_count = {"n": 0}

        def scalars_side_effect(stmt):
            r = MagicMock()
            if call_count["n"] == 0:
                r.all.return_value = [msg]
            else:
                r.all.return_value = []
            call_count["n"] += 1
            return r

        with (
            patch("core.memory.token_buffer_memory.db") as mock_db,
            patch(
                "core.memory.token_buffer_memory.extract_thread_messages",
                return_value=[msg],
            ),
            patch(
                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
                return_value=None,
            ),
        ):
            mock_db.session.scalars.side_effect = scalars_side_effect
            result = mem.get_history_prompt_messages(max_token_limit=2000)

        # After pruning, we should have fewer than the 2 initial messages
        assert len(result) <= 1

    def test_token_pruning_stops_at_single_message(self):
        """Pruning stops when only 1 message remains (to prevent empty list)."""
        conv = _make_conversation()
        conv.app = MagicMock()

        # Always over limit
        mi = MagicMock()
        mi.get_llm_num_tokens.return_value = 99999

        mem = TokenBufferMemory(conversation=conv, model_instance=mi)

        msg = _make_message()
        msg.parent_message_id = None

        call_count = {"n": 0}

        def scalars_side_effect(stmt):
            r = MagicMock()
            if call_count["n"] == 0:
                r.all.return_value = [msg]
            else:
                r.all.return_value = []
            call_count["n"] += 1
            return r

        with (
            patch("core.memory.token_buffer_memory.db") as mock_db,
            patch(
                "core.memory.token_buffer_memory.extract_thread_messages",
                return_value=[msg],
            ),
            patch(
                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
                return_value=None,
            ),
        ):
            mock_db.session.scalars.side_effect = scalars_side_effect
            result = mem.get_history_prompt_messages(max_token_limit=1)

        # At least 1 message should remain
        assert len(result) >= 1

    def test_no_pruning_when_within_limit(self):
        """When tokens ≤ limit, no pruning occurs."""
        mem = self._make_memory()
        mem.model_instance.get_llm_num_tokens.return_value = 50  # well under default 2000

        msg = _make_message()
        msg.parent_message_id = None

        call_count = {"n": 0}

        def scalars_side_effect(stmt):
            r = MagicMock()
            if call_count["n"] == 0:
                r.all.return_value = [msg]
            else:
                r.all.return_value = []
            call_count["n"] += 1
            return r

        with (
            patch("core.memory.token_buffer_memory.db") as mock_db,
            patch(
                "core.memory.token_buffer_memory.extract_thread_messages",
                return_value=[msg],
            ),
            patch(
                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
                return_value=None,
            ),
        ):
            mock_db.session.scalars.side_effect = scalars_side_effect
            result = mem.get_history_prompt_messages(max_token_limit=2000)

        assert len(result) == 2  # user + assistant

    def test_plain_user_and_assistant_messages_returned(self):
        """Without files, plain UserPromptMessage and AssistantPromptMessage appear."""
        mem = self._make_memory()

        msg = _make_message(answer="My answer")
        msg.query = "My query"
        msg.parent_message_id = None

        call_count = {"n": 0}

        def scalars_side_effect(stmt):
            r = MagicMock()
            if call_count["n"] == 0:
                r.all.return_value = [msg]
            else:
                r.all.return_value = []
            call_count["n"] += 1
            return r

        with (
            patch("core.memory.token_buffer_memory.db") as mock_db,
            patch(
                "core.memory.token_buffer_memory.extract_thread_messages",
                return_value=[msg],
            ),
            patch(
                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
                return_value=None,
            ),
        ):
            mock_db.session.scalars.side_effect = scalars_side_effect
            result = mem.get_history_prompt_messages()

        assert len(result) == 2
        user_msg, ai_msg = result
        assert isinstance(user_msg, UserPromptMessage)
        assert user_msg.content == "My query"
        assert isinstance(ai_msg, AssistantPromptMessage)
        assert ai_msg.content == "My answer"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests for get_history_prompt_text
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestGetHistoryPromptText:
    """Tests for get_history_prompt_text."""

    def _make_memory(self) -> TokenBufferMemory:
        # A chat conversation with a mocked app is enough for the text formatter.
        conversation = _make_conversation()
        conversation.app = MagicMock()
        return TokenBufferMemory(conversation=conversation, model_instance=_make_model_instance())

    def test_empty_messages_returns_empty_string(self):
        memory = self._make_memory()
        with patch.object(memory, "get_history_prompt_messages", return_value=[]):
            assert memory.get_history_prompt_text() == ""

    def test_user_and_assistant_messages_formatted(self):
        memory = self._make_memory()
        history = [
            UserPromptMessage(content="Hello"),
            AssistantPromptMessage(content="World"),
        ]
        with patch.object(memory, "get_history_prompt_messages", return_value=history):
            text = memory.get_history_prompt_text(human_prefix="H", ai_prefix="A")
        assert text == "H: Hello\nA: World"

    def test_custom_prefixes_applied(self):
        memory = self._make_memory()
        history = [
            UserPromptMessage(content="Hi"),
            AssistantPromptMessage(content="Bye"),
        ]
        with patch.object(memory, "get_history_prompt_messages", return_value=history):
            text = memory.get_history_prompt_text(human_prefix="Human", ai_prefix="Bot")
        assert "Human: Hi" in text
        assert "Bot: Bye" in text

    def test_list_content_with_text_and_image(self):
        """List content: TextPromptMessageContent → text; ImagePromptMessageContent → [image]."""
        memory = self._make_memory()
        mixed_parts = [
            TextPromptMessageContent(data="caption"),
            ImagePromptMessageContent(url="http://img", format="png", mime_type="image/png"),
        ]
        history = [UserPromptMessage(content=mixed_parts)]
        with patch.object(memory, "get_history_prompt_messages", return_value=history):
            text = memory.get_history_prompt_text()
        assert "caption" in text
        assert "[image]" in text

    def test_list_content_text_only(self):
        memory = self._make_memory()
        history = [UserPromptMessage(content=[TextPromptMessageContent(data="just text")])]
        with patch.object(memory, "get_history_prompt_messages", return_value=history):
            text = memory.get_history_prompt_text()
        assert "just text" in text

    def test_list_content_image_only(self):
        memory = self._make_memory()
        image_only = [
            ImagePromptMessageContent(url="http://img", format="jpg", mime_type="image/jpeg"),
        ]
        history = [UserPromptMessage(content=image_only)]
        with patch.object(memory, "get_history_prompt_messages", return_value=history):
            text = memory.get_history_prompt_text()
        assert "[image]" in text

    def test_unknown_role_skipped(self):
        """Messages with a role that is not USER or ASSISTANT are skipped."""
        memory = self._make_memory()

        # A mock message carrying a SYSTEM role should be dropped from the output.
        system_msg = MagicMock()
        system_msg.role = PromptMessageRole.SYSTEM
        system_msg.content = "system instruction"

        history = [system_msg, UserPromptMessage(content="hi")]
        with patch.object(memory, "get_history_prompt_messages", return_value=history):
            text = memory.get_history_prompt_text()

        assert "system instruction" not in text
        assert "Human: hi" in text

    def test_passes_max_token_limit_and_message_limit(self):
        """Parameters are forwarded to get_history_prompt_messages."""
        memory = self._make_memory()
        with patch.object(memory, "get_history_prompt_messages", return_value=[]) as forwarded:
            memory.get_history_prompt_text(max_token_limit=500, message_limit=10)
        forwarded.assert_called_once_with(max_token_limit=500, message_limit=10)

    def test_multiple_messages_joined_by_newline(self):
        memory = self._make_memory()
        history = [
            UserPromptMessage(content="Q1"),
            AssistantPromptMessage(content="A1"),
            UserPromptMessage(content="Q2"),
            AssistantPromptMessage(content="A2"),
        ]
        with patch.object(memory, "get_history_prompt_messages", return_value=history):
            text = memory.get_history_prompt_text()
        assert text.split("\n") == [
            "Human: Q1",
            "Assistant: A1",
            "Human: Q2",
            "Assistant: A2",
        ]

    def test_assistant_list_content_formatted(self):
        """AssistantPromptMessage with list content is also handled."""
        memory = self._make_memory()
        assistant_parts = [
            TextPromptMessageContent(data="response text"),
            ImagePromptMessageContent(url="http://img2", format="png", mime_type="image/png"),
        ]
        history = [AssistantPromptMessage(content=assistant_parts)]
        with patch.object(memory, "get_history_prompt_messages", return_value=history):
            text = memory.get_history_prompt_text()
        assert "response text" in text
        assert "[image]" in text
|
||||
@ -1,114 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import _increase_tool_call
|
||||
|
||||
ToolCall = AssistantPromptMessage.ToolCall


def _tc(call_id: str, name: str, arguments: str) -> ToolCall:
    """Build a streamed tool-call delta with the given id/name/argument fragment."""
    return ToolCall(
        id=call_id,
        type="function",
        function=ToolCall.ToolCallFunction(name=name, arguments=arguments),
    )


# CASE 1: Single tool call
INPUTS_CASE_1 = [
    _tc("1", "func_foo", ""),
    _tc("", "", '{"arg1": '),
    _tc("", "", '"value"}'),
]
EXPECTED_CASE_1 = [
    _tc("1", "func_foo", '{"arg1": "value"}'),
]

# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...)
INPUTS_CASE_2 = [
    _tc("1", "func_foo", ""),
    _tc("", "", '{"arg1": '),
    _tc("", "", '"value"}'),
    _tc("2", "func_bar", ""),
    _tc("", "", '{"arg2": '),
    _tc("", "", '"value"}'),
]
EXPECTED_CASE_2 = [
    _tc("1", "func_foo", '{"arg1": "value"}'),
    _tc("2", "func_bar", '{"arg2": "value"}'),
]

# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...)
INPUTS_CASE_3 = [
    _tc("1", "func_foo", ""),
    _tc("1", "", '{"arg1": '),
    _tc("1", "", '"value"}'),
    _tc("2", "func_bar", ""),
    _tc("2", "", '{"arg2": '),
    _tc("2", "", '"value"}'),
]
EXPECTED_CASE_3 = [
    _tc("1", "func_foo", '{"arg1": "value"}'),
    _tc("2", "func_bar", '{"arg2": "value"}'),
]

# CASE 4: Tool call sequences with no IDs
INPUTS_CASE_4 = [
    _tc("", "func_foo", ""),
    _tc("", "", '{"arg1": '),
    _tc("", "", '"value"}'),
    _tc("", "func_bar", ""),
    _tc("", "", '{"arg2": '),
    _tc("", "", '"value"}'),
]
EXPECTED_CASE_4 = [
    _tc("RANDOM_ID_1", "func_foo", '{"arg1": "value"}'),
    _tc("RANDOM_ID_2", "func_bar", '{"arg2": "value"}'),
]
|
||||
|
||||
|
||||
def _run_case(inputs: list[ToolCall], expected: list[ToolCall]):
    """Feed the delta sequence through _increase_tool_call and compare the merged result."""
    merged: list[ToolCall] = []
    _increase_tool_call(inputs, merged)
    assert merged == expected
|
||||
|
||||
|
||||
def test__increase_tool_call():
    # Cases 1-3 carry explicit IDs, so no ID generation is involved.
    for inputs, expected in (
        (INPUTS_CASE_1, EXPECTED_CASE_1),
        (INPUTS_CASE_2, EXPECTED_CASE_2),
        (INPUTS_CASE_3, EXPECTED_CASE_3),
    ):
        _run_case(inputs, expected)

    # Case 4 has no IDs at all: stub the generator so the assigned IDs are deterministic.
    mock_id_generator = MagicMock(side_effect=[expected_call.id for expected_call in EXPECTED_CASE_4])
    with patch(
        "dify_graph.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator
    ):
        _run_case(INPUTS_CASE_4, EXPECTED_CASE_4)
|
||||
|
||||
|
||||
def test__increase_tool_call__no_id_no_name_first_delta_should_raise():
    """A first delta lacking both an ID and a function name cannot anchor a tool call."""
    deltas = [
        ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
        ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='"value"}')),
    ]
    merged: list[ToolCall] = []
    with patch("dify_graph.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", MagicMock()):
        with pytest.raises(ValueError):
            _increase_tool_call(deltas, merged)
|
||||
@ -1,126 +0,0 @@
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_plugin_result
|
||||
|
||||
|
||||
def _make_chunk(
    *,
    model: str = "test-model",
    content: str | list[TextPromptMessageContent] | None,
    tool_calls: list[AssistantPromptMessage.ToolCall] | None = None,
    usage: LLMUsage | None = None,
    system_fingerprint: str | None = None,
) -> LLMResultChunk:
    """Wrap an assistant message into a single streaming result chunk."""
    assistant_message = AssistantPromptMessage(content=content, tool_calls=tool_calls or [])
    return LLMResultChunk(
        model=model,
        delta=LLMResultChunkDelta(index=0, message=assistant_message, usage=usage),
        system_fingerprint=system_fingerprint,
    )
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__from_first_chunk_str_content_and_tool_calls():
    """String content and streamed tool-call deltas are folded into one LLMResult."""
    prompt_messages = [UserPromptMessage(content="hi")]

    tool_call_cls = AssistantPromptMessage.ToolCall
    tool_calls = [
        tool_call_cls(
            id="1",
            type="function",
            function=tool_call_cls.ToolCallFunction(name="func_foo", arguments=""),
        ),
        tool_call_cls(
            id="",
            type="function",
            function=tool_call_cls.ToolCallFunction(name="", arguments='{"arg1": '),
        ),
        tool_call_cls(
            id="",
            type="function",
            function=tool_call_cls.ToolCallFunction(name="", arguments='"value"}'),
        ),
    ]

    usage = LLMUsage.empty_usage().model_copy(update={"prompt_tokens": 1, "total_tokens": 1})
    chunk = _make_chunk(content="hello", tool_calls=tool_calls, usage=usage, system_fingerprint="fp-1")

    result = _normalize_non_stream_plugin_result(
        model="test-model", prompt_messages=prompt_messages, result=iter([chunk])
    )

    assert result.model == "test-model"
    assert result.prompt_messages == prompt_messages
    assert result.message.content == "hello"
    assert result.usage.prompt_tokens == 1
    assert result.system_fingerprint == "fp-1"
    expected_call = tool_call_cls(
        id="1",
        type="function",
        function=tool_call_cls.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'),
    )
    assert result.message.tool_calls == [expected_call]
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__from_first_chunk_list_content():
    """List-of-parts content survives normalization unchanged."""
    prompt_messages = [UserPromptMessage(content="hi")]
    content_parts = [TextPromptMessageContent(data="a"), TextPromptMessageContent(data="b")]

    result = _normalize_non_stream_plugin_result(
        model="test-model",
        prompt_messages=prompt_messages,
        result=iter([_make_chunk(content=content_parts, usage=LLMUsage.empty_usage())]),
    )

    assert result.message.content == content_parts
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__passthrough_llm_result():
    """A ready-made LLMResult is passed through unchanged."""
    prompt_messages = [UserPromptMessage(content="hi")]
    ready_result = LLMResult(
        model="test-model",
        prompt_messages=prompt_messages,
        message=AssistantPromptMessage(content="ok"),
        usage=LLMUsage.empty_usage(),
    )

    normalized = _normalize_non_stream_plugin_result(
        model="test-model", prompt_messages=prompt_messages, result=ready_result
    )
    assert normalized == ready_result
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__empty_iterator_defaults():
    """An exhausted iterator yields a default, empty LLMResult."""
    prompt_messages = [UserPromptMessage(content="hi")]

    result = _normalize_non_stream_plugin_result(model="test-model", prompt_messages=prompt_messages, result=iter(()))

    assert result.model == "test-model"
    assert result.prompt_messages == prompt_messages
    assert result.message.content == []
    assert result.message.tool_calls == []
    assert result.usage == LLMUsage.empty_usage()
    assert result.system_fingerprint is None
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__accumulates_all_chunks():
    """All chunks are accumulated from the iterator."""
    prompt_messages = [UserPromptMessage(content="hi")]

    closed: list[bool] = []

    def _chunk_iter():
        # The finally clause records that the generator was fully consumed/closed.
        try:
            for piece in ("hello", " world"):
                yield _make_chunk(content=piece, usage=LLMUsage.empty_usage())
        finally:
            closed.append(True)

    result = _normalize_non_stream_plugin_result(
        model="test-model",
        prompt_messages=prompt_messages,
        result=_chunk_iter(),
    )

    assert result.message.content == "hello world"
    assert closed == [True]
|
||||
@ -1,148 +0,0 @@
|
||||
"""Tests for LLMUsage entity."""
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
|
||||
|
||||
class TestLLMUsage:
    """Test cases for LLMUsage class."""

    def test_from_metadata_with_all_tokens(self):
        """Test from_metadata when all token types are provided."""
        usage = LLMUsage.from_metadata(
            {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150,
                "prompt_unit_price": 0.001,
                "completion_unit_price": 0.002,
                "total_price": 0.2,
                "currency": "USD",
                "latency": 1.5,
            }
        )

        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 50
        assert usage.total_tokens == 150
        assert usage.prompt_unit_price == Decimal("0.001")
        assert usage.completion_unit_price == Decimal("0.002")
        assert usage.total_price == Decimal("0.2")
        assert usage.currency == "USD"
        assert usage.latency == 1.5

    def test_from_metadata_with_prompt_tokens_only(self):
        """Test from_metadata when only prompt_tokens is provided."""
        usage = LLMUsage.from_metadata({"prompt_tokens": 100, "total_tokens": 100})

        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 100

    def test_from_metadata_with_completion_tokens_only(self):
        """Test from_metadata when only completion_tokens is provided."""
        usage = LLMUsage.from_metadata({"completion_tokens": 50, "total_tokens": 50})

        assert usage.prompt_tokens == 0
        assert usage.completion_tokens == 50
        assert usage.total_tokens == 50

    def test_from_metadata_calculates_total_when_missing(self):
        """Test from_metadata calculates total_tokens when not provided."""
        usage = LLMUsage.from_metadata({"prompt_tokens": 100, "completion_tokens": 50})

        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 50
        # total_tokens is derived as prompt + completion when absent.
        assert usage.total_tokens == 150

    def test_from_metadata_with_total_but_no_completion(self):
        """
        Test from_metadata when total_tokens is provided but completion_tokens is 0.
        This tests the fix for issue #24360 - prompt tokens should NOT be assigned to completion_tokens.
        """
        usage = LLMUsage.from_metadata(
            {
                "prompt_tokens": 479,
                "completion_tokens": 0,
                "total_tokens": 521,
            }
        )

        # Key fix: prompt tokens must remain prompt tokens.
        assert usage.prompt_tokens == 479
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 521

    def test_from_metadata_with_empty_metadata(self):
        """Test from_metadata with empty metadata."""
        usage = LLMUsage.from_metadata({})

        assert usage.prompt_tokens == 0
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 0
        assert usage.currency == "USD"
        assert usage.latency == 0.0

    def test_from_metadata_preserves_zero_completion_tokens(self):
        """
        Test that zero completion_tokens are preserved when explicitly set.
        This is important for agent nodes that only use prompt tokens.
        """
        usage = LLMUsage.from_metadata(
            {
                "prompt_tokens": 1000,
                "completion_tokens": 0,
                "total_tokens": 1000,
                "prompt_unit_price": 0.15,
                "completion_unit_price": 0.60,
                "prompt_price": 0.00015,
                "completion_price": 0,
                "total_price": 0.00015,
            }
        )

        assert usage.prompt_tokens == 1000
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 1000
        assert usage.prompt_price == Decimal("0.00015")
        assert usage.completion_price == Decimal(0)
        assert usage.total_price == Decimal("0.00015")

    def test_from_metadata_with_decimal_values(self):
        """Test from_metadata handles decimal values correctly."""
        usage = LLMUsage.from_metadata(
            {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150,
                "prompt_unit_price": "0.001",
                "completion_unit_price": "0.002",
                "prompt_price": "0.1",
                "completion_price": "0.1",
                "total_price": "0.2",
            }
        )

        assert usage.prompt_unit_price == Decimal("0.001")
        assert usage.completion_unit_price == Decimal("0.002")
        assert usage.prompt_price == Decimal("0.1")
        assert usage.completion_price == Decimal("0.1")
        assert usage.total_price == Decimal("0.2")
|
||||
Reference in New Issue
Block a user