mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
This commit is contained in:
@ -1,4 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
@ -6,6 +8,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
@ -14,9 +18,9 @@ from extensions.ext_database import db
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
|
||||
def admin_required(view):
|
||||
def admin_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
|
||||
@ -87,7 +87,7 @@ class BaseApiKeyListResource(Resource):
|
||||
custom="max_keys_exceeded",
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_user.current_tenant_id
|
||||
|
||||
@ -207,7 +207,7 @@ class InstructionGenerationTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self) -> dict:
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, NoReturn
|
||||
from typing import NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
@ -31,7 +31,7 @@ from services.workflow_service import WorkflowService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
@ -42,8 +42,7 @@ def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||
"""Serialize variable value. If variable is truncated, return the truncated value."""
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
value = variable.get_value()
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import cast
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||
|
||||
import flask_login
|
||||
from flask import jsonify, request
|
||||
@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
|
||||
|
||||
from .. import api
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
def oauth_server_client_id_required(view):
|
||||
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
@ -30,18 +35,15 @@ def oauth_server_client_id_required(view):
|
||||
if not oauth_provider_app:
|
||||
raise NotFound("client_id is invalid")
|
||||
|
||||
kwargs["oauth_provider_app"] = oauth_provider_app
|
||||
|
||||
return view(*args, **kwargs)
|
||||
return view(self, oauth_provider_app, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def oauth_server_access_token_required(view):
|
||||
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
oauth_provider_app = kwargs.get("oauth_provider_app")
|
||||
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
|
||||
if not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
raise BadRequest("Invalid oauth_provider_app")
|
||||
|
||||
authorization_header = request.headers.get("Authorization")
|
||||
@ -79,9 +81,7 @@ def oauth_server_access_token_required(view):
|
||||
response.headers["WWW-Authenticate"] = "Bearer"
|
||||
return response
|
||||
|
||||
kwargs["account"] = account
|
||||
|
||||
return view(*args, **kwargs)
|
||||
return view(self, oauth_provider_app, account, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@ -17,9 +17,10 @@ class Subscription(Resource):
|
||||
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
||||
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
args = parser.parse_args()
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
|
||||
assert current_user.current_tenant_id is not None
|
||||
return BillingService.get_subscription(
|
||||
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
|
||||
)
|
||||
@ -31,7 +32,9 @@ class Invoices(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
||||
|
||||
|
||||
|
||||
@ -477,6 +477,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
data_source_info = document.data_source_info_dict
|
||||
|
||||
if document.data_source_type == "upload_file":
|
||||
if not data_source_info:
|
||||
continue
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
@ -493,6 +495,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
elif document.data_source_type == "notion_import":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
@ -506,6 +510,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif document.data_source_type == "website_crawl":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
website_info={
|
||||
|
||||
@ -43,6 +43,8 @@ class ExploreAppMetaApi(InstalledAppResource):
|
||||
def get(self, installed_app: InstalledApp):
|
||||
"""Get app meta"""
|
||||
app_model = installed_app.app
|
||||
if not app_model:
|
||||
raise ValueError("App not found")
|
||||
return AppService().get_app_meta(app_model)
|
||||
|
||||
|
||||
|
||||
@ -36,6 +36,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
Run workflow
|
||||
"""
|
||||
app_model = installed_app.app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
@ -74,6 +76,8 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||
Stop workflow task
|
||||
"""
|
||||
app_model = installed_app.app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, Optional, ParamSpec, TypeVar
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
@ -13,19 +15,15 @@ from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
def installed_app_required(view=None):
|
||||
def decorator(view):
|
||||
|
||||
def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
|
||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not kwargs.get("installed_app_id"):
|
||||
raise ValueError("missing installed_app_id in path parameters")
|
||||
|
||||
installed_app_id = kwargs.get("installed_app_id")
|
||||
installed_app_id = str(installed_app_id)
|
||||
|
||||
del kwargs["installed_app_id"]
|
||||
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
.where(
|
||||
@ -52,10 +50,10 @@ def installed_app_required(view=None):
|
||||
return decorator
|
||||
|
||||
|
||||
def user_allowed_to_access_app(view=None):
|
||||
def decorator(view):
|
||||
def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
|
||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(installed_app: InstalledApp, *args, **kwargs):
|
||||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||
feature = FeatureService.get_system_features()
|
||||
if feature.webapp_auth.enabled:
|
||||
app_id = installed_app.app_id
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy.orm import Session
|
||||
@ -7,14 +9,17 @@ from werkzeug.exceptions import Forbidden
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def plugin_permission_required(
|
||||
install_required: bool = False,
|
||||
debug_required: bool = False,
|
||||
):
|
||||
def interceptor(view):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
|
||||
@ -2,7 +2,9 @@ import contextlib
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, request
|
||||
from flask_login import current_user
|
||||
@ -19,10 +21,13 @@ from services.operation_service import OperationService
|
||||
|
||||
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
def account_initialization_required(view):
|
||||
|
||||
def account_initialization_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
# check account initialization
|
||||
account = current_user
|
||||
|
||||
@ -34,9 +39,9 @@ def account_initialization_required(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_cloud(view):
|
||||
def only_edition_cloud(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if dify_config.EDITION != "CLOUD":
|
||||
abort(404)
|
||||
|
||||
@ -45,9 +50,9 @@ def only_edition_cloud(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_enterprise(view):
|
||||
def only_edition_enterprise(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
abort(404)
|
||||
|
||||
@ -56,9 +61,9 @@ def only_edition_enterprise(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_self_hosted(view):
|
||||
def only_edition_self_hosted(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
abort(404)
|
||||
|
||||
@ -67,9 +72,9 @@ def only_edition_self_hosted(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def cloud_edition_billing_enabled(view):
|
||||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if not features.billing.enabled:
|
||||
abort(403, "Billing feature is not enabled.")
|
||||
@ -79,9 +84,9 @@ def cloud_edition_billing_enabled(view):
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str):
|
||||
def interceptor(view):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
@ -120,9 +125,9 @@ def cloud_edition_billing_resource_check(resource: str):
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
def interceptor(view):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
@ -142,9 +147,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
|
||||
|
||||
def cloud_edition_billing_rate_limit_check(resource: str):
|
||||
def interceptor(view):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if resource == "knowledge":
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
@ -176,9 +181,9 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_utm_record(view):
|
||||
def cloud_utm_record(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
|
||||
@ -194,9 +199,9 @@ def cloud_utm_record(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def setup_required(view):
|
||||
def setup_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
# check setup
|
||||
if (
|
||||
dify_config.EDITION == "SELF_HOSTED"
|
||||
@ -212,9 +217,9 @@ def setup_required(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_license_required(view):
|
||||
def enterprise_license_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
settings = FeatureService.get_system_features()
|
||||
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
|
||||
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
|
||||
@ -224,9 +229,9 @@ def enterprise_license_required(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def email_password_login_enabled(view):
|
||||
def email_password_login_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if features.enable_email_password_login:
|
||||
return view(*args, **kwargs)
|
||||
@ -237,9 +242,9 @@ def email_password_login_enabled(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def enable_change_email(view):
|
||||
def enable_change_email(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if features.enable_change_email:
|
||||
return view(*args, **kwargs)
|
||||
@ -250,9 +255,9 @@ def enable_change_email(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def is_allow_transfer_owner(view):
|
||||
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if features.is_allow_transfer_workspace:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -99,7 +99,7 @@ class MCPAppApi(Resource):
|
||||
|
||||
return mcp_server, app
|
||||
|
||||
def _validate_server_status(self, mcp_server: AppMCPServer) -> None:
|
||||
def _validate_server_status(self, mcp_server: AppMCPServer):
|
||||
"""Validate MCP server status"""
|
||||
if mcp_server.status != AppMCPServerStatus.ACTIVE:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")
|
||||
|
||||
@ -440,7 +440,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate segment belongs to the specified document
|
||||
if segment.document_id != document_id:
|
||||
if str(segment.document_id) != str(document_id):
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
# check child chunk
|
||||
@ -451,7 +451,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
# validate child chunk belongs to the specified segment
|
||||
if child_chunk.segment_id != segment.id:
|
||||
if str(child_chunk.segment_id) != str(segment.id):
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
try:
|
||||
@ -500,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate segment belongs to the specified document
|
||||
if segment.document_id != document_id:
|
||||
if str(segment.document_id) != str(document_id):
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# get child chunk
|
||||
@ -511,7 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
# validate child chunk belongs to the specified segment
|
||||
if child_chunk.segment_id != segment.id:
|
||||
if str(child_chunk.segment_id) != str(segment.id):
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
# validate args
|
||||
|
||||
@ -3,7 +3,7 @@ from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
from typing import Optional, ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
@ -22,6 +22,9 @@ from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App, EndUser
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class WhereisUserArg(StrEnum):
|
||||
"""
|
||||
@ -60,27 +63,6 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
||||
if tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("The workspace's status is archived.")
|
||||
|
||||
tenant_account_join = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == api_token.tenant_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.role.in_(["owner"]))
|
||||
.where(Tenant.status == TenantStatus.NORMAL)
|
||||
.one_or_none()
|
||||
) # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.query(Account).where(Account.id == ta.account_id).first()
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
else:
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
|
||||
kwargs["app_model"] = app_model
|
||||
|
||||
if fetch_user_arg:
|
||||
@ -118,8 +100,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
def interceptor(view):
|
||||
def decorated(*args, **kwargs):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
|
||||
@ -148,9 +130,9 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
||||
def interceptor(view):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
if features.billing.enabled:
|
||||
@ -170,9 +152,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
|
||||
|
||||
|
||||
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
||||
def interceptor(view):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
|
||||
if resource == "knowledge":
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from datetime import UTC, datetime
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -15,6 +16,9 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def validate_jwt_token(view=None):
|
||||
def decorator(view):
|
||||
|
||||
Reference in New Issue
Block a user