This commit is contained in:
jyong
2025-06-03 19:02:57 +08:00
parent 309fffd1e4
commit 9cdd2cbb27
35 changed files with 229 additions and 300 deletions

View File

@ -109,8 +109,6 @@ class OAuthDataSourceSync(Resource):
return {"result": "success"}, 200
api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")

View File

@ -4,24 +4,19 @@ from typing import Optional
import requests
from flask import current_app, redirect, request
from flask_login import current_user
from flask_restful import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from werkzeug.exceptions import Unauthorized
from configs import dify_config
from constants.languages import languages
from controllers.console.wraps import account_initialization_required, setup_required
from core.plugin.impl.oauth import OAuthHandler
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
from models.account import AccountStatus
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
@ -186,6 +181,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
return account
api.add_resource(OAuthLogin, "/oauth/login/<provider>")
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")

View File

@ -1,12 +1,9 @@
from flask import redirect, request
from flask_login import current_user # type: ignore
from flask_restful import ( # type: ignore
Resource, # type: ignore
marshal_with,
reqparse,
)
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
@ -16,7 +13,6 @@ from controllers.console.wraps import (
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.impl.oauth import OAuthHandler
from extensions.ext_database import db
from libs.login import login_required
@ -33,10 +29,9 @@ class DatasourcePluginOauthApi(Resource):
if not current_user.is_editor:
raise Forbidden()
# get all plugin oauth configs
plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by(
provider=provider,
plugin_id=plugin_id
).first()
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if not plugin_oauth_config:
raise NotFound()
oauth_handler = OAuthHandler()
@ -45,24 +40,20 @@ class DatasourcePluginOauthApi(Resource):
if system_credentials:
system_credentials["redirect_url"] = redirect_url
response = oauth_handler.get_authorization_url(
current_user.current_tenant.id,
current_user.id,
plugin_id,
provider,
system_credentials=system_credentials
current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
)
return response.model_dump()
class DatasourceOauthCallback(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, plugin_id):
oauth_handler = OAuthHandler()
plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by(
provider=provider,
plugin_id=plugin_id
).first()
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if not plugin_oauth_config:
raise NotFound()
credentials = oauth_handler.get_credentials(
@ -71,18 +62,16 @@ class DatasourceOauthCallback(Resource):
plugin_id,
provider,
system_credentials=plugin_oauth_config.system_credentials,
request=request
request=request,
)
datasource_provider = DatasourceProvider(
plugin_id=plugin_id,
provider=provider,
auth_type="oauth",
encrypted_credentials=credentials
plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
)
db.session.add(datasource_provider)
db.session.commit()
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
class DatasourceAuth(Resource):
@setup_required
@login_required
@ -99,28 +88,27 @@ class DatasourceAuth(Resource):
try:
datasource_provider_service.datasource_provider_credentials_validate(
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id,
credentials=args["credentials"]
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id,
credentials=args["credentials"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 201
@setup_required
@login_required
@account_initialization_required
def get(self, provider, plugin_id):
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id
tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
)
return {"result": datasources}, 200
class DatasourceAuthDeleteApi(Resource):
@setup_required
@login_required
@ -130,12 +118,11 @@ class DatasourceAuthDeleteApi(Resource):
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id
tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
)
return {"result": "success"}, 200
# Import Rag Pipeline
api.add_resource(
DatasourcePluginOauthApi,
@ -149,4 +136,3 @@ api.add_resource(
DatasourceAuth,
"/auth/datasource/provider/<string:provider>/plugin/<string:plugin_id>",
)

View File

@ -110,6 +110,7 @@ class CustomizedPipelineTemplateApi(Resource):
dsl = yaml.safe_load(template.yaml_content)
return {"data": dsl}, 200
class CustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@ -142,6 +143,7 @@ class CustomizedPipelineTemplateApi(Resource):
RagPipelineService.publish_customized_pipeline_template(pipeline_id, args)
return 200
api.add_resource(
PipelineTemplateListApi,
"/rag/pipeline/templates",

View File

@ -540,7 +540,6 @@ class RagPipelineConfigApi(Resource):
@login_required
@account_initialization_required
def get(self, pipeline_id):
return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}