chore: reuse injected SQLAlchemy sessions in app read paths (#36798)

This commit is contained in:
Myshkin451
2026-05-30 08:23:58 +08:00
committed by GitHub
parent 91ac465982
commit 0b60338ad5
6 changed files with 127 additions and 22 deletions

View File

@ -467,7 +467,8 @@ class AppListApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
@with_session(write=False)
def get(self, session: Session):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
@ -504,7 +505,7 @@ class AppListApi(Resource):
draft_trigger_app_ids: set[str] = set()
if workflow_capable_app_ids:
draft_workflows = (
db.session.execute(
session.execute(
select(Workflow).where(
Workflow.version == Workflow.VERSION_DRAFT,
Workflow.app_id.in_(workflow_capable_app_ids),

View File

@ -2,6 +2,7 @@ from collections.abc import Sequence
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
@ -11,6 +12,7 @@ from controllers.console.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import with_session
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
from core.app.app_config.entities import ModelConfig
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
@ -19,7 +21,6 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
@ -158,7 +159,8 @@ class InstructionGenerateApi(Resource):
@login_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str):
@with_session(write=False)
def post(self, session: Session, current_tenant_id: str):
args = InstructionGeneratePayload.model_validate(console_ns.payload)
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
@ -168,10 +170,10 @@ class InstructionGenerateApi(Resource):
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.get(App, args.flow_id)
app = session.get(App, args.flow_id)
if not app:
return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)
workflow = WorkflowService().get_draft_workflow(app_model=app, session=session)
if not workflow:
return {"error": f"workflow {args.flow_id} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"]

View File

@ -140,14 +140,21 @@ class WorkflowService:
)
return db.session.execute(stmt).scalar_one()
def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
def get_draft_workflow(
self, app_model: App, workflow_id: str | None = None, session: Session | None = None
) -> Workflow | None:
"""
Get draft workflow
When ``session`` is provided, reuse it so callers that already hold a
Session avoid checking out an extra request-scoped ``db.session``
connection. Falls back to ``db.session`` for backward compatibility.
"""
if workflow_id:
return self.get_published_workflow_by_id(app_model, workflow_id)
return self.get_published_workflow_by_id(app_model, workflow_id, session=session)
# fetch draft workflow by app_model
workflow = db.session.scalar(
bind = session if session is not None else db.session
workflow = bind.scalar(
select(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,

View File

@ -7,6 +7,7 @@ from importlib import util
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from flask.views import MethodView
@ -18,6 +19,15 @@ if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
@pytest.fixture(scope="module")
def app_module():
module_name = "controllers.console.app.app"
@ -395,3 +405,46 @@ def test_app_pagination_aliases_per_page_and_has_next(app_models):
assert len(serialized["data"]) == 2
assert serialized["data"][0]["icon_url"] == "signed:first-icon"
assert serialized["data"][1]["icon_url"] is None
def test_app_list_uses_injected_session_for_draft_workflows(app, app_module, monkeypatch):
api = app_module.AppListApi()
method = _unwrap(api.get)
current_user = SimpleNamespace(id="user-1")
app_item = SimpleNamespace(
id="app-1",
name="Workflow App",
desc_or_prompt="Summary",
mode="workflow",
mode_compatible_with_agent="workflow",
)
app_pagination = SimpleNamespace(page=1, per_page=20, total=1, has_next=False, items=[app_item])
workflow = SimpleNamespace(
id="workflow-1",
app_id="app-1",
walk_nodes=lambda: iter([("trigger-1", {"type": "trigger-webhook"})]),
)
session = MagicMock()
session.execute.return_value.scalars.return_value.all.return_value = [workflow]
scoped_session = SimpleNamespace(execute=MagicMock(side_effect=AssertionError("db.session should not be used")))
monkeypatch.setattr(app_module, "current_account_with_tenant", lambda: (current_user, "tenant-1"))
monkeypatch.setattr(
app_module,
"AppService",
lambda: SimpleNamespace(get_paginate_apps=lambda *_args, **_kwargs: app_pagination),
)
monkeypatch.setattr(
app_module,
"FeatureService",
SimpleNamespace(get_system_features=lambda: SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))),
)
monkeypatch.setattr(app_module, "db", SimpleNamespace(session=scoped_session))
with app.test_request_context("/console/api/apps?page=1&limit=20", method="GET"):
response, status = method(session)
assert status == 200
assert response["data"][0]["has_draft_trigger"] is True
session.execute.assert_called_once()
scoped_session.execute.assert_not_called()

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
@ -24,10 +25,17 @@ def _model_config_payload():
def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow):
class _Service:
def get_draft_workflow(self, app_model):
app_model = None
session = None
def get_draft_workflow(self, app_model, session=None):
self.app_model = app_model
self.session = session
return workflow
monkeypatch.setattr(generator_module, "WorkflowService", lambda: _Service())
service = _Service()
monkeypatch.setattr(generator_module, "WorkflowService", lambda: service)
return service
def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -68,7 +76,8 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None))
session = MagicMock()
session.get.return_value = None
with app.test_request_context(
"/console/api/instruction-generate",
@ -80,10 +89,11 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch
"model_config": _model_config_payload(),
},
):
response, status = method("t1")
response, status = method(session, "t1")
assert status == 400
assert response["error"] == "app app-1 not found"
session.get.assert_called_once_with(generator_module.App, "app-1")
def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -91,7 +101,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
method = _unwrap(api.post)
app_model = SimpleNamespace(id="app-1")
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
_install_workflow_service(monkeypatch, workflow=None)
with app.test_request_context(
@ -104,7 +114,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
"model_config": _model_config_payload(),
},
):
response, status = method("t1")
response, status = method(session, "t1")
assert status == 400
assert response["error"] == "workflow app-1 not found"
@ -115,7 +125,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
method = _unwrap(api.post)
app_model = SimpleNamespace(id="app-1")
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
workflow = SimpleNamespace(graph_dict={"nodes": []})
_install_workflow_service(monkeypatch, workflow=workflow)
@ -130,7 +140,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
"model_config": _model_config_payload(),
},
):
response, status = method("t1")
response, status = method(session, "t1")
assert status == 400
assert response["error"] == "node node-1 not found"
@ -141,7 +151,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
method = _unwrap(api.post)
app_model = SimpleNamespace(id="app-1")
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
workflow = SimpleNamespace(
graph_dict={
@ -150,7 +160,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
]
}
)
_install_workflow_service(monkeypatch, workflow=workflow)
workflow_service = _install_workflow_service(monkeypatch, workflow=workflow)
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", lambda **_kwargs: {"code": "x"})
with app.test_request_context(
@ -163,14 +173,17 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
"model_config": _model_config_payload(),
},
):
response = method("t1")
response = method(session, "t1")
assert response == {"code": "x"}
assert workflow_service.app_model is app_model
assert workflow_service.session is session
def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
session = SimpleNamespace()
monkeypatch.setattr(
generator_module.LLMGenerator,
@ -189,7 +202,7 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch
"model_config": _model_config_payload(),
},
):
response = method("t1")
response = method(session, "t1")
assert response == {"instruction": "ok"}
@ -197,6 +210,7 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch
def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
session = SimpleNamespace()
with app.test_request_context(
"/console/api/instruction-generate",
@ -209,7 +223,7 @@ def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.Monke
"model_config": _model_config_payload(),
},
):
response, status = method("t1")
response, status = method(session, "t1")
assert status == 400
assert response["error"] == "incompatible parameters"

View File

@ -346,6 +346,19 @@ class TestWorkflowService:
assert result == mock_workflow
def test_get_draft_workflow_uses_provided_session(self, workflow_service, mock_db_session):
"""Test get_draft_workflow can reuse an injected SQLAlchemy session."""
app = TestWorkflowAssociatedDataFactory.create_app_mock()
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock()
session = MagicMock()
session.scalar.return_value = mock_workflow
result = workflow_service.get_draft_workflow(app, session=session)
assert result == mock_workflow
session.scalar.assert_called_once()
mock_db_session.session.scalar.assert_not_called()
def test_get_draft_workflow_returns_none(self, workflow_service, mock_db_session):
"""Test get_draft_workflow returns None when no draft exists."""
app = TestWorkflowAssociatedDataFactory.create_app_mock()
@ -370,6 +383,21 @@ class TestWorkflowService:
assert result == mock_workflow
def test_get_draft_workflow_with_workflow_id_reuses_provided_session(self, workflow_service):
"""Test get_draft_workflow passes an injected session to published workflow lookup."""
app = TestWorkflowAssociatedDataFactory.create_app_mock()
workflow_id = "workflow-123"
session = MagicMock()
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1")
with patch.object(
workflow_service, "get_published_workflow_by_id", return_value=mock_workflow
) as mock_get_published:
result = workflow_service.get_draft_workflow(app, workflow_id=workflow_id, session=session)
assert result == mock_workflow
mock_get_published.assert_called_once_with(app, workflow_id, session=session)
# ==================== Get Published Workflow Tests ====================
# These tests verify retrieval of published workflows (versioned snapshots)