mirror of
https://github.com/langgenius/dify.git
synced 2026-02-27 21:17:13 +08:00
Signed-off-by: majiayu000 <1835304752@qq.com> Signed-off-by: dependabot[bot] <support@github.com> Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Signed-off-by: -LAN- <laipz8200@outlook.com> Signed-off-by: yihong0618 <zouzou0208@gmail.com> Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com> Co-authored-by: 盐粒 Yanli <yanli@dify.ai> Co-authored-by: wangxiaolei <fatelei@gmail.com> Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Cursx <33718736+Cursx@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: lif <1835304752@qq.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: fenglin <790872612@qq.com> Co-authored-by: qiaofenglin <qiaofenglin@baidu.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: TomoOkuyama <49631611+TomoOkuyama@users.noreply.github.com> Co-authored-by: Tomo Okuyama <tomo.okuyama@intersystems.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: zyssyz123 <916125788@qq.com> Co-authored-by: hj24 <mambahj24@gmail.com> Co-authored-by: Coding On Star <447357187@qq.com> Co-authored-by: CodingOnStar <hanxujiang@dify.ai> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: Xiangxuan Qu <fghpdf@outlook.com> Co-authored-by: fghpdf <fghpdf@users.noreply.github.com> Co-authored-by: coopercoder <whitetiger0127@163.com> Co-authored-by: zhaiguangpeng <zhaiguangpeng@didiglobal.com> Co-authored-by: Junyan Qin (Chin) <rockchinq@gmail.com> Co-authored-by: E.G <146701565+GlobalStar117@users.noreply.github.com> Co-authored-by: GlobalStar117 <GlobalStar117@users.noreply.github.com> Co-authored-by: Claude Haiku 4.5 <noreply@anthropic.com> Co-authored-by: CodingOnStar 
<hanxujiang@dify.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: heyszt <270985384@qq.com> Co-authored-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: moonpanda <chuanzegao@163.com> Co-authored-by: warlocgao <warlocgao@tencent.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: KVOJJJin <jzongcode@gmail.com> Co-authored-by: eux <euxx@users.noreply.github.com> Co-authored-by: bangjiehan <bangjiehan@gmail.com> Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: Nie Ronghua <nieronghua@sf-express.com> Co-authored-by: JQSevenMiao <141806521+JQSevenMiao@users.noreply.github.com> Co-authored-by: jiasiqi <jiasiqi3@tal.com> Co-authored-by: Seokrin Taron Sung <sungsjade@gmail.com> Co-authored-by: CrabSAMA <40541269+CrabSAMA@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: yessenia <yessenia.contact@gmail.com> Co-authored-by: Jax <anobaka@qq.com> Co-authored-by: niveshdandyan <155956228+niveshdandyan@users.noreply.github.com> Co-authored-by: OSS Contributor <oss-contributor@example.com> Co-authored-by: niveshdandyan <niveshdandyan@users.noreply.github.com> Co-authored-by: Sean Kenneth Doherty <Smaster7772@gmail.com>
212 lines
7.5 KiB
Python
212 lines
7.5 KiB
Python
from collections.abc import Callable
|
|
from functools import wraps
|
|
from typing import Concatenate, ParamSpec, TypeVar
|
|
|
|
from flask import jsonify, request
|
|
from flask_restx import Resource
|
|
from pydantic import BaseModel
|
|
from werkzeug.exceptions import BadRequest, NotFound
|
|
|
|
from controllers.console.wraps import account_initialization_required, setup_required
|
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
|
from libs.login import current_account_with_tenant, login_required
|
|
from models import Account
|
|
from models.model import OAuthProviderApp
|
|
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
|
|
|
from .. import console_ns
|
|
|
|
P = ParamSpec("P")
|
|
R = TypeVar("R")
|
|
T = TypeVar("T")
|
|
|
|
|
|
class OAuthClientPayload(BaseModel):
    """Minimal request body: just the identifier of the calling OAuth client."""

    # Passed to OAuthServerService.get_oauth_provider_app() to resolve the app.
    client_id: str
|
|
|
|
|
|
class OAuthProviderRequest(BaseModel):
    """Request body of the ``/oauth/provider`` endpoint."""

    # Identifier of the OAuth client application.
    client_id: str
    # Redirect URI the client wants to use; checked against the app's
    # registered ``redirect_uris`` by the endpoint.
    redirect_uri: str
|
|
|
|
|
|
class OAuthTokenRequest(BaseModel):
    """Request body of the ``/oauth/provider/token`` endpoint.

    Only ``client_id`` and ``grant_type`` are mandatory at the schema level;
    the endpoint itself enforces which optional fields each grant type needs.
    """

    # Identifier of the OAuth client application.
    client_id: str
    # Raw grant type string; converted to ``OAuthGrantType`` by the endpoint.
    grant_type: str
    # Authorization code (checked for the authorization-code grant).
    code: str | None = None
    # Client secret (checked for the authorization-code grant).
    client_secret: str | None = None
    # Redirect URI (checked for the authorization-code grant).
    redirect_uri: str | None = None
    # Refresh token (checked for the refresh-token grant).
    refresh_token: str | None = None
|
|
|
|
|
|
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
    """Resolve the OAuth provider app named by the request body's ``client_id``.

    The wrapped view receives the resolved ``OAuthProviderApp`` as its first
    argument after ``self``.

    Raises:
        BadRequest: if the parsed JSON body is ``None``.
        NotFound: if no provider app is found for the given ``client_id``.
    """

    @wraps(view)
    def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
        body = request.get_json()
        if body is None:
            raise BadRequest("client_id is required")

        # Validate the body shape first, then look the app up by its client id.
        provider_app = OAuthServerService.get_oauth_provider_app(
            OAuthClientPayload.model_validate(body).client_id
        )
        if provider_app:
            return view(self, provider_app, *args, **kwargs)

        raise NotFound("client_id is invalid")

    return wrapper
|
|
|
|
|
|
def _bearer_unauthorized(message: str):
    """Build a 401 JSON error response carrying a ``WWW-Authenticate: Bearer`` header.

    Args:
        message: Human-readable error placed under the ``"error"`` key.

    Returns:
        The Flask response object produced by ``jsonify``.
    """
    response = jsonify({"error": message})
    response.status_code = 401
    response.headers["WWW-Authenticate"] = "Bearer"
    return response


def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
    """Authenticate the request's bearer token and pass the account to ``view``.

    Expects to run *after* a decorator that supplies ``oauth_provider_app`` as
    the first argument following ``self``. The ``Authorization`` header must
    have the form ``Bearer <token>`` (scheme matched case-insensitively).

    On any authentication failure a 401 JSON response with a
    ``WWW-Authenticate: Bearer`` header is returned instead of calling the view.

    Raises:
        BadRequest: if ``oauth_provider_app`` is not an ``OAuthProviderApp``.
    """

    @wraps(view)
    def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
        if not isinstance(oauth_provider_app, OAuthProviderApp):
            raise BadRequest("Invalid oauth_provider_app")

        authorization_header = request.headers.get("Authorization")
        if not authorization_header:
            return _bearer_unauthorized("Authorization header is required")

        # Split into "<scheme> <credentials>" on the first run of whitespace.
        parts = authorization_header.strip().split(None, 1)
        if len(parts) != 2:
            return _bearer_unauthorized("Invalid Authorization header format")

        # After strip() + split(None, 1) yielding two parts, both parts are
        # non-empty and free of surrounding whitespace, so no further
        # stripping or empty-token check is needed.
        token_type, access_token = parts
        if token_type.lower() != "bearer":
            return _bearer_unauthorized("token_type is invalid")

        account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
        if not account:
            return _bearer_unauthorized("access_token or client_id is invalid")

        return view(self, oauth_provider_app, account, *args, **kwargs)

    return decorated
|
|
|
|
|
|
@console_ns.route("/oauth/provider")
class OAuthServerAppApi(Resource):
    """Endpoint returning the display information of an OAuth provider app."""

    @setup_required
    @oauth_server_client_id_required
    def post(self, oauth_provider_app: OAuthProviderApp):
        # Returns the app's icon, label and scope once the supplied
        # redirect_uri is confirmed to be one of the app's registered URIs.
        body = OAuthProviderRequest.model_validate(request.get_json())

        if body.redirect_uri in oauth_provider_app.redirect_uris:
            app_info = {
                "app_icon": oauth_provider_app.app_icon,
                "app_label": oauth_provider_app.app_label,
                "scope": oauth_provider_app.scope,
            }
            return jsonable_encoder(app_info)

        # Unregistered redirect target: refuse the request.
        raise BadRequest("redirect_uri is invalid")
|
|
|
|
|
|
@console_ns.route("/oauth/provider/authorize")
class OAuthServerUserAuthorizeApi(Resource):
    """Endpoint issuing an authorization code for the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @oauth_server_client_id_required
    def post(self, oauth_provider_app: OAuthProviderApp):
        # Sign an authorization code binding the current account's id to the
        # requesting client, and return it under the "code" key.
        account, _ = current_account_with_tenant()
        authorization_code = OAuthServerService.sign_oauth_authorization_code(
            oauth_provider_app.client_id, account.id
        )
        return jsonable_encoder({"code": authorization_code})
|
|
|
|
|
|
@console_ns.route("/oauth/provider/token")
|
|
class OAuthServerUserTokenApi(Resource):
|
|
@setup_required
|
|
@oauth_server_client_id_required
|
|
def post(self, oauth_provider_app: OAuthProviderApp):
|
|
payload = OAuthTokenRequest.model_validate(request.get_json())
|
|
|
|
try:
|
|
grant_type = OAuthGrantType(payload.grant_type)
|
|
except ValueError:
|
|
raise BadRequest("invalid grant_type")
|
|
match grant_type:
|
|
case OAuthGrantType.AUTHORIZATION_CODE:
|
|
if not payload.code:
|
|
raise BadRequest("code is required")
|
|
|
|
if payload.client_secret != oauth_provider_app.client_secret:
|
|
raise BadRequest("client_secret is invalid")
|
|
|
|
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
|
|
raise BadRequest("redirect_uri is invalid")
|
|
|
|
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
|
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
|
|
)
|
|
return jsonable_encoder(
|
|
{
|
|
"access_token": access_token,
|
|
"token_type": "Bearer",
|
|
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
|
"refresh_token": refresh_token,
|
|
}
|
|
)
|
|
case OAuthGrantType.REFRESH_TOKEN:
|
|
if not payload.refresh_token:
|
|
raise BadRequest("refresh_token is required")
|
|
|
|
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
|
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
|
|
)
|
|
return jsonable_encoder(
|
|
{
|
|
"access_token": access_token,
|
|
"token_type": "Bearer",
|
|
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
|
"refresh_token": refresh_token,
|
|
}
|
|
)
|
|
|
|
|
|
@console_ns.route("/oauth/provider/account")
class OAuthServerUserAccountApi(Resource):
    """Endpoint returning profile fields of the account behind an access token."""

    @setup_required
    @oauth_server_client_id_required
    @oauth_server_access_token_required
    def post(self, oauth_provider_app: OAuthProviderApp, account: Account):
        # Expose a fixed set of attributes of the authenticated account;
        # the tuple order is the key order of the resulting mapping.
        profile_fields = ("name", "email", "avatar", "interface_language", "timezone")
        return jsonable_encoder({field: getattr(account, field) for field in profile_fields})
|