feat: plugin OAuth with stateful

This commit is contained in:
Harry
2025-06-20 10:34:57 +08:00
parent 366ddb05ae
commit 12c20ec7f6
15 changed files with 809 additions and 72 deletions

View File

@ -1,18 +1,27 @@
import io
from flask import send_file
from flask import redirect, request, send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from flask_restful import (
Resource,
reqparse,
)
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
)
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value
from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.tool_labels_service import ToolLabelsService
@ -108,17 +117,19 @@ class ToolBuiltinProviderUpdateApi(Resource):
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
with Session(db.engine) as session:
result = BuiltinToolManageService.update_builtin_tool_provider(
session=session,
user_id=user_id,
tenant_id=tenant_id,
provider_name=provider,
credentials=args["credentials"],
credential_id=args["credential_id"],
name=args["name"]
)
session.commit()
return result
@ -555,9 +566,9 @@ class ToolBuiltinListApi(Resource):
[
provider.to_dict()
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
user_id,
tenant_id,
)
]
)
@ -576,9 +587,9 @@ class ToolApiListApi(Resource):
[
provider.to_dict()
for provider in ApiToolManageService.list_api_tools(
user_id,
tenant_id,
)
user_id,
tenant_id,
)
]
)
@ -597,9 +608,9 @@ class ToolWorkflowListApi(Resource):
[
provider.to_dict()
for provider in WorkflowToolManageService.list_tenant_workflow_tools(
user_id,
tenant_id,
)
user_id,
tenant_id,
)
]
)
@ -613,6 +624,121 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
class ToolPluginOAuthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
provider = args["provider"]
plugin_id = args["plugin_id"]
# todo check permission
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
# check if user client is configured and enabled then using user client
# if user client is not configured then using system client
tenant_id = user.current_tenant_id
user_id = user.id
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
plugin_id=plugin_id,
)
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(user_id=current_user.id,
tenant_id=tenant_id,
plugin_id=plugin_id,
provider=provider)
# todo decrypt oauth params
oauth_params = plugin_oauth_config.oauth_params
oauth_params[
'redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}"
response = oauth_handler.get_authorization_url(
tenant_id,
user.id,
plugin_id,
provider,
system_credentials=oauth_params,
)
return response.model_dump()
class ToolOAuthCallback(Resource):
@setup_required
def get(self):
args = (reqparse
.RequestParser()
.add_argument("context_id", type=str, required=True, nullable=False, location="args")
.parse_args()
)
context_id = args["context_id"]
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
user_id, tenant_id, plugin_id, provider = (
context.get("user_id"),
context.get("tenant_id"),
context.get("plugin_id"),
context.get("provider"),
)
oauth_handler = OAuthHandler()
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
plugin_id=plugin_id,
)
oauth_params = plugin_oauth_config.oauth_params
oauth_params['redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}"
credentials = oauth_handler.get_credentials(
tenant_id,
user_id,
plugin_id,
provider,
system_credentials=oauth_params,
request=request,
)
if not credentials:
raise Exception("no credentials found for this plugin")
#TODO add credentials to database
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
class ToolBuiltinProviderSetDefaultApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_user.current_tenant_id,
user_id=current_user.id,
provider=provider,
id=args["id"])
# tool oauth
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/tool")
api.add_resource(ToolOAuthCallback, "/oauth/plugin/tool/callback")
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
@ -621,6 +747,8 @@ api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-prov
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(ToolBuiltinProviderSetDefaultApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/set-default")
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
)