refactor(api): inject sessionmaker into conversation variable updater (#30609)

This commit is contained in:
-LAN-
2026-01-06 14:52:59 +08:00
committed by GitHub
parent f3ca8be9f9
commit d12b91a01a
3 changed files with 13 additions and 12 deletions

View File

@ -17,7 +17,7 @@ from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account, ConversationVariable
from models.model import App, Conversation, EndUser, Message
from services.conversation_variable_updater import conversation_variable_updater_factory
from services.conversation_variable_updater import ConversationVariableUpdater
from services.errors.conversation import (
ConversationNotExistsError,
ConversationVariableNotExistsError,
@ -337,7 +337,7 @@ class ConversationService:
updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)
# Use the conversation variable updater to persist the changes
updater = conversation_variable_updater_factory()
updater = ConversationVariableUpdater(session_factory.get_session_maker())
updater.update(conversation_id, updated_variable)
updater.flush()

View File

@ -1,8 +1,7 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from core.variables.variables import Variable
from extensions.ext_database import db
from models import ConversationVariable
@ -10,12 +9,15 @@ class ConversationVariableNotFoundError(Exception):
pass
class ConversationVariableUpdaterImpl:
class ConversationVariableUpdater:
def __init__(self, session_maker: sessionmaker[Session]) -> None:
self._session_maker: sessionmaker[Session] = session_maker
def update(self, conversation_id: str, variable: Variable) -> None:
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
with self._session_maker() as session:
row = session.scalar(stmt)
if not row:
raise ConversationVariableNotFoundError("conversation variable not found in the database")
@ -24,7 +26,3 @@ class ConversationVariableUpdaterImpl:
def flush(self) -> None:
pass
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
return ConversationVariableUpdaterImpl()