refactor(mcp): clean the auth code

This commit is contained in:
Novice
2025-10-23 17:00:02 +08:00
parent 8cf4a0d3ad
commit ffd3a461f6
9 changed files with 521 additions and 781 deletions

View File

@ -30,7 +30,7 @@ from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
@ -897,10 +897,6 @@ class ToolProviderMCPApi(Resource):
args = parser.parse_args()
user, tenant_id = current_account_with_tenant()
# Validate server URL
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
# Parse and validate models
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
@ -941,15 +937,21 @@ class ToolProviderMCPApi(Resource):
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
pass
else:
raise ValueError("Server URL is not valid.")
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
validation_result = None
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
validation_result = service.validate_server_url_change(
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
)
# No need to check for errors here, exceptions will be raised directly
# Step 2: Perform database update in a transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
@ -964,6 +966,7 @@ class ToolProviderMCPApi(Resource):
headers=args["headers"],
configuration=configuration,
authentication=authentication,
validation_result=validation_result,
)
return {"result": "success"}
@ -998,47 +1001,49 @@ class ToolMCPAuthApi(Resource):
provider_id = args["provider_id"]
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
with session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_authentication()
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_authentication()
# Try to connect without active transaction
# Try to connect without active transaction
try:
# Use MCPClientWithAuthRetry to handle authentication automatically
with MCPClient(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
):
# Create new transaction for update
with session.begin():
service.update_provider_credentials(
provider=db_provider,
credentials=provider_entity.credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
service = MCPToolManageService(session=session)
try:
# Use MCPClientWithAuthRetry to handle authentication automatically
with MCPClient(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
):
# Create new transaction for update
with session.begin():
service.update_provider_credentials(
provider=db_provider,
credentials=provider_entity.credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
service = MCPToolManageService(session=session)
try:
return auth(provider_entity, service, args.get("authorization_code"))
except MCPRefreshTokenError as e:
with session.begin():
service.clear_provider_credentials(provider=db_provider)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except MCPError as e:
auth_result = auth(provider_entity, args.get("authorization_code"))
with session.begin():
response = service.execute_auth_actions(auth_result)
return response
except MCPRefreshTokenError as e:
with session.begin():
service.clear_provider_credentials(provider=db_provider)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except MCPError as e:
with session.begin():
service.clear_provider_credentials(provider=db_provider)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
@ -1048,7 +1053,7 @@ class ToolMCPDetailApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@ -1062,7 +1067,7 @@ class ToolMCPListAllApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
tools = service.list_providers(tenant_id=tenant_id)
@ -1100,6 +1105,11 @@ class ToolMCPCallbackApi(Resource):
# Create service instance for handle_callback
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
handle_callback(state_key, authorization_code, mcp_service)
# handle_callback now returns state data and tokens
state_data, tokens = handle_callback(state_key, authorization_code)
# Save tokens using the service layer
mcp_service.save_oauth_data(
state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")