Mirror of https://github.com/langgenius/dify.git (synced 2026-04-24 04:45:51 +08:00)

fix: resolve remaining runtime decoupling CI failures
@@ -50,6 +50,7 @@ PreparedModelInstance: TypeAlias = PreparedLLMProtocol | _LegacyModelInstance

def fetch_model_schema(*, model_instance: PreparedModelInstance) -> AIModelEntity:
    model_schema: AIModelEntity | None
    get_model_schema = getattr(model_instance, "get_model_schema", None)
    if callable(get_model_schema):
        model_schema = cast(PreparedLLMProtocol, model_instance).get_model_schema()
@@ -364,28 +364,35 @@ class LLMNode(Node[LLMNodeData]):
    ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
        model_parameters = model_instance.parameters
        invoke_model_parameters = dict(model_parameters)
        invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None]
        if structured_output_enabled:
            output_schema = LLMNode.fetch_structured_output_schema(
                structured_output=structured_output or {},
            )
            request_start_time = time.perf_counter()

            invoke_result = model_instance.invoke_llm_with_structured_output(
                prompt_messages=prompt_messages,
                json_schema=output_schema,
                model_parameters=invoke_model_parameters,
                stop=stop,
                stream=True,
            invoke_result = cast(
                LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
                model_instance.invoke_llm_with_structured_output(
                    prompt_messages=prompt_messages,
                    json_schema=output_schema,
                    model_parameters=invoke_model_parameters,
                    stop=stop,
                    stream=True,
                ),
            )
        else:
            request_start_time = time.perf_counter()

            invoke_result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=invoke_model_parameters,
                tools=None,
                stop=stop,
                stream=True,
            invoke_result = cast(
                LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
                model_instance.invoke_llm(
                    prompt_messages=prompt_messages,
                    model_parameters=invoke_model_parameters,
                    tools=None,
                    stop=stop,
                    stream=True,
                ),
            )

        return LLMNode.handle_invoke_result(
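Note: the new branches above only wrap the existing calls in typing.cast; the runtime value is unchanged, and the checker is simply told the broader union type that invoke_result is annotated with. A minimal standalone sketch of that pattern (the fetch function below is hypothetical):

    from typing import cast

    def fetch() -> object:
        # The callee declares a broad return type, like the decoupled invoke methods above.
        return 41

    # cast() performs no runtime check or conversion; it only narrows the type for the checker.
    value = cast(int, fetch())
    print(value + 1)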
@@ -758,6 +758,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        except ValueError as exc:
            raise ModelSchemaNotFoundError("Model schema not found") from exc

        prompt_template: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
        if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
            prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
        else:
@@ -55,7 +55,7 @@ class TypeMismatchError(Exception):
    pass


SEGMENT_TO_VARIABLE_MAP = {
SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[Any]] = {
    ArrayAnySegment: ArrayAnyVariable,
    ArrayBooleanSegment: ArrayBooleanVariable,
    ArrayFileSegment: ArrayFileVariable,
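Note: annotating the table as a read-only Mapping, instead of letting the checker infer a concrete dict type from the literal, is what keeps the heterogeneous value types acceptable. A minimal sketch with hypothetical segment and variable classes:

    from collections.abc import Mapping
    from typing import Any

    class Segment: ...
    class ArraySegment(Segment): ...
    class Variable: ...
    class ArrayVariable(Variable): ...

    # The explicit Mapping annotation documents intent (a lookup table that is not mutated)
    # and keeps the key and value types wide enough for every entry.
    SEGMENT_TO_VARIABLE: Mapping[type[Segment], type[Any]] = {
        ArraySegment: ArrayVariable,
    }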
@@ -14,12 +14,19 @@ from werkzeug.http import parse_options_header

from core.helper import ssrf_proxy
from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers
from dify_graph.file.file_factory import standardize_file_type
from dify_graph.file.file_factory import (
    get_file_type_by_mime_type as _get_file_type_by_mime_type,
)
from dify_graph.file.file_factory import (
    standardize_file_type,
)
from extensions.ext_database import db
from models import MessageFile, ToolFile, UploadFile

logger = logging.getLogger(__name__)

get_file_type_by_mime_type = _get_file_type_by_mime_type


def build_from_message_files(
    *,
@@ -200,7 +200,8 @@ class SummaryIndexService:
            )
            if embedding_model:
                tokens_list = embedding_model.get_text_embedding_num_tokens([summary_content])
                embedding_tokens = tokens_list[0] if tokens_list else 0
                raw_embedding_tokens = tokens_list[0] if tokens_list else 0
                embedding_tokens = raw_embedding_tokens if isinstance(raw_embedding_tokens, int) else 0
        except Exception as e:
            logger.warning("Failed to calculate embedding tokens for summary: %s", str(e))
@@ -156,7 +156,6 @@ class TestCacheEmbeddingDocuments:
        # Verify model was invoked with correct parameters
        mock_model_instance.invoke_text_embedding.assert_called_once_with(
            texts=texts,
            user="test-user",
            input_type=EmbeddingInputType.DOCUMENT,
        )

@@ -651,7 +650,6 @@ class TestCacheEmbeddingQuery:
        # Verify model was invoked with QUERY input type
        mock_model_instance.invoke_text_embedding.assert_called_once_with(
            texts=[query],
            user="test-user",
            input_type=EmbeddingInputType.QUERY,
        )

@@ -1623,7 +1621,6 @@ class TestEmbeddingEdgeCases:
        # Verify user parameter was passed to model
        mock_model_instance.invoke_text_embedding.assert_called_once_with(
            texts=[query],
            user=user_id,
            input_type=EmbeddingInputType.QUERY,
        )

@@ -1676,7 +1673,6 @@ class TestEmbeddingEdgeCases:
        # Verify user parameter was passed
        mock_model_instance.invoke_text_embedding.assert_called_once()
        call_args = mock_model_instance.invoke_text_embedding.call_args
        assert call_args.kwargs["user"] == user_id
        assert call_args.kwargs["input_type"] == EmbeddingInputType.DOCUMENT
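Note: the embedding tests above rely on two unittest.mock assertion styles, exact-call matching and call_args inspection; a self-contained sketch of both (names are illustrative):

    from unittest.mock import MagicMock

    model = MagicMock()
    model.invoke_text_embedding(texts=["hi"], user="test-user")

    # Exact-call assertion: every keyword argument must match the recorded call.
    model.invoke_text_embedding.assert_called_once_with(texts=["hi"], user="test-user")

    # call_args inspection: check only the arguments the test cares about.
    assert model.invoke_text_embedding.call_args.kwargs["user"] == "test-user"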
@@ -399,7 +399,9 @@ class TestParagraphIndexProcessor:
        model_instance.invoke_llm.return_value = self._llm_result("text summary")

        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls,
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager"
            ) as mock_provider_manager,
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance",
                return_value=model_instance,
@@ -410,7 +412,7 @@ class TestParagraphIndexProcessor:
            ),
            patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
        ):
            mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock()
            mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock()
            summary, usage = ParagraphIndexProcessor.generate_summary(
                "tenant-1",
                "text content",
@@ -433,7 +435,9 @@ class TestParagraphIndexProcessor:
        image_content = ImagePromptMessageContent(format="url", mime_type="image/png", url="http://example.com/a.png")

        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls,
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager"
            ) as mock_provider_manager,
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance",
                return_value=model_instance,
@@ -448,7 +452,7 @@ class TestParagraphIndexProcessor:
            ),
            patch("core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota"),
        ):
            mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock()
            mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock()
            summary, _ = ParagraphIndexProcessor.generate_summary(
                "tenant-1",
                "text content",
@@ -469,7 +473,9 @@ class TestParagraphIndexProcessor:
        image_file = SimpleNamespace()

        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls,
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager"
            ) as mock_provider_manager,
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance",
                return_value=model_instance,
@@ -486,7 +492,7 @@ class TestParagraphIndexProcessor:
            ),
            patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
        ):
            mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock()
            mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock()
            with pytest.raises(ValueError, match="Expected LLMResult"):
                ParagraphIndexProcessor.generate_summary(
                    "tenant-1",
@@ -482,7 +482,8 @@ class TestIndexingRunnerTransform:
        # Arrange
        runner = IndexingRunner()
        mock_embedding_instance = MagicMock()
        runner.model_manager.get_model_instance.return_value = mock_embedding_instance
        model_manager = mock_dependencies["model_manager"].return_value
        model_manager.get_model_instance.return_value = mock_embedding_instance

        mock_processor = MagicMock()
        transformed_docs = [
@@ -509,7 +510,7 @@ class TestIndexingRunnerTransform:
        assert len(result) == 2
        assert result[0].page_content == "Chunk 1"
        assert result[1].page_content == "Chunk 2"
        runner.model_manager.get_model_instance.assert_called_once_with(
        model_manager.get_model_instance.assert_called_once_with(
            tenant_id=sample_dataset.tenant_id,
            provider=sample_dataset.embedding_model_provider,
            model_type=ModelType.TEXT_EMBEDDING,
@@ -522,6 +523,7 @@ class TestIndexingRunnerTransform:
        # Arrange
        runner = IndexingRunner()
        sample_dataset.indexing_technique = "economy"
        model_manager = mock_dependencies["model_manager"].return_value

        mock_processor = MagicMock()
        transformed_docs = [
@@ -539,14 +541,15 @@ class TestIndexingRunnerTransform:

        # Assert
        assert len(result) == 1
        runner.model_manager.get_model_instance.assert_not_called()
        model_manager.get_model_instance.assert_not_called()

    def test_transform_with_custom_segmentation(self, mock_dependencies, sample_dataset, sample_text_docs):
        """Test transformation with custom segmentation rules."""
        # Arrange
        runner = IndexingRunner()
        mock_embedding_instance = MagicMock()
        runner.model_manager.get_model_instance.return_value = mock_embedding_instance
        model_manager = mock_dependencies["model_manager"].return_value
        model_manager.get_model_instance.return_value = mock_embedding_instance

        mock_processor = MagicMock()
        transformed_docs = [Document(page_content="Custom chunk", metadata={"doc_id": "custom1", "doc_hash": "hash1"})]
@@ -645,7 +648,8 @@ class TestIndexingRunnerLoad:
        runner = IndexingRunner()
        mock_embedding_instance = MagicMock()
        mock_embedding_instance.get_text_embedding_num_tokens.return_value = 100
        runner.model_manager.get_model_instance.return_value = mock_embedding_instance
        model_manager = mock_dependencies["model_manager"].return_value
        model_manager.get_model_instance.return_value = mock_embedding_instance

        mock_processor = MagicMock()

@@ -664,7 +668,7 @@ class TestIndexingRunnerLoad:
        runner._load(mock_processor, sample_dataset, sample_dataset_document, sample_documents)

        # Assert
        runner.model_manager.get_model_instance.assert_called_once()
        model_manager.get_model_instance.assert_called_once()
        # Verify executor was used for parallel processing
        assert mock_executor_instance.submit.called

@@ -714,7 +718,8 @@ class TestIndexingRunnerLoad:

        mock_embedding_instance = MagicMock()
        mock_embedding_instance.get_text_embedding_num_tokens.return_value = 50
        runner.model_manager.get_model_instance.return_value = mock_embedding_instance
        model_manager = mock_dependencies["model_manager"].return_value
        model_manager.get_model_instance.return_value = mock_embedding_instance

        mock_processor = MagicMock()
@@ -352,7 +352,9 @@ class TestRerankModelRunner:
        # Assert: Empty result is returned
        assert len(result) == 0

    def test_user_parameter_passed_to_model(self, rerank_runner, mock_model_instance, sample_documents):
    def test_user_parameter_passed_to_model(
        self, rerank_runner, mock_model_instance, sample_documents, mock_model_manager
    ):
        """Test that user parameter is passed to model invocation.

        Verifies:
@@ -366,6 +368,7 @@ class TestRerankModelRunner:
                RerankDocument(index=0, text=sample_documents[0].page_content, score=0.90),
            ],
        )
        mock_model_manager.return_value.get_model_instance.return_value = mock_model_instance
        mock_model_instance.invoke_rerank.return_value = mock_rerank_result

        # Act: Run reranking with user parameter
@@ -375,9 +378,10 @@ class TestRerankModelRunner:
            user="user123",
        )

        # Assert: User parameter is passed to model
        # Assert: User context is bound through ModelManager and the rebound instance is invoked.
        mock_model_manager.assert_any_call(tenant_id="test-tenant-id", user_id="user123")
        call_kwargs = mock_model_instance.invoke_rerank.call_args.kwargs
        assert call_kwargs["user"] == "user123"
        assert "user" not in call_kwargs


class _ForwardingBaseRerankRunner(BaseRerankRunner):
@@ -539,10 +543,14 @@ class TestRerankModelRunnerMultimodal:
        )
        mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result

        session = MagicMock()
        session.query.return_value = query_chain
        with (
            patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain),
            patch("core.rag.rerank.rerank_model.ModelManager.for_tenant") as mock_model_manager,
            patch("core.rag.rerank.rerank_model.db.session", session),
            patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"),
        ):
            mock_model_manager.return_value.get_model_instance.return_value = mock_model_instance
            result, unique_documents = rerank_runner.fetch_multimodal_rerank(
                query="query-upload-id",
                documents=[text_doc],
@@ -554,10 +562,11 @@ class TestRerankModelRunnerMultimodal:

        assert result == rerank_result
        assert unique_documents == [text_doc]
        mock_model_manager.assert_any_call(tenant_id="test-tenant-id", user_id="user-1")
        invoke_kwargs = mock_model_instance.invoke_multimodal_rerank.call_args.kwargs
        assert invoke_kwargs["query"]["content_type"] == DocType.IMAGE
        assert invoke_kwargs["docs"][0]["content"] == "text-content"
        assert invoke_kwargs["user"] == "user-1"
        assert "user" not in invoke_kwargs

    def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner):
        query_chain = Mock()
@@ -595,7 +604,7 @@ class TestWeightRerankRunner:
    @pytest.fixture
    def mock_model_manager(self):
        """Mock ModelManager for embedding model."""
        with patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant", autospec=True) as mock_manager:
        with patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager:
            yield mock_manager

    @pytest.fixture
@@ -1527,7 +1536,7 @@ class TestRerankEdgeCases:
        # Mock dependencies
        with (
            patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba,
            patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant", autospec=True) as mock_manager,
            patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager,
            patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache,
        ):
            mock_handler = MagicMock()
@@ -1673,7 +1682,7 @@ class TestRerankPerformance:

        with (
            patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba,
            patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant", autospec=True) as mock_manager,
            patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager,
            patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache,
        ):
            mock_handler = MagicMock()
@@ -1824,7 +1833,7 @@ class TestRerankErrorHandling:

        with (
            patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba,
            patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant", autospec=True) as mock_manager,
            patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager,
            patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache,
        ):
            mock_handler = MagicMock()
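Note: the rerank tests now patch the ModelManager.for_tenant factory and assert on it, instead of expecting a user keyword on the invoke call; a runnable sketch of that mocking shape, using hypothetical class and function names:

    from unittest.mock import patch

    class ModelManager:
        @classmethod
        def for_tenant(cls, tenant_id: str, user_id: str) -> "ModelManager":
            raise RuntimeError("real factory not used in tests")

    def rerank(tenant_id: str, user_id: str) -> str:
        # The code under test binds the user context via the factory, not per invocation.
        manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
        return manager.get_model_instance().invoke_rerank()

    def test_rerank_binds_user_context() -> None:
        with patch(f"{__name__}.ModelManager.for_tenant") as mock_for_tenant:
            mock_for_tenant.return_value.get_model_instance.return_value.invoke_rerank.return_value = "ok"
            assert rerank("t1", "u1") == "ok"
            mock_for_tenant.assert_any_call(tenant_id="t1", user_id="u1")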
@@ -162,7 +162,11 @@ class TestReactMultiDatasetRouter:
        model_instance = Mock()
        model_instance.invoke_llm.return_value = iter([chunk])

        with patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct:
        with (
            patch("core.rag.retrieval.router.multi_dataset_react_route.ModelManager.for_tenant") as mock_manager,
            patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct,
        ):
            mock_manager.return_value.get_model_instance.return_value = model_instance
            text, returned_usage = router._invoke_llm(
                completion_param={"temperature": 0.1},
                model_instance=model_instance,
@@ -174,6 +178,7 @@ class TestReactMultiDatasetRouter:

        assert text == "part"
        assert returned_usage == usage
        mock_manager.assert_any_call(tenant_id="t1", user_id="u1")
        mock_deduct.assert_called_once()

    def test_handle_invoke_result_with_empty_usage(self) -> None:
@@ -111,7 +111,7 @@ class MockNodeFactory(DifyNodeFactory):
                mock_config=self.mock_config,
                http_request_config=self._http_request_config,
                http_client=self._http_request_http_client,
                tool_file_manager_factory=self._http_request_tool_file_manager_factory,
                tool_file_manager_factory=self._bound_tool_file_manager_factory,
                file_manager=self._http_request_file_manager,
            )
        elif node_type in {
@@ -6,17 +6,14 @@ from unittest.mock import MagicMock
import httpx
import pytest

from core.helper import ssrf_proxy
from core.tools import signature
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.file import FileTransferMethod, FileType, models
from dify_graph.file import FileTransferMethod, FileType
from dify_graph.nodes.llm.file_saver import (
    FileSaverImpl,
    _extract_content_type_and_extension,
    _get_extension,
    _validate_extension_override,
)
from models import ToolFile

_PNG_DATA = b"\x89PNG\r\n\x1a\n"

@@ -27,58 +24,46 @@ def _gen_id():

class TestFileSaverImpl:
    def test_save_binary_string(self, monkeypatch: pytest.MonkeyPatch):
        user_id = _gen_id()
        tenant_id = _gen_id()
        file_type = FileType.IMAGE
        mime_type = "image/png"
        mock_signed_url = "https://example.com/image.png"
        mock_tool_file = ToolFile(
            user_id=user_id,
            tenant_id=tenant_id,
            conversation_id=None,
            file_key="test-file-key",
            mimetype=mime_type,
            original_url=None,
            name=f"{_gen_id()}.png",
            size=len(_PNG_DATA),
        )
        mock_tool_file = MagicMock()
        mock_tool_file.id = _gen_id()
        mock_tool_file.name = f"{_gen_id()}.png"
        mock_tool_file.file_key = "test-file-key"
        mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)

        mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
        monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
        # Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here.
        mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
        # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
        monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
        mocked_sign_file.return_value = mock_signed_url
        file_reference = MagicMock()
        file_reference_factory = MagicMock()
        file_reference_factory.build_from_mapping.return_value = file_reference
        http_client = MagicMock()

        storage_file_manager = FileSaverImpl(
            user_id=user_id,
            tenant_id=tenant_id,
        file_saver = FileSaverImpl(
            tool_file_manager=mocked_tool_file_manager,
            file_reference_factory=file_reference_factory,
            http_client=http_client,
        )

        file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
        assert file.tenant_id == tenant_id
        assert file.type == file_type
        assert file.transfer_method == FileTransferMethod.TOOL_FILE
        assert file.extension == ".png"
        assert file.mime_type == mime_type
        assert file.size == len(_PNG_DATA)
        assert file.related_id == mock_tool_file.id

        assert file.generate_url() == mock_signed_url
        file = file_saver.save_binary_string(_PNG_DATA, mime_type, file_type)
        assert file is file_reference

        mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
            user_id=user_id,
            tenant_id=tenant_id,
            conversation_id=None,
            file_binary=_PNG_DATA,
            mimetype=mime_type,
        )
        mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True)
        file_reference_factory.build_from_mapping.assert_called_once_with(
            mapping={
                "type": file_type,
                "transfer_method": FileTransferMethod.TOOL_FILE,
                "filename": mock_tool_file.name,
                "extension": ".png",
                "mime_type": mime_type,
                "size": len(_PNG_DATA),
                "tool_file_id": mock_tool_file.id,
                "related_id": mock_tool_file.id,
                "storage_key": mock_tool_file.file_key,
            }
        )

    def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch):
        _TEST_URL = "https://example.com/image.png"
@@ -91,8 +76,8 @@ class TestFileSaverImpl:
        http_client.get.return_value = mock_response

        file_saver = FileSaverImpl(
            user_id=_gen_id(),
            tenant_id=_gen_id(),
            tool_file_manager=MagicMock(),
            file_reference_factory=MagicMock(),
            http_client=http_client,
        )

@@ -104,8 +89,6 @@ class TestFileSaverImpl:
    def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch):
        _TEST_URL = "https://example.com/image.png"
        mime_type = "image/png"
        user_id = _gen_id()
        tenant_id = _gen_id()

        mock_request = httpx.Request("GET", _TEST_URL)
        mock_response = httpx.Response(
@@ -117,21 +100,13 @@ class TestFileSaverImpl:
        http_client = MagicMock()
        http_client.get.return_value = mock_response

        file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client)
        mock_tool_file = ToolFile(
            user_id=user_id,
            tenant_id=tenant_id,
            conversation_id=None,
            file_key="test-file-key",
            mimetype=mime_type,
            original_url=None,
            name=f"{_gen_id()}.png",
            size=len(_PNG_DATA),
        file_saver = FileSaverImpl(
            tool_file_manager=MagicMock(),
            file_reference_factory=MagicMock(),
            http_client=http_client,
        )
        mock_tool_file.id = _gen_id()
        mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
        monkeypatch.setattr(ssrf_proxy, "get", mock_get)
        mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
        expected_file = MagicMock()
        mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=expected_file)
        monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)

        file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
@@ -141,7 +116,7 @@ class TestFileSaverImpl:
            FileType.IMAGE,
            extension_override=".png",
        )
        assert file == mock_tool_file
        assert file is expected_file


def test_validate_extension_override():
@@ -1,7 +1,6 @@
import logging
from collections.abc import Generator, Iterator, Sequence
from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
from types import SimpleNamespace
from typing import Any
@@ -12,7 +11,6 @@ import pytest
import dify_graph.model_runtime.model_providers.__base.large_language_model as llm_module

# Access large_language_model members via llm_module to avoid partial import issues in CI
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.callbacks.base_callback import Callback
from dify_graph.model_runtime.entities.llm_entities import (
    LLMResult,
@@ -116,24 +114,10 @@ class _TestLLM(llm_module.LargeLanguageModel):

@pytest.fixture
def llm() -> _TestLLM:
    plugin_provider = PluginModelProviderEntity.model_construct(
        id="provider-id",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        provider="provider",
        tenant_id="tenant",
        plugin_unique_identifier="plugin-uid",
        plugin_id="plugin-id",
        declaration=MagicMock(),
    )
    return _TestLLM.model_construct(
        tenant_id="tenant",
        model_type=ModelType.LLM,
        plugin_id="plugin-id",
        provider_name="provider",
        plugin_model_provider=plugin_provider,
        started_at=1.0,
    )
    provider_schema = SimpleNamespace(provider="provider", label=SimpleNamespace(en_US="Provider"))
    model_runtime = MagicMock()
    model_runtime.get_llm_num_tokens.return_value = 0
    return _TestLLM(provider_schema=provider_schema, model_runtime=model_runtime, started_at=1.0)


def test_gen_tool_call_id_is_uuid_based(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -280,23 +264,11 @@ def test_build_llm_result_from_chunks_accumulates_all_chunks() -> None:
    assert result.message.content == "firstsecond"


def test_invoke_llm_via_plugin_passes_list_converted_stop(monkeypatch: pytest.MonkeyPatch) -> None:
    invoked: dict[str, Any] = {}

    class FakePluginModelClient:
        def invoke_llm(self, **kwargs: Any) -> str:
            invoked.update(kwargs)
            return "ok"

    import core.plugin.impl.model as plugin_model_module

    monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)

def test_invoke_llm_via_runtime_passes_list_converted_stop(llm: _TestLLM) -> None:
    llm.model_runtime = MagicMock()
    prompt_messages: Sequence[PromptMessage] = (UserPromptMessage(content="hi"),)
    result = llm_module._invoke_llm_via_plugin(
        tenant_id="t",
        user_id="u",
        plugin_id="p",
    result = llm_module._invoke_llm_via_runtime(
        llm_model=llm,
        provider="prov",
        model="m",
        credentials={"k": "v"},
@@ -307,21 +279,29 @@ def test_invoke_llm_via_plugin_passes_list_converted_stop(monkeypatch: pytest.Mo
        stream=True,
    )

    assert result == "ok"
    assert invoked["prompt_messages"] == list(prompt_messages)
    assert invoked["stop"] == ["a", "b"]
    llm.model_runtime.invoke_llm.assert_called_once_with(
        provider="prov",
        model="m",
        credentials={"k": "v"},
        model_parameters={"temp": 1},
        prompt_messages=list(prompt_messages),
        tools=None,
        stop=("a", "b"),
        stream=True,
    )
    assert result is llm.model_runtime.invoke_llm.return_value


def test_normalize_non_stream_plugin_result_passthrough_llmresult() -> None:
def test_normalize_non_stream_runtime_result_passthrough_llmresult() -> None:
    llm_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage())
    assert (
        llm_module._normalize_non_stream_plugin_result(model="m", prompt_messages=[], result=llm_result) is llm_result
        llm_module._normalize_non_stream_runtime_result(model="m", prompt_messages=[], result=llm_result) is llm_result
    )


def test_normalize_non_stream_plugin_result_builds_from_chunks() -> None:
def test_normalize_non_stream_runtime_result_builds_from_chunks() -> None:
    chunks = iter([_chunk(content="hello", usage=_usage(1, 1))])
    result = llm_module._normalize_non_stream_plugin_result(
    result = llm_module._normalize_non_stream_runtime_result(
        model="m", prompt_messages=[UserPromptMessage(content="u")], result=chunks
    )
    assert isinstance(result, LLMResult)
@@ -331,7 +311,7 @@ def test_normalize_non_stream_plugin_result_builds_from_chunks() -> None:
def test_invoke_non_stream_normalizes_and_sets_prompt_messages(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
    plugin_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage())
    monkeypatch.setattr(
        "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
        "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime",
        lambda **_: plugin_result,
    )
    cb = SpyCallback()
@@ -355,7 +335,7 @@ def test_invoke_stream_wraps_generator_and_triggers_callbacks(llm: _TestLLM, mon
        ]
    )
    monkeypatch.setattr(
        "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
        "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime",
        lambda **_: plugin_chunks,
    )

@@ -383,7 +363,7 @@ def test_invoke_triggers_error_callbacks_and_raises_transformed(llm: _TestLLM, m
        raise ValueError("plugin down")

    monkeypatch.setattr(
        "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin", boom
        "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime", boom
    )
    cb = SpyCallback()
    with pytest.raises(RuntimeError, match="transformed: plugin down"):
@@ -397,8 +377,8 @@ def test_invoke_triggers_error_callbacks_and_raises_transformed(llm: _TestLLM, m
def test_invoke_raises_not_implemented_for_unsupported_result_type(
    llm: _TestLLM, monkeypatch: pytest.MonkeyPatch
) -> None:
    monkeypatch.setattr(llm_module, "_invoke_llm_via_plugin", lambda **_: "not-a-result")
    monkeypatch.setattr(llm_module, "_normalize_non_stream_plugin_result", lambda **_: "not-a-result")
    monkeypatch.setattr(llm_module, "_invoke_llm_via_runtime", lambda **_: "not-a-result")
    monkeypatch.setattr(llm_module, "_normalize_non_stream_runtime_result", lambda **_: "not-a-result")
    with pytest.raises(NotImplementedError, match="unsupported invoke result type"):
        llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False)

@@ -410,9 +390,9 @@ def test_invoke_appends_logging_callback_in_debug(llm: _TestLLM, monkeypatch: py
        pass

    monkeypatch.setattr(llm_module, "LoggingCallback", FakeLoggingCallback)
    monkeypatch.setattr(llm_module.dify_config, "DEBUG", True)
    monkeypatch.setattr(llm_module.logger, "isEnabledFor", lambda level: level == logging.DEBUG)
    monkeypatch.setattr(
        "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
        "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime",
        lambda **_: LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()),
    )

@@ -427,26 +407,22 @@ def test_invoke_appends_logging_callback_in_debug(llm: _TestLLM, monkeypatch: py
    assert any(isinstance(cb, FakeLoggingCallback) for cb in captured_callbacks[0])


def test_get_num_tokens_returns_0_when_plugin_disabled(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(llm_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False)
def test_get_num_tokens_returns_0_when_runtime_returns_0(llm: _TestLLM) -> None:
    llm.model_runtime.get_llm_num_tokens.return_value = 0
    assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 0


def test_get_num_tokens_uses_plugin_when_enabled(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(llm_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", True)

    class FakePluginModelClient:
        def get_llm_num_tokens(self, **kwargs: Any) -> int:
            assert kwargs["tenant_id"] == "tenant"
            assert kwargs["plugin_id"] == "plugin-id"
            assert kwargs["provider"] == "provider"
            assert kwargs["model_type"] == "llm"
            return 42

    import core.plugin.impl.model as plugin_model_module

    monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
def test_get_num_tokens_uses_runtime(llm: _TestLLM) -> None:
    llm.model_runtime.get_llm_num_tokens.return_value = 42
    assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 42
    llm.model_runtime.get_llm_num_tokens.assert_called_once_with(
        provider="provider",
        model_type=ModelType.LLM,
        model="m",
        credentials={},
        prompt_messages=[UserPromptMessage(content="x")],
        tools=None,
    )


def test_calc_response_usage_uses_prices_and_latency(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
@@ -237,8 +237,7 @@ class TestAudioServiceASR:
        # Assert
        assert result == {"text": "Transcribed text"}
        mock_model_instance.invoke_speech2text.assert_called_once()
        call_args = mock_model_instance.invoke_speech2text.call_args
        assert call_args.kwargs["user"] == "user-123"
        mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123")

    @patch("services.audio_service.ModelManager.for_tenant", autospec=True)
    def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory):
@@ -398,10 +397,9 @@ class TestAudioServiceTTS:

        # Assert
        assert result == b"audio data"
        mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123")
        mock_model_instance.invoke_tts.assert_called_once_with(
            content_text="Hello world",
            user="user-123",
            tenant_id=app.tenant_id,
            voice="en-US-Neural",
        )
@@ -189,7 +189,7 @@ def test_vectorize_summary_retries_connection_errors_then_succeeds(monkeypatch:
    embedding_model.get_text_embedding_num_tokens.return_value = [5]
    model_manager = MagicMock()
    model_manager.get_model_instance.return_value = embedding_model
    monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager))
    monkeypatch.setattr(summary_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager))

    vector_instance = MagicMock()
    vector_instance.add_texts.side_effect = [RuntimeError("connection timeout"), None]
@@ -228,7 +228,7 @@ def test_vectorize_summary_without_session_creates_record_when_missing(monkeypat

    model_manager = MagicMock()
    model_manager.get_model_instance.side_effect = RuntimeError("no model")
    monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager))
    monkeypatch.setattr(summary_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager))

    # New session used after vectorization succeeds (record not found by id nor chunk_id).
    session = MagicMock(name="session")
@@ -405,8 +405,8 @@ def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch
    vector_instance.add_texts.return_value = None
    monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance))
    monkeypatch.setattr(
        summary_module,
        "ModelManager",
        summary_module.ModelManager,
        "for_tenant",
        MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))),
    )

@@ -439,8 +439,8 @@ def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pyte
        summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None)))
    )
    monkeypatch.setattr(
        summary_module,
        "ModelManager",
        summary_module.ModelManager,
        "for_tenant",
        MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))),
    )

@@ -472,8 +472,8 @@ def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(mon
        summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None)))
    )
    monkeypatch.setattr(
        summary_module,
        "ModelManager",
        summary_module.ModelManager,
        "for_tenant",
        MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))),
    )

@@ -508,8 +508,8 @@ def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatc
        summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None)))
    )
    monkeypatch.setattr(
        summary_module,
        "ModelManager",
        summary_module.ModelManager,
        "for_tenant",
        MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))),
    )
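Note: the summary tests switch from replacing the whole module-level ModelManager symbol to patching only its for_tenant attribute on the class; a small runnable sketch of monkeypatching a classmethod in place, with a hypothetical class:

    import pytest

    class ModelManager:
        @classmethod
        def for_tenant(cls, tenant_id: str) -> "ModelManager":
            raise RuntimeError("not available in unit tests")

    def test_for_tenant_is_replaced(monkeypatch: pytest.MonkeyPatch) -> None:
        sentinel = object()
        # Patching the attribute on the class leaves other ModelManager members intact,
        # unlike swapping out the whole ModelManager name in the module under test.
        monkeypatch.setattr(ModelManager, "for_tenant", lambda tenant_id: sentinel)
        assert ModelManager.for_tenant(tenant_id="t-1") is sentinel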
|
||||
@ -213,7 +213,9 @@ def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_ex
|
||||
embedding_model_instance = MagicMock(name="embedding_model_instance")
|
||||
model_manager_instance = MagicMock(name="model_manager_instance")
|
||||
model_manager_instance.get_model_instance.return_value = embedding_model_instance
|
||||
monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance))
|
||||
monkeypatch.setattr(
|
||||
vector_service_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager_instance)
|
||||
)
|
||||
|
||||
generate_child_chunks_mock = MagicMock()
|
||||
monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock)
|
||||
@ -261,7 +263,9 @@ def test_create_segments_vector_parent_child_uses_default_embedding_model_when_p
|
||||
embedding_model_instance = MagicMock()
|
||||
model_manager_instance = MagicMock()
|
||||
model_manager_instance.get_default_model_instance.return_value = embedding_model_instance
|
||||
monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance))
|
||||
monkeypatch.setattr(
|
||||
vector_service_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager_instance)
|
||||
)
|
||||
|
||||
generate_child_chunks_mock = MagicMock()
|
||||
monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock)
|
||||
|
||||