refactor: consolidate LLM runtime model state on ModelInstance (#32746)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-01 02:29:32 +08:00
committed by GitHub
parent 48d8667c4f
commit 962df17a15
20 changed files with 375 additions and 324 deletions

View File

@ -48,3 +48,19 @@ def get_mocked_fetch_model_config(
)
return MagicMock(return_value=(model_instance, model_config))
def get_mocked_fetch_model_instance(
    provider: str,
    model: str,
    mode: str,
    credentials: dict,
):
    """Return a MagicMock that yields a mocked ModelInstance when called.

    Delegates to get_mocked_fetch_model_config to build the
    (model_instance, model_config) pair, keeps only the instance, and
    wraps it so callers can use the mock as a fetch function.
    """
    fetch_config_mock = get_mocked_fetch_model_config(
        provider=provider,
        model=model,
        mode=mode,
        credentials=credentials,
    )
    instance, _config = fetch_config_mock()
    return MagicMock(return_value=instance)

View File

@ -5,13 +5,13 @@ from collections.abc import Generator
from unittest.mock import MagicMock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.node_factory import DifyNodeFactory
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_manager import ModelInstance
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
@ -67,21 +67,14 @@ def init_llm_node(config: dict) -> LLMNode:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
node = LLMNode(
id=str(uuid.uuid4()),
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(),
model_factory=MagicMock(),
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
)
return node
@ -116,8 +109,7 @@ def test_execute_llm():
db.session.close = MagicMock()
# Mock the _fetch_model_config to avoid database calls
def mock_fetch_model_config(*_args, **_kwargs):
def build_mock_model_instance() -> MagicMock:
from decimal import Decimal
from unittest.mock import MagicMock
@ -125,7 +117,20 @@ def test_execute_llm():
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock()
mock_model_instance = MagicMock(spec=ModelInstance)
mock_model_instance.provider = "openai"
mock_model_instance.model_name = "gpt-3.5-turbo"
mock_model_instance.credentials = {}
mock_model_instance.parameters = {}
mock_model_instance.stop = []
mock_model_instance.model_type_instance = MagicMock()
mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
model_properties={},
parameter_rules=[],
features=[],
)
mock_model_instance.provider_model_bundle = MagicMock()
mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
@ -149,14 +154,7 @@ def test_execute_llm():
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
return mock_model_instance, mock_model_config
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_1(**_kwargs):
@ -167,10 +165,9 @@ def test_execute_llm():
UserPromptMessage(content="what's the weather today?"),
], []
with (
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
):
node._model_instance = build_mock_model_instance()
with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1):
# execute node
result = node._run()
assert isinstance(result, Generator)
@ -228,8 +225,7 @@ def test_execute_llm_with_jinja2():
# Mock db.session.close()
db.session.close = MagicMock()
# Mock the _fetch_model_config method
def mock_fetch_model_config(*_args, **_kwargs):
def build_mock_model_instance() -> MagicMock:
from decimal import Decimal
from unittest.mock import MagicMock
@ -237,7 +233,20 @@ def test_execute_llm_with_jinja2():
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock()
mock_model_instance = MagicMock(spec=ModelInstance)
mock_model_instance.provider = "openai"
mock_model_instance.model_name = "gpt-3.5-turbo"
mock_model_instance.credentials = {}
mock_model_instance.parameters = {}
mock_model_instance.stop = []
mock_model_instance.model_type_instance = MagicMock()
mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
model_properties={},
parameter_rules=[],
features=[],
)
mock_model_instance.provider_model_bundle = MagicMock()
mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
@ -261,14 +270,7 @@ def test_execute_llm_with_jinja2():
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
return mock_model_instance, mock_model_config
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_2(**_kwargs):
@ -279,10 +281,9 @@ def test_execute_llm_with_jinja2():
UserPromptMessage(content="what's the weather today?"),
], []
with (
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
):
node._model_instance = build_mock_model_instance()
with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2):
# execute node
result = node._run()

View File

@ -4,18 +4,17 @@ import uuid
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.node_factory import DifyNodeFactory
from core.model_manager import ModelInstance
from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
@ -72,14 +71,6 @@ def init_parameter_extractor_node(config: dict):
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
node = ParameterExtractorNode(
id=str(uuid.uuid4()),
config=config,
@ -87,6 +78,7 @@ def init_parameter_extractor_node(config: dict):
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
)
return node
@ -116,12 +108,12 @@ def test_function_calling_parameter_extractor(setup_model_mock):
}
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()
result = node._run()
@ -157,12 +149,12 @@ def test_instructions(setup_model_mock):
},
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()
result = node._run()
@ -207,12 +199,12 @@ def test_chat_parameter_extractor(setup_model_mock):
},
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()
result = node._run()
@ -258,12 +250,12 @@ def test_completion_parameter_extractor(setup_model_mock):
},
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo-instruct",
mode="completion",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()
result = node._run()
@ -383,12 +375,12 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
},
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
# Test the mock before running the actual test
monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
db.session.close = MagicMock()

View File

@ -1391,10 +1391,20 @@ class TestWorkflowService:
workflow_service = WorkflowService()
from unittest.mock import patch
from core.app.workflow.node_factory import DifyNodeFactory
from core.model_manager import ModelInstance
# Act
result = workflow_service.run_free_workflow_node(
node_data=node_data, tenant_id=tenant_id, user_id=user_id, node_id=node_id, user_inputs=user_inputs
)
with patch.object(
DifyNodeFactory,
"_build_model_instance_for_llm_node",
return_value=MagicMock(spec=ModelInstance),
):
result = workflow_service.run_free_workflow_node(
node_data=node_data, tenant_id=tenant_id, user_id=user_id, node_id=node_id, user_inputs=user_inputs
)
# Assert
assert result is not None

View File

@ -10,6 +10,7 @@ from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Optional
from unittest.mock import MagicMock
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@ -44,9 +45,10 @@ class MockNodeMixin:
mock_config: Optional["MockConfig"] = None,
**kwargs: Any,
):
if isinstance(self, (LLMNode, QuestionClassifierNode)):
if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)):
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
super().__init__(
id=id,

View File

@ -9,11 +9,12 @@ This test validates that:
"""
import time
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.node_factory import DifyNodeFactory
from core.model_manager import ModelInstance
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
@ -115,7 +116,12 @@ def test_parallel_streaming_workflow():
# Create node factory and graph
node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
with patch.object(
DifyNodeFactory,
"_build_model_instance_for_llm_node",
return_value=MagicMock(spec=ModelInstance),
):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
# Create the graph engine
engine = GraphEngine(

View File

@ -547,8 +547,22 @@ class TableTestRunner:
"""Run tests in parallel."""
results = []
flask_app: Any = None
try:
from flask import current_app
flask_app = current_app._get_current_object() # type: ignore[attr-defined]
except RuntimeError:
flask_app = None
def _run_test_case_with_context(test_case: WorkflowTestCase) -> WorkflowTestResult:
if flask_app is None:
return self.run_test_case(test_case)
with flask_app.app_context():
return self.run_test_case(test_case)
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
future_to_test = {executor.submit(self.run_test_case, tc): tc for tc in test_cases}
future_to_test = {executor.submit(_run_test_case_with_context, tc): tc for tc in test_cases}
for future in as_completed(future_to_test):
test_case = future_to_test[future]

View File

@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre
from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.model_manager import ModelInstance
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
@ -115,6 +116,7 @@ def llm_node(
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
)
return node
@ -601,6 +603,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
)
return node, mock_file_saver