chore(refactor): queries in service and auth components

This commit is contained in:
Novice
2025-06-25 14:09:19 +08:00
parent 01922f2d02
commit f783ad68e4
9 changed files with 135 additions and 148 deletions

View File

@ -98,7 +98,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
full_state_data.code_verifier,
full_state_data.redirect_uri,
)
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id)
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
provider.save_tokens(tokens)
return full_state_data
@ -275,6 +275,7 @@ def auth(
server_url: str,
authorization_code: Optional[str] = None,
state_param: Optional[str] = None,
for_list: bool = False,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
metadata = discover_oauth_metadata(server_url)
@ -337,8 +338,8 @@ def auth(
metadata,
client_information,
provider.redirect_url,
provider.provider_id,
provider.tenant_id,
provider.mcp_provider.id,
provider.mcp_provider.tenant_id,
)
provider.save_code_verifier(code_verifier)

View File

@ -7,18 +7,20 @@ from core.mcp.types import (
OAuthClientMetadata,
OAuthTokens,
)
from models.tools import MCPToolProvider
from services.tools.mcp_tools_mange_service import MCPToolManageService
LATEST_PROTOCOL_VERSION = "1.0"
class OAuthClientProvider:
provider_id: str
tenant_id: str
mcp_provider: MCPToolProvider
def __init__(self, provider_id: str, tenant_id: str):
self.provider_id = provider_id
self.tenant_id = tenant_id
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
if for_list:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
else:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
@property
def redirect_url(self) -> str:
@ -39,12 +41,7 @@ class OAuthClientProvider:
def client_information(self) -> Optional[OAuthClientInformation]:
"""Loads information about this OAuth client."""
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider:
return None
client_information = MCPToolManageService.get_mcp_provider_decrypted_credentials(
self.tenant_id, self.provider_id
).get("client_information", {})
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
if not client_information:
return None
return OAuthClientInformation.model_validate(client_information)
@ -52,15 +49,13 @@ class OAuthClientProvider:
def save_client_information(self, client_information: OAuthClientInformationFull) -> None:
"""Saves client information after dynamic registration."""
MCPToolManageService.update_mcp_provider_credentials(
self.tenant_id, self.provider_id, {"client_information": client_information.model_dump()}
self.mcp_provider,
{"client_information": client_information.model_dump()},
)
def tokens(self) -> Optional[OAuthTokens]:
"""Loads any existing OAuth tokens for the current session."""
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider:
return None
credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
credentials = self.mcp_provider.decrypted_credentials
if not credentials:
return None
return OAuthTokens(
@ -74,20 +69,13 @@ class OAuthClientProvider:
"""Stores new OAuth tokens for the current session."""
# update mcp provider credentials
token_dict = tokens.model_dump()
MCPToolManageService.update_mcp_provider_credentials(self.tenant_id, self.provider_id, token_dict, authed=True)
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
def save_code_verifier(self, code_verifier: str) -> None:
"""Saves a PKCE code verifier for the current session."""
# update mcp provider credentials
MCPToolManageService.update_mcp_provider_credentials(
self.tenant_id, self.provider_id, {"code_verifier": code_verifier}
)
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
def code_verifier(self) -> str:
"""Loads the PKCE code verifier for the current session."""
# get code verifier from mcp provider credentials
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider:
return ""
credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
return str(credentials.get("code_verifier", ""))
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))

View File

@ -22,6 +22,7 @@ class MCPClient:
tenant_id: str,
authed: bool = True,
authorization_code: Optional[str] = None,
for_list: bool = False,
):
# Initialize info
self.provider_id = provider_id
@ -35,7 +36,7 @@ class MCPClient:
if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id)
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
self.token = self.provider.tokens()
# Initialize session and client objects