mirror of
https://github.com/langgenius/dify.git
synced 2026-05-31 06:06:20 +08:00
chore: reuse injected SQLAlchemy sessions in app read paths (#36798)
This commit is contained in:
@ -467,7 +467,8 @@ class AppListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
@with_session(write=False)
|
||||
def get(self, session: Session):
|
||||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -504,7 +505,7 @@ class AppListApi(Resource):
|
||||
draft_trigger_app_ids: set[str] = set()
|
||||
if workflow_capable_app_ids:
|
||||
draft_workflows = (
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
select(Workflow).where(
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
Workflow.app_id.in_(workflow_capable_app_ids),
|
||||
|
||||
@ -2,6 +2,7 @@ from collections.abc import Sequence
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
@ -11,6 +12,7 @@ from controllers.console.app.error import (
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.wraps import with_session
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
@ -19,7 +21,6 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
@ -158,7 +159,8 @@ class InstructionGenerateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
@with_session(write=False)
|
||||
def post(self, session: Session, current_tenant_id: str):
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
@ -168,10 +170,10 @@ class InstructionGenerateApi(Resource):
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.get(App, args.flow_id)
|
||||
app = session.get(App, args.flow_id)
|
||||
if not app:
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app, session=session)
|
||||
if not workflow:
|
||||
return {"error": f"workflow {args.flow_id} not found"}, 400
|
||||
nodes: Sequence = workflow.graph_dict["nodes"]
|
||||
|
||||
@ -140,14 +140,21 @@ class WorkflowService:
|
||||
)
|
||||
return db.session.execute(stmt).scalar_one()
|
||||
|
||||
def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
|
||||
def get_draft_workflow(
|
||||
self, app_model: App, workflow_id: str | None = None, session: Session | None = None
|
||||
) -> Workflow | None:
|
||||
"""
|
||||
Get draft workflow
|
||||
|
||||
When ``session`` is provided, reuse it so callers that already hold a
|
||||
Session avoid checking out an extra request-scoped ``db.session``
|
||||
connection. Falls back to ``db.session`` for backward compatibility.
|
||||
"""
|
||||
if workflow_id:
|
||||
return self.get_published_workflow_by_id(app_model, workflow_id)
|
||||
return self.get_published_workflow_by_id(app_model, workflow_id, session=session)
|
||||
# fetch draft workflow by app_model
|
||||
workflow = db.session.scalar(
|
||||
bind = session if session is not None else db.session
|
||||
workflow = bind.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
|
||||
@ -7,6 +7,7 @@ from importlib import util
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask.views import MethodView
|
||||
@ -18,6 +19,15 @@ if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def app_module():
|
||||
module_name = "controllers.console.app.app"
|
||||
@ -395,3 +405,46 @@ def test_app_pagination_aliases_per_page_and_has_next(app_models):
|
||||
assert len(serialized["data"]) == 2
|
||||
assert serialized["data"][0]["icon_url"] == "signed:first-icon"
|
||||
assert serialized["data"][1]["icon_url"] is None
|
||||
|
||||
|
||||
def test_app_list_uses_injected_session_for_draft_workflows(app, app_module, monkeypatch):
|
||||
api = app_module.AppListApi()
|
||||
method = _unwrap(api.get)
|
||||
current_user = SimpleNamespace(id="user-1")
|
||||
app_item = SimpleNamespace(
|
||||
id="app-1",
|
||||
name="Workflow App",
|
||||
desc_or_prompt="Summary",
|
||||
mode="workflow",
|
||||
mode_compatible_with_agent="workflow",
|
||||
)
|
||||
app_pagination = SimpleNamespace(page=1, per_page=20, total=1, has_next=False, items=[app_item])
|
||||
workflow = SimpleNamespace(
|
||||
id="workflow-1",
|
||||
app_id="app-1",
|
||||
walk_nodes=lambda: iter([("trigger-1", {"type": "trigger-webhook"})]),
|
||||
)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalars.return_value.all.return_value = [workflow]
|
||||
scoped_session = SimpleNamespace(execute=MagicMock(side_effect=AssertionError("db.session should not be used")))
|
||||
|
||||
monkeypatch.setattr(app_module, "current_account_with_tenant", lambda: (current_user, "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
app_module,
|
||||
"AppService",
|
||||
lambda: SimpleNamespace(get_paginate_apps=lambda *_args, **_kwargs: app_pagination),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app_module,
|
||||
"FeatureService",
|
||||
SimpleNamespace(get_system_features=lambda: SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))),
|
||||
)
|
||||
monkeypatch.setattr(app_module, "db", SimpleNamespace(session=scoped_session))
|
||||
|
||||
with app.test_request_context("/console/api/apps?page=1&limit=20", method="GET"):
|
||||
response, status = method(session)
|
||||
|
||||
assert status == 200
|
||||
assert response["data"][0]["has_draft_trigger"] is True
|
||||
session.execute.assert_called_once()
|
||||
scoped_session.execute.assert_not_called()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@ -24,10 +25,17 @@ def _model_config_payload():
|
||||
|
||||
def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow):
|
||||
class _Service:
|
||||
def get_draft_workflow(self, app_model):
|
||||
app_model = None
|
||||
session = None
|
||||
|
||||
def get_draft_workflow(self, app_model, session=None):
|
||||
self.app_model = app_model
|
||||
self.session = session
|
||||
return workflow
|
||||
|
||||
monkeypatch.setattr(generator_module, "WorkflowService", lambda: _Service())
|
||||
service = _Service()
|
||||
monkeypatch.setattr(generator_module, "WorkflowService", lambda: service)
|
||||
return service
|
||||
|
||||
|
||||
def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -68,7 +76,8 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None))
|
||||
session = MagicMock()
|
||||
session.get.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
@ -80,10 +89,11 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method("t1")
|
||||
response, status = method(session, "t1")
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "app app-1 not found"
|
||||
session.get.assert_called_once_with(generator_module.App, "app-1")
|
||||
|
||||
|
||||
def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -91,7 +101,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
|
||||
method = _unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
|
||||
_install_workflow_service(monkeypatch, workflow=None)
|
||||
|
||||
with app.test_request_context(
|
||||
@ -104,7 +114,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method("t1")
|
||||
response, status = method(session, "t1")
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "workflow app-1 not found"
|
||||
@ -115,7 +125,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
|
||||
method = _unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
|
||||
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []})
|
||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||
@ -130,7 +140,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method("t1")
|
||||
response, status = method(session, "t1")
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "node node-1 not found"
|
||||
@ -141,7 +151,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
|
||||
method = _unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={
|
||||
@ -150,7 +160,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
|
||||
]
|
||||
}
|
||||
)
|
||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||
workflow_service = _install_workflow_service(monkeypatch, workflow=workflow)
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", lambda **_kwargs: {"code": "x"})
|
||||
|
||||
with app.test_request_context(
|
||||
@ -163,14 +173,17 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response = method("t1")
|
||||
response = method(session, "t1")
|
||||
|
||||
assert response == {"code": "x"}
|
||||
assert workflow_service.app_model is app_model
|
||||
assert workflow_service.session is session
|
||||
|
||||
|
||||
def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
session = SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(
|
||||
generator_module.LLMGenerator,
|
||||
@ -189,7 +202,7 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response = method("t1")
|
||||
response = method(session, "t1")
|
||||
|
||||
assert response == {"instruction": "ok"}
|
||||
|
||||
@ -197,6 +210,7 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch
|
||||
def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
session = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
@ -209,7 +223,7 @@ def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.Monke
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method("t1")
|
||||
response, status = method(session, "t1")
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "incompatible parameters"
|
||||
|
||||
@ -346,6 +346,19 @@ class TestWorkflowService:
|
||||
|
||||
assert result == mock_workflow
|
||||
|
||||
def test_get_draft_workflow_uses_provided_session(self, workflow_service, mock_db_session):
|
||||
"""Test get_draft_workflow can reuse an injected SQLAlchemy session."""
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock()
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock()
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = mock_workflow
|
||||
|
||||
result = workflow_service.get_draft_workflow(app, session=session)
|
||||
|
||||
assert result == mock_workflow
|
||||
session.scalar.assert_called_once()
|
||||
mock_db_session.session.scalar.assert_not_called()
|
||||
|
||||
def test_get_draft_workflow_returns_none(self, workflow_service, mock_db_session):
|
||||
"""Test get_draft_workflow returns None when no draft exists."""
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock()
|
||||
@ -370,6 +383,21 @@ class TestWorkflowService:
|
||||
|
||||
assert result == mock_workflow
|
||||
|
||||
def test_get_draft_workflow_with_workflow_id_reuses_provided_session(self, workflow_service):
|
||||
"""Test get_draft_workflow passes an injected session to published workflow lookup."""
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock()
|
||||
workflow_id = "workflow-123"
|
||||
session = MagicMock()
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1")
|
||||
|
||||
with patch.object(
|
||||
workflow_service, "get_published_workflow_by_id", return_value=mock_workflow
|
||||
) as mock_get_published:
|
||||
result = workflow_service.get_draft_workflow(app, workflow_id=workflow_id, session=session)
|
||||
|
||||
assert result == mock_workflow
|
||||
mock_get_published.assert_called_once_with(app, workflow_id, session=session)
|
||||
|
||||
# ==================== Get Published Workflow Tests ====================
|
||||
# These tests verify retrieval of published workflows (versioned snapshots)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user