feat: compat tool provider credentials to updated data

This commit is contained in:
Yeuoly
2024-09-30 23:22:03 +08:00
parent e12f4009d3
commit 56b7853afe
7 changed files with 111 additions and 51 deletions

View File

@ -221,7 +221,7 @@ class ApiToolManageService:
labels = ToolLabelManager.get_tool_labels(controller)
return [
ToolTransformService.tool_to_user_tool(
ToolTransformService.convert_tool_entity_to_api_entity(
tool_bundle,
tenant_id=tenant_id,
labels=labels,
@ -465,7 +465,7 @@ class ApiToolManageService:
for tool in tools:
user_provider.tools.append(
ToolTransformService.tool_to_user_tool(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
)
)

View File

@ -7,6 +7,7 @@ from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.tool_entities import ToolProviderID
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
@ -40,14 +41,7 @@ class BuiltinToolManageService:
provider_identity=provider_controller.entity.identity.name,
)
# check if user has added the provider
builtin_provider: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.first()
)
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
credentials = {}
if builtin_provider is not None:
@ -58,7 +52,7 @@ class BuiltinToolManageService:
result = []
for tool in tools:
result.append(
ToolTransformService.tool_to_user_tool(
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
@ -86,14 +80,7 @@ class BuiltinToolManageService:
update builtin tool provider
"""
# get if the provider exists
provider: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
)
.first()
)
provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
try:
# get provider
@ -149,14 +136,7 @@ class BuiltinToolManageService:
"""
get builtin tool provider credentials
"""
provider_obj: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.first()
)
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
if provider_obj is None:
return {}
@ -177,14 +157,7 @@ class BuiltinToolManageService:
"""
delete tool provider
"""
provider_obj: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
)
.first()
)
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
if provider_obj is None:
raise ValueError(f"you have not added provider {provider_name}")
@ -227,6 +200,13 @@ class BuiltinToolManageService:
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
)
# rewrite db_providers
for db_provider in db_providers:
try:
ToolProviderID(db_provider.provider)
except Exception:
db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}"
# find provider
find_provider = lambda provider: next(
filter(lambda db_provider: db_provider.provider == provider, db_providers), None
@ -258,7 +238,7 @@ class BuiltinToolManageService:
tools = provider_controller.get_tools()
for tool in tools:
user_builtin_provider.tools.append(
ToolTransformService.tool_to_user_tool(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
credentials=user_builtin_provider.original_credentials,
@ -271,3 +251,40 @@ class BuiltinToolManageService:
raise e
return BuiltinToolProviderSort.sort(result)
@staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
try:
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
return None
provider_obj = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name) | (BuiltinToolProvider.provider == provider_name),
)
.first()
)
if provider_obj is None:
return None
try:
ToolProviderID(provider_obj.provider)
except Exception:
provider_obj.provider = f"langgenius/{provider_obj.provider}/{provider_obj.provider}"
return provider_obj
except Exception:
# it's an old provider without organization
return (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name),
)
.first()
)

View File

@ -223,7 +223,7 @@ class ToolTransformService:
return result
@staticmethod
def tool_to_user_tool(
def convert_tool_entity_to_api_entity(
tool: Union[ApiToolBundle, WorkflowTool, Tool],
tenant_id: str,
credentials: dict | None = None,

View File

@ -210,7 +210,7 @@ class WorkflowToolManageService:
)
ToolTransformService.repack_provider(user_tool_provider)
user_tool_provider.tools = [
ToolTransformService.tool_to_user_tool(
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool.get_tools(user_id, tenant_id)[0],
labels=labels.get(tool.provider_id, []),
tenant_id=tenant_id,
@ -299,7 +299,7 @@ class WorkflowToolManageService:
"icon": json.loads(db_tool.icon),
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
"tool": ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool.get_tools(db_tool.tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool),
tenant_id=tenant_id,
@ -329,7 +329,7 @@ class WorkflowToolManageService:
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
return [
ToolTransformService.tool_to_user_tool(
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool.get_tools(db_tool.tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool),
tenant_id=tenant_id,