Merge branch 'main' into feat/agent-node-v2

This commit is contained in:
Novice
2025-12-17 15:55:27 +08:00
110 changed files with 12457 additions and 2403 deletions

View File

@ -22,7 +22,12 @@ from controllers.console.error import (
NotAllowedCreateWorkspace,
WorkspacesLimitExceeded,
)
from controllers.console.wraps import email_password_login_enabled, setup_required
from controllers.console.wraps import (
decrypt_code_field,
decrypt_password_field,
email_password_login_enabled,
setup_required,
)
from events.tenant_event import tenant_was_created
from libs.helper import EmailStr, extract_remote_ip
from libs.login import current_account_with_tenant
@ -79,6 +84,7 @@ class LoginApi(Resource):
@setup_required
@email_password_login_enabled
@console_ns.expect(console_ns.models[LoginPayload.__name__])
@decrypt_password_field
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
@ -218,6 +224,7 @@ class EmailCodeLoginSendEmailApi(Resource):
class EmailCodeLoginApi(Resource):
@setup_required
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
@decrypt_code_field
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)

View File

@ -140,6 +140,18 @@ class DataSourceNotionListApi(Resource):
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
datasource_parameters = {}
if datasource_parameters_str:
try:
datasource_parameters = json.loads(datasource_parameters_str)
if not isinstance(datasource_parameters, dict):
raise ValueError("datasource_parameters must be a JSON object.")
except json.JSONDecodeError:
raise ValueError("Invalid datasource_parameters JSON format.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
@ -187,7 +199,7 @@ class DataSourceNotionListApi(Resource):
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=current_user.id,
datasource_parameters={},
datasource_parameters=datasource_parameters,
provider_type=datasource_runtime.datasource_provider_type(),
)
)
@ -218,14 +230,14 @@ class DataSourceNotionListApi(Resource):
@console_ns.route(
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
"/notion/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
class DataSourceNotionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, workspace_id, page_id, page_type):
def get(self, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str)
@ -239,11 +251,10 @@ class DataSourceNotionApi(Resource):
plugin_id="langgenius/notion_datasource",
)
workspace_id = str(workspace_id)
page_id = str(page_id)
extractor = NotionExtractor(
notion_workspace_id=workspace_id,
notion_workspace_id="",
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),

View File

@ -4,7 +4,7 @@ from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
from flask_restx import Resource, marshal_with # type: ignore
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -975,6 +975,11 @@ class RagPipelineRecommendedPluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, location="args", required=False, default="all")
args = parser.parse_args()
type = args["type"]
rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins()
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
return recommended_plugins

View File

@ -9,10 +9,12 @@ from typing import ParamSpec, TypeVar
from flask import abort, request
from configs import dify_config
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
from controllers.console.workspace.error import AccountNotInitializedError
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.encryption import FieldEncryption
from libs.login import current_account_with_tenant
from models.account import AccountStatus
from models.dataset import RateLimitLog
@ -25,6 +27,14 @@ from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogo
P = ParamSpec("P")
R = TypeVar("R")
# Field names for decryption
FIELD_NAME_PASSWORD = "password"
FIELD_NAME_CODE = "code"
# Error messages for decryption failures
ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
@ -419,3 +429,75 @@ def annotation_import_concurrency_limit(view: Callable[P, R]):
return view(*args, **kwargs)
return decorated
def _decrypt_field(field_name: str, error_class: type[Exception], error_message: str) -> None:
"""
Helper to decode a Base64 encoded field in the request payload.
Args:
field_name: Name of the field to decode
error_class: Exception class to raise on decoding failure
error_message: Error message to include in the exception
"""
if not request or not request.is_json:
return
# Get the payload dict - it's cached and mutable
payload = request.get_json()
if not payload or field_name not in payload:
return
encoded_value = payload[field_name]
decoded_value = FieldEncryption.decrypt_field(encoded_value)
# If decoding failed, raise error immediately
if decoded_value is None:
raise error_class(error_message)
# Update payload dict in-place with decoded value
# Since payload is a mutable dict and get_json() returns the cached reference,
# modifying it will affect all subsequent accesses including console_ns.payload
payload[field_name] = decoded_value
def decrypt_password_field(view: Callable[P, R]):
"""
Decorator to decrypt password field in request payload.
Automatically decrypts the 'password' field if encryption is enabled.
If decryption fails, raises AuthenticationFailedError.
Usage:
@decrypt_password_field
def post(self):
args = LoginPayload.model_validate(console_ns.payload)
# args.password is now decrypted
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_decrypt_field(FIELD_NAME_PASSWORD, AuthenticationFailedError, ERROR_MSG_INVALID_ENCRYPTED_DATA)
return view(*args, **kwargs)
return decorated
def decrypt_code_field(view: Callable[P, R]):
"""
Decorator to decrypt verification code field in request payload.
Automatically decrypts the 'code' field if encryption is enabled.
If decryption fails, raises EmailCodeError.
Usage:
@decrypt_code_field
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
# args.code is now decrypted
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_decrypt_field(FIELD_NAME_CODE, EmailCodeError, ERROR_MSG_INVALID_ENCRYPTED_CODE)
return view(*args, **kwargs)
return decorated