mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
refactor: use libs.login current_user in console controllers (#26745)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
@ -1,5 +1,4 @@
|
|||||||
import flask_restx
|
import flask_restx
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal_with
|
from flask_restx import Resource, fields, marshal_with
|
||||||
from flask_restx._http import HTTPStatus
|
from flask_restx._http import HTTPStatus
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@ -8,7 +7,8 @@ from werkzeug.exceptions import Forbidden
|
|||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
|
from models.account import Account
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
from models.model import ApiToken, App
|
from models.model import ApiToken, App
|
||||||
|
|
||||||
@ -57,6 +57,8 @@ class BaseApiKeyListResource(Resource):
|
|||||||
def get(self, resource_id):
|
def get(self, resource_id):
|
||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||||
keys = db.session.scalars(
|
keys = db.session.scalars(
|
||||||
select(ApiToken).where(
|
select(ApiToken).where(
|
||||||
@ -69,8 +71,10 @@ class BaseApiKeyListResource(Resource):
|
|||||||
def post(self, resource_id):
|
def post(self, resource_id):
|
||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||||
if not current_user.is_editor:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
current_key_count = (
|
current_key_count = (
|
||||||
@ -108,6 +112,8 @@ class BaseApiKeyResource(Resource):
|
|||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
api_key_id = str(api_key_id)
|
api_key_id = str(api_key_id)
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
|
||||||
from libs.helper import extract_remote_ip
|
from libs.helper import extract_remote_ip
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
|
from models.account import Account
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
from .. import console_ns
|
from .. import console_ns
|
||||||
@ -17,6 +17,8 @@ class ComplianceApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def get(self):
|
def get(self):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("doc_name", type=str, required=True, location="args")
|
parser.add_argument("doc_name", type=str, required=True, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import marshal, reqparse
|
from flask_restx import marshal, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
@ -21,6 +19,7 @@ from core.errors.error import (
|
|||||||
)
|
)
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from fields.hit_testing_fields import hit_testing_record_fields
|
from fields.hit_testing_fields import hit_testing_record_fields
|
||||||
|
from libs.login import current_user
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from services.dataset_service import DatasetService
|
from services.dataset_service import DatasetService
|
||||||
from services.hit_testing_service import HitTestingService
|
from services.hit_testing_service import HitTestingService
|
||||||
@ -31,6 +30,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class DatasetsHitTestingBase:
|
class DatasetsHitTestingBase:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_and_validate_dataset(dataset_id: str):
|
def get_and_validate_dataset(dataset_id: str):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
@ -57,11 +57,12 @@ class DatasetsHitTestingBase:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def perform_hit_testing(dataset, args):
|
def perform_hit_testing(dataset, args):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
try:
|
try:
|
||||||
response = HitTestingService.retrieve(
|
response = HitTestingService.retrieve(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
query=args["query"],
|
query=args["query"],
|
||||||
account=cast(Account, current_user),
|
account=current_user,
|
||||||
retrieval_model=args["retrieval_model"],
|
retrieval_model=args["retrieval_model"],
|
||||||
external_retrieval_model=args["external_retrieval_model"],
|
external_retrieval_model=args["external_retrieval_model"],
|
||||||
limit=10,
|
limit=10,
|
||||||
|
|||||||
@ -2,15 +2,15 @@ from collections.abc import Callable
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Concatenate, ParamSpec, TypeVar
|
from typing import Concatenate, ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console.explore.error import AppAccessDeniedError
|
from controllers.console.explore.error import AppAccessDeniedError
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
from models import InstalledApp
|
from models import InstalledApp
|
||||||
|
from models.account import Account
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
@ -24,6 +24,8 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
|
|||||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
installed_app = (
|
installed_app = (
|
||||||
db.session.query(InstalledApp)
|
db.session.query(InstalledApp)
|
||||||
.where(
|
.where(
|
||||||
@ -56,6 +58,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
|||||||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||||
feature = FeatureService.get_system_features()
|
feature = FeatureService.get_system_features()
|
||||||
if feature.webapp_auth.enabled:
|
if feature.webapp_auth.enabled:
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
app_id = installed_app.app_id
|
app_id = installed_app.app_id
|
||||||
app_code = AppService.get_app_code_by_id(app_id)
|
app_code = AppService.get_app_code_by_id(app_id)
|
||||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
|
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from fields.api_based_extension_fields import api_based_extension_fields
|
from fields.api_based_extension_fields import api_based_extension_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
|
from models.account import Account
|
||||||
from models.api_based_extension import APIBasedExtension
|
from models.api_based_extension import APIBasedExtension
|
||||||
from services.api_based_extension_service import APIBasedExtensionService
|
from services.api_based_extension_service import APIBasedExtensionService
|
||||||
from services.code_based_extension_service import CodeBasedExtensionService
|
from services.code_based_extension_service import CodeBasedExtensionService
|
||||||
@ -47,6 +47,8 @@ class APIBasedExtensionAPI(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_based_extension_fields)
|
@marshal_with(api_based_extension_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||||
|
|
||||||
@ -68,6 +70,8 @@ class APIBasedExtensionAPI(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_based_extension_fields)
|
@marshal_with(api_based_extension_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
parser.add_argument("name", type=str, required=True, location="json")
|
||||||
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||||
@ -95,6 +99,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_based_extension_fields)
|
@marshal_with(api_based_extension_fields)
|
||||||
def get(self, id):
|
def get(self, id):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
api_based_extension_id = str(id)
|
api_based_extension_id = str(id)
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
@ -119,6 +125,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_based_extension_fields)
|
@marshal_with(api_based_extension_fields)
|
||||||
def post(self, id):
|
def post(self, id):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
api_based_extension_id = str(id)
|
api_based_extension_id = str(id)
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
@ -146,6 +154,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, id):
|
def delete(self, id):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
api_based_extension_id = str(id)
|
api_based_extension_id = str(id)
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource, fields
|
||||||
|
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
|
from models.account import Account
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
from . import api, console_ns
|
from . import api, console_ns
|
||||||
@ -23,6 +23,8 @@ class FeatureApi(Resource):
|
|||||||
@cloud_utm_record
|
@cloud_utm_record
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Get feature configuration for current tenant"""
|
"""Get feature configuration for current tenant"""
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
|
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
|
|
||||||
import services
|
import services
|
||||||
@ -16,6 +14,7 @@ from core.file import helpers as file_helpers
|
|||||||
from core.helper import ssrf_proxy
|
from core.helper import ssrf_proxy
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||||
|
from libs.login import current_user
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
|
|
||||||
@ -65,7 +64,8 @@ class RemoteFileUploadApi(Resource):
|
|||||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user = cast(Account, current_user)
|
assert isinstance(current_user, Account)
|
||||||
|
user = current_user
|
||||||
upload_file = FileService(db.engine).upload_file(
|
upload_file = FileService(db.engine).upload_file(
|
||||||
filename=file_info.filename,
|
filename=file_info.filename,
|
||||||
content=content,
|
content=content,
|
||||||
|
|||||||
@ -1,12 +1,12 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from fields.tag_fields import dataset_tag_fields
|
from fields.tag_fields import dataset_tag_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
|
from models.account import Account
|
||||||
from models.model import Tag
|
from models.model import Tag
|
||||||
from services.tag_service import TagService
|
from services.tag_service import TagService
|
||||||
|
|
||||||
@ -24,6 +24,8 @@ class TagListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(dataset_tag_fields)
|
@marshal_with(dataset_tag_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
tag_type = request.args.get("type", type=str, default="")
|
tag_type = request.args.get("type", type=str, default="")
|
||||||
keyword = request.args.get("keyword", default=None, type=str)
|
keyword = request.args.get("keyword", default=None, type=str)
|
||||||
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
|
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
|
||||||
@ -34,8 +36,10 @@ class TagListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -59,9 +63,11 @@ class TagUpdateDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def patch(self, tag_id):
|
def patch(self, tag_id):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
tag_id = str(tag_id)
|
tag_id = str(tag_id)
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -81,9 +87,11 @@ class TagUpdateDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, tag_id):
|
def delete(self, tag_id):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
tag_id = str(tag_id)
|
tag_id = str(tag_id)
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.is_editor:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
TagService.delete_tag(tag_id)
|
TagService.delete_tag(tag_id)
|
||||||
@ -97,8 +105,10 @@ class TagBindingCreateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -123,8 +133,10 @@ class TagBindingDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource, fields
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
|
from models.account import Account
|
||||||
from services.agent_service import AgentService
|
from services.agent_service import AgentService
|
||||||
|
|
||||||
|
|
||||||
@ -21,7 +21,9 @@ class AgentProviderListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
user = current_user
|
user = current_user
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = user.current_tenant_id
|
||||||
@ -43,7 +45,9 @@ class AgentProviderApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider_name: str):
|
def get(self, provider_name: str):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
user = current_user
|
user = current_user
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = user.current_tenant_id
|
||||||
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
|
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -6,10 +5,18 @@ from controllers.console import api, console_ns
|
|||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
|
from models.account import Account
|
||||||
from services.plugin.endpoint_service import EndpointService
|
from services.plugin.endpoint_service import EndpointService
|
||||||
|
|
||||||
|
|
||||||
|
def _current_account_with_tenant() -> tuple[Account, str]:
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
assert tenant_id is not None
|
||||||
|
return current_user, tenant_id
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/endpoints/create")
|
@console_ns.route("/workspaces/current/endpoints/create")
|
||||||
class EndpointCreateApi(Resource):
|
class EndpointCreateApi(Resource):
|
||||||
@api.doc("create_endpoint")
|
@api.doc("create_endpoint")
|
||||||
@ -34,7 +41,7 @@ class EndpointCreateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = _current_account_with_tenant()
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
@ -51,7 +58,7 @@ class EndpointCreateApi(Resource):
|
|||||||
try:
|
try:
|
||||||
return {
|
return {
|
||||||
"success": EndpointService.create_endpoint(
|
"success": EndpointService.create_endpoint(
|
||||||
tenant_id=user.current_tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
plugin_unique_identifier=plugin_unique_identifier,
|
plugin_unique_identifier=plugin_unique_identifier,
|
||||||
name=name,
|
name=name,
|
||||||
@ -80,7 +87,7 @@ class EndpointListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = _current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("page", type=int, required=True, location="args")
|
parser.add_argument("page", type=int, required=True, location="args")
|
||||||
@ -93,7 +100,7 @@ class EndpointListApi(Resource):
|
|||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
{
|
{
|
||||||
"endpoints": EndpointService.list_endpoints(
|
"endpoints": EndpointService.list_endpoints(
|
||||||
tenant_id=user.current_tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
page=page,
|
page=page,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
@ -123,7 +130,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = _current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("page", type=int, required=True, location="args")
|
parser.add_argument("page", type=int, required=True, location="args")
|
||||||
@ -138,7 +145,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
{
|
{
|
||||||
"endpoints": EndpointService.list_endpoints_for_single_plugin(
|
"endpoints": EndpointService.list_endpoints_for_single_plugin(
|
||||||
tenant_id=user.current_tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
page=page,
|
page=page,
|
||||||
@ -165,7 +172,7 @@ class EndpointDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = _current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("endpoint_id", type=str, required=True)
|
parser.add_argument("endpoint_id", type=str, required=True)
|
||||||
@ -177,9 +184,7 @@ class EndpointDeleteApi(Resource):
|
|||||||
endpoint_id = args["endpoint_id"]
|
endpoint_id = args["endpoint_id"]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": EndpointService.delete_endpoint(
|
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -207,7 +212,7 @@ class EndpointUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = _current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("endpoint_id", type=str, required=True)
|
parser.add_argument("endpoint_id", type=str, required=True)
|
||||||
@ -224,7 +229,7 @@ class EndpointUpdateApi(Resource):
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"success": EndpointService.update_endpoint(
|
"success": EndpointService.update_endpoint(
|
||||||
tenant_id=user.current_tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
endpoint_id=endpoint_id,
|
endpoint_id=endpoint_id,
|
||||||
name=name,
|
name=name,
|
||||||
@ -250,7 +255,7 @@ class EndpointEnableApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = _current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("endpoint_id", type=str, required=True)
|
parser.add_argument("endpoint_id", type=str, required=True)
|
||||||
@ -262,9 +267,7 @@ class EndpointEnableApi(Resource):
|
|||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": EndpointService.enable_endpoint(
|
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -285,7 +288,7 @@ class EndpointDisableApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = _current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("endpoint_id", type=str, required=True)
|
parser.add_argument("endpoint_id", type=str, required=True)
|
||||||
@ -297,7 +300,5 @@ class EndpointDisableApi(Resource):
|
|||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": EndpointService.disable_endpoint(
|
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from urllib import parse
|
from urllib import parse
|
||||||
|
|
||||||
from flask import abort, request
|
from flask import abort, request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
|
|
||||||
import services
|
import services
|
||||||
@ -26,7 +25,7 @@ from controllers.console.wraps import (
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.member_fields import account_with_role_list_fields
|
from fields.member_fields import account_with_role_list_fields
|
||||||
from libs.helper import extract_remote_ip
|
from libs.helper import extract_remote_ip
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
from models.account import Account, TenantAccountRole
|
from models.account import Account, TenantAccountRole
|
||||||
from services.account_service import AccountService, RegisterService, TenantService
|
from services.account_service import AccountService, RegisterService, TenantService
|
||||||
from services.errors.account import AccountAlreadyInTenantError
|
from services.errors.account import AccountAlreadyInTenantError
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from werkzeug.exceptions import Unauthorized
|
from werkzeug.exceptions import Unauthorized
|
||||||
@ -24,7 +23,7 @@ from controllers.console.wraps import (
|
|||||||
)
|
)
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
from models.account import Account, Tenant, TenantStatus
|
from models.account import Account, Tenant, TenantStatus
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|||||||
@ -7,13 +7,13 @@ from functools import wraps
|
|||||||
from typing import ParamSpec, TypeVar
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import abort, request
|
from flask import abort, request
|
||||||
from flask_login import current_user
|
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.console.workspace.error import AccountNotInitializedError
|
from controllers.console.workspace.error import AccountNotInitializedError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.account import AccountStatus
|
from libs.login import current_user
|
||||||
|
from models.account import Account, AccountStatus
|
||||||
from models.dataset import RateLimitLog
|
from models.dataset import RateLimitLog
|
||||||
from models.model import DifySetup
|
from models.model import DifySetup
|
||||||
from services.feature_service import FeatureService, LicenseStatus
|
from services.feature_service import FeatureService, LicenseStatus
|
||||||
@ -25,11 +25,16 @@ P = ParamSpec("P")
|
|||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
def _current_account() -> Account:
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
def account_initialization_required(view: Callable[P, R]):
|
def account_initialization_required(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
# check account initialization
|
# check account initialization
|
||||||
account = current_user
|
account = _current_account()
|
||||||
|
|
||||||
if account.status == AccountStatus.UNINITIALIZED:
|
if account.status == AccountStatus.UNINITIALIZED:
|
||||||
raise AccountNotInitializedError()
|
raise AccountNotInitializedError()
|
||||||
@ -75,7 +80,9 @@ def only_edition_self_hosted(view: Callable[P, R]):
|
|||||||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
account = _current_account()
|
||||||
|
assert account.current_tenant_id is not None
|
||||||
|
features = FeatureService.get_features(account.current_tenant_id)
|
||||||
if not features.billing.enabled:
|
if not features.billing.enabled:
|
||||||
abort(403, "Billing feature is not enabled.")
|
abort(403, "Billing feature is not enabled.")
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
@ -87,7 +94,10 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
account = _current_account()
|
||||||
|
assert account.current_tenant_id is not None
|
||||||
|
tenant_id = account.current_tenant_id
|
||||||
|
features = FeatureService.get_features(tenant_id)
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
members = features.members
|
members = features.members
|
||||||
apps = features.apps
|
apps = features.apps
|
||||||
@ -128,7 +138,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
|||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
account = _current_account()
|
||||||
|
assert account.current_tenant_id is not None
|
||||||
|
features = FeatureService.get_features(account.current_tenant_id)
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
if resource == "add_segment":
|
if resource == "add_segment":
|
||||||
if features.billing.subscription.plan == "sandbox":
|
if features.billing.subscription.plan == "sandbox":
|
||||||
@ -151,10 +163,13 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if resource == "knowledge":
|
if resource == "knowledge":
|
||||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
account = _current_account()
|
||||||
|
assert account.current_tenant_id is not None
|
||||||
|
tenant_id = account.current_tenant_id
|
||||||
|
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
|
||||||
if knowledge_rate_limit.enabled:
|
if knowledge_rate_limit.enabled:
|
||||||
current_time = int(time.time() * 1000)
|
current_time = int(time.time() * 1000)
|
||||||
key = f"rate_limit_{current_user.current_tenant_id}"
|
key = f"rate_limit_{tenant_id}"
|
||||||
|
|
||||||
redis_client.zadd(key, {current_time: current_time})
|
redis_client.zadd(key, {current_time: current_time})
|
||||||
|
|
||||||
@ -165,7 +180,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||||||
if request_count > knowledge_rate_limit.limit:
|
if request_count > knowledge_rate_limit.limit:
|
||||||
# add ratelimit record
|
# add ratelimit record
|
||||||
rate_limit_log = RateLimitLog(
|
rate_limit_log = RateLimitLog(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=tenant_id,
|
||||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||||
operation="knowledge",
|
operation="knowledge",
|
||||||
)
|
)
|
||||||
@ -185,14 +200,17 @@ def cloud_utm_record(view: Callable[P, R]):
|
|||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
account = _current_account()
|
||||||
|
assert account.current_tenant_id is not None
|
||||||
|
tenant_id = account.current_tenant_id
|
||||||
|
features = FeatureService.get_features(tenant_id)
|
||||||
|
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
utm_info = request.cookies.get("utm_info")
|
utm_info = request.cookies.get("utm_info")
|
||||||
|
|
||||||
if utm_info:
|
if utm_info:
|
||||||
utm_info_dict: dict = json.loads(utm_info)
|
utm_info_dict: dict = json.loads(utm_info)
|
||||||
OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
|
OperationService.record_utm(tenant_id, utm_info_dict)
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
@ -271,7 +289,9 @@ def enable_change_email(view: Callable[P, R]):
|
|||||||
def is_allow_transfer_owner(view: Callable[P, R]):
|
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
account = _current_account()
|
||||||
|
assert account.current_tenant_id is not None
|
||||||
|
features = FeatureService.get_features(account.current_tenant_id)
|
||||||
if features.is_allow_transfer_workspace:
|
if features.is_allow_transfer_workspace:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
@ -284,7 +304,9 @@ def is_allow_transfer_owner(view: Callable[P, R]):
|
|||||||
def knowledge_pipeline_publish_enabled(view):
|
def knowledge_pipeline_publish_enabled(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
account = _current_account()
|
||||||
|
assert account.current_tenant_id is not None
|
||||||
|
features = FeatureService.get_features(account.current_tenant_id)
|
||||||
if features.knowledge_pipeline.publish_enabled:
|
if features.knowledge_pipeline.publish_enabled:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
abort(403)
|
abort(403)
|
||||||
|
|||||||
@ -60,7 +60,7 @@ class TestAccountInitialization:
|
|||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with patch("controllers.console.wraps.current_user", mock_user):
|
with patch("controllers.console.wraps._current_account", return_value=mock_user):
|
||||||
result = protected_view()
|
result = protected_view()
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
@ -77,7 +77,7 @@ class TestAccountInitialization:
|
|||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with patch("controllers.console.wraps.current_user", mock_user):
|
with patch("controllers.console.wraps._current_account", return_value=mock_user):
|
||||||
with pytest.raises(AccountNotInitializedError):
|
with pytest.raises(AccountNotInitializedError):
|
||||||
protected_view()
|
protected_view()
|
||||||
|
|
||||||
@ -163,7 +163,7 @@ class TestBillingResourceLimits:
|
|||||||
return "member_added"
|
return "member_added"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with patch("controllers.console.wraps.current_user"):
|
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||||
result = add_member()
|
result = add_member()
|
||||||
|
|
||||||
@ -185,7 +185,7 @@ class TestBillingResourceLimits:
|
|||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with app.test_request_context():
|
with app.test_request_context():
|
||||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||||
with pytest.raises(Exception) as exc_info:
|
with pytest.raises(Exception) as exc_info:
|
||||||
add_member()
|
add_member()
|
||||||
@ -207,7 +207,7 @@ class TestBillingResourceLimits:
|
|||||||
|
|
||||||
# Test 1: Should reject when source is datasets
|
# Test 1: Should reject when source is datasets
|
||||||
with app.test_request_context("/?source=datasets"):
|
with app.test_request_context("/?source=datasets"):
|
||||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||||
with pytest.raises(Exception) as exc_info:
|
with pytest.raises(Exception) as exc_info:
|
||||||
upload_document()
|
upload_document()
|
||||||
@ -215,7 +215,7 @@ class TestBillingResourceLimits:
|
|||||||
|
|
||||||
# Test 2: Should allow when source is not datasets
|
# Test 2: Should allow when source is not datasets
|
||||||
with app.test_request_context("/?source=other"):
|
with app.test_request_context("/?source=other"):
|
||||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||||
result = upload_document()
|
result = upload_document()
|
||||||
assert result == "document_uploaded"
|
assert result == "document_uploaded"
|
||||||
@ -239,7 +239,7 @@ class TestRateLimiting:
|
|||||||
return "knowledge_success"
|
return "knowledge_success"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with patch("controllers.console.wraps.current_user"):
|
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||||
with patch(
|
with patch(
|
||||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||||
):
|
):
|
||||||
@ -271,7 +271,7 @@ class TestRateLimiting:
|
|||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with app.test_request_context():
|
with app.test_request_context():
|
||||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||||
with patch(
|
with patch(
|
||||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||||
):
|
):
|
||||||
|
|||||||
Reference in New Issue
Block a user