mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
chore(refactor): queries in service and auth components
This commit is contained in:
@ -98,7 +98,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
|
||||
full_state_data.code_verifier,
|
||||
full_state_data.redirect_uri,
|
||||
)
|
||||
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id)
|
||||
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
|
||||
provider.save_tokens(tokens)
|
||||
return full_state_data
|
||||
|
||||
@ -275,6 +275,7 @@ def auth(
|
||||
server_url: str,
|
||||
authorization_code: Optional[str] = None,
|
||||
state_param: Optional[str] = None,
|
||||
for_list: bool = False,
|
||||
) -> dict[str, str]:
|
||||
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
||||
metadata = discover_oauth_metadata(server_url)
|
||||
@ -337,8 +338,8 @@ def auth(
|
||||
metadata,
|
||||
client_information,
|
||||
provider.redirect_url,
|
||||
provider.provider_id,
|
||||
provider.tenant_id,
|
||||
provider.mcp_provider.id,
|
||||
provider.mcp_provider.tenant_id,
|
||||
)
|
||||
|
||||
provider.save_code_verifier(code_verifier)
|
||||
|
||||
@ -7,18 +7,20 @@ from core.mcp.types import (
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
)
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "1.0"
|
||||
|
||||
|
||||
class OAuthClientProvider:
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
mcp_provider: MCPToolProvider
|
||||
|
||||
def __init__(self, provider_id: str, tenant_id: str):
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
|
||||
if for_list:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
else:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
|
||||
|
||||
@property
|
||||
def redirect_url(self) -> str:
|
||||
@ -39,12 +41,7 @@ class OAuthClientProvider:
|
||||
|
||||
def client_information(self) -> Optional[OAuthClientInformation]:
|
||||
"""Loads information about this OAuth client."""
|
||||
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
|
||||
if not mcp_provider:
|
||||
return None
|
||||
client_information = MCPToolManageService.get_mcp_provider_decrypted_credentials(
|
||||
self.tenant_id, self.provider_id
|
||||
).get("client_information", {})
|
||||
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
|
||||
if not client_information:
|
||||
return None
|
||||
return OAuthClientInformation.model_validate(client_information)
|
||||
@ -52,15 +49,13 @@ class OAuthClientProvider:
|
||||
def save_client_information(self, client_information: OAuthClientInformationFull) -> None:
|
||||
"""Saves client information after dynamic registration."""
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
self.tenant_id, self.provider_id, {"client_information": client_information.model_dump()}
|
||||
self.mcp_provider,
|
||||
{"client_information": client_information.model_dump()},
|
||||
)
|
||||
|
||||
def tokens(self) -> Optional[OAuthTokens]:
|
||||
"""Loads any existing OAuth tokens for the current session."""
|
||||
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
|
||||
if not mcp_provider:
|
||||
return None
|
||||
credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
|
||||
credentials = self.mcp_provider.decrypted_credentials
|
||||
if not credentials:
|
||||
return None
|
||||
return OAuthTokens(
|
||||
@ -74,20 +69,13 @@ class OAuthClientProvider:
|
||||
"""Stores new OAuth tokens for the current session."""
|
||||
# update mcp provider credentials
|
||||
token_dict = tokens.model_dump()
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.tenant_id, self.provider_id, token_dict, authed=True)
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
|
||||
|
||||
def save_code_verifier(self, code_verifier: str) -> None:
|
||||
"""Saves a PKCE code verifier for the current session."""
|
||||
# update mcp provider credentials
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
self.tenant_id, self.provider_id, {"code_verifier": code_verifier}
|
||||
)
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
|
||||
|
||||
def code_verifier(self) -> str:
|
||||
"""Loads the PKCE code verifier for the current session."""
|
||||
# get code verifier from mcp provider credentials
|
||||
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
|
||||
if not mcp_provider:
|
||||
return ""
|
||||
credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
|
||||
return str(credentials.get("code_verifier", ""))
|
||||
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))
|
||||
|
||||
@ -22,6 +22,7 @@ class MCPClient:
|
||||
tenant_id: str,
|
||||
authed: bool = True,
|
||||
authorization_code: Optional[str] = None,
|
||||
for_list: bool = False,
|
||||
):
|
||||
# Initialize info
|
||||
self.provider_id = provider_id
|
||||
@ -35,7 +36,7 @@ class MCPClient:
|
||||
if authed:
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
|
||||
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id)
|
||||
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
|
||||
self.token = self.provider.tokens()
|
||||
|
||||
# Initialize session and client objects
|
||||
|
||||
Reference in New Issue
Block a user