diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 06555e5842..c1c255f206 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -467,7 +467,8 @@ class AppListApi(Resource): @login_required @account_initialization_required @enterprise_license_required - def get(self): + @with_session(write=False) + def get(self, session: Session): """Get app list""" current_user, current_tenant_id = current_account_with_tenant() @@ -504,7 +505,7 @@ class AppListApi(Resource): draft_trigger_app_ids: set[str] = set() if workflow_capable_app_ids: draft_workflows = ( - db.session.execute( + session.execute( select(Workflow).where( Workflow.version == Workflow.VERSION_DRAFT, Workflow.app_id.in_(workflow_capable_app_ids), diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 47a9b8aedb..67367cbe99 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from flask_restx import Resource from pydantic import BaseModel, Field +from sqlalchemy.orm import Session from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns @@ -11,6 +12,7 @@ from controllers.console.app.error import ( ProviderNotInitializeError, ProviderQuotaExceededError, ) +from controllers.console.app.wraps import with_session from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id from core.app.app_config.entities import ModelConfig from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError @@ -19,7 +21,6 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from extensions.ext_database import db from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.errors.invoke import InvokeError from libs.login import login_required @@ -158,7 +159,8 @@ class InstructionGenerateApi(Resource): @login_required @account_initialization_required @with_current_tenant_id - def post(self, current_tenant_id: str): + @with_session(write=False) + def post(self, session: Session, current_tenant_id: str): args = InstructionGeneratePayload.model_validate(console_ns.payload) providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] code_provider: type[CodeNodeProvider] | None = next( @@ -168,10 +170,10 @@ class InstructionGenerateApi(Resource): try: # Generate from nothing for a workflow node if (args.current in (code_template, "")) and args.node_id != "": - app = db.session.get(App, args.flow_id) + app = session.get(App, args.flow_id) if not app: return {"error": f"app {args.flow_id} not found"}, 400 - workflow = WorkflowService().get_draft_workflow(app_model=app) + workflow = WorkflowService().get_draft_workflow(app_model=app, session=session) if not workflow: return {"error": f"workflow {args.flow_id} not found"}, 400 nodes: Sequence = workflow.graph_dict["nodes"] diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 6d9ee97fa4..36b760d37a 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -140,14 +140,21 @@ class WorkflowService: ) return db.session.execute(stmt).scalar_one() - def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None: + def get_draft_workflow( + self, app_model: App, workflow_id: str | None = None, session: Session | None = None + ) -> Workflow | None: """ Get draft workflow + + When ``session`` is provided, reuse it so callers that already hold a + Session avoid checking out an extra request-scoped ``db.session`` + connection. Falls back to ``db.session`` for backward compatibility. """ if workflow_id: - return self.get_published_workflow_by_id(app_model, workflow_id) + return self.get_published_workflow_by_id(app_model, workflow_id, session=session) # fetch draft workflow by app_model - workflow = db.session.scalar( + bind = session if session is not None else db.session + workflow = bind.scalar( select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py index 80e7c41a9e..f9d3f0ad87 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -7,6 +7,7 @@ from importlib import util from pathlib import Path from types import ModuleType, SimpleNamespace from typing import Any +from unittest.mock import MagicMock import pytest from flask.views import MethodView @@ -18,6 +19,15 @@ if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + @pytest.fixture(scope="module") def app_module(): module_name = "controllers.console.app.app" @@ -395,3 +405,46 @@ def test_app_pagination_aliases_per_page_and_has_next(app_models): assert len(serialized["data"]) == 2 assert serialized["data"][0]["icon_url"] == "signed:first-icon" assert serialized["data"][1]["icon_url"] is None + + +def test_app_list_uses_injected_session_for_draft_workflows(app, app_module, monkeypatch): + api = app_module.AppListApi() + method = _unwrap(api.get) + current_user = SimpleNamespace(id="user-1") + app_item = SimpleNamespace( + id="app-1", + name="Workflow App", + desc_or_prompt="Summary", + mode="workflow", + mode_compatible_with_agent="workflow", + ) + app_pagination = SimpleNamespace(page=1, per_page=20, total=1, has_next=False, items=[app_item]) + workflow = SimpleNamespace( + id="workflow-1", + app_id="app-1", + walk_nodes=lambda: iter([("trigger-1", {"type": "trigger-webhook"})]), + ) + session = MagicMock() + session.execute.return_value.scalars.return_value.all.return_value = [workflow] + scoped_session = SimpleNamespace(execute=MagicMock(side_effect=AssertionError("db.session should not be used"))) + + monkeypatch.setattr(app_module, "current_account_with_tenant", lambda: (current_user, "tenant-1")) + monkeypatch.setattr( + app_module, + "AppService", + lambda: SimpleNamespace(get_paginate_apps=lambda *_args, **_kwargs: app_pagination), + ) + monkeypatch.setattr( + app_module, + "FeatureService", + SimpleNamespace(get_system_features=lambda: SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))), + ) + monkeypatch.setattr(app_module, "db", SimpleNamespace(session=scoped_session)) + + with app.test_request_context("/console/api/apps?page=1&limit=20", method="GET"): + response, status = method(session) + + assert status == 200 + assert response["data"][0]["has_draft_trigger"] is True + session.execute.assert_called_once() + scoped_session.execute.assert_not_called() diff --git a/api/tests/unit_tests/controllers/console/app/test_generator_api.py b/api/tests/unit_tests/controllers/console/app/test_generator_api.py index 11c6acfcc1..b5cf867455 100644 --- a/api/tests/unit_tests/controllers/console/app/test_generator_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_generator_api.py @@ -1,6 +1,7 @@ from __future__ import annotations from types import SimpleNamespace +from unittest.mock import MagicMock import pytest @@ -24,10 +25,17 @@ def _model_config_payload(): def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow): class _Service: - def get_draft_workflow(self, app_model): + app_model = None + session = None + + def get_draft_workflow(self, app_model, session=None): + self.app_model = app_model + self.session = session return workflow - monkeypatch.setattr(generator_module, "WorkflowService", lambda: _Service()) + service = _Service() + monkeypatch.setattr(generator_module, "WorkflowService", lambda: service) + return service def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -68,7 +76,8 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch api = generator_module.InstructionGenerateApi() method = _unwrap(api.post) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None)) + session = MagicMock() + session.get.return_value = None with app.test_request_context( "/console/api/instruction-generate", @@ -80,10 +89,11 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch "model_config": _model_config_payload(), }, ): - response, status = method("t1") + response, status = method(session, "t1") assert status == 400 assert response["error"] == "app app-1 not found" + session.get.assert_called_once_with(generator_module.App, "app-1") def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -91,7 +101,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey method = _unwrap(api.post) app_model = SimpleNamespace(id="app-1") - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) + session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model) _install_workflow_service(monkeypatch, workflow=None) with app.test_request_context( @@ -104,7 +114,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey "model_config": _model_config_payload(), }, ): - response, status = method("t1") + response, status = method(session, "t1") assert status == 400 assert response["error"] == "workflow app-1 not found" @@ -115,7 +125,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) method = _unwrap(api.post) app_model = SimpleNamespace(id="app-1") - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) + session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model) workflow = SimpleNamespace(graph_dict={"nodes": []}) _install_workflow_service(monkeypatch, workflow=workflow) @@ -130,7 +140,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) "model_config": _model_config_payload(), }, ): - response, status = method("t1") + response, status = method(session, "t1") assert status == 400 assert response["error"] == "node node-1 not found" @@ -141,7 +151,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> method = _unwrap(api.post) app_model = SimpleNamespace(id="app-1") - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) + session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model) workflow = SimpleNamespace( graph_dict={ @@ -150,7 +160,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> ] } ) - _install_workflow_service(monkeypatch, workflow=workflow) + workflow_service = _install_workflow_service(monkeypatch, workflow=workflow) monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", lambda **_kwargs: {"code": "x"}) with app.test_request_context( @@ -163,14 +173,17 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> "model_config": _model_config_payload(), }, ): - response = method("t1") + response = method(session, "t1") assert response == {"code": "x"} + assert workflow_service.app_model is app_model + assert workflow_service.session is session def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None: api = generator_module.InstructionGenerateApi() method = _unwrap(api.post) + session = SimpleNamespace() monkeypatch.setattr( generator_module.LLMGenerator, @@ -189,7 +202,7 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch "model_config": _model_config_payload(), }, ): - response = method("t1") + response = method(session, "t1") assert response == {"instruction": "ok"} @@ -197,6 +210,7 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None: api = generator_module.InstructionGenerateApi() method = _unwrap(api.post) + session = SimpleNamespace() with app.test_request_context( "/console/api/instruction-generate", @@ -209,7 +223,7 @@ def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.Monke "model_config": _model_config_payload(), }, ): - response, status = method("t1") + response, status = method(session, "t1") assert status == 400 assert response["error"] == "incompatible parameters" diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index f105364094..d384c5a83b 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -346,6 +346,19 @@ class TestWorkflowService: assert result == mock_workflow + def test_get_draft_workflow_uses_provided_session(self, workflow_service, mock_db_session): + """Test get_draft_workflow can reuse an injected SQLAlchemy session.""" + app = TestWorkflowAssociatedDataFactory.create_app_mock() + mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock() + session = MagicMock() + session.scalar.return_value = mock_workflow + + result = workflow_service.get_draft_workflow(app, session=session) + + assert result == mock_workflow + session.scalar.assert_called_once() + mock_db_session.session.scalar.assert_not_called() + def test_get_draft_workflow_returns_none(self, workflow_service, mock_db_session): """Test get_draft_workflow returns None when no draft exists.""" app = TestWorkflowAssociatedDataFactory.create_app_mock() @@ -370,6 +383,21 @@ class TestWorkflowService: assert result == mock_workflow + def test_get_draft_workflow_with_workflow_id_reuses_provided_session(self, workflow_service): + """Test get_draft_workflow passes an injected session to published workflow lookup.""" + app = TestWorkflowAssociatedDataFactory.create_app_mock() + workflow_id = "workflow-123" + session = MagicMock() + mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1") + + with patch.object( + workflow_service, "get_published_workflow_by_id", return_value=mock_workflow + ) as mock_get_published: + result = workflow_service.get_draft_workflow(app, workflow_id=workflow_id, session=session) + + assert result == mock_workflow + mock_get_published.assert_called_once_with(app, workflow_id, session=session) + # ==================== Get Published Workflow Tests ==================== # These tests verify retrieval of published workflows (versioned snapshots)