mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
refactor(workflow): inject credential/model access ports into LLM nodes (#32569)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@ -82,7 +82,7 @@ class TestCacheEmbeddingDocuments:
|
||||
Mock: Configured ModelInstance with text embedding capabilities
|
||||
"""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "text-embedding-ada-002"
|
||||
model_instance.model_name = "text-embedding-ada-002"
|
||||
model_instance.provider = "openai"
|
||||
model_instance.credentials = {"api_key": "test-key"}
|
||||
|
||||
@ -597,7 +597,7 @@ class TestCacheEmbeddingQuery:
|
||||
def mock_model_instance(self):
|
||||
"""Create a mock ModelInstance for testing."""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "text-embedding-ada-002"
|
||||
model_instance.model_name = "text-embedding-ada-002"
|
||||
model_instance.provider = "openai"
|
||||
model_instance.credentials = {"api_key": "test-key"}
|
||||
return model_instance
|
||||
@ -830,7 +830,7 @@ class TestEmbeddingModelSwitching:
|
||||
"""
|
||||
# Arrange
|
||||
model_instance_ada = Mock()
|
||||
model_instance_ada.model = "text-embedding-ada-002"
|
||||
model_instance_ada.model_name = "text-embedding-ada-002"
|
||||
model_instance_ada.provider = "openai"
|
||||
|
||||
# Mock model type instance for ada
|
||||
@ -841,7 +841,7 @@ class TestEmbeddingModelSwitching:
|
||||
model_type_instance_ada.get_model_schema.return_value = model_schema_ada
|
||||
|
||||
model_instance_3_small = Mock()
|
||||
model_instance_3_small.model = "text-embedding-3-small"
|
||||
model_instance_3_small.model_name = "text-embedding-3-small"
|
||||
model_instance_3_small.provider = "openai"
|
||||
|
||||
# Mock model type instance for 3-small
|
||||
@ -914,11 +914,11 @@ class TestEmbeddingModelSwitching:
|
||||
"""
|
||||
# Arrange
|
||||
model_instance_openai = Mock()
|
||||
model_instance_openai.model = "text-embedding-ada-002"
|
||||
model_instance_openai.model_name = "text-embedding-ada-002"
|
||||
model_instance_openai.provider = "openai"
|
||||
|
||||
model_instance_cohere = Mock()
|
||||
model_instance_cohere.model = "embed-english-v3.0"
|
||||
model_instance_cohere.model_name = "embed-english-v3.0"
|
||||
model_instance_cohere.provider = "cohere"
|
||||
|
||||
cache_openai = CacheEmbedding(model_instance_openai)
|
||||
@ -1001,7 +1001,7 @@ class TestEmbeddingDimensionValidation:
|
||||
def mock_model_instance(self):
|
||||
"""Create a mock ModelInstance for testing."""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "text-embedding-ada-002"
|
||||
model_instance.model_name = "text-embedding-ada-002"
|
||||
model_instance.provider = "openai"
|
||||
model_instance.credentials = {"api_key": "test-key"}
|
||||
|
||||
@ -1123,7 +1123,7 @@ class TestEmbeddingDimensionValidation:
|
||||
"""
|
||||
# Arrange - OpenAI ada-002 (1536 dimensions)
|
||||
model_instance_ada = Mock()
|
||||
model_instance_ada.model = "text-embedding-ada-002"
|
||||
model_instance_ada.model_name = "text-embedding-ada-002"
|
||||
model_instance_ada.provider = "openai"
|
||||
|
||||
# Mock model type instance for ada
|
||||
@ -1156,7 +1156,7 @@ class TestEmbeddingDimensionValidation:
|
||||
|
||||
# Arrange - Cohere embed-english-v3.0 (1024 dimensions)
|
||||
model_instance_cohere = Mock()
|
||||
model_instance_cohere.model = "embed-english-v3.0"
|
||||
model_instance_cohere.model_name = "embed-english-v3.0"
|
||||
model_instance_cohere.provider = "cohere"
|
||||
|
||||
# Mock model type instance for cohere
|
||||
@ -1225,7 +1225,7 @@ class TestEmbeddingEdgeCases:
|
||||
- MAX_CHUNKS: 10
|
||||
"""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "text-embedding-ada-002"
|
||||
model_instance.model_name = "text-embedding-ada-002"
|
||||
model_instance.provider = "openai"
|
||||
|
||||
model_type_instance = Mock()
|
||||
@ -1702,7 +1702,7 @@ class TestEmbeddingCachePerformance:
|
||||
- MAX_CHUNKS: 10
|
||||
"""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "text-embedding-ada-002"
|
||||
model_instance.model_name = "text-embedding-ada-002"
|
||||
model_instance.provider = "openai"
|
||||
|
||||
model_type_instance = Mock()
|
||||
|
||||
@ -34,7 +34,7 @@ def create_mock_model_instance():
|
||||
mock_instance.provider_model_bundle.configuration = Mock()
|
||||
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
|
||||
mock_instance.provider = "test-provider"
|
||||
mock_instance.model = "test-model"
|
||||
mock_instance.model_name = "test-model"
|
||||
return mock_instance
|
||||
|
||||
|
||||
@ -65,7 +65,7 @@ class TestRerankModelRunner:
|
||||
mock_instance.provider_model_bundle.configuration = Mock()
|
||||
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
|
||||
mock_instance.provider = "test-provider"
|
||||
mock_instance.model = "test-model"
|
||||
mock_instance.model_name = "test-model"
|
||||
return mock_instance
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -199,11 +199,32 @@ def test_mock_config_builder():
|
||||
|
||||
def test_mock_factory_node_type_detection():
|
||||
"""Test that MockNodeFactory correctly identifies nodes to mock."""
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
|
||||
from .test_mock_factory import MockNodeFactory
|
||||
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test",
|
||||
app_id="test",
|
||||
workflow_id="test",
|
||||
graph_config={},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||
start_at=0,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=None, # Will be set by test
|
||||
graph_runtime_state=None, # Will be set by test
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=None,
|
||||
)
|
||||
|
||||
@ -288,7 +309,11 @@ def test_workflow_without_auto_mock():
|
||||
|
||||
def test_register_custom_mock_node():
|
||||
"""Test registering a custom mock implementation for a node type."""
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
|
||||
from .test_mock_factory import MockNodeFactory
|
||||
|
||||
@ -298,9 +323,25 @@ def test_register_custom_mock_node():
|
||||
# Custom mock implementation
|
||||
pass
|
||||
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test",
|
||||
app_id="test",
|
||||
workflow_id="test",
|
||||
graph_config={},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||
start_at=0,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=None,
|
||||
graph_runtime_state=None,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=None,
|
||||
)
|
||||
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import datetime
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
@ -82,7 +82,7 @@ def _build_branching_graph(
|
||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||
llm_data = LLMNodeData(
|
||||
title=title,
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text=prompt_text,
|
||||
@ -101,6 +101,8 @@ def _build_branching_graph(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
credentials_provider=mock.Mock(),
|
||||
model_factory=mock.Mock(),
|
||||
)
|
||||
return llm_node
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import datetime
|
||||
import time
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
@ -78,7 +78,7 @@ def _build_llm_human_llm_graph(
|
||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||
llm_data = LLMNodeData(
|
||||
title=title,
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text=prompt_text,
|
||||
@ -97,6 +97,8 @@ def _build_llm_human_llm_graph(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
credentials_provider=mock.Mock(),
|
||||
model_factory=mock.Mock(),
|
||||
)
|
||||
return llm_node
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import time
|
||||
from unittest import mock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
@ -85,6 +86,8 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
credentials_provider=mock.Mock(),
|
||||
model_factory=mock.Mock(),
|
||||
)
|
||||
return llm_node
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ This module provides a MockNodeFactory that automatically detects and mocks node
|
||||
requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
@ -74,7 +75,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
NodeType.CODE: MockCodeNode,
|
||||
}
|
||||
|
||||
def create_node(self, node_config: dict[str, Any]) -> Node:
|
||||
def create_node(self, node_config: Mapping[str, Any]) -> Node:
|
||||
"""
|
||||
Create a node instance, using mock implementations for third-party service nodes.
|
||||
|
||||
@ -123,6 +124,16 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
mock_config=self.mock_config,
|
||||
http_request_config=self._http_request_config,
|
||||
)
|
||||
elif node_type in {NodeType.LLM, NodeType.QUESTION_CLASSIFIER, NodeType.PARAMETER_EXTRACTOR}:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
)
|
||||
else:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
|
||||
@ -16,9 +16,33 @@ from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNo
|
||||
|
||||
def test_mock_factory_registers_iteration_node():
|
||||
"""Test that MockNodeFactory has iteration node registered."""
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
|
||||
# Create a MockNodeFactory instance
|
||||
factory = MockNodeFactory(graph_init_params=None, graph_runtime_state=None, mock_config=None)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test",
|
||||
app_id="test",
|
||||
workflow_id="test",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||
start_at=0,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=None,
|
||||
)
|
||||
|
||||
# Check that iteration node is registered
|
||||
assert NodeType.ITERATION in factory._mock_node_types
|
||||
|
||||
@ -8,6 +8,7 @@ allowing tests to run without external dependencies.
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
@ -18,6 +19,7 @@ from core.workflow.nodes.document_extractor import DocumentExtractorNode
|
||||
from core.workflow.nodes.http_request import HttpRequestNode
|
||||
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm import LLMNode
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||
@ -42,6 +44,10 @@ class MockNodeMixin:
|
||||
mock_config: Optional["MockConfig"] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
if isinstance(self, (LLMNode, QuestionClassifierNode)):
|
||||
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
|
||||
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
|
||||
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
|
||||
@ -101,11 +101,32 @@ def test_node_mock_config():
|
||||
|
||||
def test_mock_factory_detection():
|
||||
"""Test MockNodeFactory node type detection."""
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
|
||||
print("Testing MockNodeFactory detection...")
|
||||
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test",
|
||||
app_id="test",
|
||||
workflow_id="test",
|
||||
graph_config={},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||
start_at=0,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=None,
|
||||
graph_runtime_state=None,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=None,
|
||||
)
|
||||
|
||||
@ -133,11 +154,32 @@ def test_mock_factory_detection():
|
||||
|
||||
def test_mock_factory_registration():
|
||||
"""Test registering and unregistering mock node types."""
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
|
||||
print("Testing MockNodeFactory registration...")
|
||||
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test",
|
||||
app_id="test",
|
||||
workflow_id="test",
|
||||
graph_config={},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||
start_at=0,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=None,
|
||||
graph_runtime_state=None,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=None,
|
||||
)
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from unittest import mock
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||
from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
@ -32,6 +33,7 @@ from core.workflow.nodes.llm.entities import (
|
||||
)
|
||||
from core.workflow.nodes.llm.file_saver import LLMFileSaver
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
@ -100,6 +102,8 @@ def llm_node(
|
||||
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState
|
||||
) -> LLMNode:
|
||||
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
|
||||
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
@ -109,13 +113,29 @@ def llm_node(
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=mock_credentials_provider,
|
||||
model_factory=mock_model_factory,
|
||||
llm_file_saver=mock_file_saver,
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_config():
|
||||
def model_config(monkeypatch):
|
||||
from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass
|
||||
|
||||
def mock_plugin_model_providers(_self):
|
||||
providers = MockModelClass().fetch_model_providers("test")
|
||||
for provider in providers:
|
||||
provider.declaration.provider = f"{provider.plugin_id}/{provider.declaration.provider}"
|
||||
return providers
|
||||
|
||||
monkeypatch.setattr(
|
||||
ModelProviderFactory,
|
||||
"get_plugin_model_providers",
|
||||
mock_plugin_model_providers,
|
||||
)
|
||||
|
||||
# Create actual provider and model type instances
|
||||
model_provider_factory = ModelProviderFactory(tenant_id="test")
|
||||
provider_instance = model_provider_factory.get_plugin_model_provider("openai")
|
||||
@ -125,7 +145,7 @@ def model_config():
|
||||
provider_model_bundle = ProviderModelBundle(
|
||||
configuration=ProviderConfiguration(
|
||||
tenant_id="1",
|
||||
provider=provider_instance,
|
||||
provider=provider_instance.declaration,
|
||||
preferred_provider_type=ProviderType.CUSTOM,
|
||||
using_provider_type=ProviderType.CUSTOM,
|
||||
system_configuration=SystemConfiguration(enabled=False),
|
||||
@ -153,6 +173,89 @@ def model_config():
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsEntity):
|
||||
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||
|
||||
provider_model_bundle = model_config.provider_model_bundle
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
provider_model = mock.MagicMock()
|
||||
|
||||
model_instance = mock.MagicMock(
|
||||
model_type_instance=model_type_instance,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
)
|
||||
|
||||
mock_credentials_provider.fetch.return_value = {"api_key": "test"}
|
||||
mock_model_factory.init_model_instance.return_value = model_instance
|
||||
|
||||
with (
|
||||
mock.patch.object(
|
||||
provider_model_bundle.configuration.__class__,
|
||||
"get_provider_model",
|
||||
return_value=provider_model,
|
||||
),
|
||||
mock.patch.object(
|
||||
model_type_instance.__class__,
|
||||
"get_model_schema",
|
||||
return_value=model_config.model_schema,
|
||||
),
|
||||
):
|
||||
fetch_model_config(
|
||||
node_data_model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||
credentials_provider=mock_credentials_provider,
|
||||
model_factory=mock_model_factory,
|
||||
)
|
||||
|
||||
mock_credentials_provider.fetch.assert_called_once_with("openai", "gpt-3.5-turbo")
|
||||
mock_model_factory.init_model_instance.assert_called_once_with("openai", "gpt-3.5-turbo")
|
||||
provider_model.raise_for_status.assert_called_once()
|
||||
|
||||
|
||||
def test_dify_model_access_adapters_call_managers():
|
||||
mock_provider_manager = mock.MagicMock()
|
||||
mock_model_manager = mock.MagicMock()
|
||||
mock_configurations = mock.MagicMock()
|
||||
mock_provider_configuration = mock.MagicMock()
|
||||
mock_provider_model = mock.MagicMock()
|
||||
|
||||
mock_configurations.get.return_value = mock_provider_configuration
|
||||
mock_provider_configuration.get_provider_model.return_value = mock_provider_model
|
||||
mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"}
|
||||
|
||||
credentials_provider = DifyCredentialsProvider(
|
||||
tenant_id="tenant",
|
||||
provider_manager=mock_provider_manager,
|
||||
)
|
||||
model_factory = DifyModelFactory(
|
||||
tenant_id="tenant",
|
||||
model_manager=mock_model_manager,
|
||||
)
|
||||
|
||||
mock_provider_manager.get_configurations.return_value = mock_configurations
|
||||
|
||||
credentials_provider.fetch("openai", "gpt-3.5-turbo")
|
||||
model_factory.init_model_instance("openai", "gpt-3.5-turbo")
|
||||
|
||||
mock_provider_manager.get_configurations.assert_called_once_with("tenant")
|
||||
mock_configurations.get.assert_called_once_with("openai")
|
||||
mock_provider_configuration.get_provider_model.assert_called_once_with(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
mock_provider_configuration.get_current_credentials.assert_called_once_with(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
mock_provider_model.raise_for_status.assert_called_once()
|
||||
mock_model_manager.get_model_instance.assert_called_once_with(
|
||||
tenant_id="tenant",
|
||||
provider="openai",
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_files_with_file_segment():
|
||||
file = File(
|
||||
id="1",
|
||||
@ -485,6 +588,8 @@ def test_handle_list_messages_basic(llm_node):
|
||||
@pytest.fixture
|
||||
def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state) -> tuple[LLMNode, LLMFileSaver]:
|
||||
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
|
||||
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
@ -494,6 +599,8 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=mock_credentials_provider,
|
||||
model_factory=mock_model_factory,
|
||||
llm_file_saver=mock_file_saver,
|
||||
)
|
||||
return node, mock_file_saver
|
||||
|
||||
@ -642,8 +642,16 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
|
||||
|
||||
# Mock embedding model
|
||||
mock_embedding_model = Mock()
|
||||
mock_embedding_model.model = "text-embedding-ada-002"
|
||||
mock_embedding_model.model_name = "text-embedding-ada-002"
|
||||
mock_embedding_model.provider = "openai"
|
||||
mock_embedding_model.credentials = {}
|
||||
|
||||
mock_model_schema = Mock()
|
||||
mock_model_schema.features = []
|
||||
|
||||
mock_text_embedding_model = Mock()
|
||||
mock_text_embedding_model.get_model_schema.return_value = mock_model_schema
|
||||
mock_embedding_model.model_type_instance = mock_text_embedding_model
|
||||
|
||||
mock_model_instance = Mock()
|
||||
mock_model_instance.get_model_instance.return_value = mock_embedding_model
|
||||
|
||||
@ -174,7 +174,7 @@ class DatasetServiceTestDataFactory:
|
||||
Mock: Embedding model mock with model and provider attributes
|
||||
"""
|
||||
embedding_model = Mock()
|
||||
embedding_model.model = model
|
||||
embedding_model.model_name = model
|
||||
embedding_model.provider = provider
|
||||
return embedding_model
|
||||
|
||||
@ -434,7 +434,7 @@ class TestDatasetServiceCreateDataset:
|
||||
# Assert
|
||||
assert result.indexing_technique == "high_quality"
|
||||
assert result.embedding_model_provider == embedding_model.provider
|
||||
assert result.embedding_model == embedding_model.model
|
||||
assert result.embedding_model == embedding_model.model_name
|
||||
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
|
||||
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
|
||||
@ -46,7 +46,7 @@ class DatasetCreateTestDataFactory:
|
||||
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
|
||||
"""Create a mock embedding model."""
|
||||
embedding_model = Mock()
|
||||
embedding_model.model = model
|
||||
embedding_model.model_name = model
|
||||
embedding_model.provider = provider
|
||||
return embedding_model
|
||||
|
||||
@ -244,7 +244,7 @@ class TestDatasetServiceCreateEmptyDataset:
|
||||
# Assert
|
||||
assert result.indexing_technique == "high_quality"
|
||||
assert result.embedding_model_provider == embedding_model.provider
|
||||
assert result.embedding_model == embedding_model.model
|
||||
assert result.embedding_model == embedding_model.model_name
|
||||
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
|
||||
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user