mirror of
https://github.com/langgenius/dify.git
synced 2026-03-08 00:55:57 +08:00
feat(oauth): refactor tool provider methods and enhance credential handling
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
import io
|
||||
|
||||
from flask import redirect, request, send_file
|
||||
from flask import make_response, redirect, request, send_file
|
||||
from flask_login import current_user
|
||||
from flask_restful import (
|
||||
Resource,
|
||||
@ -17,6 +17,7 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentialType
|
||||
from extensions.ext_database import db
|
||||
@ -127,7 +128,7 @@ class ToolBuiltinProviderAddApi(Resource):
|
||||
return BuiltinToolManageService.add_builtin_tool_provider(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
api_type=ToolProviderCredentialType.of(args["type"]),
|
||||
@ -373,10 +374,11 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider, credential_type):
|
||||
user = current_user
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, credential_type, tenant_id)
|
||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(
|
||||
provider, ToolProviderCredentialType.of(credential_type), tenant_id
|
||||
)
|
||||
|
||||
|
||||
class ToolApiProviderSchemaApi(Resource):
|
||||
@ -613,15 +615,12 @@ class ToolApiListApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
provider.to_dict()
|
||||
for provider in ApiToolManageService.list_api_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
]
|
||||
@ -662,13 +661,10 @@ class ToolPluginOAuthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
provider = args["provider"]
|
||||
plugin_id = args["plugin_id"]
|
||||
def get(self, provider):
|
||||
tool_provider = ToolProviderID(provider)
|
||||
plugin_id = tool_provider.plugin_id
|
||||
provider_name = tool_provider.provider_name
|
||||
|
||||
# todo check permission
|
||||
user = current_user
|
||||
@ -679,63 +675,66 @@ class ToolPluginOAuthApi(Resource):
|
||||
tenant_id = user.current_tenant_id
|
||||
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider
|
||||
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
|
||||
)
|
||||
# todo decrypt oauth params
|
||||
# TODO decrypt oauth params
|
||||
oauth_params = plugin_oauth_config.oauth_params
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}"
|
||||
oauth_params["redirect_uri"] = redirect_uri
|
||||
|
||||
response = oauth_handler.get_authorization_url(
|
||||
tenant_id,
|
||||
user.id,
|
||||
plugin_id,
|
||||
provider,
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_params,
|
||||
)
|
||||
return response.model_dump()
|
||||
response = make_response(jsonable_encoder(authorization_url_response))
|
||||
response.set_cookie(
|
||||
"context_id",
|
||||
context_id,
|
||||
httponly=True,
|
||||
samesite="Lax",
|
||||
max_age=OAuthProxyService.__MAX_AGE__,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class ToolOAuthCallback(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
args = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("context_id", type=str, required=True, nullable=False, location="args")
|
||||
.parse_args()
|
||||
)
|
||||
context_id = args["context_id"]
|
||||
def get(self, provider):
|
||||
context_id = request.cookies.get("context_id")
|
||||
if not context_id:
|
||||
raise Forbidden("context_id not found")
|
||||
|
||||
context = OAuthProxyService.use_proxy_context(context_id)
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
user_id, tenant_id, plugin_id, provider = (
|
||||
context.get("user_id"),
|
||||
context.get("tenant_id"),
|
||||
context.get("plugin_id"),
|
||||
context.get("provider"),
|
||||
)
|
||||
tool_provider = ToolProviderID(provider)
|
||||
plugin_id = tool_provider.plugin_id
|
||||
provider_name = tool_provider.provider_name
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
oauth_params = plugin_oauth_config.oauth_params
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}"
|
||||
oauth_params["redirect_uri"] = redirect_uri
|
||||
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
||||
credentials = oauth_handler.get_credentials(
|
||||
tenant_id,
|
||||
user_id,
|
||||
plugin_id,
|
||||
provider,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_params,
|
||||
request=request,
|
||||
).credentials
|
||||
@ -747,12 +746,11 @@ class ToolOAuthCallback(Resource):
|
||||
BuiltinToolManageService.add_builtin_tool_provider(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider,
|
||||
provider=provider,
|
||||
credentials=dict(credentials),
|
||||
name=provider,
|
||||
api_type=ToolProviderCredentialType.OAUTH2,
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth/plugin/{provider}/tool/success")
|
||||
|
||||
|
||||
class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@ -768,9 +766,41 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
class ToolOAuthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
return BuiltinToolManageService.setup_oauth_custom_client(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider=provider,
|
||||
client_params=args["client_params"],
|
||||
)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
||||
tenant_id=current_user.current_tenant_id, provider_name=provider
|
||||
)
|
||||
|
||||
|
||||
# tool oauth
|
||||
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/tool")
|
||||
api.add_resource(ToolOAuthCallback, "/oauth/plugin/tool/callback")
|
||||
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
|
||||
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
|
||||
|
||||
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
|
||||
|
||||
# tool provider
|
||||
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
||||
@ -782,14 +812,14 @@ api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/b
|
||||
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
||||
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/set-default"
|
||||
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderCredentialsSchemaApi,
|
||||
"/workspaces/current/tool-provider/builtin/<path:provider>/<path:credential_type>/credentials_schema",
|
||||
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema/<path:credential_type>",
|
||||
)
|
||||
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user