refactor(tool oauth): update api implementation

This commit is contained in:
Harry
2025-06-23 16:51:28 +08:00
parent 7f292dc261
commit 5e7c5863ef
16 changed files with 393 additions and 738 deletions

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.tool_entities import ToolProviderCredentialType, ToolProviderType
class ToolApiEntity(BaseModel):
@ -70,3 +70,14 @@ class ToolProviderApiEntity(BaseModel):
"tools": tools,
"labels": self.labels,
}
class ToolProviderCredentialApiEntity(BaseModel):
id: str = Field(description="The unique id of the credential")
name: str = Field(description="The name of the credential")
provider: str = Field(description="The provider of the credential")
credential_type: ToolProviderCredentialType = Field(description="The type of the credential")
is_default: bool = Field(
default=False, description="Whether the credential is the default credential for the provider in the workspace"
)
credentials: dict = Field(description="The credentials of the provider")

View File

@ -434,3 +434,36 @@ class ToolSelector(BaseModel):
def to_plugin_parameter(self) -> dict[str, Any]:
return self.model_dump()
class ToolProviderCredentialType(enum.StrEnum):
API_KEY = "api_key"
OAUTH2 = "oauth2"
def get_name(self):
if self == ToolProviderCredentialType.API_KEY:
return "API KEY"
elif self == ToolProviderCredentialType.OAUTH2:
return "AUTH"
else:
return self.value.replace("_", " ").upper()
def is_editable(self):
return self == ToolProviderCredentialType.API_KEY
def is_validate_allowed(self):
return self == ToolProviderCredentialType.API_KEY
@classmethod
def values(cls):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "ToolProviderCredentialType":
type_name = credential_type.lower()
if type_name == "api_key":
return cls.API_KEY
elif type_name == "oauth2":
return cls.OAUTH2
else:
raise ValueError(f"Invalid credential type: {credential_type}")

View File

@ -200,6 +200,7 @@ class ToolManager:
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
@ -209,6 +210,7 @@ class ToolManager:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
@ -575,18 +577,27 @@ class ToolManager:
with db.session.no_autoflush:
if "builtin" in filters:
# get builtin providers
def get_builtin_providers(tenant_id):
# according to multi credentials, select the one with is_default=True first, then created_at oldest
# for compatibility with old version
sql = """
SELECT DISTINCT ON (tenant_id, provider) id
FROM tool_builtin_providers
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
builtin_providers = cls.list_builtin_providers(tenant_id)
# get db builtin providers
db_builtin_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
)
# get builtin providers
db_builtin_providers = get_builtin_providers(tenant_id)
# rewrite db_builtin_providers
for db_provider in db_builtin_providers:
tool_provider_id = str(ToolProviderID(db_provider.provider))
db_provider.provider = tool_provider_id
db_provider.provider = str(ToolProviderID(db_provider.provider))
def find_db_builtin_provider(provider):
return next((x for x in db_builtin_providers if x.provider == provider), None)