refactor(workflow): inject credential/model access ports into LLM nodes (#32569)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-02-27 14:36:41 +08:00
committed by GitHub
parent d20880d102
commit a694533fc9
38 changed files with 676 additions and 179 deletions

View File

@ -82,7 +82,7 @@ class TestCacheEmbeddingDocuments:
Mock: Configured ModelInstance with text embedding capabilities
"""
model_instance = Mock()
model_instance.model = "text-embedding-ada-002"
model_instance.model_name = "text-embedding-ada-002"
model_instance.provider = "openai"
model_instance.credentials = {"api_key": "test-key"}
@ -597,7 +597,7 @@ class TestCacheEmbeddingQuery:
def mock_model_instance(self):
"""Create a mock ModelInstance for testing."""
model_instance = Mock()
model_instance.model = "text-embedding-ada-002"
model_instance.model_name = "text-embedding-ada-002"
model_instance.provider = "openai"
model_instance.credentials = {"api_key": "test-key"}
return model_instance
@ -830,7 +830,7 @@ class TestEmbeddingModelSwitching:
"""
# Arrange
model_instance_ada = Mock()
model_instance_ada.model = "text-embedding-ada-002"
model_instance_ada.model_name = "text-embedding-ada-002"
model_instance_ada.provider = "openai"
# Mock model type instance for ada
@ -841,7 +841,7 @@ class TestEmbeddingModelSwitching:
model_type_instance_ada.get_model_schema.return_value = model_schema_ada
model_instance_3_small = Mock()
model_instance_3_small.model = "text-embedding-3-small"
model_instance_3_small.model_name = "text-embedding-3-small"
model_instance_3_small.provider = "openai"
# Mock model type instance for 3-small
@ -914,11 +914,11 @@ class TestEmbeddingModelSwitching:
"""
# Arrange
model_instance_openai = Mock()
model_instance_openai.model = "text-embedding-ada-002"
model_instance_openai.model_name = "text-embedding-ada-002"
model_instance_openai.provider = "openai"
model_instance_cohere = Mock()
model_instance_cohere.model = "embed-english-v3.0"
model_instance_cohere.model_name = "embed-english-v3.0"
model_instance_cohere.provider = "cohere"
cache_openai = CacheEmbedding(model_instance_openai)
@ -1001,7 +1001,7 @@ class TestEmbeddingDimensionValidation:
def mock_model_instance(self):
"""Create a mock ModelInstance for testing."""
model_instance = Mock()
model_instance.model = "text-embedding-ada-002"
model_instance.model_name = "text-embedding-ada-002"
model_instance.provider = "openai"
model_instance.credentials = {"api_key": "test-key"}
@ -1123,7 +1123,7 @@ class TestEmbeddingDimensionValidation:
"""
# Arrange - OpenAI ada-002 (1536 dimensions)
model_instance_ada = Mock()
model_instance_ada.model = "text-embedding-ada-002"
model_instance_ada.model_name = "text-embedding-ada-002"
model_instance_ada.provider = "openai"
# Mock model type instance for ada
@ -1156,7 +1156,7 @@ class TestEmbeddingDimensionValidation:
# Arrange - Cohere embed-english-v3.0 (1024 dimensions)
model_instance_cohere = Mock()
model_instance_cohere.model = "embed-english-v3.0"
model_instance_cohere.model_name = "embed-english-v3.0"
model_instance_cohere.provider = "cohere"
# Mock model type instance for cohere
@ -1225,7 +1225,7 @@ class TestEmbeddingEdgeCases:
- MAX_CHUNKS: 10
"""
model_instance = Mock()
model_instance.model = "text-embedding-ada-002"
model_instance.model_name = "text-embedding-ada-002"
model_instance.provider = "openai"
model_type_instance = Mock()
@ -1702,7 +1702,7 @@ class TestEmbeddingCachePerformance:
- MAX_CHUNKS: 10
"""
model_instance = Mock()
model_instance.model = "text-embedding-ada-002"
model_instance.model_name = "text-embedding-ada-002"
model_instance.provider = "openai"
model_type_instance = Mock()

View File

@ -34,7 +34,7 @@ def create_mock_model_instance():
mock_instance.provider_model_bundle.configuration = Mock()
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
mock_instance.provider = "test-provider"
mock_instance.model = "test-model"
mock_instance.model_name = "test-model"
return mock_instance
@ -65,7 +65,7 @@ class TestRerankModelRunner:
mock_instance.provider_model_bundle.configuration = Mock()
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
mock_instance.provider = "test-provider"
mock_instance.model = "test-model"
mock_instance.model_name = "test-model"
return mock_instance
@pytest.fixture

View File

@ -199,11 +199,32 @@ def test_mock_config_builder():
def test_mock_factory_node_type_detection():
"""Test that MockNodeFactory correctly identifies nodes to mock."""
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
from .test_mock_factory import MockNodeFactory
graph_init_params = GraphInitParams(
tenant_id="test",
app_id="test",
workflow_id="test",
graph_config={},
user_id="test",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
start_at=0,
total_tokens=0,
node_run_steps=0,
)
factory = MockNodeFactory(
graph_init_params=None, # Will be set by test
graph_runtime_state=None, # Will be set by test
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=None,
)
@ -288,7 +309,11 @@ def test_workflow_without_auto_mock():
def test_register_custom_mock_node():
"""Test registering a custom mock implementation for a node type."""
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
from .test_mock_factory import MockNodeFactory
@ -298,9 +323,25 @@ def test_register_custom_mock_node():
# Custom mock implementation
pass
graph_init_params = GraphInitParams(
tenant_id="test",
app_id="test",
workflow_id="test",
graph_config={},
user_id="test",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
start_at=0,
total_tokens=0,
node_run_steps=0,
)
factory = MockNodeFactory(
graph_init_params=None,
graph_runtime_state=None,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=None,
)

View File

@ -1,9 +1,9 @@
import datetime
import time
from collections.abc import Iterable
from unittest import mock
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@ -82,7 +82,7 @@ def _build_branching_graph(
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
llm_data = LLMNodeData(
title=title,
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[
LLMNodeChatModelMessage(
text=prompt_text,
@ -101,6 +101,8 @@ def _build_branching_graph(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
credentials_provider=mock.Mock(),
model_factory=mock.Mock(),
)
return llm_node

View File

@ -1,8 +1,8 @@
import datetime
import time
from unittest import mock
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@ -78,7 +78,7 @@ def _build_llm_human_llm_graph(
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
llm_data = LLMNodeData(
title=title,
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[
LLMNodeChatModelMessage(
text=prompt_text,
@ -97,6 +97,8 @@ def _build_llm_human_llm_graph(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
credentials_provider=mock.Mock(),
model_factory=mock.Mock(),
)
return llm_node

View File

@ -1,4 +1,5 @@
import time
from unittest import mock
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
@ -85,6 +86,8 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
credentials_provider=mock.Mock(),
model_factory=mock.Mock(),
)
return llm_node

View File

@ -5,6 +5,7 @@ This module provides a MockNodeFactory that automatically detects and mocks node
requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
"""
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from core.app.workflow.node_factory import DifyNodeFactory
@ -74,7 +75,7 @@ class MockNodeFactory(DifyNodeFactory):
NodeType.CODE: MockCodeNode,
}
def create_node(self, node_config: dict[str, Any]) -> Node:
def create_node(self, node_config: Mapping[str, Any]) -> Node:
"""
Create a node instance, using mock implementations for third-party service nodes.
@ -123,6 +124,16 @@ class MockNodeFactory(DifyNodeFactory):
mock_config=self.mock_config,
http_request_config=self._http_request_config,
)
elif node_type in {NodeType.LLM, NodeType.QUESTION_CLASSIFIER, NodeType.PARAMETER_EXTRACTOR}:
mock_instance = mock_class(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
)
else:
mock_instance = mock_class(
id=node_id,

View File

@ -16,9 +16,33 @@ from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNo
def test_mock_factory_registers_iteration_node():
"""Test that MockNodeFactory has iteration node registered."""
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
# Create a MockNodeFactory instance
factory = MockNodeFactory(graph_init_params=None, graph_runtime_state=None, mock_config=None)
graph_init_params = GraphInitParams(
tenant_id="test",
app_id="test",
workflow_id="test",
graph_config={"nodes": [], "edges": []},
user_id="test",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
start_at=0,
total_tokens=0,
node_run_steps=0,
)
factory = MockNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=None,
)
# Check that iteration node is registered
assert NodeType.ITERATION in factory._mock_node_types

View File

@ -8,6 +8,7 @@ allowing tests to run without external dependencies.
import time
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Optional
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -18,6 +19,7 @@ from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from core.workflow.nodes.question_classifier import QuestionClassifierNode
from core.workflow.nodes.template_transform import TemplateTransformNode
@ -42,6 +44,10 @@ class MockNodeMixin:
mock_config: Optional["MockConfig"] = None,
**kwargs: Any,
):
if isinstance(self, (LLMNode, QuestionClassifierNode)):
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
super().__init__(
id=id,
config=config,

View File

@ -101,11 +101,32 @@ def test_node_mock_config():
def test_mock_factory_detection():
"""Test MockNodeFactory node type detection."""
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
print("Testing MockNodeFactory detection...")
graph_init_params = GraphInitParams(
tenant_id="test",
app_id="test",
workflow_id="test",
graph_config={},
user_id="test",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
start_at=0,
total_tokens=0,
node_run_steps=0,
)
factory = MockNodeFactory(
graph_init_params=None,
graph_runtime_state=None,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=None,
)
@ -133,11 +154,32 @@ def test_mock_factory_detection():
def test_mock_factory_registration():
"""Test registering and unregistering mock node types."""
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
print("Testing MockNodeFactory registration...")
graph_init_params = GraphInitParams(
tenant_id="test",
app_id="test",
workflow_id="test",
graph_config={},
user_id="test",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
start_at=0,
total_tokens=0,
node_run_steps=0,
)
factory = MockNodeFactory(
graph_init_params=None,
graph_runtime_state=None,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=None,
)

View File

@ -6,6 +6,7 @@ from unittest import mock
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.model_runtime.entities.common_entities import I18nObject
@ -32,6 +33,7 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
@ -100,6 +102,8 @@ def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState
) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
@ -109,13 +113,29 @@ def llm_node(
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
model_factory=mock_model_factory,
llm_file_saver=mock_file_saver,
)
return node
@pytest.fixture
def model_config():
def model_config(monkeypatch):
from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass
def mock_plugin_model_providers(_self):
providers = MockModelClass().fetch_model_providers("test")
for provider in providers:
provider.declaration.provider = f"{provider.plugin_id}/{provider.declaration.provider}"
return providers
monkeypatch.setattr(
ModelProviderFactory,
"get_plugin_model_providers",
mock_plugin_model_providers,
)
# Create actual provider and model type instances
model_provider_factory = ModelProviderFactory(tenant_id="test")
provider_instance = model_provider_factory.get_plugin_model_provider("openai")
@ -125,7 +145,7 @@ def model_config():
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(
tenant_id="1",
provider=provider_instance,
provider=provider_instance.declaration,
preferred_provider_type=ProviderType.CUSTOM,
using_provider_type=ProviderType.CUSTOM,
system_configuration=SystemConfiguration(enabled=False),
@ -153,6 +173,89 @@ def model_config():
)
def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsEntity):
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
provider_model_bundle = model_config.provider_model_bundle
model_type_instance = provider_model_bundle.model_type_instance
provider_model = mock.MagicMock()
model_instance = mock.MagicMock(
model_type_instance=model_type_instance,
provider_model_bundle=provider_model_bundle,
)
mock_credentials_provider.fetch.return_value = {"api_key": "test"}
mock_model_factory.init_model_instance.return_value = model_instance
with (
mock.patch.object(
provider_model_bundle.configuration.__class__,
"get_provider_model",
return_value=provider_model,
),
mock.patch.object(
model_type_instance.__class__,
"get_model_schema",
return_value=model_config.model_schema,
),
):
fetch_model_config(
node_data_model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
credentials_provider=mock_credentials_provider,
model_factory=mock_model_factory,
)
mock_credentials_provider.fetch.assert_called_once_with("openai", "gpt-3.5-turbo")
mock_model_factory.init_model_instance.assert_called_once_with("openai", "gpt-3.5-turbo")
provider_model.raise_for_status.assert_called_once()
def test_dify_model_access_adapters_call_managers():
mock_provider_manager = mock.MagicMock()
mock_model_manager = mock.MagicMock()
mock_configurations = mock.MagicMock()
mock_provider_configuration = mock.MagicMock()
mock_provider_model = mock.MagicMock()
mock_configurations.get.return_value = mock_provider_configuration
mock_provider_configuration.get_provider_model.return_value = mock_provider_model
mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"}
credentials_provider = DifyCredentialsProvider(
tenant_id="tenant",
provider_manager=mock_provider_manager,
)
model_factory = DifyModelFactory(
tenant_id="tenant",
model_manager=mock_model_manager,
)
mock_provider_manager.get_configurations.return_value = mock_configurations
credentials_provider.fetch("openai", "gpt-3.5-turbo")
model_factory.init_model_instance("openai", "gpt-3.5-turbo")
mock_provider_manager.get_configurations.assert_called_once_with("tenant")
mock_configurations.get.assert_called_once_with("openai")
mock_provider_configuration.get_provider_model.assert_called_once_with(
model_type=ModelType.LLM,
model="gpt-3.5-turbo",
)
mock_provider_configuration.get_current_credentials.assert_called_once_with(
model_type=ModelType.LLM,
model="gpt-3.5-turbo",
)
mock_provider_model.raise_for_status.assert_called_once()
mock_model_manager.get_model_instance.assert_called_once_with(
tenant_id="tenant",
provider="openai",
model_type=ModelType.LLM,
model="gpt-3.5-turbo",
)
def test_fetch_files_with_file_segment():
file = File(
id="1",
@ -485,6 +588,8 @@ def test_handle_list_messages_basic(llm_node):
@pytest.fixture
def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state) -> tuple[LLMNode, LLMFileSaver]:
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
@ -494,6 +599,8 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
model_factory=mock_model_factory,
llm_file_saver=mock_file_saver,
)
return node, mock_file_saver

View File

@ -642,8 +642,16 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
# Mock embedding model
mock_embedding_model = Mock()
mock_embedding_model.model = "text-embedding-ada-002"
mock_embedding_model.model_name = "text-embedding-ada-002"
mock_embedding_model.provider = "openai"
mock_embedding_model.credentials = {}
mock_model_schema = Mock()
mock_model_schema.features = []
mock_text_embedding_model = Mock()
mock_text_embedding_model.get_model_schema.return_value = mock_model_schema
mock_embedding_model.model_type_instance = mock_text_embedding_model
mock_model_instance = Mock()
mock_model_instance.get_model_instance.return_value = mock_embedding_model

View File

@ -174,7 +174,7 @@ class DatasetServiceTestDataFactory:
Mock: Embedding model mock with model and provider attributes
"""
embedding_model = Mock()
embedding_model.model = model
embedding_model.model_name = model
embedding_model.provider = provider
return embedding_model
@ -434,7 +434,7 @@ class TestDatasetServiceCreateDataset:
# Assert
assert result.indexing_technique == "high_quality"
assert result.embedding_model_provider == embedding_model.provider
assert result.embedding_model == embedding_model.model
assert result.embedding_model == embedding_model.model_name
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
)

View File

@ -46,7 +46,7 @@ class DatasetCreateTestDataFactory:
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
"""Create a mock embedding model."""
embedding_model = Mock()
embedding_model.model = model
embedding_model.model_name = model
embedding_model.provider = provider
return embedding_model
@ -244,7 +244,7 @@ class TestDatasetServiceCreateEmptyDataset:
# Assert
assert result.indexing_technique == "high_quality"
assert result.embedding_model_provider == embedding_model.provider
assert result.embedding_model == embedding_model.model
assert result.embedding_model == embedding_model.model_name
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
)