refactor: consolidate LLM runtime model state on ModelInstance (#32746)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-01 02:29:32 +08:00
committed by GitHub
parent 48d8667c4f
commit 962df17a15
20 changed files with 375 additions and 324 deletions

View File

@ -48,3 +48,19 @@ def get_mocked_fetch_model_config(
)
return MagicMock(return_value=(model_instance, model_config))
def get_mocked_fetch_model_instance(
    provider: str,
    model: str,
    mode: str,
    credentials: dict,
):
    """Return a mock callable that yields a mocked model instance.

    The arguments are forwarded unchanged to ``get_mocked_fetch_model_config``.
    The mock it returns is invoked once, and only the first element of the
    resulting pair (the model instance) is kept; the second element is
    discarded.

    Returns:
        A ``MagicMock`` whose ``return_value`` is that model instance, so
        callers obtain the instance by calling the returned mock.
    """
    config_kwargs = {
        "provider": provider,
        "model": model,
        "mode": mode,
        "credentials": credentials,
    }
    # Call the config mock right away; only the instance half is needed here.
    instance_and_config = get_mocked_fetch_model_config(**config_kwargs)()
    model_instance, _ = instance_and_config
    return MagicMock(return_value=model_instance)

View File

@ -5,13 +5,13 @@ from collections.abc import Generator
from unittest.mock import MagicMock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.node_factory import DifyNodeFactory
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_manager import ModelInstance
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
@ -67,21 +67,14 @@ def init_llm_node(config: dict) -> LLMNode:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
node = LLMNode(
id=str(uuid.uuid4()),
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(),
model_factory=MagicMock(),
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
)
return node
@ -116,8 +109,7 @@ def test_execute_llm():
db.session.close = MagicMock()
# Build a mocked ModelInstance to avoid database calls
def mock_fetch_model_config(*_args, **_kwargs):
def build_mock_model_instance() -> MagicMock:
from decimal import Decimal
from unittest.mock import MagicMock
@ -125,7 +117,20 @@ def test_execute_llm():
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock()
mock_model_instance = MagicMock(spec=ModelInstance)
mock_model_instance.provider = "openai"
mock_model_instance.model_name = "gpt-3.5-turbo"
mock_model_instance.credentials = {}
mock_model_instance.parameters = {}
mock_model_instance.stop = []
mock_model_instance.model_type_instance = MagicMock()
mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
model_properties={},
parameter_rules=[],
features=[],
)
mock_model_instance.provider_model_bundle = MagicMock()
mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
@ -149,14 +154,7 @@ def test_execute_llm():
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
return mock_model_instance, mock_model_config
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_1(**_kwargs):
@ -167,10 +165,9 @@ def test_execute_llm():
UserPromptMessage(content="what's the weather today?"),
], []
with (
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
):
node._model_instance = build_mock_model_instance()
with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1):
# execute node
result = node._run()
assert isinstance(result, Generator)
@ -228,8 +225,7 @@ def test_execute_llm_with_jinja2():
# Mock db.session.close()
db.session.close = MagicMock()
# Build a mocked ModelInstance for the node
def mock_fetch_model_config(*_args, **_kwargs):
def build_mock_model_instance() -> MagicMock:
from decimal import Decimal
from unittest.mock import MagicMock
@ -237,7 +233,20 @@ def test_execute_llm_with_jinja2():
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock()
mock_model_instance = MagicMock(spec=ModelInstance)
mock_model_instance.provider = "openai"
mock_model_instance.model_name = "gpt-3.5-turbo"
mock_model_instance.credentials = {}
mock_model_instance.parameters = {}
mock_model_instance.stop = []
mock_model_instance.model_type_instance = MagicMock()
mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
model_properties={},
parameter_rules=[],
features=[],
)
mock_model_instance.provider_model_bundle = MagicMock()
mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
@ -261,14 +270,7 @@ def test_execute_llm_with_jinja2():
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
return mock_model_instance, mock_model_config
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_2(**_kwargs):
@ -279,10 +281,9 @@ def test_execute_llm_with_jinja2():
UserPromptMessage(content="what's the weather today?"),
], []
with (
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
):
node._model_instance = build_mock_model_instance()
with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2):
# execute node
result = node._run()

View File

@ -4,18 +4,17 @@ import uuid
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.node_factory import DifyNodeFactory
from core.model_manager import ModelInstance
from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
@ -72,14 +71,6 @@ def init_parameter_extractor_node(config: dict):
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
node = ParameterExtractorNode(
id=str(uuid.uuid4()),
config=config,
@ -87,6 +78,7 @@ def init_parameter_extractor_node(config: dict):
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
)
return node
@ -116,12 +108,12 @@ def test_function_calling_parameter_extractor(setup_model_mock):
}
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()
result = node._run()
@ -157,12 +149,12 @@ def test_instructions(setup_model_mock):
},
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()
result = node._run()
@ -207,12 +199,12 @@ def test_chat_parameter_extractor(setup_model_mock):
},
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()
result = node._run()
@ -258,12 +250,12 @@ def test_completion_parameter_extractor(setup_model_mock):
},
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo-instruct",
mode="completion",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()
result = node._run()
@ -383,12 +375,12 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
},
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
# Test the mock before running the actual test
monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
db.session.close = MagicMock()