mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
refactor: consolidate LLM runtime model state on ModelInstance (#32746)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@ -48,3 +48,19 @@ def get_mocked_fetch_model_config(
|
||||
)
|
||||
|
||||
return MagicMock(return_value=(model_instance, model_config))
|
||||
|
||||
|
||||
def get_mocked_fetch_model_instance(
|
||||
provider: str,
|
||||
model: str,
|
||||
mode: str,
|
||||
credentials: dict,
|
||||
):
|
||||
mock_fetch_model_config = get_mocked_fetch_model_config(
|
||||
provider=provider,
|
||||
model=model,
|
||||
mode=mode,
|
||||
credentials=credentials,
|
||||
)
|
||||
model_instance, _ = mock_fetch_model_config()
|
||||
return MagicMock(return_value=model_instance)
|
||||
|
||||
@ -5,13 +5,13 @@ from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.llm_generator.output_parser.structured_output import _parse_structured_output
|
||||
from core.model_manager import ModelInstance
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import StreamCompletedEvent
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
@ -67,21 +67,14 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = LLMNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=MagicMock(),
|
||||
model_factory=MagicMock(),
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
model_factory=MagicMock(spec=ModelFactory),
|
||||
model_instance=MagicMock(spec=ModelInstance),
|
||||
)
|
||||
|
||||
return node
|
||||
@ -116,8 +109,7 @@ def test_execute_llm():
|
||||
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# Mock the _fetch_model_config to avoid database calls
|
||||
def mock_fetch_model_config(*_args, **_kwargs):
|
||||
def build_mock_model_instance() -> MagicMock:
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -125,7 +117,20 @@ def test_execute_llm():
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
# Create mock model instance
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance = MagicMock(spec=ModelInstance)
|
||||
mock_model_instance.provider = "openai"
|
||||
mock_model_instance.model_name = "gpt-3.5-turbo"
|
||||
mock_model_instance.credentials = {}
|
||||
mock_model_instance.parameters = {}
|
||||
mock_model_instance.stop = []
|
||||
mock_model_instance.model_type_instance = MagicMock()
|
||||
mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
|
||||
model_properties={},
|
||||
parameter_rules=[],
|
||||
features=[],
|
||||
)
|
||||
mock_model_instance.provider_model_bundle = MagicMock()
|
||||
mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
|
||||
mock_usage = LLMUsage(
|
||||
prompt_tokens=30,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
@ -149,14 +154,7 @@ def test_execute_llm():
|
||||
)
|
||||
mock_model_instance.invoke_llm.return_value = mock_llm_result
|
||||
|
||||
# Create mock model config
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.mode = "chat"
|
||||
mock_model_config.provider = "openai"
|
||||
mock_model_config.model = "gpt-3.5-turbo"
|
||||
mock_model_config.parameters = {}
|
||||
|
||||
return mock_model_instance, mock_model_config
|
||||
return mock_model_instance
|
||||
|
||||
# Mock fetch_prompt_messages to avoid database calls
|
||||
def mock_fetch_prompt_messages_1(**_kwargs):
|
||||
@ -167,10 +165,9 @@ def test_execute_llm():
|
||||
UserPromptMessage(content="what's the weather today?"),
|
||||
], []
|
||||
|
||||
with (
|
||||
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
|
||||
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
|
||||
):
|
||||
node._model_instance = build_mock_model_instance()
|
||||
|
||||
with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1):
|
||||
# execute node
|
||||
result = node._run()
|
||||
assert isinstance(result, Generator)
|
||||
@ -228,8 +225,7 @@ def test_execute_llm_with_jinja2():
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# Mock the _fetch_model_config method
|
||||
def mock_fetch_model_config(*_args, **_kwargs):
|
||||
def build_mock_model_instance() -> MagicMock:
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -237,7 +233,20 @@ def test_execute_llm_with_jinja2():
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
# Create mock model instance
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance = MagicMock(spec=ModelInstance)
|
||||
mock_model_instance.provider = "openai"
|
||||
mock_model_instance.model_name = "gpt-3.5-turbo"
|
||||
mock_model_instance.credentials = {}
|
||||
mock_model_instance.parameters = {}
|
||||
mock_model_instance.stop = []
|
||||
mock_model_instance.model_type_instance = MagicMock()
|
||||
mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
|
||||
model_properties={},
|
||||
parameter_rules=[],
|
||||
features=[],
|
||||
)
|
||||
mock_model_instance.provider_model_bundle = MagicMock()
|
||||
mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
|
||||
mock_usage = LLMUsage(
|
||||
prompt_tokens=30,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
@ -261,14 +270,7 @@ def test_execute_llm_with_jinja2():
|
||||
)
|
||||
mock_model_instance.invoke_llm.return_value = mock_llm_result
|
||||
|
||||
# Create mock model config
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.mode = "chat"
|
||||
mock_model_config.provider = "openai"
|
||||
mock_model_config.model = "gpt-3.5-turbo"
|
||||
mock_model_config.parameters = {}
|
||||
|
||||
return mock_model_instance, mock_model_config
|
||||
return mock_model_instance
|
||||
|
||||
# Mock fetch_prompt_messages to avoid database calls
|
||||
def mock_fetch_prompt_messages_2(**_kwargs):
|
||||
@ -279,10 +281,9 @@ def test_execute_llm_with_jinja2():
|
||||
UserPromptMessage(content="what's the weather today?"),
|
||||
], []
|
||||
|
||||
with (
|
||||
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
|
||||
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
|
||||
):
|
||||
node._model_instance = build_mock_model_instance()
|
||||
|
||||
with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2):
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
|
||||
@ -4,18 +4,17 @@ import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import AssistantPromptMessage
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
|
||||
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
|
||||
|
||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
|
||||
@ -72,14 +71,6 @@ def init_parameter_extractor_node(config: dict):
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = ParameterExtractorNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
@ -87,6 +78,7 @@ def init_parameter_extractor_node(config: dict):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
model_factory=MagicMock(spec=ModelFactory),
|
||||
model_instance=MagicMock(spec=ModelInstance),
|
||||
)
|
||||
return node
|
||||
|
||||
@ -116,12 +108,12 @@ def test_function_calling_parameter_extractor(setup_model_mock):
|
||||
}
|
||||
)
|
||||
|
||||
node._fetch_model_config = get_mocked_fetch_model_config(
|
||||
node._model_instance = get_mocked_fetch_model_instance(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-3.5-turbo",
|
||||
mode="chat",
|
||||
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
|
||||
)
|
||||
)()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
result = node._run()
|
||||
@ -157,12 +149,12 @@ def test_instructions(setup_model_mock):
|
||||
},
|
||||
)
|
||||
|
||||
node._fetch_model_config = get_mocked_fetch_model_config(
|
||||
node._model_instance = get_mocked_fetch_model_instance(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-3.5-turbo",
|
||||
mode="chat",
|
||||
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
|
||||
)
|
||||
)()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
result = node._run()
|
||||
@ -207,12 +199,12 @@ def test_chat_parameter_extractor(setup_model_mock):
|
||||
},
|
||||
)
|
||||
|
||||
node._fetch_model_config = get_mocked_fetch_model_config(
|
||||
node._model_instance = get_mocked_fetch_model_instance(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-3.5-turbo",
|
||||
mode="chat",
|
||||
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
|
||||
)
|
||||
)()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
result = node._run()
|
||||
@ -258,12 +250,12 @@ def test_completion_parameter_extractor(setup_model_mock):
|
||||
},
|
||||
)
|
||||
|
||||
node._fetch_model_config = get_mocked_fetch_model_config(
|
||||
node._model_instance = get_mocked_fetch_model_instance(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
mode="completion",
|
||||
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
|
||||
)
|
||||
)()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
result = node._run()
|
||||
@ -383,12 +375,12 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
|
||||
},
|
||||
)
|
||||
|
||||
node._fetch_model_config = get_mocked_fetch_model_config(
|
||||
node._model_instance = get_mocked_fetch_model_instance(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-3.5-turbo",
|
||||
mode="chat",
|
||||
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
|
||||
)
|
||||
)()
|
||||
# Test the mock before running the actual test
|
||||
monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
|
||||
db.session.close = MagicMock()
|
||||
|
||||
Reference in New Issue
Block a user