mirror of https://github.com/langgenius/dify.git (synced 2026-05-04 09:28:04 +08:00)
refactor: consolidate LLM runtime model state on ModelInstance (#32746)
Signed-off-by: -LAN- <laipz8200@outlook.com>
@@ -48,3 +48,19 @@ def get_mocked_fetch_model_config(
     )

     return MagicMock(return_value=(model_instance, model_config))
+
+
+def get_mocked_fetch_model_instance(
+    provider: str,
+    model: str,
+    mode: str,
+    credentials: dict,
+):
+    mock_fetch_model_config = get_mocked_fetch_model_config(
+        provider=provider,
+        model=model,
+        mode=mode,
+        credentials=credentials,
+    )
+    model_instance, _ = mock_fetch_model_config()
+    return MagicMock(return_value=model_instance)
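The helper deliberately returns a factory (a MagicMock whose return_value is the model instance) rather than the instance itself, mirroring the shape of get_mocked_fetch_model_config. A minimal usage sketch with illustrative values; the real call sites appear in the parameter-extractor tests below:

# Hedged sketch: invoking the returned factory once yields the mocked ModelInstance.
mock_factory = get_mocked_fetch_model_instance(
    provider="langgenius/openai/openai",
    model="gpt-3.5-turbo",
    mode="chat",
    credentials={"openai_api_key": "test-key"},  # placeholder credential
)
node._model_instance = mock_factory()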
@@ -5,13 +5,13 @@ from collections.abc import Generator
 from unittest.mock import MagicMock, patch

 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.app.workflow.node_factory import DifyNodeFactory
 from core.llm_generator.output_parser.structured_output import _parse_structured_output
+from core.model_manager import ModelInstance
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import WorkflowNodeExecutionStatus
-from core.workflow.graph import Graph
 from core.workflow.node_events import StreamCompletedEvent
 from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from extensions.ext_database import db
@@ -67,21 +67,14 @@ def init_llm_node(config: dict) -> LLMNode:

     graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

-    # Create node factory
-    node_factory = DifyNodeFactory(
-        graph_init_params=init_params,
-        graph_runtime_state=graph_runtime_state,
-    )
-
-    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
-
     node = LLMNode(
         id=str(uuid.uuid4()),
         config=config,
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
-        credentials_provider=MagicMock(),
-        model_factory=MagicMock(),
+        credentials_provider=MagicMock(spec=CredentialsProvider),
+        model_factory=MagicMock(spec=ModelFactory),
+        model_instance=MagicMock(spec=ModelInstance),
     )

     return node
@@ -116,8 +109,7 @@ def test_execute_llm():

     db.session.close = MagicMock()

-    # Mock the _fetch_model_config to avoid database calls
-    def mock_fetch_model_config(*_args, **_kwargs):
+    def build_mock_model_instance() -> MagicMock:
         from decimal import Decimal
         from unittest.mock import MagicMock

@@ -125,7 +117,20 @@ def test_execute_llm():
         from core.model_runtime.entities.message_entities import AssistantPromptMessage

         # Create mock model instance
-        mock_model_instance = MagicMock()
+        mock_model_instance = MagicMock(spec=ModelInstance)
+        mock_model_instance.provider = "openai"
+        mock_model_instance.model_name = "gpt-3.5-turbo"
+        mock_model_instance.credentials = {}
+        mock_model_instance.parameters = {}
+        mock_model_instance.stop = []
+        mock_model_instance.model_type_instance = MagicMock()
+        mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
+            model_properties={},
+            parameter_rules=[],
+            features=[],
+        )
+        mock_model_instance.provider_model_bundle = MagicMock()
+        mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
         mock_usage = LLMUsage(
             prompt_tokens=30,
             prompt_unit_price=Decimal("0.001"),
@@ -149,14 +154,7 @@ def test_execute_llm():
         )
         mock_model_instance.invoke_llm.return_value = mock_llm_result

-        # Create mock model config
-        mock_model_config = MagicMock()
-        mock_model_config.mode = "chat"
-        mock_model_config.provider = "openai"
-        mock_model_config.model = "gpt-3.5-turbo"
-        mock_model_config.parameters = {}
-
-        return mock_model_instance, mock_model_config
+        return mock_model_instance

     # Mock fetch_prompt_messages to avoid database calls
     def mock_fetch_prompt_messages_1(**_kwargs):
@@ -167,10 +165,9 @@ def test_execute_llm():
             UserPromptMessage(content="what's the weather today?"),
         ], []

-    with (
-        patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
-        patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
-    ):
+    node._model_instance = build_mock_model_instance()
+
+    with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1):
         # execute node
         result = node._run()
         assert isinstance(result, Generator)
@@ -228,8 +225,7 @@ def test_execute_llm_with_jinja2():
     # Mock db.session.close()
     db.session.close = MagicMock()

-    # Mock the _fetch_model_config method
-    def mock_fetch_model_config(*_args, **_kwargs):
+    def build_mock_model_instance() -> MagicMock:
         from decimal import Decimal
         from unittest.mock import MagicMock

@@ -237,7 +233,20 @@ def test_execute_llm_with_jinja2():
         from core.model_runtime.entities.message_entities import AssistantPromptMessage

         # Create mock model instance
-        mock_model_instance = MagicMock()
+        mock_model_instance = MagicMock(spec=ModelInstance)
+        mock_model_instance.provider = "openai"
+        mock_model_instance.model_name = "gpt-3.5-turbo"
+        mock_model_instance.credentials = {}
+        mock_model_instance.parameters = {}
+        mock_model_instance.stop = []
+        mock_model_instance.model_type_instance = MagicMock()
+        mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
+            model_properties={},
+            parameter_rules=[],
+            features=[],
+        )
+        mock_model_instance.provider_model_bundle = MagicMock()
+        mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
         mock_usage = LLMUsage(
             prompt_tokens=30,
             prompt_unit_price=Decimal("0.001"),
@@ -261,14 +270,7 @@ def test_execute_llm_with_jinja2():
         )
         mock_model_instance.invoke_llm.return_value = mock_llm_result

-        # Create mock model config
-        mock_model_config = MagicMock()
-        mock_model_config.mode = "chat"
-        mock_model_config.provider = "openai"
-        mock_model_config.model = "gpt-3.5-turbo"
-        mock_model_config.parameters = {}
-
-        return mock_model_instance, mock_model_config
+        return mock_model_instance

     # Mock fetch_prompt_messages to avoid database calls
     def mock_fetch_prompt_messages_2(**_kwargs):
@@ -279,10 +281,9 @@ def test_execute_llm_with_jinja2():
             UserPromptMessage(content="what's the weather today?"),
         ], []

-    with (
-        patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
-        patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
-    ):
+    node._model_instance = build_mock_model_instance()
+
+    with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2):
         # execute node
         result = node._run()

@@ -4,18 +4,17 @@ import uuid
 from unittest.mock import MagicMock

 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.app.workflow.node_factory import DifyNodeFactory
+from core.model_manager import ModelInstance
 from core.model_runtime.entities import AssistantPromptMessage
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import WorkflowNodeExecutionStatus
-from core.workflow.graph import Graph
 from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from extensions.ext_database import db
 from models.enums import UserFrom
-from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
+from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance

 """FOR MOCK FIXTURES, DO NOT REMOVE"""
 from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
@@ -72,14 +72,6 @@ def init_parameter_extractor_node(config: dict):

     graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

-    # Create node factory
-    node_factory = DifyNodeFactory(
-        graph_init_params=init_params,
-        graph_runtime_state=graph_runtime_state,
-    )
-
-    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
-
     node = ParameterExtractorNode(
         id=str(uuid.uuid4()),
         config=config,
@@ -87,6 +78,7 @@ def init_parameter_extractor_node(config: dict):
         graph_runtime_state=graph_runtime_state,
         credentials_provider=MagicMock(spec=CredentialsProvider),
         model_factory=MagicMock(spec=ModelFactory),
+        model_instance=MagicMock(spec=ModelInstance),
     )
     return node

@@ -116,12 +108,12 @@ def test_function_calling_parameter_extractor(setup_model_mock):
         }
     )

-    node._fetch_model_config = get_mocked_fetch_model_config(
+    node._model_instance = get_mocked_fetch_model_instance(
        provider="langgenius/openai/openai",
        model="gpt-3.5-turbo",
        mode="chat",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
-    )
+    )()
     db.session.close = MagicMock()

     result = node._run()
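Note the trailing () on the new call: get_mocked_fetch_model_instance returns a MagicMock factory, so the extra call unwraps it to the ModelInstance mock before assignment. A self-contained illustration of the underlying unittest.mock behavior:

from unittest.mock import MagicMock

sentinel_instance = MagicMock(name="model_instance")
factory = MagicMock(return_value=sentinel_instance)

# Calling the factory returns the preconfigured object, which is why
# `node._model_instance = get_mocked_fetch_model_instance(...)()` ends
# up holding the instance itself, not the factory.
assert factory() is sentinel_instance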
@@ -157,12 +149,12 @@ def test_instructions(setup_model_mock):
        },
    )

-    node._fetch_model_config = get_mocked_fetch_model_config(
+    node._model_instance = get_mocked_fetch_model_instance(
        provider="langgenius/openai/openai",
        model="gpt-3.5-turbo",
        mode="chat",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
-    )
+    )()
    db.session.close = MagicMock()

    result = node._run()
@@ -207,12 +199,12 @@ def test_chat_parameter_extractor(setup_model_mock):
        },
    )

-    node._fetch_model_config = get_mocked_fetch_model_config(
+    node._model_instance = get_mocked_fetch_model_instance(
        provider="langgenius/openai/openai",
        model="gpt-3.5-turbo",
        mode="chat",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
-    )
+    )()
    db.session.close = MagicMock()

    result = node._run()
@@ -258,12 +250,12 @@ def test_completion_parameter_extractor(setup_model_mock):
        },
    )

-    node._fetch_model_config = get_mocked_fetch_model_config(
+    node._model_instance = get_mocked_fetch_model_instance(
        provider="langgenius/openai/openai",
        model="gpt-3.5-turbo-instruct",
        mode="completion",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
-    )
+    )()
    db.session.close = MagicMock()

    result = node._run()
@@ -383,12 +375,12 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
        },
    )

-    node._fetch_model_config = get_mocked_fetch_model_config(
+    node._model_instance = get_mocked_fetch_model_instance(
        provider="langgenius/openai/openai",
        model="gpt-3.5-turbo",
        mode="chat",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
-    )
+    )()
    # Test the mock before running the actual test
    monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
    db.session.close = MagicMock()

@@ -1391,10 +1391,20 @@ class TestWorkflowService:

         workflow_service = WorkflowService()

+        from unittest.mock import patch
+
+        from core.app.workflow.node_factory import DifyNodeFactory
+        from core.model_manager import ModelInstance
+
         # Act
-        result = workflow_service.run_free_workflow_node(
-            node_data=node_data, tenant_id=tenant_id, user_id=user_id, node_id=node_id, user_inputs=user_inputs
-        )
+        with patch.object(
+            DifyNodeFactory,
+            "_build_model_instance_for_llm_node",
+            return_value=MagicMock(spec=ModelInstance),
+        ):
+            result = workflow_service.run_free_workflow_node(
+                node_data=node_data, tenant_id=tenant_id, user_id=user_id, node_id=node_id, user_inputs=user_inputs
+            )

         # Assert
         assert result is not None
@@ -10,6 +10,7 @@ from collections.abc import Generator, Mapping
 from typing import TYPE_CHECKING, Any, Optional
 from unittest.mock import MagicMock

+from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@@ -44,9 +45,10 @@ class MockNodeMixin:
         mock_config: Optional["MockConfig"] = None,
         **kwargs: Any,
     ):
-        if isinstance(self, (LLMNode, QuestionClassifierNode)):
+        if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)):
             kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
             kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
+            kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))

         super().__init__(
             id=id,
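A note on the spec= arguments used throughout this refactor: a spec'd MagicMock exposes only the attribute surface of the spec class and passes isinstance checks against it, so node code can treat the injected mock as a genuine ModelInstance while any off-spec attribute access fails fast. A small sketch of that general behavior (Example is a hypothetical stand-in, not Dify's ModelInstance):

from unittest.mock import MagicMock

class Example:
    def invoke(self) -> str: ...

m = MagicMock(spec=Example)
assert isinstance(m, Example)  # spec'd mocks pass isinstance checks
m.invoke()                     # attributes defined on the spec are callable
try:
    m.missing_attribute        # anything off-spec raises immediately
except AttributeError:
    pass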
@@ -9,11 +9,12 @@ This test validates that:
 """

 import time
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 from uuid import uuid4

 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.workflow.node_factory import DifyNodeFactory
+from core.model_manager import ModelInstance
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
 from core.workflow.graph import Graph
@@ -115,7 +116,12 @@ def test_parallel_streaming_workflow():

     # Create node factory and graph
     node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state)
-    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
+    with patch.object(
+        DifyNodeFactory,
+        "_build_model_instance_for_llm_node",
+        return_value=MagicMock(spec=ModelInstance),
+    ):
+        graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

     # Create the graph engine
     engine = GraphEngine(
@@ -547,8 +547,22 @@ class TableTestRunner:
         """Run tests in parallel."""
         results = []

+        flask_app: Any = None
+        try:
+            from flask import current_app
+
+            flask_app = current_app._get_current_object()  # type: ignore[attr-defined]
+        except RuntimeError:
+            flask_app = None
+
+        def _run_test_case_with_context(test_case: WorkflowTestCase) -> WorkflowTestResult:
+            if flask_app is None:
+                return self.run_test_case(test_case)
+            with flask_app.app_context():
+                return self.run_test_case(test_case)
+
         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
-            future_to_test = {executor.submit(self.run_test_case, tc): tc for tc in test_cases}
+            future_to_test = {executor.submit(_run_test_case_with_context, tc): tc for tc in test_cases}

             for future in as_completed(future_to_test):
                 test_case = future_to_test[future]
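The worker wrapper is needed because Flask's application context is thread-local: current_app resolves only in a thread that has pushed a context, so the runner captures the concrete app up front (unwrapping the proxy via _get_current_object()) and pushes a fresh context inside each pool thread. A standalone sketch of the same pattern, assuming a plain Flask app:

from concurrent.futures import ThreadPoolExecutor

from flask import Flask, current_app

app = Flask(__name__)

def uses_app() -> str:
    # current_app only resolves while an app context is pushed
    # in the *current* thread.
    return current_app.name

def run_with_context() -> str:
    with app.app_context():
        return uses_app()

with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(run_with_context) for _ in range(2)]
    print([f.result() for f in futures])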
@@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre
 from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
 from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
+from core.model_manager import ModelInstance
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
@@ -115,6 +116,7 @@ def llm_node(
         graph_runtime_state=graph_runtime_state,
         credentials_provider=mock_credentials_provider,
         model_factory=mock_model_factory,
+        model_instance=mock.MagicMock(spec=ModelInstance),
         llm_file_saver=mock_file_saver,
     )
     return node
@@ -601,6 +603,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
         graph_runtime_state=graph_runtime_state,
         credentials_provider=mock_credentials_provider,
         model_factory=mock_model_factory,
+        model_instance=mock.MagicMock(spec=ModelInstance),
         llm_file_saver=mock_file_saver,
     )
     return node, mock_file_saver