mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
refactor: use sessionmaker in tool_label_manager.py (#34895)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
@ -19,10 +20,18 @@ class ToolLabelManager:
|
||||
return list(set(tool_labels))
|
||||
|
||||
@classmethod
|
||||
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
|
||||
def update_tool_labels(
|
||||
cls, controller: ToolProviderController, labels: list[str], session: Session | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Update tool labels
|
||||
|
||||
:param controller: tool provider controller
|
||||
:param labels: list of tool labels
|
||||
:param session: database session, if None, a new session will be created
|
||||
:return: None
|
||||
"""
|
||||
|
||||
labels = cls.filter_tool_labels(labels)
|
||||
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
@ -30,26 +39,46 @@ class ToolLabelManager:
|
||||
else:
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
if session is not None:
|
||||
cls._update_tool_labels_logics(session, provider_id, controller, labels)
|
||||
else:
|
||||
with sessionmaker(db.engine).begin() as _session:
|
||||
cls._update_tool_labels_logics(_session, provider_id, controller, labels)
|
||||
|
||||
@classmethod
|
||||
def _update_tool_labels_logics(
|
||||
cls, session: Session, provider_id: str, controller: ToolProviderController, labels: list[str]
|
||||
) -> None:
|
||||
"""
|
||||
Update tool labels logics
|
||||
|
||||
:param session: database session
|
||||
:param provider_id: tool provider ID
|
||||
:param controller: tool provider controller
|
||||
:param labels: list of tool labels
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# delete old labels
|
||||
db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id))
|
||||
_ = session.execute(
|
||||
delete(ToolLabelBinding).where(
|
||||
ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_type == controller.provider_type
|
||||
)
|
||||
)
|
||||
|
||||
# insert new labels
|
||||
for label in labels:
|
||||
db.session.add(
|
||||
ToolLabelBinding(
|
||||
tool_id=provider_id,
|
||||
tool_type=controller.provider_type,
|
||||
label_name=label,
|
||||
)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
session.add(ToolLabelBinding(tool_id=provider_id, tool_type=controller.provider_type, label_name=label))
|
||||
|
||||
@classmethod
|
||||
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
|
||||
"""
|
||||
Get tool labels
|
||||
|
||||
:param controller: tool provider controller
|
||||
:return: list of tool labels (str)
|
||||
"""
|
||||
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id
|
||||
elif isinstance(controller, BuiltinToolProviderController):
|
||||
@ -60,9 +89,11 @@ class ToolLabelManager:
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type,
|
||||
)
|
||||
labels = db.session.scalars(stmt).all()
|
||||
|
||||
return list(labels)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
labels: list[str] = list(_session.scalars(stmt).all())
|
||||
|
||||
return labels
|
||||
|
||||
@classmethod
|
||||
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
|
||||
@ -78,16 +109,22 @@ class ToolLabelManager:
|
||||
if not tool_providers:
|
||||
return {}
|
||||
|
||||
provider_ids: list[str] = []
|
||||
provider_types: set[str] = set()
|
||||
|
||||
for controller in tool_providers:
|
||||
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
provider_ids = []
|
||||
for controller in tool_providers:
|
||||
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
|
||||
provider_ids.append(controller.provider_id)
|
||||
provider_types.add(controller.provider_type)
|
||||
|
||||
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
|
||||
labels: list[ToolLabelBinding] = []
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
stmt = select(ToolLabelBinding).where(
|
||||
ToolLabelBinding.tool_id.in_(provider_ids), ToolLabelBinding.tool_type.in_(list(provider_types))
|
||||
)
|
||||
labels = list(_session.scalars(stmt).all())
|
||||
|
||||
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user