mirror of
https://github.com/langgenius/dify.git
synced 2026-03-29 09:59:59 +08:00
refactor(mcp): clean the auth code
This commit is contained in:
@ -30,7 +30,7 @@ from models.provider_ids import ToolProviderID
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType
|
||||
from services.tools.tool_labels_service import ToolLabelsService
|
||||
from services.tools.tools_manage_service import ToolCommonService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
@ -897,10 +897,6 @@ class ToolProviderMCPApi(Resource):
|
||||
args = parser.parse_args()
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
# Validate server URL
|
||||
if not is_valid_url(args["server_url"]):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
|
||||
# Parse and validate models
|
||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
@ -941,15 +937,21 @@ class ToolProviderMCPApi(Resource):
|
||||
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if not is_valid_url(args["server_url"]):
|
||||
if "[__HIDDEN__]" in args["server_url"]:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Server URL is not valid.")
|
||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
|
||||
validation_result = None
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
validation_result = service.validate_server_url_change(
|
||||
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
|
||||
)
|
||||
|
||||
# No need to check for errors here, exceptions will be raised directly
|
||||
|
||||
# Step 2: Perform database update in a transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider(
|
||||
@ -964,6 +966,7 @@ class ToolProviderMCPApi(Resource):
|
||||
headers=args["headers"],
|
||||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
validation_result=validation_result,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
@ -998,47 +1001,49 @@ class ToolMCPAuthApi(Resource):
|
||||
provider_id = args["provider_id"]
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
if not db_provider:
|
||||
raise ValueError("provider not found")
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
if not db_provider:
|
||||
raise ValueError("provider not found")
|
||||
|
||||
# Convert to entity
|
||||
provider_entity = db_provider.to_entity()
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
headers = provider_entity.decrypt_authentication()
|
||||
# Convert to entity
|
||||
provider_entity = db_provider.to_entity()
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
headers = provider_entity.decrypt_authentication()
|
||||
|
||||
# Try to connect without active transaction
|
||||
# Try to connect without active transaction
|
||||
try:
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
with MCPClient(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
):
|
||||
# Create new transaction for update
|
||||
with session.begin():
|
||||
service.update_provider_credentials(
|
||||
provider=db_provider,
|
||||
credentials=provider_entity.credentials,
|
||||
authed=True,
|
||||
)
|
||||
return {"result": "success"}
|
||||
except MCPAuthError as e:
|
||||
service = MCPToolManageService(session=session)
|
||||
try:
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
with MCPClient(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
):
|
||||
# Create new transaction for update
|
||||
with session.begin():
|
||||
service.update_provider_credentials(
|
||||
provider=db_provider,
|
||||
credentials=provider_entity.credentials,
|
||||
authed=True,
|
||||
)
|
||||
return {"result": "success"}
|
||||
except MCPAuthError as e:
|
||||
service = MCPToolManageService(session=session)
|
||||
try:
|
||||
return auth(provider_entity, service, args.get("authorization_code"))
|
||||
except MCPRefreshTokenError as e:
|
||||
with session.begin():
|
||||
service.clear_provider_credentials(provider=db_provider)
|
||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||
except MCPError as e:
|
||||
auth_result = auth(provider_entity, args.get("authorization_code"))
|
||||
with session.begin():
|
||||
response = service.execute_auth_actions(auth_result)
|
||||
return response
|
||||
except MCPRefreshTokenError as e:
|
||||
with session.begin():
|
||||
service.clear_provider_credentials(provider=db_provider)
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||
except MCPError as e:
|
||||
with session.begin():
|
||||
service.clear_provider_credentials(provider=db_provider)
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
|
||||
@ -1048,7 +1053,7 @@ class ToolMCPDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
@ -1062,7 +1067,7 @@ class ToolMCPListAllApi(Resource):
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
tools = service.list_providers(tenant_id=tenant_id)
|
||||
|
||||
@ -1100,6 +1105,11 @@ class ToolMCPCallbackApi(Resource):
|
||||
# Create service instance for handle_callback
|
||||
with Session(db.engine) as session, session.begin():
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
handle_callback(state_key, authorization_code, mcp_service)
|
||||
# handle_callback now returns state data and tokens
|
||||
state_data, tokens = handle_callback(state_key, authorization_code)
|
||||
# Save tokens using the service layer
|
||||
mcp_service.save_oauth_data(
|
||||
state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS
|
||||
)
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
Reference in New Issue
Block a user