refactor:Decouple Domain Models from Direct Database Access (#27316)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
heyszt
2025-10-28 09:59:30 +08:00
committed by GitHub
parent 341b3ae7c9
commit 543c5236e7
7 changed files with 595 additions and 264 deletions

View File

@ -1,6 +1,7 @@
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
@ -18,7 +19,9 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from factories import file_factory
from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import Workflow, WorkflowRun
from models.workflow import Workflow
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
class TokenBufferMemory:
@ -29,6 +32,14 @@ class TokenBufferMemory:
):
self.conversation = conversation
self.model_instance = model_instance
self._workflow_run_repo: APIWorkflowRunRepository | None = None
@property
def workflow_run_repo(self) -> APIWorkflowRunRepository:
if self._workflow_run_repo is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return self._workflow_run_repo
def _build_prompt_message_with_files(
self,
@ -50,7 +61,16 @@ class TokenBufferMemory:
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_run = db.session.scalar(select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id))
app = self.conversation.app
if not app:
raise ValueError("App not found for conversation")
if not message.workflow_run_id:
raise ValueError("Workflow run ID not found")
workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
)
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))

View File

@ -12,7 +12,7 @@ from uuid import UUID, uuid4
from cachetools import LRUCache
from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
@ -34,7 +34,8 @@ from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog, WorkflowRun
from models.workflow import WorkflowAppLog
from repositories.factory import DifyAPIRepositoryFactory
from tasks.ops_trace_task import process_trace_tasks
if TYPE_CHECKING:
@ -419,6 +420,18 @@ class OpsTraceManager:
class TraceTask:
_workflow_run_repo = None
_repo_lock = threading.Lock()
@classmethod
def _get_workflow_run_repo(cls):
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return cls._workflow_run_repo
def __init__(
self,
trace_type: Any,
@ -486,27 +499,27 @@ class TraceTask:
if not workflow_run_id:
return {}
workflow_run_repo = self._get_workflow_run_repo()
workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(run_id=workflow_run_id)
if not workflow_run:
raise ValueError("Workflow run not found")
workflow_id = workflow_run.workflow_id
tenant_id = workflow_run.tenant_id
workflow_run_id = workflow_run.id
workflow_run_elapsed_time = workflow_run.elapsed_time
workflow_run_status = workflow_run.status
workflow_run_inputs = workflow_run.inputs_dict
workflow_run_outputs = workflow_run.outputs_dict
workflow_run_version = workflow_run.version
error = workflow_run.error or ""
total_tokens = workflow_run.total_tokens
file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
with Session(db.engine) as session:
workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalars(workflow_run_stmt).first()
if not workflow_run:
raise ValueError("Workflow run not found")
workflow_id = workflow_run.workflow_id
tenant_id = workflow_run.tenant_id
workflow_run_id = workflow_run.id
workflow_run_elapsed_time = workflow_run.elapsed_time
workflow_run_status = workflow_run.status
workflow_run_inputs = workflow_run.inputs_dict
workflow_run_outputs = workflow_run.outputs_dict
workflow_run_version = workflow_run.version
error = workflow_run.error or ""
total_tokens = workflow_run.total_tokens
file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
# get workflow_app_log_id
workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
WorkflowAppLog.tenant_id == tenant_id,
@ -523,43 +536,43 @@ class TraceTask:
)
message_id = session.scalar(message_data_stmt)
metadata = {
"workflow_id": workflow_id,
"conversation_id": conversation_id,
"workflow_run_id": workflow_run_id,
"tenant_id": tenant_id,
"elapsed_time": workflow_run_elapsed_time,
"status": workflow_run_status,
"version": workflow_run_version,
"total_tokens": total_tokens,
"file_list": file_list,
"triggered_from": workflow_run.triggered_from,
"user_id": user_id,
"app_id": workflow_run.app_id,
}
metadata = {
"workflow_id": workflow_id,
"conversation_id": conversation_id,
"workflow_run_id": workflow_run_id,
"tenant_id": tenant_id,
"elapsed_time": workflow_run_elapsed_time,
"status": workflow_run_status,
"version": workflow_run_version,
"total_tokens": total_tokens,
"file_list": file_list,
"triggered_from": workflow_run.triggered_from,
"user_id": user_id,
"app_id": workflow_run.app_id,
}
workflow_trace_info = WorkflowTraceInfo(
trace_id=self.trace_id,
workflow_data=workflow_run.to_dict(),
conversation_id=conversation_id,
workflow_id=workflow_id,
tenant_id=tenant_id,
workflow_run_id=workflow_run_id,
workflow_run_elapsed_time=workflow_run_elapsed_time,
workflow_run_status=workflow_run_status,
workflow_run_inputs=workflow_run_inputs,
workflow_run_outputs=workflow_run_outputs,
workflow_run_version=workflow_run_version,
error=error,
total_tokens=total_tokens,
file_list=file_list,
query=query,
metadata=metadata,
workflow_app_log_id=workflow_app_log_id,
message_id=message_id,
start_time=workflow_run.created_at,
end_time=workflow_run.finished_at,
)
workflow_trace_info = WorkflowTraceInfo(
trace_id=self.trace_id,
workflow_data=workflow_run.to_dict(),
conversation_id=conversation_id,
workflow_id=workflow_id,
tenant_id=tenant_id,
workflow_run_id=workflow_run_id,
workflow_run_elapsed_time=workflow_run_elapsed_time,
workflow_run_status=workflow_run_status,
workflow_run_inputs=workflow_run_inputs,
workflow_run_outputs=workflow_run_outputs,
workflow_run_version=workflow_run_version,
error=error,
total_tokens=total_tokens,
file_list=file_list,
query=query,
metadata=metadata,
workflow_app_log_id=workflow_app_log_id,
message_id=message_id,
start_time=workflow_run.created_at,
end_time=workflow_run.finished_at,
)
return workflow_trace_info
def message_trace(self, message_id: str | None):