chore: fix review issues

This commit is contained in:
Novice
2025-10-14 20:36:13 +08:00
parent d5a7a537e5
commit 5c6a2af448
11 changed files with 296 additions and 257 deletions

View File

@ -16,7 +16,7 @@ from controllers.console.wraps import (
enterprise_license_required,
setup_required,
)
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPSupportGrantType
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
@ -44,7 +44,9 @@ def is_valid_url(url: str) -> bool:
try:
parsed = urlparse(url)
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
except Exception:
except (ValueError, TypeError):
# ValueError: Invalid URL format
# TypeError: url is not a string
return False
@ -886,7 +888,7 @@ class ToolProviderMCPApi(Resource):
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# Create provider
with Session(db.engine) as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=tenant_id,
@ -897,14 +899,10 @@ class ToolProviderMCPApi(Resource):
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
headers=args["headers"],
client_id=authentication.client_id if authentication else None,
client_secret=authentication.client_secret if authentication else None,
grant_type=authentication.grant_type if authentication else MCPSupportGrantType.AUTHORIZATION_CODE,
configuration=configuration,
authentication=authentication,
)
session.commit()
return jsonable_encoder(result)
@setup_required
@ -932,7 +930,7 @@ class ToolProviderMCPApi(Resource):
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
tenant_id=current_tenant_id,
@ -943,14 +941,10 @@ class ToolProviderMCPApi(Resource):
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
headers=args["headers"],
client_id=authentication.client_id if authentication else None,
client_secret=authentication.client_secret if authentication else None,
grant_type=authentication.grant_type if authentication else MCPSupportGrantType.AUTHORIZATION_CODE,
configuration=configuration,
authentication=authentication,
)
session.commit()
return {"result": "success"}
@setup_required
@ -962,10 +956,9 @@ class ToolProviderMCPApi(Resource):
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id.current_tenant_id, provider_id=args["provider_id"])
session.commit()
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
@ -983,23 +976,18 @@ class ToolMCPAuthApi(Resource):
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
with session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_authentication()
# Option 1: if headers is provided, use it and don't need to get token
headers = provider_entity.decrypt_headers()
# Option 2: Add OAuth token if authed and no headers provided
if not provider_entity.headers and provider_entity.authed:
token = provider_entity.retrieve_tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
# Try to connect without active transaction
try:
# Use MCPClientWithAuthRetry to handle authentication automatically
with MCPClient(
@ -1008,18 +996,20 @@ class ToolMCPAuthApi(Resource):
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
):
service.update_provider_credentials(
provider=db_provider,
credentials=provider_entity.credentials,
authed=True,
)
session.commit()
# Create new transaction for update
with session.begin():
service.update_provider_credentials(
provider=db_provider,
credentials=provider_entity.credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
service = MCPToolManageService(session=session)
return auth(provider_entity, service, args.get("authorization_code"))
except MCPError as e:
service.clear_provider_credentials(provider=db_provider)
session.commit()
with session.begin():
service.clear_provider_credentials(provider=db_provider)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
@ -1044,7 +1034,7 @@ class ToolMCPListAllApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
service = MCPToolManageService(session=session)
tools = service.list_providers(tenant_id=tenant_id)
@ -1058,7 +1048,7 @@ class ToolMCPUpdateApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
tools = service.list_provider_tools(
tenant_id=tenant_id,
@ -1078,9 +1068,8 @@ class ToolMCPCallbackApi(Resource):
authorization_code = args["code"]
# Create service instance for handle_callback
with Session(db.engine) as session:
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
handle_callback(state_key, authorization_code, mcp_service)
session.commit()
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")