mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
fix: resolve test failures and lint errors after segment 5 merge
- Add login_manager mock to controller test fixtures (6 files) - Remove duplicate MemoryConfig import in llm_utils.py - Fix line-too-long in test_workflow_draft_variable.py Made-with: Cursor
This commit is contained in:
@ -18,7 +18,7 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att
|
||||
|
||||
from configs import dify_config
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
@ -460,7 +460,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
|
||||
# Save LLMGenerationDetail for LLM nodes with successful execution
|
||||
if (
|
||||
domain_model.node_type == NodeType.LLM
|
||||
domain_model.node_type == BuiltinNodeTypes.LLM
|
||||
and domain_model.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
and domain_model.outputs is not None
|
||||
):
|
||||
|
||||
@ -56,9 +56,13 @@ from dify_graph.enums import (
|
||||
)
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from dify_graph.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
@ -94,6 +98,7 @@ from dify_graph.variables import (
|
||||
ArrayFileSegment,
|
||||
ArrayPromptMessageSegment,
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
StringSegment,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
from flask import Flask, g, request
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
@ -33,6 +33,9 @@ def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
flask_app.login_manager = mock_lm
|
||||
return flask_app
|
||||
|
||||
|
||||
@ -110,6 +113,7 @@ def setup_test_context(
|
||||
patch("controllers.console.app.message.attach_message_extra_contents", return_value=None),
|
||||
):
|
||||
with test_app.test_request_context(full_path, method=method, json=payload):
|
||||
g._login_user = mock_account
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
|
||||
if "suggested-questions" in route_path:
|
||||
@ -202,7 +206,7 @@ class TestMessageEndpoints:
|
||||
q_mock = mock_db.data_query
|
||||
q_mock.where.return_value.first.side_effect = [mock_conv]
|
||||
q_mock.where.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_msg]
|
||||
mock_db.session.scalar.return_value = False
|
||||
mock_db.session.scalar.side_effect = [MagicMock(), False]
|
||||
|
||||
resp = api.get(**v_args)
|
||||
assert resp["limit"] == 1
|
||||
|
||||
@ -2,7 +2,7 @@ from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
from flask import Flask, g, request
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.statistic import (
|
||||
@ -22,6 +22,9 @@ from models import App, AppMode
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
flask_app.login_manager = mock_lm
|
||||
return flask_app
|
||||
|
||||
|
||||
@ -85,6 +88,7 @@ def setup_test_context(
|
||||
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with test_app.test_request_context(route_path, method="GET"):
|
||||
g._login_user = mock_account
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
api_instance = endpoint_class()
|
||||
response = api_instance.get(app_id="app_123")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
from flask import Flask, g, request
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
@ -24,6 +24,9 @@ def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
flask_app.login_manager = mock_lm
|
||||
return flask_app
|
||||
|
||||
|
||||
@ -75,6 +78,7 @@ def setup_test_context(test_app, endpoint_class, route_path, method, mock_accoun
|
||||
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with test_app.test_request_context(route_path, method=method, json=payload):
|
||||
g._login_user = mock_account
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
# extract node_id or variable_id from path manually since view_args overrides
|
||||
if "nodes/" in route_path:
|
||||
@ -102,6 +106,7 @@ class TestWorkflowDraftVariableEndpoints:
|
||||
mock_var = MagicMock()
|
||||
mock_var.app_id = "app_123"
|
||||
mock_var.id = "var_123"
|
||||
mock_var.user_id = "user_123"
|
||||
mock_var.name = "test_var"
|
||||
mock_var.description = ""
|
||||
mock_var.get_variable_type.return_value = variable_type
|
||||
@ -151,8 +156,11 @@ class TestWorkflowDraftVariableEndpoints:
|
||||
mock_app_model,
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.SandboxService")
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_workflow_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
def test_workflow_variable_collection_delete(
|
||||
self, mock_draft_srv, mock_sandbox_srv, app, mock_account, mock_app_model,
|
||||
):
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
WorkflowVariableCollectionApi,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask import Flask, g
|
||||
|
||||
from controllers.console.auth.data_source_bearer_auth import (
|
||||
ApiKeyAuthDataSource,
|
||||
@ -17,6 +17,9 @@ class TestApiKeyAuthDataSource:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@ -49,6 +52,7 @@ class TestApiKeyAuthDataSource:
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
|
||||
g._login_user = mock_account
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
@ -81,6 +85,7 @@ class TestApiKeyAuthDataSource:
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
|
||||
g._login_user = mock_account
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
@ -97,6 +102,9 @@ class TestApiKeyAuthDataSourceBinding:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@ -124,6 +132,7 @@ class TestApiKeyAuthDataSourceBinding:
|
||||
method="POST",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
):
|
||||
g._login_user = mock_account
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
@ -162,6 +171,7 @@ class TestApiKeyAuthDataSourceBinding:
|
||||
method="POST",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
):
|
||||
g._login_user = mock_account
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
@ -176,6 +186,9 @@ class TestApiKeyAuthDataSourceBindingDelete:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@ -198,6 +211,7 @@ class TestApiKeyAuthDataSourceBindingDelete:
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"):
|
||||
g._login_user = mock_account
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask import Flask, g
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.auth.data_source_oauth import (
|
||||
@ -17,6 +17,9 @@ class TestOAuthDataSource:
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
@ -85,6 +88,9 @@ class TestOAuthDataSourceCallback:
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
@ -128,6 +134,9 @@ class TestOAuthDataSourceBinding:
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
@ -161,6 +170,9 @@ class TestOAuthDataSourceSync:
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
@ -181,6 +193,7 @@ class TestOAuthDataSourceSync:
|
||||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"):
|
||||
g._login_user = mock_account
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = OAuthDataSourceSync()
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask import Flask, g
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.auth.oauth_server import (
|
||||
@ -17,6 +17,9 @@ class TestOAuthServerAppApi:
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
@ -85,6 +88,9 @@ class TestOAuthServerUserAuthorizeApi:
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
@ -117,6 +123,7 @@ class TestOAuthServerUserAuthorizeApi:
|
||||
mock_sign.return_value = "auth_code_123"
|
||||
|
||||
with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}):
|
||||
g._login_user = mock_account
|
||||
with patch("libs.login.current_user", mock_account):
|
||||
api_instance = OAuthServerUserAuthorizeApi()
|
||||
response = api_instance.post()
|
||||
@ -130,6 +137,9 @@ class TestOAuthServerUserTokenApi:
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
@ -291,6 +301,9 @@ class TestOAuthServerUserAccountApi:
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -42,6 +42,7 @@ def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.Mon
|
||||
predecessor_node_id=None,
|
||||
iteration_id="iter-1",
|
||||
loop_id=None,
|
||||
parent_node_id=None,
|
||||
created_at=node_execution.created_at,
|
||||
)
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
||||
)
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
BuiltinNodeTypes,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
@ -67,7 +67,7 @@ def _execution(
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
node_id="node-id",
|
||||
node_type=NodeType.LLM,
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
title="Title",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
@ -387,7 +387,7 @@ def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch)
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "node"
|
||||
db_model.node_type = NodeType.LLM
|
||||
db_model.node_type = BuiltinNodeTypes.LLM
|
||||
db_model.title = "t"
|
||||
db_model.inputs = json.dumps({"trunc": "i"})
|
||||
db_model.process_data = json.dumps({"trunc": "p"})
|
||||
@ -441,7 +441,7 @@ def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "node"
|
||||
db_model.node_type = NodeType.LLM
|
||||
db_model.node_type = BuiltinNodeTypes.LLM
|
||||
db_model.title = "t"
|
||||
db_model.inputs = json.dumps({"i": 1})
|
||||
db_model.process_data = json.dumps({"p": 2})
|
||||
|
||||
@ -593,7 +593,6 @@ def test_handle_list_messages_basic(llm_node):
|
||||
|
||||
|
||||
def test_handle_list_messages_jinja2_uses_template_renderer(llm_node):
|
||||
llm_node._template_renderer.render_jinja2.return_value = "Hello, world"
|
||||
messages = [
|
||||
LLMNodeChatModelMessage(
|
||||
text="",
|
||||
@ -603,20 +602,16 @@ def test_handle_list_messages_jinja2_uses_template_renderer(llm_node):
|
||||
)
|
||||
]
|
||||
|
||||
result = llm_node.handle_list_messages(
|
||||
messages=messages,
|
||||
context=None,
|
||||
jinja2_variables=[],
|
||||
variable_pool=llm_node.graph_runtime_state.variable_pool,
|
||||
vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
template_renderer=llm_node._template_renderer,
|
||||
)
|
||||
with mock.patch("dify_graph.nodes.llm.node._render_jinja2_message", return_value="Hello, world"):
|
||||
result = llm_node.handle_list_messages(
|
||||
messages=messages,
|
||||
context=None,
|
||||
jinja2_variables=[],
|
||||
variable_pool=llm_node.graph_runtime_state.variable_pool,
|
||||
vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
)
|
||||
|
||||
assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])]
|
||||
llm_node._template_renderer.render_jinja2.assert_called_once_with(
|
||||
template="Hello, {{ name }}",
|
||||
inputs={},
|
||||
)
|
||||
|
||||
|
||||
def test_handle_memory_completion_mode_uses_prompt_message_interface():
|
||||
|
||||
@ -728,7 +728,7 @@ def test_get_oauth_client_should_return_decrypted_system_client_when_verified(
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
|
||||
mocker.patch(
|
||||
"services.trigger.trigger_provider_service.decrypt_system_oauth_params",
|
||||
"services.trigger.trigger_provider_service.decrypt_system_params",
|
||||
return_value={"client_id": "system"},
|
||||
)
|
||||
|
||||
@ -754,7 +754,7 @@ def test_get_oauth_client_should_raise_error_when_system_decryption_fails(
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
|
||||
mocker.patch(
|
||||
"services.trigger.trigger_provider_service.decrypt_system_oauth_params",
|
||||
"services.trigger.trigger_provider_service.decrypt_system_params",
|
||||
side_effect=RuntimeError("bad data"),
|
||||
)
|
||||
|
||||
|
||||
@ -2444,6 +2444,8 @@ class TestWorkflowServiceDraftExecution:
|
||||
patch("services.workflow_service.DifyCoreRepositoryFactory") as mock_repo_factory,
|
||||
patch("services.workflow_service.DraftVariableSaver") as mock_saver_cls,
|
||||
patch("services.workflow_service.storage"),
|
||||
patch("services.workflow_service.SandboxProviderService"),
|
||||
patch("services.workflow_service.SandboxService"),
|
||||
):
|
||||
mock_node = MagicMock()
|
||||
mock_node.node_type = BuiltinNodeTypes.START
|
||||
@ -2513,6 +2515,8 @@ class TestWorkflowServiceDraftExecution:
|
||||
patch("services.workflow_service.DifyCoreRepositoryFactory"),
|
||||
patch("services.workflow_service.DraftVariableSaver"),
|
||||
patch("services.workflow_service.storage"),
|
||||
patch("services.workflow_service.SandboxProviderService"),
|
||||
patch("services.workflow_service.SandboxService"),
|
||||
):
|
||||
mock_node = MagicMock()
|
||||
mock_node.node_type = BuiltinNodeTypes.LLM
|
||||
|
||||
Reference in New Issue
Block a user