fix: resolve remaining runtime decoupling CI failures

This commit is contained in:
-LAN-
2026-03-16 06:13:29 +08:00
parent b9324bbf18
commit c9e3e84e61
17 changed files with 175 additions and 184 deletions

View File

@ -50,6 +50,7 @@ PreparedModelInstance: TypeAlias = PreparedLLMProtocol | _LegacyModelInstance
def fetch_model_schema(*, model_instance: PreparedModelInstance) -> AIModelEntity:
model_schema: AIModelEntity | None
get_model_schema = getattr(model_instance, "get_model_schema", None)
if callable(get_model_schema):
model_schema = cast(PreparedLLMProtocol, model_instance).get_model_schema()

View File

@ -364,28 +364,35 @@ class LLMNode(Node[LLMNodeData]):
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
model_parameters = model_instance.parameters
invoke_model_parameters = dict(model_parameters)
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None]
if structured_output_enabled:
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output or {},
)
request_start_time = time.perf_counter()
invoke_result = model_instance.invoke_llm_with_structured_output(
prompt_messages=prompt_messages,
json_schema=output_schema,
model_parameters=invoke_model_parameters,
stop=stop,
stream=True,
invoke_result = cast(
LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
model_instance.invoke_llm_with_structured_output(
prompt_messages=prompt_messages,
json_schema=output_schema,
model_parameters=invoke_model_parameters,
stop=stop,
stream=True,
),
)
else:
request_start_time = time.perf_counter()
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=invoke_model_parameters,
tools=None,
stop=stop,
stream=True,
invoke_result = cast(
LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=invoke_model_parameters,
tools=None,
stop=stop,
stream=True,
),
)
return LLMNode.handle_invoke_result(

View File

@ -758,6 +758,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
prompt_template: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
else:

View File

@ -55,7 +55,7 @@ class TypeMismatchError(Exception):
pass
SEGMENT_TO_VARIABLE_MAP = {
SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[Any]] = {
ArrayAnySegment: ArrayAnyVariable,
ArrayBooleanSegment: ArrayBooleanVariable,
ArrayFileSegment: ArrayFileVariable,

View File

@ -14,12 +14,19 @@ from werkzeug.http import parse_options_header
from core.helper import ssrf_proxy
from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers
from dify_graph.file.file_factory import standardize_file_type
from dify_graph.file.file_factory import (
get_file_type_by_mime_type as _get_file_type_by_mime_type,
)
from dify_graph.file.file_factory import (
standardize_file_type,
)
from extensions.ext_database import db
from models import MessageFile, ToolFile, UploadFile
logger = logging.getLogger(__name__)
get_file_type_by_mime_type = _get_file_type_by_mime_type
def build_from_message_files(
*,

View File

@ -200,7 +200,8 @@ class SummaryIndexService:
)
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens([summary_content])
embedding_tokens = tokens_list[0] if tokens_list else 0
raw_embedding_tokens = tokens_list[0] if tokens_list else 0
embedding_tokens = raw_embedding_tokens if isinstance(raw_embedding_tokens, int) else 0
except Exception as e:
logger.warning("Failed to calculate embedding tokens for summary: %s", str(e))

View File

@ -156,7 +156,6 @@ class TestCacheEmbeddingDocuments:
# Verify model was invoked with correct parameters
mock_model_instance.invoke_text_embedding.assert_called_once_with(
texts=texts,
user="test-user",
input_type=EmbeddingInputType.DOCUMENT,
)
@ -651,7 +650,6 @@ class TestCacheEmbeddingQuery:
# Verify model was invoked with QUERY input type
mock_model_instance.invoke_text_embedding.assert_called_once_with(
texts=[query],
user="test-user",
input_type=EmbeddingInputType.QUERY,
)
@ -1623,7 +1621,6 @@ class TestEmbeddingEdgeCases:
# Verify user parameter was passed to model
mock_model_instance.invoke_text_embedding.assert_called_once_with(
texts=[query],
user=user_id,
input_type=EmbeddingInputType.QUERY,
)
@ -1676,7 +1673,6 @@ class TestEmbeddingEdgeCases:
# Verify user parameter was passed
mock_model_instance.invoke_text_embedding.assert_called_once()
call_args = mock_model_instance.invoke_text_embedding.call_args
assert call_args.kwargs["user"] == user_id
assert call_args.kwargs["input_type"] == EmbeddingInputType.DOCUMENT

View File

@ -399,7 +399,9 @@ class TestParagraphIndexProcessor:
model_instance.invoke_llm.return_value = self._llm_result("text summary")
with (
patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls,
patch(
"core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager"
) as mock_provider_manager,
patch(
"core.rag.index_processor.processor.paragraph_index_processor.ModelInstance",
return_value=model_instance,
@ -410,7 +412,7 @@ class TestParagraphIndexProcessor:
),
patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
):
mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock()
mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock()
summary, usage = ParagraphIndexProcessor.generate_summary(
"tenant-1",
"text content",
@ -433,7 +435,9 @@ class TestParagraphIndexProcessor:
image_content = ImagePromptMessageContent(format="url", mime_type="image/png", url="http://example.com/a.png")
with (
patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls,
patch(
"core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager"
) as mock_provider_manager,
patch(
"core.rag.index_processor.processor.paragraph_index_processor.ModelInstance",
return_value=model_instance,
@ -448,7 +452,7 @@ class TestParagraphIndexProcessor:
),
patch("core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota"),
):
mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock()
mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock()
summary, _ = ParagraphIndexProcessor.generate_summary(
"tenant-1",
"text content",
@ -469,7 +473,9 @@ class TestParagraphIndexProcessor:
image_file = SimpleNamespace()
with (
patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls,
patch(
"core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager"
) as mock_provider_manager,
patch(
"core.rag.index_processor.processor.paragraph_index_processor.ModelInstance",
return_value=model_instance,
@ -486,7 +492,7 @@ class TestParagraphIndexProcessor:
),
patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
):
mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock()
mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock()
with pytest.raises(ValueError, match="Expected LLMResult"):
ParagraphIndexProcessor.generate_summary(
"tenant-1",

View File

@ -482,7 +482,8 @@ class TestIndexingRunnerTransform:
# Arrange
runner = IndexingRunner()
mock_embedding_instance = MagicMock()
runner.model_manager.get_model_instance.return_value = mock_embedding_instance
model_manager = mock_dependencies["model_manager"].return_value
model_manager.get_model_instance.return_value = mock_embedding_instance
mock_processor = MagicMock()
transformed_docs = [
@ -509,7 +510,7 @@ class TestIndexingRunnerTransform:
assert len(result) == 2
assert result[0].page_content == "Chunk 1"
assert result[1].page_content == "Chunk 2"
runner.model_manager.get_model_instance.assert_called_once_with(
model_manager.get_model_instance.assert_called_once_with(
tenant_id=sample_dataset.tenant_id,
provider=sample_dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
@ -522,6 +523,7 @@ class TestIndexingRunnerTransform:
# Arrange
runner = IndexingRunner()
sample_dataset.indexing_technique = "economy"
model_manager = mock_dependencies["model_manager"].return_value
mock_processor = MagicMock()
transformed_docs = [
@ -539,14 +541,15 @@ class TestIndexingRunnerTransform:
# Assert
assert len(result) == 1
runner.model_manager.get_model_instance.assert_not_called()
model_manager.get_model_instance.assert_not_called()
def test_transform_with_custom_segmentation(self, mock_dependencies, sample_dataset, sample_text_docs):
"""Test transformation with custom segmentation rules."""
# Arrange
runner = IndexingRunner()
mock_embedding_instance = MagicMock()
runner.model_manager.get_model_instance.return_value = mock_embedding_instance
model_manager = mock_dependencies["model_manager"].return_value
model_manager.get_model_instance.return_value = mock_embedding_instance
mock_processor = MagicMock()
transformed_docs = [Document(page_content="Custom chunk", metadata={"doc_id": "custom1", "doc_hash": "hash1"})]
@ -645,7 +648,8 @@ class TestIndexingRunnerLoad:
runner = IndexingRunner()
mock_embedding_instance = MagicMock()
mock_embedding_instance.get_text_embedding_num_tokens.return_value = 100
runner.model_manager.get_model_instance.return_value = mock_embedding_instance
model_manager = mock_dependencies["model_manager"].return_value
model_manager.get_model_instance.return_value = mock_embedding_instance
mock_processor = MagicMock()
@ -664,7 +668,7 @@ class TestIndexingRunnerLoad:
runner._load(mock_processor, sample_dataset, sample_dataset_document, sample_documents)
# Assert
runner.model_manager.get_model_instance.assert_called_once()
model_manager.get_model_instance.assert_called_once()
# Verify executor was used for parallel processing
assert mock_executor_instance.submit.called
@ -714,7 +718,8 @@ class TestIndexingRunnerLoad:
mock_embedding_instance = MagicMock()
mock_embedding_instance.get_text_embedding_num_tokens.return_value = 50
runner.model_manager.get_model_instance.return_value = mock_embedding_instance
model_manager = mock_dependencies["model_manager"].return_value
model_manager.get_model_instance.return_value = mock_embedding_instance
mock_processor = MagicMock()

View File

@ -352,7 +352,9 @@ class TestRerankModelRunner:
# Assert: Empty result is returned
assert len(result) == 0
def test_user_parameter_passed_to_model(self, rerank_runner, mock_model_instance, sample_documents):
def test_user_parameter_passed_to_model(
self, rerank_runner, mock_model_instance, sample_documents, mock_model_manager
):
"""Test that user parameter is passed to model invocation.
Verifies:
@ -366,6 +368,7 @@ class TestRerankModelRunner:
RerankDocument(index=0, text=sample_documents[0].page_content, score=0.90),
],
)
mock_model_manager.return_value.get_model_instance.return_value = mock_model_instance
mock_model_instance.invoke_rerank.return_value = mock_rerank_result
# Act: Run reranking with user parameter
@ -375,9 +378,10 @@ class TestRerankModelRunner:
user="user123",
)
# Assert: User parameter is passed to model
# Assert: User context is bound through ModelManager and the rebound instance is invoked.
mock_model_manager.assert_any_call(tenant_id="test-tenant-id", user_id="user123")
call_kwargs = mock_model_instance.invoke_rerank.call_args.kwargs
assert call_kwargs["user"] == "user123"
assert "user" not in call_kwargs
class _ForwardingBaseRerankRunner(BaseRerankRunner):
@ -539,10 +543,14 @@ class TestRerankModelRunnerMultimodal:
)
mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result
session = MagicMock()
session.query.return_value = query_chain
with (
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain),
patch("core.rag.rerank.rerank_model.ModelManager.for_tenant") as mock_model_manager,
patch("core.rag.rerank.rerank_model.db.session", session),
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"),
):
mock_model_manager.return_value.get_model_instance.return_value = mock_model_instance
result, unique_documents = rerank_runner.fetch_multimodal_rerank(
query="query-upload-id",
documents=[text_doc],
@ -554,10 +562,11 @@ class TestRerankModelRunnerMultimodal:
assert result == rerank_result
assert unique_documents == [text_doc]
mock_model_manager.assert_any_call(tenant_id="test-tenant-id", user_id="user-1")
invoke_kwargs = mock_model_instance.invoke_multimodal_rerank.call_args.kwargs
assert invoke_kwargs["query"]["content_type"] == DocType.IMAGE
assert invoke_kwargs["docs"][0]["content"] == "text-content"
assert invoke_kwargs["user"] == "user-1"
assert "user" not in invoke_kwargs
def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner):
query_chain = Mock()
@ -595,7 +604,7 @@ class TestWeightRerankRunner:
@pytest.fixture
def mock_model_manager(self):
"""Mock ModelManager for embedding model."""
with patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant", autospec=True) as mock_manager:
with patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager:
yield mock_manager
@pytest.fixture
@ -1527,7 +1536,7 @@ class TestRerankEdgeCases:
# Mock dependencies
with (
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba,
patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant", autospec=True) as mock_manager,
patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager,
patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache,
):
mock_handler = MagicMock()
@ -1673,7 +1682,7 @@ class TestRerankPerformance:
with (
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba,
patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant", autospec=True) as mock_manager,
patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager,
patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache,
):
mock_handler = MagicMock()
@ -1824,7 +1833,7 @@ class TestRerankErrorHandling:
with (
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba,
patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant", autospec=True) as mock_manager,
patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager,
patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache,
):
mock_handler = MagicMock()

View File

@ -162,7 +162,11 @@ class TestReactMultiDatasetRouter:
model_instance = Mock()
model_instance.invoke_llm.return_value = iter([chunk])
with patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct:
with (
patch("core.rag.retrieval.router.multi_dataset_react_route.ModelManager.for_tenant") as mock_manager,
patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct,
):
mock_manager.return_value.get_model_instance.return_value = model_instance
text, returned_usage = router._invoke_llm(
completion_param={"temperature": 0.1},
model_instance=model_instance,
@ -174,6 +178,7 @@ class TestReactMultiDatasetRouter:
assert text == "part"
assert returned_usage == usage
mock_manager.assert_any_call(tenant_id="t1", user_id="u1")
mock_deduct.assert_called_once()
def test_handle_invoke_result_with_empty_usage(self) -> None:

View File

@ -111,7 +111,7 @@ class MockNodeFactory(DifyNodeFactory):
mock_config=self.mock_config,
http_request_config=self._http_request_config,
http_client=self._http_request_http_client,
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
tool_file_manager_factory=self._bound_tool_file_manager_factory,
file_manager=self._http_request_file_manager,
)
elif node_type in {

View File

@ -6,17 +6,14 @@ from unittest.mock import MagicMock
import httpx
import pytest
from core.helper import ssrf_proxy
from core.tools import signature
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.file import FileTransferMethod, FileType, models
from dify_graph.file import FileTransferMethod, FileType
from dify_graph.nodes.llm.file_saver import (
FileSaverImpl,
_extract_content_type_and_extension,
_get_extension,
_validate_extension_override,
)
from models import ToolFile
_PNG_DATA = b"\x89PNG\r\n\x1a\n"
@ -27,58 +24,46 @@ def _gen_id():
class TestFileSaverImpl:
def test_save_binary_string(self, monkeypatch: pytest.MonkeyPatch):
user_id = _gen_id()
tenant_id = _gen_id()
file_type = FileType.IMAGE
mime_type = "image/png"
mock_signed_url = "https://example.com/image.png"
mock_tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_key="test-file-key",
mimetype=mime_type,
original_url=None,
name=f"{_gen_id()}.png",
size=len(_PNG_DATA),
)
mock_tool_file = MagicMock()
mock_tool_file.id = _gen_id()
mock_tool_file.name = f"{_gen_id()}.png"
mock_tool_file.file_key = "test-file-key"
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
# Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here.
mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
mocked_sign_file.return_value = mock_signed_url
file_reference = MagicMock()
file_reference_factory = MagicMock()
file_reference_factory.build_from_mapping.return_value = file_reference
http_client = MagicMock()
storage_file_manager = FileSaverImpl(
user_id=user_id,
tenant_id=tenant_id,
file_saver = FileSaverImpl(
tool_file_manager=mocked_tool_file_manager,
file_reference_factory=file_reference_factory,
http_client=http_client,
)
file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
assert file.tenant_id == tenant_id
assert file.type == file_type
assert file.transfer_method == FileTransferMethod.TOOL_FILE
assert file.extension == ".png"
assert file.mime_type == mime_type
assert file.size == len(_PNG_DATA)
assert file.related_id == mock_tool_file.id
assert file.generate_url() == mock_signed_url
file = file_saver.save_binary_string(_PNG_DATA, mime_type, file_type)
assert file is file_reference
mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_binary=_PNG_DATA,
mimetype=mime_type,
)
mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True)
file_reference_factory.build_from_mapping.assert_called_once_with(
mapping={
"type": file_type,
"transfer_method": FileTransferMethod.TOOL_FILE,
"filename": mock_tool_file.name,
"extension": ".png",
"mime_type": mime_type,
"size": len(_PNG_DATA),
"tool_file_id": mock_tool_file.id,
"related_id": mock_tool_file.id,
"storage_key": mock_tool_file.file_key,
}
)
def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch):
_TEST_URL = "https://example.com/image.png"
@ -91,8 +76,8 @@ class TestFileSaverImpl:
http_client.get.return_value = mock_response
file_saver = FileSaverImpl(
user_id=_gen_id(),
tenant_id=_gen_id(),
tool_file_manager=MagicMock(),
file_reference_factory=MagicMock(),
http_client=http_client,
)
@ -104,8 +89,6 @@ class TestFileSaverImpl:
def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch):
_TEST_URL = "https://example.com/image.png"
mime_type = "image/png"
user_id = _gen_id()
tenant_id = _gen_id()
mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response(
@ -117,21 +100,13 @@ class TestFileSaverImpl:
http_client = MagicMock()
http_client.get.return_value = mock_response
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client)
mock_tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_key="test-file-key",
mimetype=mime_type,
original_url=None,
name=f"{_gen_id()}.png",
size=len(_PNG_DATA),
file_saver = FileSaverImpl(
tool_file_manager=MagicMock(),
file_reference_factory=MagicMock(),
http_client=http_client,
)
mock_tool_file.id = _gen_id()
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
expected_file = MagicMock()
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=expected_file)
monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)
file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
@ -141,7 +116,7 @@ class TestFileSaverImpl:
FileType.IMAGE,
extension_override=".png",
)
assert file == mock_tool_file
assert file is expected_file
def test_validate_extension_override():

View File

@ -1,7 +1,6 @@
import logging
from collections.abc import Generator, Iterator, Sequence
from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
from types import SimpleNamespace
from typing import Any
@ -12,7 +11,6 @@ import pytest
import dify_graph.model_runtime.model_providers.__base.large_language_model as llm_module
# Access large_language_model members via llm_module to avoid partial import issues in CI
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.callbacks.base_callback import Callback
from dify_graph.model_runtime.entities.llm_entities import (
LLMResult,
@ -116,24 +114,10 @@ class _TestLLM(llm_module.LargeLanguageModel):
@pytest.fixture
def llm() -> _TestLLM:
plugin_provider = PluginModelProviderEntity.model_construct(
id="provider-id",
created_at=datetime.now(),
updated_at=datetime.now(),
provider="provider",
tenant_id="tenant",
plugin_unique_identifier="plugin-uid",
plugin_id="plugin-id",
declaration=MagicMock(),
)
return _TestLLM.model_construct(
tenant_id="tenant",
model_type=ModelType.LLM,
plugin_id="plugin-id",
provider_name="provider",
plugin_model_provider=plugin_provider,
started_at=1.0,
)
provider_schema = SimpleNamespace(provider="provider", label=SimpleNamespace(en_US="Provider"))
model_runtime = MagicMock()
model_runtime.get_llm_num_tokens.return_value = 0
return _TestLLM(provider_schema=provider_schema, model_runtime=model_runtime, started_at=1.0)
def test_gen_tool_call_id_is_uuid_based(monkeypatch: pytest.MonkeyPatch) -> None:
@ -280,23 +264,11 @@ def test_build_llm_result_from_chunks_accumulates_all_chunks() -> None:
assert result.message.content == "firstsecond"
def test_invoke_llm_via_plugin_passes_list_converted_stop(monkeypatch: pytest.MonkeyPatch) -> None:
invoked: dict[str, Any] = {}
class FakePluginModelClient:
def invoke_llm(self, **kwargs: Any) -> str:
invoked.update(kwargs)
return "ok"
import core.plugin.impl.model as plugin_model_module
monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
def test_invoke_llm_via_runtime_passes_list_converted_stop(llm: _TestLLM) -> None:
llm.model_runtime = MagicMock()
prompt_messages: Sequence[PromptMessage] = (UserPromptMessage(content="hi"),)
result = llm_module._invoke_llm_via_plugin(
tenant_id="t",
user_id="u",
plugin_id="p",
result = llm_module._invoke_llm_via_runtime(
llm_model=llm,
provider="prov",
model="m",
credentials={"k": "v"},
@ -307,21 +279,29 @@ def test_invoke_llm_via_plugin_passes_list_converted_stop(monkeypatch: pytest.Mo
stream=True,
)
assert result == "ok"
assert invoked["prompt_messages"] == list(prompt_messages)
assert invoked["stop"] == ["a", "b"]
llm.model_runtime.invoke_llm.assert_called_once_with(
provider="prov",
model="m",
credentials={"k": "v"},
model_parameters={"temp": 1},
prompt_messages=list(prompt_messages),
tools=None,
stop=("a", "b"),
stream=True,
)
assert result is llm.model_runtime.invoke_llm.return_value
def test_normalize_non_stream_plugin_result_passthrough_llmresult() -> None:
def test_normalize_non_stream_runtime_result_passthrough_llmresult() -> None:
llm_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage())
assert (
llm_module._normalize_non_stream_plugin_result(model="m", prompt_messages=[], result=llm_result) is llm_result
llm_module._normalize_non_stream_runtime_result(model="m", prompt_messages=[], result=llm_result) is llm_result
)
def test_normalize_non_stream_plugin_result_builds_from_chunks() -> None:
def test_normalize_non_stream_runtime_result_builds_from_chunks() -> None:
chunks = iter([_chunk(content="hello", usage=_usage(1, 1))])
result = llm_module._normalize_non_stream_plugin_result(
result = llm_module._normalize_non_stream_runtime_result(
model="m", prompt_messages=[UserPromptMessage(content="u")], result=chunks
)
assert isinstance(result, LLMResult)
@ -331,7 +311,7 @@ def test_normalize_non_stream_plugin_result_builds_from_chunks() -> None:
def test_invoke_non_stream_normalizes_and_sets_prompt_messages(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
plugin_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage())
monkeypatch.setattr(
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime",
lambda **_: plugin_result,
)
cb = SpyCallback()
@ -355,7 +335,7 @@ def test_invoke_stream_wraps_generator_and_triggers_callbacks(llm: _TestLLM, mon
]
)
monkeypatch.setattr(
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime",
lambda **_: plugin_chunks,
)
@ -383,7 +363,7 @@ def test_invoke_triggers_error_callbacks_and_raises_transformed(llm: _TestLLM, m
raise ValueError("plugin down")
monkeypatch.setattr(
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin", boom
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime", boom
)
cb = SpyCallback()
with pytest.raises(RuntimeError, match="transformed: plugin down"):
@ -397,8 +377,8 @@ def test_invoke_triggers_error_callbacks_and_raises_transformed(llm: _TestLLM, m
def test_invoke_raises_not_implemented_for_unsupported_result_type(
llm: _TestLLM, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setattr(llm_module, "_invoke_llm_via_plugin", lambda **_: "not-a-result")
monkeypatch.setattr(llm_module, "_normalize_non_stream_plugin_result", lambda **_: "not-a-result")
monkeypatch.setattr(llm_module, "_invoke_llm_via_runtime", lambda **_: "not-a-result")
monkeypatch.setattr(llm_module, "_normalize_non_stream_runtime_result", lambda **_: "not-a-result")
with pytest.raises(NotImplementedError, match="unsupported invoke result type"):
llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False)
@ -410,9 +390,9 @@ def test_invoke_appends_logging_callback_in_debug(llm: _TestLLM, monkeypatch: py
pass
monkeypatch.setattr(llm_module, "LoggingCallback", FakeLoggingCallback)
monkeypatch.setattr(llm_module.dify_config, "DEBUG", True)
monkeypatch.setattr(llm_module.logger, "isEnabledFor", lambda level: level == logging.DEBUG)
monkeypatch.setattr(
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
"dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime",
lambda **_: LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()),
)
@ -427,26 +407,22 @@ def test_invoke_appends_logging_callback_in_debug(llm: _TestLLM, monkeypatch: py
assert any(isinstance(cb, FakeLoggingCallback) for cb in captured_callbacks[0])
def test_get_num_tokens_returns_0_when_plugin_disabled(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(llm_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False)
def test_get_num_tokens_returns_0_when_runtime_returns_0(llm: _TestLLM) -> None:
llm.model_runtime.get_llm_num_tokens.return_value = 0
assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 0
def test_get_num_tokens_uses_plugin_when_enabled(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(llm_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", True)
class FakePluginModelClient:
def get_llm_num_tokens(self, **kwargs: Any) -> int:
assert kwargs["tenant_id"] == "tenant"
assert kwargs["plugin_id"] == "plugin-id"
assert kwargs["provider"] == "provider"
assert kwargs["model_type"] == "llm"
return 42
import core.plugin.impl.model as plugin_model_module
monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
def test_get_num_tokens_uses_runtime(llm: _TestLLM) -> None:
llm.model_runtime.get_llm_num_tokens.return_value = 42
assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 42
llm.model_runtime.get_llm_num_tokens.assert_called_once_with(
provider="provider",
model_type=ModelType.LLM,
model="m",
credentials={},
prompt_messages=[UserPromptMessage(content="x")],
tools=None,
)
def test_calc_response_usage_uses_prices_and_latency(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:

View File

@ -237,8 +237,7 @@ class TestAudioServiceASR:
# Assert
assert result == {"text": "Transcribed text"}
mock_model_instance.invoke_speech2text.assert_called_once()
call_args = mock_model_instance.invoke_speech2text.call_args
assert call_args.kwargs["user"] == "user-123"
mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123")
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory):
@ -398,10 +397,9 @@ class TestAudioServiceTTS:
# Assert
assert result == b"audio data"
mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123")
mock_model_instance.invoke_tts.assert_called_once_with(
content_text="Hello world",
user="user-123",
tenant_id=app.tenant_id,
voice="en-US-Neural",
)

View File

@ -189,7 +189,7 @@ def test_vectorize_summary_retries_connection_errors_then_succeeds(monkeypatch:
embedding_model.get_text_embedding_num_tokens.return_value = [5]
model_manager = MagicMock()
model_manager.get_model_instance.return_value = embedding_model
monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager))
monkeypatch.setattr(summary_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager))
vector_instance = MagicMock()
vector_instance.add_texts.side_effect = [RuntimeError("connection timeout"), None]
@ -228,7 +228,7 @@ def test_vectorize_summary_without_session_creates_record_when_missing(monkeypat
model_manager = MagicMock()
model_manager.get_model_instance.side_effect = RuntimeError("no model")
monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager))
monkeypatch.setattr(summary_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager))
# New session used after vectorization succeeds (record not found by id nor chunk_id).
session = MagicMock(name="session")
@ -405,8 +405,8 @@ def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch
vector_instance.add_texts.return_value = None
monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance))
monkeypatch.setattr(
summary_module,
"ModelManager",
summary_module.ModelManager,
"for_tenant",
MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))),
)
@ -439,8 +439,8 @@ def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pyte
summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None)))
)
monkeypatch.setattr(
summary_module,
"ModelManager",
summary_module.ModelManager,
"for_tenant",
MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))),
)
@ -472,8 +472,8 @@ def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(mon
summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None)))
)
monkeypatch.setattr(
summary_module,
"ModelManager",
summary_module.ModelManager,
"for_tenant",
MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))),
)
@ -508,8 +508,8 @@ def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatc
summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None)))
)
monkeypatch.setattr(
summary_module,
"ModelManager",
summary_module.ModelManager,
"for_tenant",
MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))),
)

View File

@ -213,7 +213,9 @@ def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_ex
embedding_model_instance = MagicMock(name="embedding_model_instance")
model_manager_instance = MagicMock(name="model_manager_instance")
model_manager_instance.get_model_instance.return_value = embedding_model_instance
monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance))
monkeypatch.setattr(
vector_service_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager_instance)
)
generate_child_chunks_mock = MagicMock()
monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock)
@ -261,7 +263,9 @@ def test_create_segments_vector_parent_child_uses_default_embedding_model_when_p
embedding_model_instance = MagicMock()
model_manager_instance = MagicMock()
model_manager_instance.get_default_model_instance.return_value = embedding_model_instance
monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance))
monkeypatch.setattr(
vector_service_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager_instance)
)
generate_child_chunks_mock = MagicMock()
monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock)