mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 22:48:07 +08:00
feat: plugin OAuth with stateful
This commit is contained in:
@ -1,18 +1,27 @@
|
||||
import io
|
||||
|
||||
from flask import send_file
|
||||
from flask import redirect, request, send_file
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from flask_restful import (
|
||||
Resource,
|
||||
reqparse,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
from services.tools.tool_labels_service import ToolLabelsService
|
||||
@ -108,17 +117,19 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
result = BuiltinToolManageService.update_builtin_tool_provider(
|
||||
session=session,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_id=args["credential_id"],
|
||||
name=args["name"]
|
||||
)
|
||||
session.commit()
|
||||
return result
|
||||
@ -555,9 +566,9 @@ class ToolBuiltinListApi(Resource):
|
||||
[
|
||||
provider.to_dict()
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@ -576,9 +587,9 @@ class ToolApiListApi(Resource):
|
||||
[
|
||||
provider.to_dict()
|
||||
for provider in ApiToolManageService.list_api_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@ -597,9 +608,9 @@ class ToolWorkflowListApi(Resource):
|
||||
[
|
||||
provider.to_dict()
|
||||
for provider in WorkflowToolManageService.list_tenant_workflow_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@ -613,6 +624,121 @@ class ToolLabelsApi(Resource):
|
||||
return jsonable_encoder(ToolLabelsService.list_tool_labels())
|
||||
|
||||
|
||||
class ToolPluginOAuthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
provider = args["provider"]
|
||||
plugin_id = args["plugin_id"]
|
||||
|
||||
# todo check permission
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
# check if user client is configured and enabled then using user client
|
||||
# if user client is not configured then using system client
|
||||
tenant_id = user.current_tenant_id
|
||||
user_id = user.id
|
||||
|
||||
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
context_id = OAuthProxyService.create_proxy_context(user_id=current_user.id,
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider)
|
||||
# todo decrypt oauth params
|
||||
oauth_params = plugin_oauth_config.oauth_params
|
||||
oauth_params[
|
||||
'redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}"
|
||||
|
||||
response = oauth_handler.get_authorization_url(
|
||||
tenant_id,
|
||||
user.id,
|
||||
plugin_id,
|
||||
provider,
|
||||
system_credentials=oauth_params,
|
||||
)
|
||||
return response.model_dump()
|
||||
|
||||
|
||||
class ToolOAuthCallback(Resource):
|
||||
|
||||
@setup_required
|
||||
def get(self):
|
||||
args = (reqparse
|
||||
.RequestParser()
|
||||
.add_argument("context_id", type=str, required=True, nullable=False, location="args")
|
||||
.parse_args()
|
||||
)
|
||||
context_id = args["context_id"]
|
||||
context = OAuthProxyService.use_proxy_context(context_id)
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
user_id, tenant_id, plugin_id, provider = (
|
||||
context.get("user_id"),
|
||||
context.get("tenant_id"),
|
||||
context.get("plugin_id"),
|
||||
context.get("provider"),
|
||||
)
|
||||
oauth_handler = OAuthHandler()
|
||||
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
oauth_params = plugin_oauth_config.oauth_params
|
||||
oauth_params['redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}"
|
||||
|
||||
credentials = oauth_handler.get_credentials(
|
||||
tenant_id,
|
||||
user_id,
|
||||
plugin_id,
|
||||
provider,
|
||||
system_credentials=oauth_params,
|
||||
request=request,
|
||||
)
|
||||
|
||||
if not credentials:
|
||||
raise Exception("no credentials found for this plugin")
|
||||
|
||||
#TODO add credentials to database
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
|
||||
|
||||
|
||||
class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
return BuiltinToolManageService.set_default_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
provider=provider,
|
||||
id=args["id"])
|
||||
|
||||
|
||||
# tool oauth
|
||||
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/tool")
|
||||
api.add_resource(ToolOAuthCallback, "/oauth/plugin/tool/callback")
|
||||
|
||||
# tool provider
|
||||
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
||||
|
||||
@ -621,6 +747,8 @@ api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-prov
|
||||
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
|
||||
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
||||
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
||||
api.add_resource(ToolBuiltinProviderSetDefaultApi,
|
||||
"/workspaces/current/tool-provider/builtin/<path:provider>/set-default")
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user