refactor: use sessionmaker in workflow_tools_manage_service.py (#34896)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
HeYinKazune
2026-04-13 12:47:29 +09:00
committed by GitHub
parent 8436470fcb
commit a111d56ea3

View File

@ -4,7 +4,7 @@ from datetime import datetime
from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import delete, or_, select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
@ -42,32 +42,43 @@ class WorkflowToolManageService:
labels: list[str] | None = None,
):
# check if the name is unique
existing_workflow_tool_provider = db.session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
existing_workflow_tool_provider: WorkflowToolProvider | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
# query if the name or app_id exists
existing_workflow_tool_provider = _session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
)
.limit(1)
)
.limit(1)
)
# if the name or app_id exists raise error
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
app: App | None = db.session.scalar(
select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1)
)
# query the app
app: App | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
app = _session.scalar(select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1))
# if not found raise error
if app is None:
raise ValueError(f"App {workflow_app_id} not found")
# query the workflow
workflow: Workflow | None = app.workflow
# if not found raise error
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_app_id}")
# check if workflow configuration is synced
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict)
# create workflow tool provider
workflow_tool_provider = WorkflowToolProvider(
tenant_id=tenant_id,
user_id=user_id,
@ -87,13 +98,15 @@ class WorkflowToolManageService:
logger.warning(e, exc_info=True)
raise ValueError(str(e))
with Session(db.engine, expire_on_commit=False) as session, session.begin():
session.add(workflow_tool_provider)
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
_session.add(workflow_tool_provider)
# keep the session open to make orm instances in the same session
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
return {"result": "success"}
@classmethod
@ -112,6 +125,7 @@ class WorkflowToolManageService:
):
"""
Update a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_tool_id: workflow tool id
@ -187,28 +201,32 @@ class WorkflowToolManageService:
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
"""
List workflow tools.
:param user_id: the user id
:param tenant_id: the tenant id
:return: the list of tools
"""
db_tools = db.session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()
providers: list[WorkflowToolProvider] = []
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
providers = list(
_session.scalars(select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)).all()
)
# Create a mapping from provider_id to app_id
provider_id_to_app_id = {provider.id: provider.app_id for provider in db_tools}
provider_id_to_app_id = {provider.id: provider.app_id for provider in providers}
tools: list[WorkflowToolProviderController] = []
for provider in db_tools:
for provider in providers:
try:
tools.append(ToolTransformService.workflow_provider_to_controller(provider))
except Exception:
# skip deleted tools
logger.exception("Failed to load workflow tool provider %s", provider.id)
labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)])
labels = ToolLabelManager.get_tools_labels([tool for tool in tools if isinstance(tool, ToolProviderController)])
result = []
result: list[ToolProviderApiEntity] = []
for tool in tools:
workflow_app_id = provider_id_to_app_id.get(tool.provider_id)
@ -233,17 +251,18 @@ class WorkflowToolManageService:
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Delete a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
"""
db.session.execute(
delete(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
)
)
db.session.commit()
with sessionmaker(db.engine).begin() as _session:
_ = _session.execute(
delete(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
)
)
return {"result": "success"}
@ -251,47 +270,59 @@ class WorkflowToolManageService:
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Get a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.limit(1)
)
return cls._get_workflow_tool(tenant_id, db_tool)
tool_provider: WorkflowToolProvider | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
tool_provider = _session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.limit(1)
)
return cls._get_workflow_tool(tenant_id, tool_provider)
@classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str):
"""
Get a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.limit(1)
)
return cls._get_workflow_tool(tenant_id, db_tool)
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
tool_provider: WorkflowToolProvider | None = _session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.limit(1)
)
return cls._get_workflow_tool(tenant_id, tool_provider)
@classmethod
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
"""
Get a workflow tool.
:db_tool: the database tool
:return: the tool
"""
if db_tool is None:
raise ValueError("Tool not found")
workflow_app: App | None = db.session.scalar(
select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1)
)
workflow_app: App | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
workflow_app = _session.scalar(
select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1)
)
if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found")
@ -331,28 +362,32 @@ class WorkflowToolManageService:
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]:
"""
List workflow tool provider tools.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
:return: the list of tools
"""
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.limit(1)
)
if db_tool is None:
provider: WorkflowToolProvider | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
provider = _session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.limit(1)
)
if provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
tool = ToolTransformService.workflow_provider_to_controller(provider)
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
if len(workflow_tools) == 0:
raise ValueError(f"Tool {workflow_tool_id} not found")
return [
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool.get_tools(db_tool.tenant_id)[0],
tool=tool.get_tools(provider.tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool),
tenant_id=tenant_id,
)