feat: refactor OAuth provider handling and improve provider name generation

This commit is contained in:
Harry
2025-07-18 12:47:32 +08:00
parent 9f2a9ad271
commit 0ac5c0bf3e
4 changed files with 256 additions and 155 deletions

View File

@ -1,6 +1,5 @@
import json
import logging
import re
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional
@ -11,6 +10,7 @@ from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.helper.provider_name_generator import generate_provider_name
from core.plugin.entities.plugin import ToolProviderID
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@ -299,42 +299,18 @@ class BuiltinToolManageService:
def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
try:
db_providers = (
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type.value,
)
.order_by(BuiltinToolProvider.created_at.desc())
.all()
db_providers = (
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type.value,
)
# Get the default name pattern
default_pattern = f"{credential_type.get_name()}"
# Find all names that match the default pattern: "{default_pattern} {number}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for db_provider in db_providers:
if db_provider.name:
match = re.match(pattern, db_provider.name.strip())
if match:
numbers.append(int(match.group(1)))
# If no default pattern names found, start with 1
if not numbers:
return f"{default_pattern} 1"
# Find the next number
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
# fallback
return f"{credential_type.get_name()} 1"
.order_by(BuiltinToolProvider.created_at.desc())
.all()
)
return generate_provider_name(db_providers, credential_type, f"builtin tool provider {provider}")
@staticmethod
def get_builtin_tool_provider_credentials(