Merge remote-tracking branch 'origin/main' into feat/trigger

This commit is contained in:
lyzno1
2025-10-15 09:53:03 +08:00
167 changed files with 4679 additions and 2534 deletions

View File

@ -1,6 +1,7 @@
#!/bin/bash #!/bin/bash
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml

View File

@ -200,6 +200,11 @@ class PluginConfig(BaseSettings):
default="plugin-api-key", default="plugin-api-key",
) )
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=300.0,
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key") INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
PLUGIN_REMOTE_INSTALL_HOST: str = Field( PLUGIN_REMOTE_INSTALL_HOST: str = Field(

View File

@ -7,8 +7,7 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import TimestampField from libs.helper import TimestampField
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from models.dataset import Dataset from models.dataset import Dataset
from models.model import ApiToken, App from models.model import ApiToken, App
@ -57,9 +56,9 @@ class BaseApiKeyListResource(Resource):
def get(self, resource_id): def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _get_resource(resource_id, current_tenant_id, self.resource_model)
keys = db.session.scalars( keys = db.session.scalars(
select(ApiToken).where( select(ApiToken).where(
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
@ -71,9 +70,8 @@ class BaseApiKeyListResource(Resource):
def post(self, resource_id): def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None _get_resource(resource_id, current_tenant_id, self.resource_model)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.has_edit_permission: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
@ -93,7 +91,7 @@ class BaseApiKeyListResource(Resource):
key = ApiToken.generate_api_key(self.token_prefix or "", 24) key = ApiToken.generate_api_key(self.token_prefix or "", 24)
api_token = ApiToken() api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id) setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id api_token.tenant_id = current_tenant_id
api_token.token = key api_token.token = key
api_token.type = self.resource_type api_token.type = self.resource_type
db.session.add(api_token) db.session.add(api_token)
@ -112,9 +110,8 @@ class BaseApiKeyResource(Resource):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
api_key_id = str(api_key_id) api_key_id = str(api_key_id)
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None _get_resource(resource_id, current_tenant_id, self.resource_model)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
@ -158,11 +155,6 @@ class AppApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for an app""" """Create a new API key for an app"""
return super().post(resource_id) return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app" resource_type = "app"
resource_model = App resource_model = App
resource_id_field = "app_id" resource_id_field = "app_id"
@ -179,11 +171,6 @@ class AppApiKeyResource(BaseApiKeyResource):
"""Delete an API key for an app""" """Delete an API key for an app"""
return super().delete(resource_id, api_key_id) return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app" resource_type = "app"
resource_model = App resource_model = App
resource_id_field = "app_id" resource_id_field = "app_id"
@ -208,11 +195,6 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for a dataset""" """Create a new API key for a dataset"""
return super().post(resource_id) return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset" resource_type = "dataset"
resource_model = Dataset resource_model = Dataset
resource_id_field = "dataset_id" resource_id_field = "dataset_id"
@ -229,11 +211,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
"""Delete an API key for a dataset""" """Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id) return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset" resource_type = "dataset"
resource_model = Dataset resource_model = Dataset
resource_id_field = "dataset_id" resource_id_field = "dataset_id"

View File

@ -1,7 +1,6 @@
from typing import Literal from typing import Literal
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -17,7 +16,7 @@ from fields.annotation_fields import (
annotation_fields, annotation_fields,
annotation_hit_history_fields, annotation_hit_history_fields,
) )
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
@ -43,7 +42,9 @@ class AnnotationReplyActionApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
def post(self, app_id, action: Literal["enable", "disable"]): def post(self, app_id, action: Literal["enable", "disable"]):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
@ -70,7 +71,9 @@ class AppAnnotationSettingDetailApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
@ -99,7 +102,9 @@ class AppAnnotationSettingUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, app_id, annotation_setting_id): def post(self, app_id, annotation_setting_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
@ -125,7 +130,9 @@ class AnnotationReplyActionStatusApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
def get(self, app_id, job_id, action): def get(self, app_id, job_id, action):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
job_id = str(job_id) job_id = str(job_id)
@ -160,7 +167,9 @@ class AnnotationApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
@ -199,7 +208,9 @@ class AnnotationApi(Resource):
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_id): def post(self, app_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
@ -214,7 +225,9 @@ class AnnotationApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, app_id): def delete(self, app_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
@ -250,7 +263,9 @@ class AnnotationExportApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
@ -273,7 +288,9 @@ class AnnotationUpdateDeleteApi(Resource):
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_id, annotation_id): def post(self, app_id, annotation_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
@ -289,7 +306,9 @@ class AnnotationUpdateDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, app_id, annotation_id): def delete(self, app_id, annotation_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
@ -311,7 +330,9 @@ class AnnotationBatchImportApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
def post(self, app_id): def post(self, app_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
@ -342,7 +363,9 @@ class AnnotationBatchImportStatusApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
def get(self, app_id, job_id): def get(self, app_id, job_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
job_id = str(job_id) job_id = str(job_id)
@ -377,7 +400,9 @@ class AnnotationHitHistoryListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id, annotation_id): def get(self, app_id, annotation_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)

View File

@ -1,6 +1,3 @@
from typing import cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -13,8 +10,7 @@ from controllers.console.wraps import (
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account
from models.model import App from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
@ -32,7 +28,8 @@ class AppImportApi(Resource):
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check("apps")
def post(self): def post(self):
# Check user role first # Check user role first
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -51,7 +48,7 @@ class AppImportApi(Resource):
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = AppDslService(session) import_service = AppDslService(session)
# Import app # Import app
account = cast(Account, current_user) account = current_user
result = import_service.import_app( result = import_service.import_app(
account=account, account=account,
import_mode=args["mode"], import_mode=args["mode"],
@ -85,14 +82,15 @@ class AppImportConfirmApi(Resource):
@marshal_with(app_import_fields) @marshal_with(app_import_fields)
def post(self, import_id): def post(self, import_id):
# Check user role first # Check user role first
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = AppDslService(session) import_service = AppDslService(session)
# Confirm import # Confirm import
account = cast(Account, current_user) account = current_user
result = import_service.confirm_import(import_id=import_id, account=account) result = import_service.confirm_import(import_id=import_id, account=account)
session.commit() session.commit()
@ -110,7 +108,8 @@ class AppImportCheckDependenciesApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(app_import_check_dependencies_fields) @marshal_with(app_import_check_dependencies_fields)
def get(self, app_model: App): def get(self, app_model: App):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
with Session(db.engine) as session: with Session(db.engine) as session:

View File

@ -1,6 +1,5 @@
from collections.abc import Sequence from collections.abc import Sequence
from flask_login import current_user
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns from controllers.console import api, console_ns
@ -18,7 +17,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
from core.llm_generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import App from models import App
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
@ -49,11 +48,11 @@ class RuleGenerateApi(Resource):
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try: try:
rules = LLMGenerator.generate_rule_config( rules = LLMGenerator.generate_rule_config(
tenant_id=account.current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
no_variable=args["no_variable"], no_variable=args["no_variable"],
@ -100,11 +99,11 @@ class RuleCodeGenerateApi(Resource):
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json") parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try: try:
code_result = LLMGenerator.generate_code( code_result = LLMGenerator.generate_code(
tenant_id=account.current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
code_language=args["code_language"], code_language=args["code_language"],
@ -145,11 +144,11 @@ class RuleStructuredOutputGenerateApi(Resource):
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try: try:
structured_output = LLMGenerator.generate_structured_output( structured_output = LLMGenerator.generate_structured_output(
tenant_id=account.current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
) )
@ -199,6 +198,7 @@ class InstructionGenerateApi(Resource):
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("ideal_output", type=str, required=False, default="", location="json") parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next( code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args["language"])), None (p for p in providers if p.is_accept_language(args["language"])), None
@ -221,21 +221,21 @@ class InstructionGenerateApi(Resource):
match node_type: match node_type:
case "llm": case "llm":
return LLMGenerator.generate_rule_config( return LLMGenerator.generate_rule_config(
current_user.current_tenant_id, current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
no_variable=True, no_variable=True,
) )
case "agent": case "agent":
return LLMGenerator.generate_rule_config( return LLMGenerator.generate_rule_config(
current_user.current_tenant_id, current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
no_variable=True, no_variable=True,
) )
case "code": case "code":
return LLMGenerator.generate_code( return LLMGenerator.generate_code(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
code_language=args["language"], code_language=args["language"],
@ -244,7 +244,7 @@ class InstructionGenerateApi(Resource):
return {"error": f"invalid node type: {node_type}"} return {"error": f"invalid node type: {node_type}"}
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy( return LLMGenerator.instruction_modify_legacy(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
flow_id=args["flow_id"], flow_id=args["flow_id"],
current=args["current"], current=args["current"],
instruction=args["instruction"], instruction=args["instruction"],
@ -253,7 +253,7 @@ class InstructionGenerateApi(Resource):
) )
if args["node_id"] != "" and args["current"] != "": # For workflow node if args["node_id"] != "" and args["current"] != "": # For workflow node
return LLMGenerator.instruction_modify_workflow( return LLMGenerator.instruction_modify_workflow(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
flow_id=args["flow_id"], flow_id=args["flow_id"],
node_id=args["node_id"], node_id=args["node_id"],
current=args["current"], current=args["current"],

View File

@ -2,7 +2,6 @@ import json
from typing import cast from typing import cast
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -15,8 +14,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated from events.app_event import app_model_config_was_updated
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from models.model import AppMode, AppModelConfig from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService from services.app_model_config_service import AppModelConfigService
@ -54,16 +52,14 @@ class ModelConfigResource(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model): def post(self, app_model):
"""Modify app model config""" """Modify app model config"""
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise Forbidden()
if not current_user.has_edit_permission: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
# validate config # validate config
model_configuration = AppModelConfigService.validate_configuration( model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
config=cast(dict, request.json), config=cast(dict, request.json),
app_mode=AppMode.value_of(app_model.mode), app_mode=AppMode.value_of(app_model.mode),
) )
@ -95,12 +91,12 @@ class ModelConfigResource(Resource):
# get tool # get tool
try: try:
tool_runtime = ToolManager.get_agent_tool_runtime( tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
app_id=app_model.id, app_id=app_model.id,
agent_tool=agent_tool_entity, agent_tool=agent_tool_entity,
) )
manager = ToolParameterConfigurationManager( manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
tool_runtime=tool_runtime, tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id, provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type, provider_type=agent_tool_entity.provider_type,
@ -134,7 +130,7 @@ class ModelConfigResource(Resource):
else: else:
try: try:
tool_runtime = ToolManager.get_agent_tool_runtime( tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
app_id=app_model.id, app_id=app_model.id,
agent_tool=agent_tool_entity, agent_tool=agent_tool_entity,
) )
@ -142,7 +138,7 @@ class ModelConfigResource(Resource):
continue continue
manager = ToolParameterConfigurationManager( manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
tool_runtime=tool_runtime, tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id, provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type, provider_type=agent_tool_entity.provider_type,

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -9,7 +8,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_site_fields from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account, Site from models import Account, Site
@ -76,9 +75,10 @@ class AppSite(Resource):
@marshal_with(app_site_fields) @marshal_with(app_site_fields)
def post(self, app_model): def post(self, app_model):
args = parse_app_site_args() args = parse_app_site_args()
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be editor, admin, or owner # The role of the current user in the ta table must be editor, admin, or owner
if not current_user.is_editor: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first() site = db.session.query(Site).where(Site.app_id == app_model.id).first()
@ -131,6 +131,8 @@ class AppSiteAccessTokenReset(Resource):
@marshal_with(app_site_fields) @marshal_with(app_site_fields)
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()

View File

@ -1,10 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService from services.auth.api_key_auth_service import ApiKeyAuthService
from ..wraps import account_initialization_required, setup_required from ..wraps import account_initialization_required, setup_required
@ -16,7 +15,8 @@ class ApiKeyAuthDataSource(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id) _, current_tenant_id = current_account_with_tenant()
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
if data_source_api_key_bindings: if data_source_api_key_bindings:
return { return {
"sources": [ "sources": [
@ -41,6 +41,8 @@ class ApiKeyAuthDataSourceBinding(Resource):
@account_initialization_required @account_initialization_required
def post(self): def post(self):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -50,7 +52,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
args = parser.parse_args() args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args) ApiKeyAuthService.validate_api_key_auth_args(args)
try: try:
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args) ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
except Exception as e: except Exception as e:
raise ApiKeyAuthFailedError(str(e)) raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -63,9 +65,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@account_initialization_required @account_initialization_required
def delete(self, binding_id): def delete(self, binding_id):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
return {"result": "success"}, 204 return {"result": "success"}, 204

View File

@ -3,7 +3,6 @@ from collections.abc import Generator
from typing import cast from typing import cast
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -20,7 +19,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService from services.datasource_provider_service import DatasourceProviderService
@ -37,10 +36,12 @@ class DataSourceApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(integrate_list_fields) @marshal_with(integrate_list_fields)
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
# get workspace data source integrates # get workspace data source integrates
data_source_integrates = db.session.scalars( data_source_integrates = db.session.scalars(
select(DataSourceOauthBinding).where( select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_tenant_id,
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
) )
).all() ).all()
@ -120,13 +121,15 @@ class DataSourceNotionListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(integrate_notion_info_list_fields) @marshal_with(integrate_notion_info_list_fields)
def get(self): def get(self):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = request.args.get("dataset_id", default=None, type=str) dataset_id = request.args.get("dataset_id", default=None, type=str)
credential_id = request.args.get("credential_id", default=None, type=str) credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id: if not credential_id:
raise ValueError("Credential id is required.") raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials( credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
credential_id=credential_id, credential_id=credential_id,
provider="notion_datasource", provider="notion_datasource",
plugin_id="langgenius/notion_datasource", plugin_id="langgenius/notion_datasource",
@ -146,7 +149,7 @@ class DataSourceNotionListApi(Resource):
documents = session.scalars( documents = session.scalars(
select(Document).filter_by( select(Document).filter_by(
dataset_id=dataset_id, dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
data_source_type="notion_import", data_source_type="notion_import",
enabled=True, enabled=True,
) )
@ -161,7 +164,7 @@ class DataSourceNotionListApi(Resource):
datasource_runtime = DatasourceManager.get_datasource_runtime( datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id="langgenius/notion_datasource/notion_datasource", provider_id="langgenius/notion_datasource/notion_datasource",
datasource_name="notion_datasource", datasource_name="notion_datasource",
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
) )
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
@ -210,12 +213,14 @@ class DataSourceNotionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, workspace_id, page_id, page_type): def get(self, workspace_id, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str) credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id: if not credential_id:
raise ValueError("Credential id is required.") raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials( credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
credential_id=credential_id, credential_id=credential_id,
provider="notion_datasource", provider="notion_datasource",
plugin_id="langgenius/notion_datasource", plugin_id="langgenius/notion_datasource",
@ -229,7 +234,7 @@ class DataSourceNotionApi(Resource):
notion_obj_id=page_id, notion_obj_id=page_id,
notion_page_type=page_type, notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"), notion_access_token=credential.get("integration_secret"),
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
) )
text_docs = extractor.extract() text_docs = extractor.extract()
@ -239,6 +244,8 @@ class DataSourceNotionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json") parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
@ -263,7 +270,7 @@ class DataSourceNotionApi(Resource):
"notion_workspace_id": workspace_id, "notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"], "notion_obj_id": page["page_id"],
"notion_page_type": page["type"], "notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id, "tenant_id": current_tenant_id,
} }
), ),
document_model=args["doc_form"], document_model=args["doc_form"],
@ -271,7 +278,7 @@ class DataSourceNotionApi(Resource):
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate( response = indexing_runner.indexing_estimate(
current_user.current_tenant_id, current_tenant_id,
extract_settings, extract_settings,
args["process_rule"], args["process_rule"],
args["doc_form"], args["doc_form"],

View File

@ -1,7 +1,6 @@
import uuid import uuid
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse from flask_restx import Resource, marshal, reqparse
from sqlalchemy import select from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService from services.dataset_service import DatasetService, DocumentService, SegmentService
@ -43,6 +42,8 @@ class DatasetDocumentSegmentListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id, document_id): def get(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -79,7 +80,7 @@ class DatasetDocumentSegmentListApi(Resource):
select(DocumentSegment) select(DocumentSegment)
.where( .where(
DocumentSegment.document_id == str(document_id), DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id, DocumentSegment.tenant_id == current_tenant_id,
) )
.order_by(DocumentSegment.position.asc()) .order_by(DocumentSegment.position.asc())
) )
@ -115,6 +116,8 @@ class DatasetDocumentSegmentListApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id): def delete(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -148,6 +151,8 @@ class DatasetDocumentSegmentApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action): def patch(self, dataset_id, document_id, action):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
@ -171,7 +176,7 @@ class DatasetDocumentSegmentApi(Resource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -204,6 +209,8 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -221,7 +228,7 @@ class DatasetDocumentSegmentAddApi(Resource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -255,6 +262,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id): def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -272,7 +281,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -287,7 +296,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -317,6 +326,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id): def delete(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -333,7 +344,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -361,6 +372,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -396,7 +409,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
upload_file_id, upload_file_id,
dataset_id, dataset_id,
document_id, document_id,
current_user.current_tenant_id, current_tenant_id,
current_user.id, current_user.id,
) )
except Exception as e: except Exception as e:
@ -427,6 +440,8 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id, segment_id): def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -441,7 +456,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -453,7 +468,7 @@ class ChildChunkAddApi(Resource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -483,6 +498,8 @@ class ChildChunkAddApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id, document_id, segment_id): def get(self, dataset_id, document_id, segment_id):
_, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -499,7 +516,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -530,6 +547,8 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id): def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -546,7 +565,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -580,6 +599,8 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id, child_chunk_id): def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -596,7 +617,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -607,7 +628,7 @@ class ChildChunkUpdateApi(Resource):
db.session.query(ChildChunk) db.session.query(ChildChunk)
.where( .where(
ChildChunk.id == str(child_chunk_id), ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id, ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id, ChildChunk.document_id == document_id,
) )
@ -634,6 +655,8 @@ class ChildChunkUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id): def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -650,7 +673,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -661,7 +684,7 @@ class ChildChunkUpdateApi(Resource):
db.session.query(ChildChunk) db.session.query(ChildChunk)
.where( .where(
ChildChunk.id == str(child_chunk_id), ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id, ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id, ChildChunk.document_id == document_id,
) )

View File

@ -1,5 +1,4 @@
from flask import make_response, redirect, request from flask import make_response, redirect, request
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -13,7 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen from libs.helper import StrLen
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService from services.plugin.oauth_service import OAuthProxyService
@ -25,9 +24,10 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider_id: str): def get(self, provider_id: str):
user = current_user current_user, current_tenant_id = current_account_with_tenant()
tenant_id = user.current_tenant_id
if not current_user.is_editor: tenant_id = current_tenant_id
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
credential_id = request.args.get("credential_id") credential_id = request.args.get("credential_id")
@ -52,7 +52,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback" redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
authorization_url_response = oauth_handler.get_authorization_url( authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user.id, user_id=current_user.id,
plugin_id=plugin_id, plugin_id=plugin_id,
provider=provider_name, provider=provider_name,
redirect_uri=redirect_uri, redirect_uri=redirect_uri,
@ -131,7 +131,9 @@ class DatasourceAuth(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider_id: str): def post(self, provider_id: str):
if not current_user.is_editor: current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -145,7 +147,7 @@ class DatasourceAuth(Resource):
try: try:
datasource_provider_service.add_datasource_api_key_provider( datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider_id=datasource_provider_id, provider_id=datasource_provider_id,
credentials=args["credentials"], credentials=args["credentials"],
name=args["name"], name=args["name"],
@ -160,8 +162,10 @@ class DatasourceAuth(Resource):
def get(self, provider_id: str): def get(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
_, current_tenant_id = current_account_with_tenant()
datasources = datasource_provider_service.list_datasource_credentials( datasources = datasource_provider_service.list_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=datasource_provider_id.provider_name, provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id, plugin_id=datasource_provider_id.plugin_id,
) )
@ -174,17 +178,19 @@ class DatasourceAuthDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider_id: str): def post(self, provider_id: str):
current_user, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name provider_name = datasource_provider_id.provider_name
if not current_user.is_editor: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials( datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
auth_id=args["credential_id"], auth_id=args["credential_id"],
provider=provider_name, provider=provider_name,
plugin_id=plugin_id, plugin_id=plugin_id,
@ -198,17 +204,19 @@ class DatasourceAuthUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider_id: str): def post(self, provider_id: str):
current_user, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json") parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.is_editor: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials( datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
auth_id=args["credential_id"], auth_id=args["credential_id"],
provider=datasource_provider_id.provider_name, provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id, plugin_id=datasource_provider_id.plugin_id,
@ -224,10 +232,10 @@ class DatasourceAuthListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_all_datasource_credentials( datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
tenant_id=current_user.current_tenant_id
)
return {"result": jsonable_encoder(datasources)}, 200 return {"result": jsonable_encoder(datasources)}, 200
@ -237,10 +245,10 @@ class DatasourceHardCodeAuthListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_hard_code_datasource_credentials( datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
tenant_id=current_user.current_tenant_id
)
return {"result": jsonable_encoder(datasources)}, 200 return {"result": jsonable_encoder(datasources)}, 200
@ -250,7 +258,9 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider_id: str): def post(self, provider_id: str):
if not current_user.is_editor: current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
@ -259,7 +269,7 @@ class DatasourceAuthOauthCustomClient(Resource):
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params( datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}), client_params=args.get("client_params", {}),
enabled=args.get("enable_oauth_custom_client", False), enabled=args.get("enable_oauth_custom_client", False),
@ -270,10 +280,12 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider_id: str): def delete(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_oauth_custom_client_params( datasource_provider_service.remove_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -285,7 +297,9 @@ class DatasourceAuthDefaultApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider_id: str): def post(self, provider_id: str):
if not current_user.is_editor: current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json") parser.add_argument("id", type=str, required=True, nullable=False, location="json")
@ -293,7 +307,7 @@ class DatasourceAuthDefaultApi(Resource):
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider( datasource_provider_service.set_default_datasource_provider(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
credential_id=args["id"], credential_id=args["id"],
) )
@ -306,7 +320,9 @@ class DatasourceUpdateProviderNameApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider_id: str): def post(self, provider_id: str):
if not current_user.is_editor: current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json") parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
@ -315,7 +331,7 @@ class DatasourceUpdateProviderNameApi(Resource):
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name( datasource_provider_service.update_datasource_provider_name(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
name=args["name"], name=args["name"],
credential_id=args["credential_id"], credential_id=args["credential_id"],

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse from flask_restx import Resource, marshal, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -13,7 +12,7 @@ from controllers.console.wraps import (
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.dataset import DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@ -38,7 +37,7 @@ class CreateRagPipelineDatasetApi(Resource):
) )
args = parser.parse_args() args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
@ -58,12 +57,12 @@ class CreateRagPipelineDatasetApi(Resource):
with Session(db.engine) as session: with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session) rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset( import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity, rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
) )
if rag_pipeline_dataset_create_entity.permission == "partial_members": if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list( DatasetPermissionService.update_partial_member_list(
current_user.current_tenant_id, current_tenant_id,
import_info["dataset_id"], import_info["dataset_id"],
rag_pipeline_dataset_create_entity.partial_member_list, rag_pipeline_dataset_create_entity.partial_member_list,
) )
@ -81,10 +80,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
dataset = DatasetService.create_empty_rag_pipeline_dataset( dataset = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity( rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
name="", name="",
description="", description="",

View File

@ -12,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
from extensions.ext_database import db from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields from fields.installed_app_fields import installed_app_list_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models import Account, App, InstalledApp, RecommendedApp from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService from services.account_service import TenantService
from services.app_service import AppService from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
@ -29,9 +29,7 @@ class InstalledAppsListApi(Resource):
@marshal_with(installed_app_list_fields) @marshal_with(installed_app_list_fields)
def get(self): def get(self):
app_id = request.args.get("app_id", default=None, type=str) app_id = request.args.get("app_id", default=None, type=str)
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
if app_id: if app_id:
installed_apps = db.session.scalars( installed_apps = db.session.scalars(
@ -121,9 +119,8 @@ class InstalledAppsListApi(Resource):
if recommended_app is None: if recommended_app is None:
raise NotFound("App not found") raise NotFound("App not found")
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
app = db.session.query(App).where(App.id == args["app_id"]).first() app = db.session.query(App).where(App.id == args["app_id"]).first()
if app is None: if app is None:
@ -163,9 +160,8 @@ class InstalledAppApi(InstalledAppResource):
""" """
def delete(self, installed_app): def delete(self, installed_app):
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("current_user must be an Account instance") if installed_app.app_owner_tenant_id == current_tenant_id:
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant") raise BadRequest("You can't uninstall an app owned by the current tenant")
db.session.delete(installed_app) db.session.delete(installed_app)

View File

@ -4,7 +4,7 @@ from constants import HIDDEN_VALUE
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, current_user, login_required
from models.account import Account from models.account import Account
from models.api_based_extension import APIBasedExtension from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService from services.api_based_extension_service import APIBasedExtensionService
@ -47,9 +47,7 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(api_based_extension_fields) @marshal_with(api_based_extension_fields)
def get(self): def get(self):
assert isinstance(current_user, Account) _, tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
tenant_id = current_user.current_tenant_id
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@api.doc("create_api_based_extension") @api.doc("create_api_based_extension")
@ -77,9 +75,10 @@ class APIBasedExtensionAPI(Resource):
parser.add_argument("api_endpoint", type=str, required=True, location="json") parser.add_argument("api_endpoint", type=str, required=True, location="json")
parser.add_argument("api_key", type=str, required=True, location="json") parser.add_argument("api_key", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension( extension_data = APIBasedExtension(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
name=args["name"], name=args["name"],
api_endpoint=args["api_endpoint"], api_endpoint=args["api_endpoint"],
api_key=args["api_key"], api_key=args["api_key"],
@ -102,7 +101,7 @@ class APIBasedExtensionDetailAPI(Resource):
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None assert current_user.current_tenant_id is not None
api_based_extension_id = str(id) api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
@ -128,9 +127,9 @@ class APIBasedExtensionDetailAPI(Resource):
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None assert current_user.current_tenant_id is not None
api_based_extension_id = str(id) api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id _, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("name", type=str, required=True, location="json")
@ -157,9 +156,9 @@ class APIBasedExtensionDetailAPI(Resource):
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None assert current_user.current_tenant_id is not None
api_based_extension_id = str(id) api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id _, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
APIBasedExtensionService.delete(extension_data_from_db) APIBasedExtensionService.delete(extension_data_from_db)

View File

@ -1,7 +1,6 @@
from flask_restx import Resource, fields from flask_restx import Resource, fields
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from services.feature_service import FeatureService from services.feature_service import FeatureService
from . import api, console_ns from . import api, console_ns
@ -23,9 +22,9 @@ class FeatureApi(Resource):
@cloud_utm_record @cloud_utm_record
def get(self): def get(self):
"""Get feature configuration for current tenant""" """Get feature configuration for current tenant"""
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
return FeatureService.get_features(current_user.current_tenant_id).model_dump() return FeatureService.get_features(current_tenant_id).model_dump()
@console_ns.route("/system-features") @console_ns.route("/system-features")

View File

@ -108,4 +108,4 @@ class FileSupportTypeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
return {"allowed_extensions": DOCUMENT_EXTENSIONS} return {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}

View File

@ -14,8 +14,7 @@ from core.file import helpers as file_helpers
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from extensions.ext_database import db from extensions.ext_database import db
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from libs.login import current_user from libs.login import current_account_with_tenant
from models.account import Account
from services.file_service import FileService from services.file_service import FileService
from . import console_ns from . import console_ns
@ -64,8 +63,7 @@ class RemoteFileUploadApi(Resource):
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try: try:
assert isinstance(current_user, Account) user, _ = current_account_with_tenant()
user = current_user
upload_file = FileService(db.engine).upload_file( upload_file = FileService(db.engine).upload_file(
filename=file_info.filename, filename=file_info.filename,
content=content, content=content,

View File

@ -5,18 +5,10 @@ from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from services.plugin.endpoint_service import EndpointService from services.plugin.endpoint_service import EndpointService
def _current_account_with_tenant() -> tuple[Account, str]:
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
return current_user, tenant_id
@console_ns.route("/workspaces/current/endpoints/create") @console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource): class EndpointCreateApi(Resource):
@api.doc("create_endpoint") @api.doc("create_endpoint")
@ -41,7 +33,7 @@ class EndpointCreateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = _current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
@ -87,7 +79,7 @@ class EndpointListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user, tenant_id = _current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args") parser.add_argument("page", type=int, required=True, location="args")
@ -130,7 +122,7 @@ class EndpointListForSinglePluginApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user, tenant_id = _current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args") parser.add_argument("page", type=int, required=True, location="args")
@ -172,7 +164,7 @@ class EndpointDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = _current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True) parser.add_argument("endpoint_id", type=str, required=True)
@ -212,7 +204,7 @@ class EndpointUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = _current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True) parser.add_argument("endpoint_id", type=str, required=True)
@ -255,7 +247,7 @@ class EndpointEnableApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = _current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True) parser.add_argument("endpoint_id", type=str, required=True)
@ -288,7 +280,7 @@ class EndpointDisableApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = _current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True) parser.add_argument("endpoint_id", type=str, required=True)

View File

@ -25,7 +25,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields from fields.member_fields import account_with_role_list_fields
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole from models.account import Account, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError from services.errors.account import AccountAlreadyInTenantError
@ -41,8 +41,7 @@ class MemberListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_with_role_list_fields) @marshal_with(account_with_role_list_fields)
def get(self): def get(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant) members = TenantService.get_tenant_members(current_user.current_tenant)
@ -69,9 +68,7 @@ class MemberInviteEmailApi(Resource):
interface_language = args["language"] interface_language = args["language"]
if not TenantAccountRole.is_non_owner_role(invitee_role): if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400 return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
inviter = current_user inviter = current_user
if not inviter.current_tenant: if not inviter.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
@ -120,8 +117,7 @@ class MemberCancelInviteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, member_id): def delete(self, member_id):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
member = db.session.query(Account).where(Account.id == str(member_id)).first() member = db.session.query(Account).where(Account.id == str(member_id)).first()
@ -160,9 +156,7 @@ class MemberUpdateRoleApi(Resource):
if not TenantAccountRole.is_valid_role(new_role): if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400 return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id)) member = db.session.get(Account, str(member_id))
@ -189,8 +183,7 @@ class DatasetOperatorMemberListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_with_role_list_fields) @marshal_with(account_with_role_list_fields)
def get(self): def get(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant) members = TenantService.get_dataset_operator_members(current_user.current_tenant)
@ -212,10 +205,8 @@ class SendOwnerTransferEmailApi(Resource):
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError() raise EmailSendIpLimitError()
current_user, _ = current_account_with_tenant()
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant): if not TenantService.is_owner(current_user, current_user.current_tenant):
@ -250,8 +241,7 @@ class OwnerTransferCheckApi(Resource):
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant): if not TenantService.is_owner(current_user, current_user.current_tenant):
@ -296,8 +286,7 @@ class OwnerTransfer(Resource):
args = parser.parse_args() args = parser.parse_args()
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant): if not TenantService.is_owner(current_user, current_user.current_tenant):

View File

@ -1,7 +1,6 @@
import io import io
from flask import send_file from flask import send_file
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -11,8 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value from libs.helper import StrLen, uuid_value
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from services.billing_service import BillingService from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService from services.model_provider_service import ModelProviderService
@ -23,11 +21,8 @@ class ModelProviderListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account") tenant_id = current_tenant_id
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument( parser.add_argument(
@ -52,11 +47,8 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider: str): def get(self, provider: str):
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account") tenant_id = current_tenant_id
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
# if credential_id is not provided, return current used credential # if credential_id is not provided, return current used credential
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
@ -73,8 +65,7 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
@ -85,11 +76,9 @@ class ModelProviderCredentialApi(Resource):
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try: try:
model_provider_service.create_provider_credential( model_provider_service.create_provider_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
credentials=args["credentials"], credentials=args["credentials"],
credential_name=args["name"], credential_name=args["name"],
@ -103,8 +92,7 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self, provider: str): def put(self, provider: str):
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
@ -116,11 +104,9 @@ class ModelProviderCredentialApi(Resource):
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try: try:
model_provider_service.update_provider_credential( model_provider_service.update_provider_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
credentials=args["credentials"], credentials=args["credentials"],
credential_id=args["credential_id"], credential_id=args["credential_id"],
@ -135,19 +121,16 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential( model_provider_service.remove_provider_credential(
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
) )
return {"result": "success"}, 204 return {"result": "success"}, 204
@ -159,19 +142,16 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
service = ModelProviderService() service = ModelProviderService()
service.switch_active_provider_credential( service.switch_active_provider_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
credential_id=args["credential_id"], credential_id=args["credential_id"],
) )
@ -184,15 +164,12 @@ class ModelProviderValidateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id: tenant_id = current_tenant_id
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -240,14 +217,11 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
if not current_user.current_tenant_id: tenant_id = current_tenant_id
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument( parser.add_argument(
@ -276,14 +250,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
def get(self, provider: str): def get(self, provider: str):
if provider != "anthropic": if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid") raise ValueError(f"provider name {provider} is invalid")
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
data = BillingService.get_model_provider_payment_link( data = BillingService.get_model_provider_payment_link(
provider_name=provider, provider_name=provider,
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
account_id=current_user.id, account_id=current_user.id,
prefilled_email=current_user.email, prefilled_email=current_user.email,
) )

View File

@ -1,6 +1,5 @@
import logging import logging
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -10,7 +9,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value from libs.helper import StrLen, uuid_value
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from services.model_load_balancing_service import ModelLoadBalancingService from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService from services.model_provider_service import ModelProviderService
@ -23,6 +22,8 @@ class DefaultModelApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument( parser.add_argument(
"model_type", "model_type",
@ -34,8 +35,6 @@ class DefaultModelApi(Resource):
) )
args = parser.parse_args() args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type( default_model_entity = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id, model_type=args["model_type"] tenant_id=tenant_id, model_type=args["model_type"]
@ -47,15 +46,14 @@ class DefaultModelApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json") parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_settings = args["model_settings"] model_settings = args["model_settings"]
for model_setting in model_settings: for model_setting in model_settings:
@ -92,7 +90,7 @@ class ModelProviderModelApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider) models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
@ -104,11 +102,11 @@ class ModelProviderModelApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
# To save the model's load balance configs # To save the model's load balance configs
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument( parser.add_argument(
@ -129,7 +127,7 @@ class ModelProviderModelApi(Resource):
raise ValueError("credential_id is required when configuring a custom-model") raise ValueError("credential_id is required when configuring a custom-model")
service = ModelProviderService() service = ModelProviderService()
service.switch_active_custom_model_credential( service.switch_active_custom_model_credential(
tenant_id=current_user.current_tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
model_type=args["model_type"], model_type=args["model_type"],
model=args["model"], model=args["model"],
@ -164,11 +162,11 @@ class ModelProviderModelApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument( parser.add_argument(
@ -195,7 +193,7 @@ class ModelProviderModelCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider: str): def get(self, provider: str):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="args") parser.add_argument("model", type=str, required=True, nullable=False, location="args")
@ -257,6 +255,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
@ -274,7 +274,6 @@ class ModelProviderModelCredentialApi(Resource):
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
try: try:
@ -301,6 +300,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self, provider: str): def put(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
@ -323,7 +324,7 @@ class ModelProviderModelCredentialApi(Resource):
try: try:
model_provider_service.update_model_credential( model_provider_service.update_model_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
model_type=args["model_type"], model_type=args["model_type"],
model=args["model"], model=args["model"],
@ -340,6 +341,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -357,7 +360,7 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential( model_provider_service.remove_model_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
model_type=args["model_type"], model_type=args["model_type"],
model=args["model"], model=args["model"],
@ -373,6 +376,8 @@ class ModelProviderModelCredentialSwitchApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -390,7 +395,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
service = ModelProviderService() service = ModelProviderService()
service.add_model_credential_to_model_list( service.add_model_credential_to_model_list(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
model_type=args["model_type"], model_type=args["model_type"],
model=args["model"], model=args["model"],
@ -407,7 +412,7 @@ class ModelProviderModelEnableApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, provider: str): def patch(self, provider: str):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@ -437,7 +442,7 @@ class ModelProviderModelDisableApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, provider: str): def patch(self, provider: str):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@ -465,7 +470,7 @@ class ModelProviderModelValidateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@ -514,8 +519,7 @@ class ModelProviderModelParameterRuleApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="args") parser.add_argument("model", type=str, required=True, nullable=False, location="args")
args = parser.parse_args() args = parser.parse_args()
_, tenant_id = current_account_with_tenant()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules( parameter_rules = model_provider_service.get_model_parameter_rules(
@ -531,8 +535,7 @@ class ModelProviderAvailableModelApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, model_type): def get(self, model_type):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

View File

@ -1,7 +1,6 @@
import io import io
from flask import request, send_file from flask import request, send_file
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -11,7 +10,7 @@ from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService from services.plugin.plugin_parameter_service import PluginParameterService
@ -26,7 +25,7 @@ class PluginDebuggingKeyApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(debug_required=True)
def get(self): def get(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
try: try:
return { return {
@ -44,7 +43,7 @@ class PluginListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=False, location="args", default=1) parser.add_argument("page", type=int, required=False, location="args", default=1)
parser.add_argument("page_size", type=int, required=False, location="args", default=256) parser.add_argument("page_size", type=int, required=False, location="args", default=256)
@ -81,7 +80,7 @@ class PluginListInstallationsFromIdsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("plugin_ids", type=list, required=True, location="json") parser.add_argument("plugin_ids", type=list, required=True, location="json")
@ -120,7 +119,7 @@ class PluginUploadFromPkgApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
file = request.files["pkg"] file = request.files["pkg"]
@ -144,7 +143,7 @@ class PluginUploadFromGithubApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json") parser.add_argument("repo", type=str, required=True, location="json")
@ -167,7 +166,7 @@ class PluginUploadFromBundleApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
file = request.files["bundle"] file = request.files["bundle"]
@ -191,7 +190,7 @@ class PluginInstallFromPkgApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
@ -217,7 +216,7 @@ class PluginInstallFromGithubApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json") parser.add_argument("repo", type=str, required=True, location="json")
@ -247,7 +246,7 @@ class PluginInstallFromMarketplaceApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
@ -273,7 +272,7 @@ class PluginFetchMarketplacePkgApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def get(self): def get(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
@ -299,7 +298,7 @@ class PluginFetchManifestApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def get(self): def get(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
@ -324,7 +323,7 @@ class PluginFetchInstallTasksApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def get(self): def get(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args") parser.add_argument("page", type=int, required=True, location="args")
@ -346,7 +345,7 @@ class PluginFetchInstallTaskApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def get(self, task_id: str): def get(self, task_id: str):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
try: try:
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)}) return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
@ -361,7 +360,7 @@ class PluginDeleteInstallTaskApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self, task_id: str): def post(self, task_id: str):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
try: try:
return {"success": PluginService.delete_install_task(tenant_id, task_id)} return {"success": PluginService.delete_install_task(tenant_id, task_id)}
@ -376,7 +375,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
try: try:
return {"success": PluginService.delete_all_install_task_items(tenant_id)} return {"success": PluginService.delete_all_install_task_items(tenant_id)}
@ -391,7 +390,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self, task_id: str, identifier: str): def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
try: try:
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)} return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
@ -406,7 +405,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
@ -430,7 +429,7 @@ class PluginUpgradeFromGithubApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
@ -466,7 +465,7 @@ class PluginUninstallApi(Resource):
req.add_argument("plugin_installation_id", type=str, required=True, location="json") req.add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args() args = req.parse_args()
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
try: try:
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])} return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
@ -480,6 +479,7 @@ class PluginChangePermissionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user user = current_user
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
@ -492,7 +492,7 @@ class PluginChangePermissionApi(Resource):
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"]) install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"]) debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
tenant_id = user.current_tenant_id tenant_id = current_tenant_id
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)} return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
@ -503,7 +503,7 @@ class PluginFetchPermissionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
permission = PluginPermissionService.get_permission(tenant_id) permission = PluginPermissionService.get_permission(tenant_id)
if not permission: if not permission:
@ -529,10 +529,10 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
@account_initialization_required @account_initialization_required
def get(self): def get(self):
# check if the user is admin or owner # check if the user is admin or owner
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id
user_id = current_user.id user_id = current_user.id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -567,7 +567,7 @@ class PluginChangePreferencesApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
@ -576,8 +576,6 @@ class PluginChangePreferencesApi(Resource):
req.add_argument("auto_upgrade", type=dict, required=True, location="json") req.add_argument("auto_upgrade", type=dict, required=True, location="json")
args = req.parse_args() args = req.parse_args()
tenant_id = user.current_tenant_id
permission = args["permission"] permission = args["permission"]
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone")) install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
@ -623,7 +621,7 @@ class PluginFetchPreferencesApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
permission = PluginPermissionService.get_permission(tenant_id) permission = PluginPermissionService.get_permission(tenant_id)
permission_dict = { permission_dict = {
@ -663,7 +661,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
@account_initialization_required @account_initialization_required
def post(self): def post(self):
# exclude one single plugin # exclude one single plugin
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
req = reqparse.RequestParser() req = reqparse.RequestParser()
req.add_argument("plugin_id", type=str, required=True, location="json") req.add_argument("plugin_id", type=str, required=True, location="json")

View File

@ -2,7 +2,6 @@ import io
from urllib.parse import urlparse from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file from flask import make_response, redirect, request, send_file
from flask_login import current_user
from flask_restx import ( from flask_restx import (
Resource, Resource,
reqparse, reqparse,
@ -24,7 +23,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen, alphanumeric, uuid_value from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.provider_ids import ToolProviderID from models.provider_ids import ToolProviderID
# from models.provider_ids import ToolProviderID # from models.provider_ids import ToolProviderID
@ -55,10 +54,9 @@ class ToolProviderListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
req = reqparse.RequestParser() req = reqparse.RequestParser()
req.add_argument( req.add_argument(
@ -80,9 +78,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
user = current_user _, tenant_id = current_account_with_tenant()
tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.list_builtin_tool_provider_tools( BuiltinToolManageService.list_builtin_tool_provider_tools(
@ -98,9 +94,7 @@ class ToolBuiltinProviderInfoApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
user = current_user _, tenant_id = current_account_with_tenant()
tenant_id = user.current_tenant_id
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@ -111,11 +105,10 @@ class ToolBuiltinProviderDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
tenant_id = user.current_tenant_id
req = reqparse.RequestParser() req = reqparse.RequestParser()
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json") req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = req.parse_args() args = req.parse_args()
@ -133,10 +126,9 @@ class ToolBuiltinProviderAddApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -163,13 +155,12 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
@ -195,7 +186,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credentials( BuiltinToolManageService.get_builtin_tool_provider_credentials(
@ -220,13 +211,12 @@ class ToolApiProviderAddApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -260,10 +250,9 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -284,10 +273,9 @@ class ToolApiProviderListToolsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -310,13 +298,12 @@ class ToolApiProviderUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -352,13 +339,12 @@ class ToolApiProviderDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -379,10 +365,9 @@ class ToolApiProviderGetApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -403,8 +388,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider, credential_type): def get(self, provider, credential_type):
user = current_user _, tenant_id = current_account_with_tenant()
tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.list_builtin_provider_credentials_schema( BuiltinToolManageService.list_builtin_provider_credentials_schema(
@ -446,9 +430,9 @@ class ToolApiProviderPreviousTestApi(Resource):
parser.add_argument("schema", type=str, required=True, nullable=False, location="json") parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
return ApiToolManageService.test_api_tool_preview( return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id, current_tenant_id,
args["provider_name"] or "", args["provider_name"] or "",
args["tool_name"], args["tool_name"],
args["credentials"], args["credentials"],
@ -464,13 +448,12 @@ class ToolWorkflowProviderCreateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser() reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
@ -504,13 +487,12 @@ class ToolWorkflowProviderUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser() reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@ -547,13 +529,12 @@ class ToolWorkflowProviderDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser() reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@ -573,10 +554,9 @@ class ToolWorkflowProviderGetApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
@ -608,10 +588,9 @@ class ToolWorkflowProviderListToolApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
@ -633,10 +612,9 @@ class ToolBuiltinListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
[ [
@ -655,8 +633,7 @@ class ToolApiListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user _, tenant_id = current_account_with_tenant()
tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
[ [
@ -674,10 +651,9 @@ class ToolWorkflowListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
[ [
@ -711,19 +687,18 @@ class ToolPluginOAuthApi(Resource):
provider_name = tool_provider.provider_name provider_name = tool_provider.provider_name
# todo check permission # todo check permission
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
tenant_id = user.current_tenant_id
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider) oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None: if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider") raise Forbidden("no oauth available client config found for this tool provider")
oauth_handler = OAuthHandler() oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context( context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name user_id=user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
) )
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
authorization_url_response = oauth_handler.get_authorization_url( authorization_url_response = oauth_handler.get_authorization_url(
@ -802,11 +777,12 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json") parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
return BuiltinToolManageService.set_default_provider( return BuiltinToolManageService.set_default_provider(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
) )
@ -821,13 +797,13 @@ class ToolOAuthCustomClient(Resource):
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
return BuiltinToolManageService.save_custom_oauth_client_params( return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
client_params=args.get("client_params", {}), client_params=args.get("client_params", {}),
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
@ -837,20 +813,18 @@ class ToolOAuthCustomClient(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params( BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
tenant_id=current_user.current_tenant_id, provider=provider
)
) )
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider): def delete(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params( BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
tenant_id=current_user.current_tenant_id, provider=provider
)
) )
@ -860,9 +834,10 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema( BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
tenant_id=current_user.current_tenant_id, provider_name=provider tenant_id=current_tenant_id, provider_name=provider
) )
) )
@ -873,7 +848,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info( BuiltinToolManageService.get_builtin_tool_provider_credential_info(
@ -902,12 +877,12 @@ class ToolProviderMCPApi(Resource):
) )
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
args = parser.parse_args() args = parser.parse_args()
user = current_user user, tenant_id = current_account_with_tenant()
if not is_valid_url(args["server_url"]): if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.") raise ValueError("Server URL is not valid.")
return jsonable_encoder( return jsonable_encoder(
MCPToolManageService.create_mcp_provider( MCPToolManageService.create_mcp_provider(
tenant_id=user.current_tenant_id, tenant_id=tenant_id,
server_url=args["server_url"], server_url=args["server_url"],
name=args["name"], name=args["name"],
icon=args["icon"], icon=args["icon"],
@ -942,8 +917,9 @@ class ToolProviderMCPApi(Resource):
pass pass
else: else:
raise ValueError("Server URL is not valid.") raise ValueError("Server URL is not valid.")
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.update_mcp_provider( MCPToolManageService.update_mcp_provider(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider_id=args["provider_id"], provider_id=args["provider_id"],
server_url=args["server_url"], server_url=args["server_url"],
name=args["name"], name=args["name"],
@ -964,7 +940,8 @@ class ToolProviderMCPApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"]) _, current_tenant_id = current_account_with_tenant()
MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"} return {"result": "success"}
@ -979,7 +956,7 @@ class ToolMCPAuthApi(Resource):
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json") parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
provider_id = args["provider_id"] provider_id = args["provider_id"]
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider: if not provider:
raise ValueError("provider not found") raise ValueError("provider not found")
@ -1020,8 +997,8 @@ class ToolMCPDetailApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider_id): def get(self, provider_id):
user = current_user _, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id) provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@ -1031,8 +1008,7 @@ class ToolMCPListAllApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user _, tenant_id = current_account_with_tenant()
tenant_id = user.current_tenant_id
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id) tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
@ -1045,7 +1021,7 @@ class ToolMCPUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider_id): def get(self, provider_id):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.list_mcp_tool_from_remote_server( tools = MCPToolManageService.list_mcp_tool_from_remote_server(
tenant_id=tenant_id, tenant_id=tenant_id,
provider_id=provider_id, provider_id=provider_id,

View File

@ -12,8 +12,8 @@ from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs.login import current_user from libs.login import current_account_with_tenant
from models.account import Account, AccountStatus from models.account import AccountStatus
from models.dataset import RateLimitLog from models.dataset import RateLimitLog
from models.model import DifySetup from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus from services.feature_service import FeatureService, LicenseStatus
@ -25,16 +25,13 @@ P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
def _current_account() -> Account:
assert isinstance(current_user, Account)
return current_user
def account_initialization_required(view: Callable[P, R]): def account_initialization_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization # check account initialization
account = _current_account() current_user, _ = current_account_with_tenant()
account = current_user
if account.status == AccountStatus.UNINITIALIZED: if account.status == AccountStatus.UNINITIALIZED:
raise AccountNotInitializedError() raise AccountNotInitializedError()
@ -80,9 +77,8 @@ def only_edition_self_hosted(view: Callable[P, R]):
def cloud_edition_billing_enabled(view: Callable[P, R]): def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account() _, current_tenant_id = current_account_with_tenant()
assert account.current_tenant_id is not None features = FeatureService.get_features(current_tenant_id)
features = FeatureService.get_features(account.current_tenant_id)
if not features.billing.enabled: if not features.billing.enabled:
abort(403, "Billing feature is not enabled.") abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs) return view(*args, **kwargs)
@ -94,10 +90,8 @@ def cloud_edition_billing_resource_check(resource: str):
def interceptor(view: Callable[P, R]): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account() _, current_tenant_id = current_account_with_tenant()
assert account.current_tenant_id is not None features = FeatureService.get_features(current_tenant_id)
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
if features.billing.enabled: if features.billing.enabled:
members = features.members members = features.members
apps = features.apps apps = features.apps
@ -138,9 +132,8 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view: Callable[P, R]): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account() _, current_tenant_id = current_account_with_tenant()
assert account.current_tenant_id is not None features = FeatureService.get_features(current_tenant_id)
features = FeatureService.get_features(account.current_tenant_id)
if features.billing.enabled: if features.billing.enabled:
if resource == "add_segment": if resource == "add_segment":
if features.billing.subscription.plan == "sandbox": if features.billing.subscription.plan == "sandbox":
@ -163,13 +156,11 @@ def cloud_edition_billing_rate_limit_check(resource: str):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge": if resource == "knowledge":
account = _current_account() _, current_tenant_id = current_account_with_tenant()
assert account.current_tenant_id is not None knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_tenant_id)
tenant_id = account.current_tenant_id
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
if knowledge_rate_limit.enabled: if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000) current_time = int(time.time() * 1000)
key = f"rate_limit_{tenant_id}" key = f"rate_limit_{current_tenant_id}"
redis_client.zadd(key, {current_time: current_time}) redis_client.zadd(key, {current_time: current_time})
@ -180,7 +171,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
if request_count > knowledge_rate_limit.limit: if request_count > knowledge_rate_limit.limit:
# add ratelimit record # add ratelimit record
rate_limit_log = RateLimitLog( rate_limit_log = RateLimitLog(
tenant_id=tenant_id, tenant_id=current_tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan, subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge", operation="knowledge",
) )
@ -200,17 +191,15 @@ def cloud_utm_record(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
account = _current_account() _, current_tenant_id = current_account_with_tenant()
assert account.current_tenant_id is not None features = FeatureService.get_features(current_tenant_id)
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
if features.billing.enabled: if features.billing.enabled:
utm_info = request.cookies.get("utm_info") utm_info = request.cookies.get("utm_info")
if utm_info: if utm_info:
utm_info_dict: dict = json.loads(utm_info) utm_info_dict: dict = json.loads(utm_info)
OperationService.record_utm(tenant_id, utm_info_dict) OperationService.record_utm(current_tenant_id, utm_info_dict)
return view(*args, **kwargs) return view(*args, **kwargs)
@ -289,9 +278,8 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]): def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account() _, current_tenant_id = current_account_with_tenant()
assert account.current_tenant_id is not None features = FeatureService.get_features(current_tenant_id)
features = FeatureService.get_features(account.current_tenant_id)
if features.is_allow_transfer_workspace: if features.is_allow_transfer_workspace:
return view(*args, **kwargs) return view(*args, **kwargs)
@ -301,12 +289,11 @@ def is_allow_transfer_owner(view: Callable[P, R]):
return decorated return decorated
def knowledge_pipeline_publish_enabled(view): def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account() _, current_tenant_id = current_account_with_tenant()
assert account.current_tenant_id is not None features = FeatureService.get_features(current_tenant_id)
features = FeatureService.get_features(account.current_tenant_id)
if features.knowledge_pipeline.publish_enabled: if features.knowledge_pipeline.publish_enabled:
return view(*args, **kwargs) return view(*args, **kwargs)
abort(403) abort(403)

View File

@ -1,5 +1,4 @@
from flask import request from flask import request
from flask_login import current_user
from flask_restx import marshal, reqparse from flask_restx import marshal, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@ -16,6 +15,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db from extensions.ext_database import db
from fields.segment_fields import child_chunk_fields, segment_fields from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import current_account_with_tenant
from models.dataset import Dataset from models.dataset import Dataset
from services.dataset_service import DatasetService, DocumentService, SegmentService from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
@ -66,6 +66,7 @@ class SegmentApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str): def post(self, tenant_id: str, dataset_id: str, document_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Create single segment.""" """Create single segment."""
# check dataset # check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -84,7 +85,7 @@ class SegmentApi(DatasetApiResource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -117,6 +118,7 @@ class SegmentApi(DatasetApiResource):
} }
) )
def get(self, tenant_id: str, dataset_id: str, document_id: str): def get(self, tenant_id: str, dataset_id: str, document_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Get segments.""" """Get segments."""
# check dataset # check dataset
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
@ -133,7 +135,7 @@ class SegmentApi(DatasetApiResource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -149,7 +151,7 @@ class SegmentApi(DatasetApiResource):
segments, total = SegmentService.get_segments( segments, total = SegmentService.get_segments(
document_id=document_id, document_id=document_id,
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
status_list=args["status"], status_list=args["status"],
keyword=args["keyword"], keyword=args["keyword"],
page=page, page=page,
@ -184,6 +186,7 @@ class DatasetSegmentApi(DatasetApiResource):
) )
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
@ -195,7 +198,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset) SegmentService.delete_segment(segment, document, dataset)
@ -217,6 +220,7 @@ class DatasetSegmentApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
@ -232,7 +236,7 @@ class DatasetSegmentApi(DatasetApiResource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -244,7 +248,7 @@ class DatasetSegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
# check segment # check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
@ -266,6 +270,7 @@ class DatasetSegmentApi(DatasetApiResource):
} }
) )
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
@ -277,7 +282,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
@ -307,6 +312,7 @@ class ChildChunkApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Create child chunk.""" """Create child chunk."""
# check dataset # check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -319,7 +325,7 @@ class ChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
@ -328,7 +334,7 @@ class ChildChunkApi(DatasetApiResource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -364,6 +370,7 @@ class ChildChunkApi(DatasetApiResource):
} }
) )
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Get child chunks.""" """Get child chunks."""
# check dataset # check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -376,7 +383,7 @@ class ChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
@ -423,6 +430,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Delete child chunk.""" """Delete child chunk."""
# check dataset # check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -435,7 +443,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
@ -444,9 +452,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check child chunk # check child chunk
child_chunk = SegmentService.get_child_chunk_by_id( child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
)
if not child_chunk: if not child_chunk:
raise NotFound("Child chunk not found.") raise NotFound("Child chunk not found.")
@ -483,6 +489,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Update child chunk.""" """Update child chunk."""
# check dataset # check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -495,7 +502,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# get segment # get segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
@ -504,9 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# get child chunk # get child chunk
child_chunk = SegmentService.get_child_chunk_by_id( child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
)
if not child_chunk: if not child_chunk:
raise NotFound("Child chunk not found.") raise NotFound("Child chunk not found.")

View File

@ -1,10 +1,12 @@
import logging import logging
import queue import queue
import threading
import time import time
from abc import abstractmethod from abc import abstractmethod
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import Any from typing import Any
from cachetools import TTLCache, cachedmethod
from redis.exceptions import RedisError from redis.exceptions import RedisError
from sqlalchemy.orm import DeclarativeMeta from sqlalchemy.orm import DeclarativeMeta
@ -45,6 +47,8 @@ class AppQueueManager:
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self._q = q self._q = q
self._stopped_cache: TTLCache[tuple, bool] = TTLCache(maxsize=1, ttl=1)
self._cache_lock = threading.Lock()
def listen(self): def listen(self):
""" """
@ -157,6 +161,7 @@ class AppQueueManager:
stopped_cache_key = cls._generate_stopped_cache_key(task_id) stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1) redis_client.setex(stopped_cache_key, 600, 1)
@cachedmethod(lambda self: self._stopped_cache, lock=lambda self: self._cache_lock)
def _is_stopped(self) -> bool: def _is_stopped(self) -> bool:
""" """
Check if task is stopped Check if task is stopped

View File

@ -472,6 +472,9 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache.delete() provider_model_credentials_cache.delete()
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session) self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
else:
# some historical data may have a provider record but not be set as valid
provider_record.is_valid = True
session.commit() session.commit()
except Exception: except Exception:

View File

@ -7,7 +7,7 @@ import uuid
from collections import deque from collections import deque
from collections.abc import Sequence from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from typing import Final from typing import Final, cast
from urllib.parse import urljoin from urllib.parse import urljoin
import httpx import httpx
@ -199,7 +199,7 @@ def convert_to_trace_id(uuid_v4: str | None) -> int:
raise ValueError("UUID cannot be None") raise ValueError("UUID cannot be None")
try: try:
uuid_obj = uuid.UUID(uuid_v4) uuid_obj = uuid.UUID(uuid_v4)
return uuid_obj.int return cast(int, uuid_obj.int)
except ValueError as e: except ValueError as e:
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e raise ValueError(f"Invalid UUID input: {uuid_v4}") from e

View File

@ -13,6 +13,7 @@ class TracingProviderEnum(StrEnum):
OPIK = "opik" OPIK = "opik"
WEAVE = "weave" WEAVE = "weave"
ALIYUN = "aliyun" ALIYUN = "aliyun"
TENCENT = "tencent"
class BaseTracingConfig(BaseModel): class BaseTracingConfig(BaseModel):
@ -195,5 +196,32 @@ class AliyunConfig(BaseTracingConfig):
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com") return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
class TencentConfig(BaseTracingConfig):
"""
Tencent APM tracing config
"""
token: str
endpoint: str
service_name: str
@field_validator("token")
@classmethod
def token_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("Token cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
@field_validator("service_name")
@classmethod
def service_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
OPS_FILE_PATH = "ops_trace/" OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

View File

@ -90,6 +90,7 @@ class SuggestedQuestionTraceInfo(BaseTraceInfo):
class DatasetRetrievalTraceInfo(BaseTraceInfo): class DatasetRetrievalTraceInfo(BaseTraceInfo):
documents: Any = None documents: Any = None
error: str | None = None
class ToolTraceInfo(BaseTraceInfo): class ToolTraceInfo(BaseTraceInfo):

View File

@ -120,6 +120,17 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
"trace_instance": AliyunDataTrace, "trace_instance": AliyunDataTrace,
} }
case TracingProviderEnum.TENCENT:
from core.ops.entities.config_entity import TencentConfig
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
return {
"config_class": TencentConfig,
"secret_keys": ["token"],
"other_keys": ["endpoint", "service_name"],
"trace_instance": TencentDataTrace,
}
case _: case _:
raise KeyError(f"Unsupported tracing provider: {provider}") raise KeyError(f"Unsupported tracing provider: {provider}")
@ -723,6 +734,7 @@ class TraceTask:
end_time=timer.get("end"), end_time=timer.get("end"),
metadata=metadata, metadata=metadata,
message_data=message_data.to_dict(), message_data=message_data.to_dict(),
error=kwargs.get("error"),
) )
return dataset_retrieval_trace_info return dataset_retrieval_trace_info
@ -889,6 +901,7 @@ class TraceQueueManager:
continue continue
file_id = uuid4().hex file_id = uuid4().hex
trace_info = task.execute() trace_info = task.execute()
task_data = TaskData( task_data = TaskData(
app_id=task.app_id, app_id=task.app_id,
trace_info_type=type(trace_info).__name__, trace_info_type=type(trace_info).__name__,

View File

View File

@ -0,0 +1,337 @@
"""
Tencent APM Trace Client - handles network operations, metrics, and API communication
"""
from __future__ import annotations
import importlib
import logging
import os
import socket
from typing import TYPE_CHECKING
from urllib.parse import urlparse
if TYPE_CHECKING:
from opentelemetry.metrics import Meter
from opentelemetry.metrics._internal.instrument import Histogram
from opentelemetry.sdk.metrics.export import MetricReader
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import SpanKind
from opentelemetry.util.types import AttributeValue
from configs import dify_config
from .entities.tencent_semconv import LLM_OPERATION_DURATION
from .entities.tencent_trace_entity import SpanData
logger = logging.getLogger(__name__)
class TencentTraceClient:
"""Tencent APM trace client using OpenTelemetry OTLP exporter"""
def __init__(
self,
service_name: str,
endpoint: str,
token: str,
max_queue_size: int = 1000,
schedule_delay_sec: int = 5,
max_export_batch_size: int = 50,
metrics_export_interval_sec: int = 10,
):
self.endpoint = endpoint
self.token = token
self.service_name = service_name
self.metrics_export_interval_sec = metrics_export_interval_sec
self.resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
}
)
# Prepare gRPC endpoint/metadata
grpc_endpoint, insecure, _, _ = self._resolve_grpc_target(endpoint)
headers = (("authorization", f"Bearer {token}"),)
self.exporter = OTLPSpanExporter(
endpoint=grpc_endpoint,
headers=headers,
insecure=insecure,
timeout=30,
)
self.tracer_provider = TracerProvider(resource=self.resource)
self.span_processor = BatchSpanProcessor(
span_exporter=self.exporter,
max_queue_size=max_queue_size,
schedule_delay_millis=schedule_delay_sec * 1000,
max_export_batch_size=max_export_batch_size,
)
self.tracer_provider.add_span_processor(self.span_processor)
self.tracer = self.tracer_provider.get_tracer("dify.tencent_apm")
# Store span contexts for parent-child relationships
self.span_contexts: dict[int, trace_api.SpanContext] = {}
self.meter: Meter | None = None
self.hist_llm_duration: Histogram | None = None
self.metric_reader: MetricReader | None = None
# Metrics exporter and instruments
try:
from opentelemetry import metrics
from opentelemetry.sdk.metrics import Histogram, MeterProvider
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader
protocol = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL", "").strip().lower()
use_http_protobuf = protocol in {"http/protobuf", "http-protobuf"}
use_http_json = protocol in {"http/json", "http-json"}
# Set preferred temporality for histograms to DELTA
preferred_temporality: dict[type, AggregationTemporality] = {Histogram: AggregationTemporality.DELTA}
def _create_metric_exporter(exporter_cls, **kwargs):
"""Create metric exporter with preferred_temporality support"""
try:
return exporter_cls(**kwargs, preferred_temporality=preferred_temporality)
except Exception:
return exporter_cls(**kwargs)
metric_reader = None
if use_http_json:
exporter_cls = None
for mod_path in (
"opentelemetry.exporter.otlp.http.json.metric_exporter",
"opentelemetry.exporter.otlp.json.metric_exporter",
):
try:
mod = importlib.import_module(mod_path)
exporter_cls = getattr(mod, "OTLPMetricExporter", None)
if exporter_cls:
break
except Exception:
continue
if exporter_cls is not None:
metric_exporter = _create_metric_exporter(
exporter_cls,
endpoint=endpoint,
headers={"authorization": f"Bearer {token}"},
)
else:
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as HttpMetricExporter,
)
metric_exporter = _create_metric_exporter(
HttpMetricExporter,
endpoint=endpoint,
headers={"authorization": f"Bearer {token}"},
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
)
elif use_http_protobuf:
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as HttpMetricExporter,
)
metric_exporter = _create_metric_exporter(
HttpMetricExporter,
endpoint=endpoint,
headers={"authorization": f"Bearer {token}"},
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
)
else:
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import (
OTLPMetricExporter as GrpcMetricExporter,
)
m_grpc_endpoint, m_insecure, _, _ = self._resolve_grpc_target(endpoint)
metric_exporter = _create_metric_exporter(
GrpcMetricExporter,
endpoint=m_grpc_endpoint,
headers={"authorization": f"Bearer {token}"},
insecure=m_insecure,
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
)
if metric_reader is not None:
provider = MeterProvider(resource=self.resource, metric_readers=[metric_reader])
metrics.set_meter_provider(provider)
self.meter = metrics.get_meter("dify-sdk", dify_config.project.version)
self.hist_llm_duration = self.meter.create_histogram(
name=LLM_OPERATION_DURATION,
unit="s",
description="LLM operation duration (seconds)",
)
self.metric_reader = metric_reader
else:
self.meter = None
self.hist_llm_duration = None
self.metric_reader = None
except Exception:
logger.exception("[Tencent APM] Metrics initialization failed; metrics disabled")
self.meter = None
self.hist_llm_duration = None
self.metric_reader = None
def add_span(self, span_data: SpanData) -> None:
"""Create and export span using OpenTelemetry Tracer API"""
try:
self._create_and_export_span(span_data)
logger.debug("[Tencent APM] Created span: %s", span_data.name)
except Exception:
logger.exception("[Tencent APM] Failed to create span: %s", span_data.name)
# Metrics recording API
def record_llm_duration(self, latency_seconds: float, attributes: dict[str, str] | None = None) -> None:
"""Record LLM operation duration histogram in seconds."""
try:
if not hasattr(self, "hist_llm_duration") or self.hist_llm_duration is None:
return
attrs: dict[str, str] = {}
if attributes:
for k, v in attributes.items():
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
self.hist_llm_duration.record(latency_seconds, attrs) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record LLM duration", exc_info=True)
def _create_and_export_span(self, span_data: SpanData) -> None:
"""Create span using OpenTelemetry Tracer API"""
try:
parent_context = None
if span_data.parent_span_id and span_data.parent_span_id in self.span_contexts:
parent_context = trace_api.set_span_in_context(
trace_api.NonRecordingSpan(self.span_contexts[span_data.parent_span_id])
)
span = self.tracer.start_span(
name=span_data.name,
context=parent_context,
kind=SpanKind.INTERNAL,
start_time=span_data.start_time,
)
self.span_contexts[span_data.span_id] = span.get_span_context()
if span_data.attributes:
attributes: dict[str, AttributeValue] = {}
for key, value in span_data.attributes.items():
if isinstance(value, (int, float, bool)):
attributes[key] = value
else:
attributes[key] = str(value)
span.set_attributes(attributes)
if span_data.events:
for event in span_data.events:
span.add_event(event.name, event.attributes, event.timestamp)
if span_data.status:
span.set_status(span_data.status)
# Manually end span; do not use context manager to avoid double-end warnings
span.end(end_time=span_data.end_time)
except Exception:
logger.exception("[Tencent APM] Error creating span: %s", span_data.name)
def api_check(self) -> bool:
"""Check API connectivity using socket connection test for gRPC endpoints"""
try:
# Resolve gRPC target consistently with exporters
_, _, host, port = self._resolve_grpc_target(self.endpoint)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(5)
result = sock.connect_ex((host, port))
sock.close()
if result == 0:
logger.info("[Tencent APM] Endpoint %s:%s is accessible", host, port)
return True
else:
logger.warning("[Tencent APM] Endpoint %s:%s is not accessible", host, port)
if host in ["127.0.0.1", "localhost"]:
logger.info("[Tencent APM] Development environment detected, allowing config save")
return True
return False
except Exception:
logger.exception("[Tencent APM] API check failed")
if "127.0.0.1" in self.endpoint or "localhost" in self.endpoint:
return True
return False
def get_project_url(self) -> str:
"""Get project console URL"""
return "https://console.cloud.tencent.com/apm"
def shutdown(self) -> None:
"""Shutdown the client and export remaining spans"""
try:
if self.span_processor:
logger.info("[Tencent APM] Flushing remaining spans before shutdown")
_ = self.span_processor.force_flush()
self.span_processor.shutdown()
if self.tracer_provider:
self.tracer_provider.shutdown()
if self.metric_reader is not None:
try:
self.metric_reader.shutdown() # type: ignore[attr-defined]
except Exception:
pass
except Exception:
logger.exception("[Tencent APM] Error during client shutdown")
@staticmethod
def _resolve_grpc_target(endpoint: str, default_port: int = 4317) -> tuple[str, bool, str, int]:
"""Normalize endpoint to gRPC target and security flag.
Returns:
(grpc_endpoint, insecure, host, port)
"""
try:
if endpoint.startswith(("http://", "https://")):
parsed = urlparse(endpoint)
host = parsed.hostname or "localhost"
port = parsed.port or default_port
insecure = parsed.scheme == "http"
return f"{host}:{port}", insecure, host, port
host = endpoint
port = default_port
if ":" in endpoint:
parts = endpoint.rsplit(":", 1)
host = parts[0] or "localhost"
try:
port = int(parts[1])
except Exception:
port = default_port
insecure = ("localhost" in host) or ("127.0.0.1" in host)
return f"{host}:{port}", insecure, host, port
except Exception:
host, port = "localhost", default_port
return f"{host}:{port}", True, host, port

View File

@ -0,0 +1 @@
# Tencent trace entities module

View File

@ -0,0 +1,73 @@
from enum import Enum
# public
GEN_AI_SESSION_ID = "gen_ai.session.id"
GEN_AI_USER_ID = "gen_ai.user.id"
GEN_AI_USER_NAME = "gen_ai.user.name"
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
GEN_AI_FRAMEWORK = "gen_ai.framework"
GEN_AI_IS_ENTRY = "gen_ai.is_entry" # mark to count the LLM-related traces
# Chain
INPUT_VALUE = "gen_ai.entity.input"
OUTPUT_VALUE = "gen_ai.entity.output"
# Retriever
RETRIEVAL_QUERY = "retrieval.query"
RETRIEVAL_DOCUMENT = "retrieval.document"
# GENERATION
GEN_AI_MODEL_NAME = "gen_ai.response.model"
GEN_AI_PROVIDER = "gen_ai.provider.name"
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
GEN_AI_PROMPT = "gen_ai.prompt"
GEN_AI_COMPLETION = "gen_ai.completion"
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
# Tool
TOOL_NAME = "tool.name"
TOOL_DESCRIPTION = "tool.description"
TOOL_PARAMETERS = "tool.parameters"
# Instrumentation Library
INSTRUMENTATION_NAME = "dify-sdk"
INSTRUMENTATION_VERSION = "0.1.0"
INSTRUMENTATION_LANGUAGE = "python"
# Metrics
LLM_OPERATION_DURATION = "gen_ai.client.operation.duration"
class GenAISpanKind(Enum):
WORKFLOW = "WORKFLOW" # OpenLLMetry
RETRIEVER = "RETRIEVER" # RAG
GENERATION = "GENERATION" # Langfuse
TOOL = "TOOL" # OpenLLMetry
AGENT = "AGENT" # OpenLLMetry
TASK = "TASK" # OpenLLMetry

View File

@ -0,0 +1,21 @@
from collections.abc import Sequence
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from pydantic import BaseModel, Field
class SpanData(BaseModel):
model_config = {"arbitrary_types_allowed": True}
trace_id: int = Field(..., description="The unique identifier for the trace.")
parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.")
span_id: int = Field(..., description="The unique identifier for this span.")
name: str = Field(..., description="The name of the span.")
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int = Field(..., description="The start time of the span in nanoseconds.")
end_time: int = Field(..., description="The end time of the span in nanoseconds.")

View File

@ -0,0 +1,372 @@
"""
Tencent APM Span Builder - handles all span construction logic
"""
import json
import logging
from datetime import datetime
from opentelemetry.trace import Status, StatusCode
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
MessageTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.entities.tencent_semconv import (
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
GEN_AI_IS_ENTRY,
GEN_AI_MODEL_NAME,
GEN_AI_PROMPT,
GEN_AI_PROVIDER,
GEN_AI_RESPONSE_FINISH_REASON,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
GEN_AI_USAGE_INPUT_TOKENS,
GEN_AI_USAGE_OUTPUT_TOKENS,
GEN_AI_USAGE_TOTAL_TOKENS,
GEN_AI_USER_ID,
INPUT_VALUE,
OUTPUT_VALUE,
RETRIEVAL_DOCUMENT,
RETRIEVAL_QUERY,
TOOL_DESCRIPTION,
TOOL_NAME,
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
from core.ops.tencent_trace.utils import TencentTraceUtils
from core.rag.models.document import Document
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
logger = logging.getLogger(__name__)
class TencentSpanBuilder:
"""Builder class for constructing different types of spans"""
@staticmethod
def _get_time_nanoseconds(time_value: datetime | None) -> int:
"""Convert datetime to nanoseconds for span creation."""
return TencentTraceUtils.convert_datetime_to_nanoseconds(time_value)
@staticmethod
def build_workflow_spans(
trace_info: WorkflowTraceInfo, trace_id: int, user_id: str, links: list | None = None
) -> list[SpanData]:
"""Build workflow-related spans"""
spans = []
links = links or []
message_span_id = None
workflow_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "workflow")
if hasattr(trace_info, "metadata") and trace_info.metadata.get("conversation_id"):
message_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "message")
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
if message_span_id:
message_span = TencentSpanBuilder._build_message_span(
trace_info, trace_id, message_span_id, user_id, status, links
)
spans.append(message_span)
workflow_span = TencentSpanBuilder._build_workflow_span(
trace_info, trace_id, workflow_span_id, message_span_id, user_id, status, links
)
spans.append(workflow_span)
return spans
@staticmethod
def _build_message_span(
trace_info: WorkflowTraceInfo, trace_id: int, message_span_id: int, user_id: str, status: Status, links: list
) -> SpanData:
"""Build message span for chatflow"""
return SpanData(
trace_id=trace_id,
parent_span_id=None,
span_id=message_span_id,
name="message",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_IS_ENTRY: "true",
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
status=status,
links=links,
)
@staticmethod
def _build_workflow_span(
trace_info: WorkflowTraceInfo,
trace_id: int,
workflow_span_id: int,
message_span_id: int | None,
user_id: str,
status: Status,
links: list,
) -> SpanData:
"""Build workflow span"""
attributes = {
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
}
if message_span_id is None:
attributes[GEN_AI_IS_ENTRY] = "true"
return SpanData(
trace_id=trace_id,
parent_span_id=message_span_id,
span_id=workflow_span_id,
name="workflow",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes=attributes,
status=status,
links=links,
)
@staticmethod
def build_workflow_llm_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build LLM span for workflow nodes."""
process_data = node_execution.process_data or {}
outputs = node_execution.outputs or {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name="GENERATION",
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
GEN_AI_PROVIDER: process_data.get("model_provider", ""),
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
GEN_AI_COMPLETION: str(outputs.get("text", "")),
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
OUTPUT_VALUE: str(outputs.get("text", "")),
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def build_message_span(
trace_info: MessageTraceInfo, trace_id: int, user_id: str, links: list | None = None
) -> SpanData:
"""Build message span."""
links = links or []
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
return SpanData(
trace_id=trace_id,
parent_span_id=None,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message"),
name="message",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_IS_ENTRY: "true",
INPUT_VALUE: str(trace_info.inputs or ""),
OUTPUT_VALUE: str(trace_info.outputs or ""),
},
status=status,
links=links,
)
@staticmethod
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build tool span."""
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "tool"),
name=trace_info.tool_name,
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: trace_info.tool_name,
TOOL_DESCRIPTION: "",
TOOL_PARAMETERS: json.dumps(trace_info.tool_parameters, ensure_ascii=False),
INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.tool_outputs),
},
status=status,
)
@staticmethod
def build_retrieval_span(trace_info: DatasetRetrievalTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build dataset retrieval span."""
status = Status(StatusCode.OK)
if getattr(trace_info, "error", None):
status = Status(StatusCode.ERROR, trace_info.error) # type: ignore[arg-type]
documents_data = TencentSpanBuilder._extract_retrieval_documents(trace_info.documents)
return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "retrieval"),
name="retrieval",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: str(trace_info.inputs or ""),
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
INPUT_VALUE: str(trace_info.inputs or ""),
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
},
status=status,
)
@staticmethod
def _get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
"""Get workflow node execution status."""
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
return Status(StatusCode.OK)
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
return Status(StatusCode.ERROR, str(node_execution.error))
return Status(StatusCode.UNSET)
@staticmethod
def build_workflow_retrieval_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build knowledge retrieval span for workflow nodes."""
input_value = ""
if node_execution.inputs:
input_value = str(node_execution.inputs.get("query", ""))
output_value = ""
if node_execution.outputs:
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: input_value,
RETRIEVAL_DOCUMENT: output_value,
INPUT_VALUE: input_value,
OUTPUT_VALUE: output_value,
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def build_workflow_tool_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build tool span for workflow nodes."""
tool_des = {}
if node_execution.metadata:
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: node_execution.title,
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def build_workflow_task_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build generic task span for workflow nodes."""
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def _extract_retrieval_documents(documents: list[Document]):
"""Extract documents data for retrieval tracing."""
documents_data = []
for document in documents:
document_data = {
"content": document.page_content,
"metadata": {
"dataset_id": document.metadata.get("dataset_id"),
"doc_id": document.metadata.get("doc_id"),
"document_id": document.metadata.get("document_id"),
},
"score": document.metadata.get("score"),
}
documents_data.append(document_data)
return documents_data

View File

@ -0,0 +1,317 @@
"""
Tencent APM tracing implementation with separated concerns
"""
import logging
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import TencentConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.client import TencentTraceClient
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
from core.ops.tencent_trace.span_builder import TencentSpanBuilder
from core.ops.tencent_trace.utils import TencentTraceUtils
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from core.workflow.nodes import NodeType
from extensions.ext_database import db
from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
class TencentDataTrace(BaseTraceInstance):
"""
Tencent APM trace implementation with single responsibility principle.
Acts as a coordinator that delegates specific tasks to specialized classes.
"""
def __init__(self, tencent_config: TencentConfig):
super().__init__(tencent_config)
self.trace_client = TencentTraceClient(
service_name=tencent_config.service_name,
endpoint=tencent_config.endpoint,
token=tencent_config.token,
metrics_export_interval_sec=5,
)
def trace(self, trace_info: BaseTraceInfo) -> None:
"""Main tracing entry point - coordinates different trace types."""
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
elif isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
elif isinstance(trace_info, ModerationTraceInfo):
pass
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
elif isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
elif isinstance(trace_info, GenerateNameTraceInfo):
pass
def api_check(self) -> bool:
return self.trace_client.api_check()
def get_project_url(self) -> str:
return self.trace_client.get_project_url()
def workflow_trace(self, trace_info: WorkflowTraceInfo) -> None:
"""Handle workflow tracing by coordinating data retrieval and span construction."""
try:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_info.workflow_run_id)
links = []
if trace_info.trace_id:
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
user_id = self._get_user_id(trace_info)
workflow_spans = TencentSpanBuilder.build_workflow_spans(trace_info, trace_id, str(user_id), links)
for span in workflow_spans:
self.trace_client.add_span(span)
self._process_workflow_nodes(trace_info, trace_id)
except Exception:
logger.exception("[Tencent APM] Failed to process workflow trace")
def message_trace(self, trace_info: MessageTraceInfo) -> None:
"""Handle message tracing."""
try:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_info.message_id)
user_id = self._get_user_id(trace_info)
links = []
if trace_info.trace_id:
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
self.trace_client.add_span(message_span)
except Exception:
logger.exception("[Tencent APM] Failed to process message trace")
def tool_trace(self, trace_info: ToolTraceInfo) -> None:
"""Handle tool tracing."""
try:
parent_span_id = None
trace_root_id = None
if trace_info.message_id:
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
trace_root_id = trace_info.message_id
if parent_span_id and trace_root_id:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_root_id)
tool_span = TencentSpanBuilder.build_tool_span(trace_info, trace_id, parent_span_id)
self.trace_client.add_span(tool_span)
except Exception:
logger.exception("[Tencent APM] Failed to process tool trace")
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo) -> None:
"""Handle dataset retrieval tracing."""
try:
parent_span_id = None
trace_root_id = None
if trace_info.message_id:
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
trace_root_id = trace_info.message_id
if parent_span_id and trace_root_id:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_root_id)
retrieval_span = TencentSpanBuilder.build_retrieval_span(trace_info, trace_id, parent_span_id)
self.trace_client.add_span(retrieval_span)
except Exception:
logger.exception("[Tencent APM] Failed to process dataset retrieval trace")
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo) -> None:
"""Handle suggested question tracing"""
try:
logger.info("[Tencent APM] Processing suggested question trace")
except Exception:
logger.exception("[Tencent APM] Failed to process suggested question trace")
def _process_workflow_nodes(self, trace_info: WorkflowTraceInfo, trace_id: int) -> None:
"""Process workflow node executions."""
try:
workflow_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "workflow")
node_executions = self._get_workflow_node_executions(trace_info)
for node_execution in node_executions:
try:
node_span = self._build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
if node_span:
self.trace_client.add_span(node_span)
if node_execution.node_type == NodeType.LLM:
self._record_llm_metrics(node_execution)
except Exception:
logger.exception("[Tencent APM] Failed to process node execution: %s", node_execution.id)
except Exception:
logger.exception("[Tencent APM] Failed to process workflow nodes")
def _build_workflow_node_span(
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
) -> SpanData | None:
"""Build span for different node types"""
try:
if node_execution.node_type == NodeType.LLM:
return TencentSpanBuilder.build_workflow_llm_span(
trace_id, workflow_span_id, trace_info, node_execution
)
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return TencentSpanBuilder.build_workflow_retrieval_span(
trace_id, workflow_span_id, trace_info, node_execution
)
elif node_execution.node_type == NodeType.TOOL:
return TencentSpanBuilder.build_workflow_tool_span(
trace_id, workflow_span_id, trace_info, node_execution
)
else:
# Handle all other node types as generic tasks
return TencentSpanBuilder.build_workflow_task_span(
trace_id, workflow_span_id, trace_info, node_execution
)
except Exception:
logger.debug(
"[Tencent APM] Error building span for node %s: %s",
node_execution.id,
node_execution.node_type,
exc_info=True,
)
return None
def _get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> list[WorkflowNodeExecution]:
"""Retrieve workflow node executions from database."""
try:
session_maker = sessionmaker(bind=db.engine)
with Session(db.engine, expire_on_commit=False) as session:
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator")
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")
service_account.set_tenant_id(current_tenant.tenant_id)
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_maker,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
executions = repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id)
return list(executions)
except Exception:
logger.exception("[Tencent APM] Failed to get workflow node executions")
return []
def _get_user_id(self, trace_info: BaseTraceInfo) -> str:
"""Get user ID from trace info."""
try:
tenant_id = None
user_id = None
if isinstance(trace_info, (WorkflowTraceInfo, GenerateNameTraceInfo)):
tenant_id = trace_info.tenant_id
if hasattr(trace_info, "metadata") and trace_info.metadata:
user_id = trace_info.metadata.get("user_id")
if user_id and tenant_id:
stmt = (
select(Account.name)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(Account.id == user_id, TenantAccountJoin.tenant_id == tenant_id)
)
session_maker = sessionmaker(bind=db.engine)
with session_maker() as session:
account_name = session.scalar(stmt)
return account_name or str(user_id)
elif user_id:
return str(user_id)
return "anonymous"
except Exception:
logger.exception("[Tencent APM] Failed to get user ID")
return "unknown"
def _record_llm_metrics(self, node_execution: WorkflowNodeExecution) -> None:
"""Record LLM performance metrics"""
try:
if not hasattr(self.trace_client, "record_llm_duration"):
return
process_data = node_execution.process_data or {}
usage = process_data.get("usage", {})
latency_s = float(usage.get("latency", 0.0))
if latency_s > 0:
attributes = {
"provider": process_data.get("model_provider", ""),
"model": process_data.get("model_name", ""),
"span_kind": "GENERATION",
}
self.trace_client.record_llm_duration(latency_s, attributes)
except Exception:
logger.debug("[Tencent APM] Failed to record LLM metrics")
def __del__(self):
"""Ensure proper cleanup on garbage collection."""
try:
if hasattr(self, "trace_client"):
self.trace_client.shutdown()
except Exception:
pass

View File

@ -0,0 +1,65 @@
"""
Utility functions for Tencent APM tracing
"""
import hashlib
import random
import uuid
from datetime import datetime
from typing import cast
from opentelemetry.trace import Link, SpanContext, TraceFlags
class TencentTraceUtils:
"""Utility class for common tracing operations."""
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
@staticmethod
def convert_to_trace_id(uuid_v4: str | None) -> int:
try:
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
return cast(int, uuid_obj.int)
@staticmethod
def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
try:
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
combined_key = f"{uuid_obj.hex}-{span_type}"
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
return int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
@staticmethod
def generate_span_id() -> int:
span_id = random.getrandbits(64)
while span_id == TencentTraceUtils.INVALID_SPAN_ID:
span_id = random.getrandbits(64)
return span_id
@staticmethod
def convert_datetime_to_nanoseconds(start_time: datetime | None) -> int:
if start_time is None:
start_time = datetime.now()
timestamp_in_seconds = start_time.timestamp()
return int(timestamp_in_seconds * 1e9)
@staticmethod
def create_link(trace_id_str: str) -> Link:
try:
trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else cast(int, uuid.UUID(trace_id_str).int)
except (ValueError, TypeError):
trace_id = cast(int, uuid.uuid4().int)
span_context = SpanContext(
trace_id=trace_id,
span_id=TencentTraceUtils.INVALID_SPAN_ID,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
return Link(span_context)

View File

@ -2,7 +2,7 @@ import inspect
import json import json
import logging import logging
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
from typing import Any, TypeVar from typing import Any, TypeVar, cast
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
@ -31,6 +31,17 @@ from core.plugin.impl.exc import (
) )
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
_plugin_daemon_timeout_config = cast(
float | httpx.Timeout | None,
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
)
plugin_daemon_request_timeout: httpx.Timeout | None
if _plugin_daemon_timeout_config is None:
plugin_daemon_request_timeout = None
elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
plugin_daemon_request_timeout = _plugin_daemon_timeout_config
else:
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
@ -58,6 +69,7 @@ class BasePluginClient:
"headers": headers, "headers": headers,
"params": params, "params": params,
"files": files, "files": files,
"timeout": plugin_daemon_request_timeout,
} }
if isinstance(prepared_data, dict): if isinstance(prepared_data, dict):
request_kwargs["data"] = prepared_data request_kwargs["data"] = prepared_data
@ -116,6 +128,7 @@ class BasePluginClient:
"headers": headers, "headers": headers,
"params": params, "params": params,
"files": files, "files": files,
"timeout": plugin_daemon_request_timeout,
} }
if isinstance(prepared_data, dict): if isinstance(prepared_data, dict):
stream_kwargs["data"] = prepared_data stream_kwargs["data"] = prepared_data

View File

@ -1,9 +1,24 @@
"""
Weaviate vector database implementation for Dify's RAG system.
This module provides integration with Weaviate vector database for storing and retrieving
document embeddings used in retrieval-augmented generation workflows.
"""
import datetime import datetime
import json import json
import logging
import uuid as _uuid
from typing import Any from typing import Any
from urllib.parse import urlparse
import weaviate # type: ignore import weaviate
import weaviate.classes.config as wc
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from weaviate.classes.data import DataObject
from weaviate.classes.init import Auth
from weaviate.classes.query import Filter, MetadataQuery
from weaviate.exceptions import UnexpectedStatusCodeError
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
@ -15,265 +30,394 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset
logger = logging.getLogger(__name__)
class WeaviateConfig(BaseModel): class WeaviateConfig(BaseModel):
"""
Configuration model for Weaviate connection settings.
Attributes:
endpoint: Weaviate server endpoint URL
api_key: Optional API key for authentication
batch_size: Number of objects to batch per insert operation
"""
endpoint: str endpoint: str
api_key: str | None = None api_key: str | None = None
batch_size: int = 100 batch_size: int = 100
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_config(cls, values: dict): def validate_config(cls, values: dict) -> dict:
"""Validates that required configuration values are present."""
if not values["endpoint"]: if not values["endpoint"]:
raise ValueError("config WEAVIATE_ENDPOINT is required") raise ValueError("config WEAVIATE_ENDPOINT is required")
return values return values
class WeaviateVector(BaseVector): class WeaviateVector(BaseVector):
"""
Weaviate vector database implementation for document storage and retrieval.
Handles creation, insertion, deletion, and querying of document embeddings
in a Weaviate collection.
"""
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list): def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
"""
Initializes the Weaviate vector store.
Args:
collection_name: Name of the Weaviate collection
config: Weaviate configuration settings
attributes: List of metadata attributes to store
"""
super().__init__(collection_name) super().__init__(collection_name)
self._client = self._init_client(config) self._client = self._init_client(config)
self._attributes = attributes self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client: def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
auth_config = weaviate.AuthApiKey(api_key=config.api_key or "") """
Initializes and returns a connected Weaviate client.
weaviate.connect.connection.has_grpc = False # ty: ignore [unresolved-attribute] Configures both HTTP and gRPC connections with proper authentication.
"""
p = urlparse(config.endpoint)
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
http_secure = p.scheme == "https"
http_port = p.port or (443 if http_secure else 80)
try: grpc_host = host
client = weaviate.Client( grpc_secure = http_secure
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None grpc_port = 443 if grpc_secure else 50051
)
except Exception as exc:
raise ConnectionError("Vector database connection error") from exc
client.batch.configure( client = weaviate.connect_to_custom(
# `batch_size` takes an `int` value to enable auto-batching http_host=host,
# (`None` is used for manual batching) http_port=http_port,
batch_size=config.batch_size, http_secure=http_secure,
# dynamically update the `batch_size` based on import speed grpc_host=grpc_host,
dynamic=True, grpc_port=grpc_port,
# `timeout_retries` takes an `int` value to retry on time outs grpc_secure=grpc_secure,
timeout_retries=3, auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
) )
if not client.is_ready():
raise ConnectionError("Vector database is not ready")
return client return client
def get_type(self) -> str: def get_type(self) -> str:
"""Returns the vector database type identifier."""
return VectorType.WEAVIATE return VectorType.WEAVIATE
def get_collection_name(self, dataset: Dataset) -> str: def get_collection_name(self, dataset: Dataset) -> str:
"""
Retrieves or generates the collection name for a dataset.
Uses existing index structure if available, otherwise generates from dataset ID.
"""
if dataset.index_struct_dict: if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
if not class_prefix.endswith("_Node"): if not class_prefix.endswith("_Node"):
# original class_prefix
class_prefix += "_Node" class_prefix += "_Node"
return class_prefix return class_prefix
dataset_id = dataset.id dataset_id = dataset.id
return Dataset.gen_collection_name_by_id(dataset_id) return Dataset.gen_collection_name_by_id(dataset_id)
def to_index_struct(self): def to_index_struct(self) -> dict:
"""Returns the index structure dictionary for persistence."""
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
# create collection """
Creates a new collection and adds initial documents with embeddings.
"""
self._create_collection() self._create_collection()
# create vector
self.add_texts(texts, embeddings) self.add_texts(texts, embeddings)
def _create_collection(self): def _create_collection(self):
"""
Creates the Weaviate collection with required schema if it doesn't exist.
Uses Redis locking to prevent concurrent creation attempts.
"""
lock_name = f"vector_indexing_lock_{self._collection_name}" lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20): with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}" cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key): if redis_client.get(cache_key):
return return
schema = self._default_schema(self._collection_name)
if not self._client.schema.contains(schema): try:
# create collection if not self._client.collections.exists(self._collection_name):
self._client.schema.create_class(schema) self._client.collections.create(
redis_client.set(collection_exist_cache_key, 1, ex=3600) name=self._collection_name,
properties=[
wc.Property(
name=Field.TEXT_KEY.value,
data_type=wc.DataType.TEXT,
tokenization=wc.Tokenization.WORD,
),
wc.Property(name="document_id", data_type=wc.DataType.TEXT),
wc.Property(name="doc_id", data_type=wc.DataType.TEXT),
wc.Property(name="chunk_index", data_type=wc.DataType.INT),
],
vector_config=wc.Configure.Vectors.self_provided(),
)
self._ensure_properties()
redis_client.set(cache_key, 1, ex=3600)
except Exception as e:
logger.exception("Error creating collection %s", self._collection_name)
raise
def _ensure_properties(self) -> None:
"""
Ensures all required properties exist in the collection schema.
Adds missing properties if the collection exists but lacks them.
"""
if not self._client.collections.exists(self._collection_name):
return
col = self._client.collections.use(self._collection_name)
cfg = col.config.get()
existing = {p.name for p in (cfg.properties or [])}
to_add = []
if "document_id" not in existing:
to_add.append(wc.Property(name="document_id", data_type=wc.DataType.TEXT))
if "doc_id" not in existing:
to_add.append(wc.Property(name="doc_id", data_type=wc.DataType.TEXT))
if "chunk_index" not in existing:
to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT))
for prop in to_add:
try:
col.config.add_property(prop)
except Exception as e:
logger.warning("Could not add property %s: %s", prop.name, e)
def _get_uuids(self, documents: list[Document]) -> list[str]:
"""
Generates deterministic UUIDs for documents based on their content.
Uses UUID5 with URL namespace to ensure consistent IDs for identical content.
"""
URL_NAMESPACE = _uuid.UUID("6ba7b811-9dad-11d1-80b4-00c04fd430c8")
uuids = []
for doc in documents:
uuid_val = _uuid.uuid5(URL_NAMESPACE, doc.page_content)
uuids.append(str(uuid_val))
return uuids
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Adds documents with their embeddings to the collection.
Batches insertions for efficiency and returns the list of inserted object IDs.
"""
uuids = self._get_uuids(documents) uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents] texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents] metadatas = [d.metadata for d in documents]
ids = [] col = self._client.collections.use(self._collection_name)
objs: list[DataObject] = []
ids_out: list[str] = []
with self._client.batch as batch: for i, text in enumerate(texts):
for i, text in enumerate(texts): props: dict[str, Any] = {Field.TEXT_KEY.value: text}
data_properties = {Field.TEXT_KEY: text} meta = metadatas[i] or {}
if metadatas is not None: for k, v in meta.items():
# metadata maybe None props[k] = self._json_serializable(v)
for key, val in (metadatas[i] or {}).items():
data_properties[key] = self._json_serializable(val)
batch.add_data_object( candidate = uuids[i] if uuids else None
data_object=data_properties, uid = candidate if (candidate and self._is_uuid(candidate)) else str(_uuid.uuid4())
class_name=self._collection_name, ids_out.append(uid)
uuid=uuids[i],
vector=embeddings[i] if embeddings else None, vec_payload = None
if embeddings and i < len(embeddings) and embeddings[i]:
vec_payload = {"default": embeddings[i]}
objs.append(
DataObject(
uuid=uid,
properties=props, # type: ignore[arg-type] # mypy incorrectly infers DataObject signature
vector=vec_payload,
) )
ids.append(uuids[i]) )
return ids
def delete_by_metadata_field(self, key: str, value: str): batch_size = max(1, int(dify_config.WEAVIATE_BATCH_SIZE or 100))
# check whether the index already exists with col.batch.dynamic() as batch:
schema = self._default_schema(self._collection_name) for obj in objs:
if self._client.schema.contains(schema): batch.add_object(properties=obj.properties, uuid=obj.uuid, vector=obj.vector)
where_filter = {"operator": "Equal", "path": [key], "valueText": value}
self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal") return ids_out
def _is_uuid(self, val: str) -> bool:
"""Validates whether a string is a valid UUID format."""
try:
_uuid.UUID(str(val))
return True
except Exception:
return False
def delete_by_metadata_field(self, key: str, value: str) -> None:
"""Deletes all objects matching a specific metadata field value."""
if not self._client.collections.exists(self._collection_name):
return
col = self._client.collections.use(self._collection_name)
col.data.delete_many(where=Filter.by_property(key).equal(value))
def delete(self): def delete(self):
# check whether the index already exists """Deletes the entire collection from Weaviate."""
schema = self._default_schema(self._collection_name) if self._client.collections.exists(self._collection_name):
if self._client.schema.contains(schema): self._client.collections.delete(self._collection_name)
self._client.schema.delete_class(self._collection_name)
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
collection_name = self._collection_name """Checks if a document with the given doc_id exists in the collection."""
schema = self._default_schema(self._collection_name) if not self._client.collections.exists(self._collection_name):
# check whether the index already exists
if not self._client.schema.contains(schema):
return False return False
result = (
self._client.query.get(collection_name) col = self._client.collections.use(self._collection_name)
.with_additional(["id"]) res = col.query.fetch_objects(
.with_where( filters=Filter.by_property("doc_id").equal(id),
{ limit=1,
"path": ["doc_id"], return_properties=["doc_id"],
"operator": "Equal",
"valueText": id,
}
)
.with_limit(1)
.do()
) )
if "errors" in result: return len(res.objects) > 0
raise ValueError(f"Error during query: {result['errors']}")
entries = result["data"]["Get"][collection_name] def delete_by_ids(self, ids: list[str]) -> None:
if len(entries) == 0: """
return False Deletes objects by their UUID identifiers.
return True Silently ignores 404 errors for non-existent IDs.
"""
if not self._client.collections.exists(self._collection_name):
return
def delete_by_ids(self, ids: list[str]): col = self._client.collections.use(self._collection_name)
# check whether the index already exists
schema = self._default_schema(self._collection_name) for uid in ids:
if self._client.schema.contains(schema): try:
for uuid in ids: col.data.delete_by_id(uid)
try: except UnexpectedStatusCodeError as e:
self._client.data_object.delete( if getattr(e, "status_code", None) != 404:
class_name=self._collection_name, raise
uuid=uuid,
)
except weaviate.UnexpectedStatusCodeException as e:
# tolerate not found error
if e.status_code != 404:
raise e
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Look up similar documents by embedding vector in Weaviate.""" """
collection_name = self._collection_name Performs vector similarity search using the provided query vector.
properties = self._attributes
properties.append(Field.TEXT_KEY)
query_obj = self._client.query.get(collection_name, properties)
vector = {"vector": query_vector} Filters by document IDs if provided and applies score threshold.
document_ids_filter = kwargs.get("document_ids_filter") Returns documents sorted by relevance score.
if document_ids_filter: """
operands = [] if not self._client.collections.exists(self._collection_name):
for document_id_filter in document_ids_filter: return []
operands.append({"path": ["document_id"], "operator": "Equal", "valueText": document_id_filter})
where_filter = {"operator": "Or", "operands": operands} col = self._client.collections.use(self._collection_name)
query_obj = query_obj.with_where(where_filter) props = list({*self._attributes, "document_id", Field.TEXT_KEY.value})
result = (
query_obj.with_near_vector(vector) where = None
.with_limit(kwargs.get("top_k", 4)) doc_ids = kwargs.get("document_ids_filter") or []
.with_additional(["vector", "distance"]) if doc_ids:
.do() ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
top_k = int(kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
res = col.query.near_vector(
near_vector=query_vector,
limit=top_k,
return_properties=props,
return_metadata=MetadataQuery(distance=True),
include_vector=False,
filters=where,
target_vector="default",
) )
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs_and_scores = [] docs: list[Document] = []
for res in result["data"]["Get"][collection_name]: for obj in res.objects:
text = res.pop(Field.TEXT_KEY) properties = dict(obj.properties or {})
score = 1 - res["_additional"]["distance"] text = properties.pop(Field.TEXT_KEY.value, "")
docs_and_scores.append((Document(page_content=text, metadata=res), score)) distance = (obj.metadata.distance if obj.metadata else None) or 1.0
score = 1.0 - distance
docs = [] if score > score_threshold:
for doc, score in docs_and_scores: properties["score"] = score
score_threshold = float(kwargs.get("score_threshold") or 0.0) docs.append(Document(page_content=text, metadata=properties))
# check score threshold
if score >= score_threshold: docs.sort(key=lambda d: d.metadata.get("score", 0.0), reverse=True)
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs using BM25F.
Args:
query: Text to look up documents similar to.
Returns:
List of Documents most similar to the query.
""" """
collection_name = self._collection_name Performs BM25 full-text search on document content.
content: dict[str, Any] = {"concepts": [query]}
properties = self._attributes Filters by document IDs if provided and returns matching documents with vectors.
properties.append(Field.TEXT_KEY) """
if kwargs.get("search_distance"): if not self._client.collections.exists(self._collection_name):
content["certainty"] = kwargs.get("search_distance") return []
query_obj = self._client.query.get(collection_name, properties)
document_ids_filter = kwargs.get("document_ids_filter") col = self._client.collections.use(self._collection_name)
if document_ids_filter: props = list({*self._attributes, Field.TEXT_KEY.value})
operands = []
for document_id_filter in document_ids_filter: where = None
operands.append({"path": ["document_id"], "operator": "Equal", "valueText": document_id_filter}) doc_ids = kwargs.get("document_ids_filter") or []
where_filter = {"operator": "Or", "operands": operands} if doc_ids:
query_obj = query_obj.with_where(where_filter) ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
query_obj = query_obj.with_additional(["vector"]) where = ors[0]
properties = ["text"] for f in ors[1:]:
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do() where = where | f
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}") top_k = int(kwargs.get("top_k", 4))
docs = []
for res in result["data"]["Get"][collection_name]: res = col.query.bm25(
text = res.pop(Field.TEXT_KEY) query=query,
additional = res.pop("_additional") query_properties=[Field.TEXT_KEY.value],
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res)) limit=top_k,
return_properties=props,
include_vector=True,
filters=where,
)
docs: list[Document] = []
for obj in res.objects:
properties = dict(obj.properties or {})
text = properties.pop(Field.TEXT_KEY.value, "")
vec = obj.vector
if isinstance(vec, dict):
vec = vec.get("default") or next(iter(vec.values()), None)
docs.append(Document(page_content=text, vector=vec, metadata=properties))
return docs return docs
def _default_schema(self, index_name: str): def _json_serializable(self, value: Any) -> Any:
return { """Converts values to JSON-serializable format, handling datetime objects."""
"class": index_name,
"properties": [
{
"name": "text",
"dataType": ["text"],
}
],
}
def _json_serializable(self, value: Any):
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
return value.isoformat() return value.isoformat()
return value return value
class WeaviateVectorFactory(AbstractVectorFactory): class WeaviateVectorFactory(AbstractVectorFactory):
"""Factory class for creating WeaviateVector instances."""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector: def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
"""
Initializes a WeaviateVector instance for the given dataset.
Uses existing collection name from dataset index structure or generates a new one.
Updates dataset index structure if not already set.
"""
if dataset.index_struct_dict: if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix collection_name = class_prefix
@ -281,7 +425,6 @@ class WeaviateVectorFactory(AbstractVectorFactory):
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
return WeaviateVector( return WeaviateVector(
collection_name=collection_name, collection_name=collection_name,
config=WeaviateConfig( config=WeaviateConfig(

View File

@ -43,8 +43,7 @@ class CacheEmbedding(Embeddings):
else: else:
embedding_queue_indices.append(i) embedding_queue_indices.append(i)
# release database connection, because embedding may take a long time # NOTE: avoid closing the shared scoped session here; downstream code may still have pending work
db.session.close()
if embedding_queue_indices: if embedding_queue_indices:
embedding_queue_texts = [texts[i] for i in embedding_queue_indices] embedding_queue_texts = [texts[i] for i in embedding_queue_indices]

View File

@ -189,6 +189,11 @@ class ToolInvokeMessage(BaseModel):
data: Mapping[str, Any] = Field(..., description="Detailed log data") data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Mapping[str, Any] = Field(default_factory=dict, description="The metadata of the log") metadata: Mapping[str, Any] = Field(default_factory=dict, description="The metadata of the log")
@field_validator("metadata", mode="before")
@classmethod
def _normalize_metadata(cls, value: Mapping[str, Any] | None) -> Mapping[str, Any]:
return value or {}
class RetrieverResourceMessage(BaseModel): class RetrieverResourceMessage(BaseModel):
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources") retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context") context: str = Field(..., description="context")
@ -377,6 +382,11 @@ class ToolEntity(BaseModel):
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
return v or [] return v or []
@field_validator("output_schema", mode="before")
@classmethod
def _normalize_output_schema(cls, value: Mapping[str, object] | None) -> Mapping[str, object]:
return value or {}
class OAuthSchema(BaseModel): class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field( client_schema: list[ProviderConfig] = Field(

View File

@ -1,10 +1,11 @@
import logging import logging
import time as time_module import time as time_module
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any, cast
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import update from sqlalchemy import update
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
@ -267,7 +268,7 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]
# Build and execute the update statement # Build and execute the update statement
stmt = update(Provider).where(*where_conditions).values(**update_values) stmt = update(Provider).where(*where_conditions).values(**update_values)
result = session.execute(stmt) result = cast(CursorResult, session.execute(stmt))
rows_affected = result.rowcount rows_affected = result.rowcount
logger.debug( logger.debug(

View File

@ -64,7 +64,10 @@ def build_from_mapping(
config: FileUploadConfig | None = None, config: FileUploadConfig | None = None,
strict_type_validation: bool = False, strict_type_validation: bool = False,
) -> File: ) -> File:
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method")) transfer_method_value = mapping.get("transfer_method")
if not transfer_method_value:
raise ValueError("transfer_method is required in file mapping")
transfer_method = FileTransferMethod.value_of(transfer_method_value)
build_functions: dict[FileTransferMethod, Callable] = { build_functions: dict[FileTransferMethod, Callable] = {
FileTransferMethod.LOCAL_FILE: _build_from_local_file, FileTransferMethod.LOCAL_FILE: _build_from_local_file,
@ -104,6 +107,8 @@ def build_from_mappings(
) -> Sequence[File]: ) -> Sequence[File]:
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
# Implement batch processing to reduce database load when handling multiple files. # Implement batch processing to reduce database load when handling multiple files.
# Filter out None/empty mappings to avoid errors
valid_mappings = [m for m in mappings if m and m.get("transfer_method")]
files = [ files = [
build_from_mapping( build_from_mapping(
mapping=mapping, mapping=mapping,
@ -111,7 +116,7 @@ def build_from_mappings(
config=config, config=config,
strict_type_validation=strict_type_validation, strict_type_validation=strict_type_validation,
) )
for mapping in mappings for mapping in valid_mappings
] ]
if ( if (

View File

@ -13,6 +13,15 @@ from models.model import EndUser
#: A proxy for the current user. If no user is logged in, this will be an #: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user #: anonymous user
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
def current_account_with_tenant():
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
return current_user, current_user.current_tenant_id
from typing import ParamSpec, TypeVar from typing import ParamSpec, TypeVar
P = ParamSpec("P") P = ParamSpec("P")

View File

@ -13,7 +13,7 @@ dependencies = [
"celery~=5.5.2", "celery~=5.5.2",
"chardet~=5.1.0", "chardet~=5.1.0",
"flask~=3.1.2", "flask~=3.1.2",
"flask-compress~=1.17", "flask-compress>=1.17,<1.18",
"flask-cors~=6.0.0", "flask-cors~=6.0.0",
"flask-login~=0.6.3", "flask-login~=0.6.3",
"flask-migrate~=4.0.7", "flask-migrate~=4.0.7",
@ -87,6 +87,7 @@ dependencies = [
"flask-restx~=1.3.0", "flask-restx~=1.3.0",
"packaging~=23.2", "packaging~=23.2",
"croniter>=6.0.0", "croniter>=6.0.0",
"weaviate-client==4.17.0",
] ]
# Before adding new dependency, consider place it in # Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group. # alphabet order (a-z) and suitable group.
@ -215,7 +216,7 @@ vdb = [
"tidb-vector==0.0.9", "tidb-vector==0.0.9",
"upstash-vector==0.6.0", "upstash-vector==0.6.0",
"volcengine-compat~=1.0.0", "volcengine-compat~=1.0.0",
"weaviate-client~=3.24.0", "weaviate-client>=4.0.0,<5.0.0",
"xinference-client~=1.2.2", "xinference-client~=1.2.2",
"mo-vector~=0.1.13", "mo-vector~=0.1.13",
"mysql-connector-python>=9.3.0", "mysql-connector-python>=9.3.0",

View File

@ -7,8 +7,10 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
from collections.abc import Sequence from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from typing import cast
from sqlalchemy import asc, delete, desc, select from sqlalchemy import asc, delete, desc, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from models.workflow import WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionModel
@ -181,7 +183,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
# Delete the batch # Delete the batch
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
result = session.execute(delete_stmt) result = cast(CursorResult, session.execute(delete_stmt))
session.commit() session.commit()
total_deleted += result.rowcount total_deleted += result.rowcount
@ -228,7 +230,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
# Delete the batch # Delete the batch
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
result = session.execute(delete_stmt) result = cast(CursorResult, session.execute(delete_stmt))
session.commit() session.commit()
total_deleted += result.rowcount total_deleted += result.rowcount
@ -285,6 +287,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
with self._session_maker() as session: with self._session_maker() as session:
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
result = session.execute(stmt) result = cast(CursorResult, session.execute(stmt))
session.commit() session.commit()
return result.rowcount return result.rowcount

View File

@ -22,8 +22,10 @@ Implementation Notes:
import logging import logging
from collections.abc import Sequence from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from typing import cast
from sqlalchemy import delete, select from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
@ -150,7 +152,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
with self._session_maker() as session: with self._session_maker() as session:
stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)) stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
result = session.execute(stmt) result = cast(CursorResult, session.execute(stmt))
session.commit() session.commit()
deleted_count = result.rowcount deleted_count = result.rowcount
@ -186,7 +188,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
# Delete the batch # Delete the batch
delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)) delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
result = session.execute(delete_stmt) result = cast(CursorResult, session.execute(delete_stmt))
session.commit() session.commit()
batch_deleted = result.rowcount batch_deleted = result.rowcount

View File

@ -1,8 +1,11 @@
import datetime import datetime
import logging import logging
import time import time
from collections.abc import Sequence
import click import click
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
import app import app
from configs import dify_config from configs import dify_config
@ -35,50 +38,53 @@ def clean_workflow_runlogs_precise():
retention_days = dify_config.WORKFLOW_LOG_RETENTION_DAYS retention_days = dify_config.WORKFLOW_LOG_RETENTION_DAYS
cutoff_date = datetime.datetime.now() - datetime.timedelta(days=retention_days) cutoff_date = datetime.datetime.now() - datetime.timedelta(days=retention_days)
session_factory = sessionmaker(db.engine, expire_on_commit=False)
try: try:
total_workflow_runs = db.session.query(WorkflowRun).where(WorkflowRun.created_at < cutoff_date).count() with session_factory.begin() as session:
if total_workflow_runs == 0: total_workflow_runs = session.query(WorkflowRun).where(WorkflowRun.created_at < cutoff_date).count()
logger.info("No expired workflow run logs found") if total_workflow_runs == 0:
return logger.info("No expired workflow run logs found")
logger.info("Found %s expired workflow run logs to clean", total_workflow_runs) return
logger.info("Found %s expired workflow run logs to clean", total_workflow_runs)
total_deleted = 0 total_deleted = 0
failed_batches = 0 failed_batches = 0
batch_count = 0 batch_count = 0
while True: while True:
workflow_runs = ( with session_factory.begin() as session:
db.session.query(WorkflowRun.id).where(WorkflowRun.created_at < cutoff_date).limit(BATCH_SIZE).all() workflow_run_ids = session.scalars(
) select(WorkflowRun.id)
.where(WorkflowRun.created_at < cutoff_date)
.order_by(WorkflowRun.created_at, WorkflowRun.id)
.limit(BATCH_SIZE)
).all()
if not workflow_runs: if not workflow_run_ids:
break
workflow_run_ids = [run.id for run in workflow_runs]
batch_count += 1
success = _delete_batch_with_retry(workflow_run_ids, failed_batches)
if success:
total_deleted += len(workflow_run_ids)
failed_batches = 0
else:
failed_batches += 1
if failed_batches >= MAX_RETRIES:
logger.error("Failed to delete batch after %s retries, aborting cleanup for today", MAX_RETRIES)
break break
batch_count += 1
success = _delete_batch(session, workflow_run_ids, failed_batches)
if success:
total_deleted += len(workflow_run_ids)
failed_batches = 0
else: else:
# Calculate incremental delay times: 5, 10, 15 minutes failed_batches += 1
retry_delay_minutes = failed_batches * 5 if failed_batches >= MAX_RETRIES:
logger.warning("Batch deletion failed, retrying in %s minutes...", retry_delay_minutes) logger.error("Failed to delete batch after %s retries, aborting cleanup for today", MAX_RETRIES)
time.sleep(retry_delay_minutes * 60) break
continue else:
# Calculate incremental delay times: 5, 10, 15 minutes
retry_delay_minutes = failed_batches * 5
logger.warning("Batch deletion failed, retrying in %s minutes...", retry_delay_minutes)
time.sleep(retry_delay_minutes * 60)
continue
logger.info("Cleanup completed: %s expired workflow run logs deleted", total_deleted) logger.info("Cleanup completed: %s expired workflow run logs deleted", total_deleted)
except Exception: except Exception:
db.session.rollback()
logger.exception("Unexpected error in workflow log cleanup") logger.exception("Unexpected error in workflow log cleanup")
raise raise
@ -87,69 +93,56 @@ def clean_workflow_runlogs_precise():
click.echo(click.style(f"Cleaned workflow run logs from db success latency: {execution_time:.2f}s", fg="green")) click.echo(click.style(f"Cleaned workflow run logs from db success latency: {execution_time:.2f}s", fg="green"))
def _delete_batch_with_retry(workflow_run_ids: list[str], attempt_count: int) -> bool: def _delete_batch(session: Session, workflow_run_ids: Sequence[str], attempt_count: int) -> bool:
"""Delete a single batch with a retry mechanism and complete cascading deletion""" """Delete a single batch of workflow runs and all related data within a nested transaction."""
try: try:
with db.session.begin_nested(): with session.begin_nested():
message_data = ( message_data = (
db.session.query(Message.id, Message.conversation_id) session.query(Message.id, Message.conversation_id)
.where(Message.workflow_run_id.in_(workflow_run_ids)) .where(Message.workflow_run_id.in_(workflow_run_ids))
.all() .all()
) )
message_id_list = [msg.id for msg in message_data] message_id_list = [msg.id for msg in message_data]
conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id}) conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id})
if message_id_list: if message_id_list:
db.session.query(AppAnnotationHitHistory).where( message_related_models = [
AppAnnotationHitHistory.message_id.in_(message_id_list) AppAnnotationHitHistory,
).delete(synchronize_session=False) MessageAgentThought,
MessageChain,
MessageFile,
MessageAnnotation,
MessageFeedback,
]
for model in message_related_models:
session.query(model).where(model.message_id.in_(message_id_list)).delete(synchronize_session=False) # type: ignore
# error: "DeclarativeAttributeIntercept" has no attribute "message_id". But this type is only in lib
# and these 6 types all have the message_id field.
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_id_list)).delete( session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete(
synchronize_session=False synchronize_session=False
) )
db.session.query(MessageChain).where(MessageChain.message_id.in_(message_id_list)).delete( session.query(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete(
synchronize_session=False
)
db.session.query(MessageFile).where(MessageFile.message_id.in_(message_id_list)).delete(
synchronize_session=False
)
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_id_list)).delete(
synchronize_session=False
)
db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_id_list)).delete(
synchronize_session=False
)
db.session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete(
synchronize_session=False
)
db.session.query(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete(
synchronize_session=False synchronize_session=False
) )
db.session.query(WorkflowNodeExecutionModel).where( session.query(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids) WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids)
).delete(synchronize_session=False) ).delete(synchronize_session=False)
if conversation_id_list: if conversation_id_list:
db.session.query(ConversationVariable).where( session.query(ConversationVariable).where(
ConversationVariable.conversation_id.in_(conversation_id_list) ConversationVariable.conversation_id.in_(conversation_id_list)
).delete(synchronize_session=False) ).delete(synchronize_session=False)
db.session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete( session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete(
synchronize_session=False synchronize_session=False
) )
db.session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False) session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False)
db.session.commit() return True
return True
except Exception: except Exception:
db.session.rollback()
logger.exception("Batch deletion failed (attempt %s)", attempt_count + 1) logger.exception("Batch deletion failed (attempt %s)", attempt_count + 1)
return False return False

View File

@ -8,8 +8,7 @@ from werkzeug.exceptions import NotFound
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import current_user from libs.login import current_account_with_tenant
from models.account import Account
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
from services.feature_service import FeatureService from services.feature_service import FeatureService
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
@ -24,10 +23,10 @@ class AppAnnotationService:
@classmethod @classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info # get app info
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -63,12 +62,12 @@ class AppAnnotationService:
db.session.commit() db.session.commit()
# if annotation reply is enabled , add annotation to index # if annotation reply is enabled , add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
assert current_user.current_tenant_id is not None assert current_tenant_id is not None
if annotation_setting: if annotation_setting:
add_annotation_to_index_task.delay( add_annotation_to_index_task.delay(
annotation.id, annotation.id,
args["question"], args["question"],
current_user.current_tenant_id, current_tenant_id,
app_id, app_id,
annotation_setting.collection_binding_id, annotation_setting.collection_binding_id,
) )
@ -86,13 +85,12 @@ class AppAnnotationService:
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
# send batch add segments task # send batch add segments task
redis_client.setnx(enable_app_annotation_job_key, "waiting") redis_client.setnx(enable_app_annotation_job_key, "waiting")
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
enable_annotation_reply_task.delay( enable_annotation_reply_task.delay(
str(job_id), str(job_id),
app_id, app_id,
current_user.id, current_user.id,
current_user.current_tenant_id, current_tenant_id,
args["score_threshold"], args["score_threshold"],
args["embedding_provider_name"], args["embedding_provider_name"],
args["embedding_model_name"], args["embedding_model_name"],
@ -101,8 +99,7 @@ class AppAnnotationService:
@classmethod @classmethod
def disable_app_annotation(cls, app_id: str): def disable_app_annotation(cls, app_id: str):
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(disable_app_annotation_key) cache_result = redis_client.get(disable_app_annotation_key)
if cache_result is not None: if cache_result is not None:
@ -113,17 +110,16 @@ class AppAnnotationService:
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}" disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
# send batch add segments task # send batch add segments task
redis_client.setnx(disable_app_annotation_job_key, "waiting") redis_client.setnx(disable_app_annotation_job_key, "waiting")
disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id) disable_annotation_reply_task.delay(str(job_id), app_id, current_tenant_id)
return {"job_id": job_id, "job_status": "waiting"} return {"job_id": job_id, "job_status": "waiting"}
@classmethod @classmethod
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
# get app info # get app info
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -153,11 +149,10 @@ class AppAnnotationService:
@classmethod @classmethod
def export_annotation_list_by_app_id(cls, app_id: str): def export_annotation_list_by_app_id(cls, app_id: str):
# get app info # get app info
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -174,11 +169,10 @@ class AppAnnotationService:
@classmethod @classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info # get app info
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -196,7 +190,7 @@ class AppAnnotationService:
add_annotation_to_index_task.delay( add_annotation_to_index_task.delay(
annotation.id, annotation.id,
args["question"], args["question"],
current_user.current_tenant_id, current_tenant_id,
app_id, app_id,
annotation_setting.collection_binding_id, annotation_setting.collection_binding_id,
) )
@ -205,11 +199,10 @@ class AppAnnotationService:
@classmethod @classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
# get app info # get app info
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -234,7 +227,7 @@ class AppAnnotationService:
update_annotation_to_index_task.delay( update_annotation_to_index_task.delay(
annotation.id, annotation.id,
annotation.question, annotation.question,
current_user.current_tenant_id, current_tenant_id,
app_id, app_id,
app_annotation_setting.collection_binding_id, app_annotation_setting.collection_binding_id,
) )
@ -244,11 +237,10 @@ class AppAnnotationService:
@classmethod @classmethod
def delete_app_annotation(cls, app_id: str, annotation_id: str): def delete_app_annotation(cls, app_id: str, annotation_id: str):
# get app info # get app info
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -277,17 +269,16 @@ class AppAnnotationService:
if app_annotation_setting: if app_annotation_setting:
delete_annotation_index_task.delay( delete_annotation_index_task.delay(
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id
) )
@classmethod @classmethod
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]): def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
# get app info # get app info
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -317,7 +308,7 @@ class AppAnnotationService:
for annotation, annotation_setting in annotations_to_delete: for annotation, annotation_setting in annotations_to_delete:
if annotation_setting: if annotation_setting:
delete_annotation_index_task.delay( delete_annotation_index_task.delay(
annotation.id, app_id, current_user.current_tenant_id, annotation_setting.collection_binding_id annotation.id, app_id, current_tenant_id, annotation_setting.collection_binding_id
) )
# Step 4: Bulk delete annotations in a single query # Step 4: Bulk delete annotations in a single query
@ -333,11 +324,10 @@ class AppAnnotationService:
@classmethod @classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage): def batch_import_app_annotations(cls, app_id, file: FileStorage):
# get app info # get app info
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -354,7 +344,7 @@ class AppAnnotationService:
if len(result) == 0: if len(result) == 0:
raise ValueError("The CSV file is empty.") raise ValueError("The CSV file is empty.")
# check annotation limit # check annotation limit
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled: if features.billing.enabled:
annotation_quota_limit = features.annotation_quota_limit annotation_quota_limit = features.annotation_quota_limit
if annotation_quota_limit.limit < len(result) + annotation_quota_limit.size: if annotation_quota_limit.limit < len(result) + annotation_quota_limit.size:
@ -364,21 +354,18 @@ class AppAnnotationService:
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
# send batch add segments task # send batch add segments task
redis_client.setnx(indexing_cache_key, "waiting") redis_client.setnx(indexing_cache_key, "waiting")
batch_import_annotations_task.delay( batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id)
str(job_id), result, app_id, current_user.current_tenant_id, current_user.id
)
except Exception as e: except Exception as e:
return {"error_msg": str(e)} return {"error_msg": str(e)}
return {"job_id": job_id, "job_status": "waiting"} return {"job_id": job_id, "job_status": "waiting"}
@classmethod @classmethod
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -445,12 +432,11 @@ class AppAnnotationService:
@classmethod @classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str): def get_app_annotation_setting_by_app_id(cls, app_id: str):
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -481,12 +467,11 @@ class AppAnnotationService:
@classmethod @classmethod
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -531,11 +516,10 @@ class AppAnnotationService:
@classmethod @classmethod
def clear_all_annotations(cls, app_id: str): def clear_all_annotations(cls, app_id: str):
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -558,7 +542,7 @@ class AppAnnotationService:
# if annotation reply is enabled, delete annotation index # if annotation reply is enabled, delete annotation index
if app_annotation_setting: if app_annotation_setting:
delete_annotation_index_task.delay( delete_annotation_index_task.delay(
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id
) )
db.session.delete(annotation) db.session.delete(annotation)

View File

@ -3,7 +3,6 @@ import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any
from flask_login import current_user
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
@ -18,6 +17,7 @@ from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
from models.provider_ids import DatasourceProviderID from models.provider_ids import DatasourceProviderID
from services.plugin.plugin_service import PluginService from services.plugin.plugin_service import PluginService
@ -93,6 +93,8 @@ class DatasourceProviderService:
""" """
get credential by id get credential by id
""" """
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session: with Session(db.engine) as session:
if credential_id: if credential_id:
datasource_provider = ( datasource_provider = (
@ -157,6 +159,8 @@ class DatasourceProviderService:
""" """
get all datasource credentials by provider get all datasource credentials by provider
""" """
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session: with Session(db.engine) as session:
datasource_providers = ( datasource_providers = (
session.query(DatasourceProvider) session.query(DatasourceProvider)
@ -604,6 +608,8 @@ class DatasourceProviderService:
""" """
provider_name = provider_id.provider_name provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id plugin_id = provider_id.plugin_id
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session: with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}" lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
with redis_client.lock(lock, timeout=20): with redis_client.lock(lock, timeout=20):
@ -901,6 +907,8 @@ class DatasourceProviderService:
""" """
update datasource credentials. update datasource credentials.
""" """
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session: with Session(db.engine) as session:
datasource_provider = ( datasource_provider = (
session.query(DatasourceProvider) session.query(DatasourceProvider)

View File

@ -102,6 +102,15 @@ class OpsService:
except Exception: except Exception:
new_decrypt_tracing_config.update({"project_url": "https://arms.console.aliyun.com/"}) new_decrypt_tracing_config.update({"project_url": "https://arms.console.aliyun.com/"})
if tracing_provider == "tencent" and (
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
try:
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://console.cloud.tencent.com/apm"})
trace_config_data.tracing_config = new_decrypt_tracing_config trace_config_data.tracing_config = new_decrypt_tracing_config
return trace_config_data.to_dict() return trace_config_data.to_dict()
@ -144,7 +153,7 @@ class OpsService:
project_url = f"{tracing_config.get('host')}/project/{project_key}" project_url = f"{tracing_config.get('host')}/project/{project_key}"
except Exception: except Exception:
project_url = None project_url = None
elif tracing_provider in ("langsmith", "opik"): elif tracing_provider in ("langsmith", "opik", "tencent"):
try: try:
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider) project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
except Exception: except Exception:

View File

@ -86,12 +86,16 @@ class WorkflowAppService:
), ),
) )
if created_by_account: if created_by_account:
account = session.scalar(select(Account).where(Account.email == created_by_account))
if not account:
raise ValueError(f"Account not found: {created_by_account}")
stmt = stmt.join( stmt = stmt.join(
Account, Account,
and_( and_(
WorkflowAppLog.created_by == Account.id, WorkflowAppLog.created_by == Account.id,
WorkflowAppLog.created_by_role == CreatorUserRole.ACCOUNT, WorkflowAppLog.created_by_role == CreatorUserRole.ACCOUNT,
Account.email == created_by_account, Account.id == account.id,
), ),
) )

View File

@ -8,7 +8,6 @@ import click
import pandas as pd import pandas as pd
from celery import shared_task from celery import shared_task
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
@ -50,54 +49,48 @@ def batch_create_segment_to_index_task(
indexing_cache_key = f"segment_batch_import_{job_id}" indexing_cache_key = f"segment_batch_import_{job_id}"
try: try:
with Session(db.engine) as session: dataset = db.session.get(Dataset, dataset_id)
dataset = session.get(Dataset, dataset_id) if not dataset:
if not dataset: raise ValueError("Dataset not exist.")
raise ValueError("Dataset not exist.")
dataset_document = session.get(Document, document_id) dataset_document = db.session.get(Document, document_id)
if not dataset_document: if not dataset_document:
raise ValueError("Document not exist.") raise ValueError("Document not exist.")
if ( if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
not dataset_document.enabled raise ValueError("Document is not available.")
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
raise ValueError("Document is not available.")
upload_file = session.get(UploadFile, upload_file_id) upload_file = db.session.get(UploadFile, upload_file_id)
if not upload_file: if not upload_file:
raise ValueError("UploadFile not found.") raise ValueError("UploadFile not found.")
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix suffix = Path(upload_file.key).suffix
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore storage.download(upload_file.key, file_path)
storage.download(upload_file.key, file_path)
# Skip the first row df = pd.read_csv(file_path)
df = pd.read_csv(file_path) content = []
content = [] for _, row in df.iterrows():
for _, row in df.iterrows(): if dataset_document.doc_form == "qa_model":
if dataset_document.doc_form == "qa_model": data = {"content": row.iloc[0], "answer": row.iloc[1]}
data = {"content": row.iloc[0], "answer": row.iloc[1]} else:
else: data = {"content": row.iloc[0]}
data = {"content": row.iloc[0]} content.append(data)
content.append(data) if len(content) == 0:
if len(content) == 0: raise ValueError("The CSV file is empty.")
raise ValueError("The CSV file is empty.")
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
word_count_change = 0 word_count_change = 0
if embedding_model: if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens( tokens_list = embedding_model.get_text_embedding_num_tokens(
@ -105,6 +98,7 @@ def batch_create_segment_to_index_task(
) )
else: else:
tokens_list = [0] * len(content) tokens_list = [0] * len(content)
for segment, tokens in zip(content, tokens_list): for segment, tokens in zip(content, tokens_list):
content = segment["content"] content = segment["content"]
doc_id = str(uuid.uuid4()) doc_id = str(uuid.uuid4())
@ -135,11 +129,11 @@ def batch_create_segment_to_index_task(
word_count_change += segment_document.word_count word_count_change += segment_document.word_count
db.session.add(segment_document) db.session.add(segment_document)
document_segments.append(segment_document) document_segments.append(segment_document)
# update document word count
assert dataset_document.word_count is not None assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change dataset_document.word_count += word_count_change
db.session.add(dataset_document) db.session.add(dataset_document)
# add index to db
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
db.session.commit() db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed") redis_client.setex(indexing_cache_key, 600, "completed")

View File

@ -75,10 +75,7 @@
</head> </head>
<body> <body>
<div class="container"> <div class="container">
<div class="header"> <div class="header"></div>
<!-- Optional: Add a logo or a header image here -->
<img src="https://assets.dify.ai/images/logo.png" alt="Dify Logo">
</div>
<div class="content"> <div class="content">
<p class="content1">Dear {{ to }},</p> <p class="content1">Dear {{ to }},</p>
<p class="content2">{{ inviter_name }} is pleased to invite you to join our workspace on {{application_title}}, a platform specifically designed for LLM application development. On {{application_title}}, you can explore, create, and collaborate to build and operate AI applications.</p> <p class="content2">{{ inviter_name }} is pleased to invite you to join our workspace on {{application_title}}, a platform specifically designed for LLM application development. On {{application_title}}, you can explore, create, and collaborate to build and operate AI applications.</p>

View File

@ -25,9 +25,7 @@ class TestAnnotationService:
patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task, patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task,
patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task, patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task,
patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task, patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task,
patch( patch("services.annotation_service.current_account_with_tenant") as mock_current_account_with_tenant,
"services.annotation_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
): ):
# Setup default mock returns # Setup default mock returns
mock_account_feature_service.get_features.return_value.billing.enabled = False mock_account_feature_service.get_features.return_value.billing.enabled = False
@ -38,6 +36,9 @@ class TestAnnotationService:
mock_disable_task.delay.return_value = None mock_disable_task.delay.return_value = None
mock_batch_import_task.delay.return_value = None mock_batch_import_task.delay.return_value = None
# Create mock user that will be returned by current_account_with_tenant
mock_user = create_autospec(Account, instance=True)
yield { yield {
"account_feature_service": mock_account_feature_service, "account_feature_service": mock_account_feature_service,
"feature_service": mock_feature_service, "feature_service": mock_feature_service,
@ -47,7 +48,8 @@ class TestAnnotationService:
"enable_task": mock_enable_task, "enable_task": mock_enable_task,
"disable_task": mock_disable_task, "disable_task": mock_disable_task,
"batch_import_task": mock_batch_import_task, "batch_import_task": mock_batch_import_task,
"current_user": mock_current_user, "current_account_with_tenant": mock_current_account_with_tenant,
"current_user": mock_user,
} }
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
@ -107,6 +109,11 @@ class TestAnnotationService:
""" """
mock_external_service_dependencies["current_user"].id = account_id mock_external_service_dependencies["current_user"].id = account_id
mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id
# Configure current_account_with_tenant to return (user, tenant_id)
mock_external_service_dependencies["current_account_with_tenant"].return_value = (
mock_external_service_dependencies["current_user"],
tenant_id,
)
def _create_test_conversation(self, app, account, fake): def _create_test_conversation(self, app, account, fake):
""" """

View File

@ -789,6 +789,31 @@ class TestWorkflowAppService:
assert result_account_filter["total"] == 3 assert result_account_filter["total"] == 3
assert all(log.created_by_role == CreatorUserRole.ACCOUNT for log in result_account_filter["data"]) assert all(log.created_by_role == CreatorUserRole.ACCOUNT for log in result_account_filter["data"])
# Test filtering by changed account email
original_email = account.email
new_email = "changed@example.com"
account.email = new_email
db_session_with_containers.commit()
assert account.email == new_email
# Results for new email, is expected to be the same as the original email
result_with_new_email = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, created_by_account=new_email, page=1, limit=20
)
assert result_with_new_email["total"] == 3
assert all(log.created_by_role == CreatorUserRole.ACCOUNT for log in result_with_new_email["data"])
# Old email unbound, is unexpected input, should raise ValueError
with pytest.raises(ValueError) as exc_info:
service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, created_by_account=original_email, page=1, limit=20
)
assert "Account not found" in str(exc_info.value)
account.email = original_email
db_session_with_containers.commit()
# Test filtering by non-existent session ID # Test filtering by non-existent session ID
result_no_session = service.get_paginate_workflow_app_logs( result_no_session = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, session=db_session_with_containers,
@ -799,15 +824,16 @@ class TestWorkflowAppService:
) )
assert result_no_session["total"] == 0 assert result_no_session["total"] == 0
# Test filtering by non-existent account email # Test filtering by non-existent account email, is unexpected input, should raise ValueError
result_no_account = service.get_paginate_workflow_app_logs( with pytest.raises(ValueError) as exc_info:
session=db_session_with_containers, service.get_paginate_workflow_app_logs(
app_model=app, session=db_session_with_containers,
created_by_account="nonexistent@example.com", app_model=app,
page=1, created_by_account="nonexistent@example.com",
limit=20, page=1,
) limit=20,
assert result_no_account["total"] == 0 )
assert "Account not found" in str(exc_info.value)
def test_get_paginate_workflow_app_logs_with_uuid_keyword_search( def test_get_paginate_workflow_app_logs_with_uuid_keyword_search(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers, mock_external_service_dependencies
@ -1057,15 +1083,15 @@ class TestWorkflowAppService:
assert len(result_no_session["data"]) == 0 assert len(result_no_session["data"]) == 0
# Test with account email that doesn't exist # Test with account email that doesn't exist
result_no_account = service.get_paginate_workflow_app_logs( with pytest.raises(ValueError) as exc_info:
session=db_session_with_containers, service.get_paginate_workflow_app_logs(
app_model=app, session=db_session_with_containers,
created_by_account="nonexistent@example.com", app_model=app,
page=1, created_by_account="nonexistent@example.com",
limit=20, page=1,
) limit=20,
assert result_no_account["total"] == 0 )
assert len(result_no_account["data"]) == 0 assert "Account not found" in str(exc_info.value)
def test_get_paginate_workflow_app_logs_with_complex_query_combinations( def test_get_paginate_workflow_app_logs_with_complex_query_combinations(
self, db_session_with_containers, mock_external_service_dependencies self, db_session_with_containers, mock_external_service_dependencies

View File

@ -0,0 +1,401 @@
"""
TestContainers-based integration tests for mail_owner_transfer_task.
This module provides comprehensive integration tests for the mail owner transfer tasks
using TestContainers to ensure real email service integration and proper functionality
testing with actual database and service dependencies.
"""
import logging
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from libs.email_i18n import EmailType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from tasks.mail_owner_transfer_task import (
send_new_owner_transfer_notify_email_task,
send_old_owner_transfer_notify_email_task,
send_owner_transfer_confirm_task,
)
logger = logging.getLogger(__name__)
class TestMailOwnerTransferTask:
"""Integration tests for mail owner transfer tasks using testcontainers."""
@pytest.fixture
def mock_mail_dependencies(self):
"""Mock setup for mail service dependencies."""
with (
patch("tasks.mail_owner_transfer_task.mail") as mock_mail,
patch("tasks.mail_owner_transfer_task.get_email_i18n_service") as mock_get_email_service,
):
# Setup mock mail service
mock_mail.is_inited.return_value = True
# Setup mock email service
mock_email_service = MagicMock()
mock_get_email_service.return_value = mock_email_service
yield {
"mail": mock_mail,
"email_service": mock_email_service,
"get_email_service": mock_get_email_service,
}
def _create_test_account_and_tenant(self, db_session_with_containers):
"""
Helper method to create test account and tenant for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
Returns:
tuple: (account, tenant) - Created account and tenant instances
"""
fake = Faker()
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
name=fake.company(),
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
return account, tenant
def test_send_owner_transfer_confirm_task_success(self, db_session_with_containers, mock_mail_dependencies):
"""
Test successful owner transfer confirmation email sending.
This test verifies:
- Proper email service initialization check
- Correct email service method calls with right parameters
- Email template context is properly constructed
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
test_language = "en-US"
test_email = account.email
test_code = "123456"
test_workspace = tenant.name
# Act: Execute the task
send_owner_transfer_confirm_task(
language=test_language,
to=test_email,
code=test_code,
workspace=test_workspace,
)
# Assert: Verify the expected outcomes
mock_mail_dependencies["mail"].is_inited.assert_called_once()
mock_mail_dependencies["get_email_service"].assert_called_once()
# Verify email service was called with correct parameters
mock_mail_dependencies["email_service"].send_email.assert_called_once()
call_args = mock_mail_dependencies["email_service"].send_email.call_args
assert call_args[1]["email_type"] == EmailType.OWNER_TRANSFER_CONFIRM
assert call_args[1]["language_code"] == test_language
assert call_args[1]["to"] == test_email
assert call_args[1]["template_context"]["to"] == test_email
assert call_args[1]["template_context"]["code"] == test_code
assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace
def test_send_owner_transfer_confirm_task_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test owner transfer confirmation email when mail service is not initialized.
This test verifies:
- Early return when mail service is not initialized
- No email service calls are made
- No exceptions are raised
"""
# Arrange: Set mail service as not initialized
mock_mail_dependencies["mail"].is_inited.return_value = False
test_language = "en-US"
test_email = "test@example.com"
test_code = "123456"
test_workspace = "Test Workspace"
# Act: Execute the task
send_owner_transfer_confirm_task(
language=test_language,
to=test_email,
code=test_code,
workspace=test_workspace,
)
# Assert: Verify no email service calls were made
mock_mail_dependencies["get_email_service"].assert_not_called()
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_owner_transfer_confirm_task_exception_handling(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test exception handling in owner transfer confirmation email.
This test verifies:
- Exceptions are properly caught and logged
- No exceptions are propagated to caller
- Email service calls are attempted
- Error logging works correctly
"""
# Arrange: Setup email service to raise exception
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
test_language = "en-US"
test_email = "test@example.com"
test_code = "123456"
test_workspace = "Test Workspace"
# Act & Assert: Verify no exception is raised
try:
send_owner_transfer_confirm_task(
language=test_language,
to=test_email,
code=test_code,
workspace=test_workspace,
)
except Exception as e:
pytest.fail(f"Task should not raise exceptions, but raised: {e}")
# Verify email service was called despite the exception
mock_mail_dependencies["email_service"].send_email.assert_called_once()
def test_send_old_owner_transfer_notify_email_task_success(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test successful old owner transfer notification email sending.
This test verifies:
- Proper email service initialization check
- Correct email service method calls with right parameters
- Email template context includes new owner email
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
test_language = "en-US"
test_email = account.email
test_workspace = tenant.name
test_new_owner_email = "newowner@example.com"
# Act: Execute the task
send_old_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
new_owner_email=test_new_owner_email,
)
# Assert: Verify the expected outcomes
mock_mail_dependencies["mail"].is_inited.assert_called_once()
mock_mail_dependencies["get_email_service"].assert_called_once()
# Verify email service was called with correct parameters
mock_mail_dependencies["email_service"].send_email.assert_called_once()
call_args = mock_mail_dependencies["email_service"].send_email.call_args
assert call_args[1]["email_type"] == EmailType.OWNER_TRANSFER_OLD_NOTIFY
assert call_args[1]["language_code"] == test_language
assert call_args[1]["to"] == test_email
assert call_args[1]["template_context"]["to"] == test_email
assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace
assert call_args[1]["template_context"]["NewOwnerEmail"] == test_new_owner_email
def test_send_old_owner_transfer_notify_email_task_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test old owner transfer notification email when mail service is not initialized.
This test verifies:
- Early return when mail service is not initialized
- No email service calls are made
- No exceptions are raised
"""
# Arrange: Set mail service as not initialized
mock_mail_dependencies["mail"].is_inited.return_value = False
test_language = "en-US"
test_email = "test@example.com"
test_workspace = "Test Workspace"
test_new_owner_email = "newowner@example.com"
# Act: Execute the task
send_old_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
new_owner_email=test_new_owner_email,
)
# Assert: Verify no email service calls were made
mock_mail_dependencies["get_email_service"].assert_not_called()
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_old_owner_transfer_notify_email_task_exception_handling(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test exception handling in old owner transfer notification email.
This test verifies:
- Exceptions are properly caught and logged
- No exceptions are propagated to caller
- Email service calls are attempted
- Error logging works correctly
"""
# Arrange: Setup email service to raise exception
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
test_language = "en-US"
test_email = "test@example.com"
test_workspace = "Test Workspace"
test_new_owner_email = "newowner@example.com"
# Act & Assert: Verify no exception is raised
try:
send_old_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
new_owner_email=test_new_owner_email,
)
except Exception as e:
pytest.fail(f"Task should not raise exceptions, but raised: {e}")
# Verify email service was called despite the exception
mock_mail_dependencies["email_service"].send_email.assert_called_once()
def test_send_new_owner_transfer_notify_email_task_success(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test successful new owner transfer notification email sending.
This test verifies:
- Proper email service initialization check
- Correct email service method calls with right parameters
- Email template context is properly constructed
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
test_language = "en-US"
test_email = account.email
test_workspace = tenant.name
# Act: Execute the task
send_new_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
)
# Assert: Verify the expected outcomes
mock_mail_dependencies["mail"].is_inited.assert_called_once()
mock_mail_dependencies["get_email_service"].assert_called_once()
# Verify email service was called with correct parameters
mock_mail_dependencies["email_service"].send_email.assert_called_once()
call_args = mock_mail_dependencies["email_service"].send_email.call_args
assert call_args[1]["email_type"] == EmailType.OWNER_TRANSFER_NEW_NOTIFY
assert call_args[1]["language_code"] == test_language
assert call_args[1]["to"] == test_email
assert call_args[1]["template_context"]["to"] == test_email
assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace
def test_send_new_owner_transfer_notify_email_task_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test new owner transfer notification email when mail service is not initialized.
This test verifies:
- Early return when mail service is not initialized
- No email service calls are made
- No exceptions are raised
"""
# Arrange: Set mail service as not initialized
mock_mail_dependencies["mail"].is_inited.return_value = False
test_language = "en-US"
test_email = "test@example.com"
test_workspace = "Test Workspace"
# Act: Execute the task
send_new_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
)
# Assert: Verify no email service calls were made
mock_mail_dependencies["get_email_service"].assert_not_called()
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_new_owner_transfer_notify_email_task_exception_handling(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test exception handling in new owner transfer notification email.
This test verifies:
- Exceptions are properly caught and logged
- No exceptions are propagated to caller
- Email service calls are attempted
- Error logging works correctly
"""
# Arrange: Setup email service to raise exception
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
test_language = "en-US"
test_email = "test@example.com"
test_workspace = "Test Workspace"
# Act & Assert: Verify no exception is raised
try:
send_new_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
)
except Exception as e:
pytest.fail(f"Task should not raise exceptions, but raised: {e}")
# Verify email service was called despite the exception
mock_mail_dependencies["email_service"].send_email.assert_called_once()

View File

@ -60,7 +60,7 @@ class TestAccountInitialization:
return "success" return "success"
# Act # Act
with patch("controllers.console.wraps._current_account", return_value=mock_user): with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")):
result = protected_view() result = protected_view()
# Assert # Assert
@ -77,7 +77,7 @@ class TestAccountInitialization:
return "success" return "success"
# Act & Assert # Act & Assert
with patch("controllers.console.wraps._current_account", return_value=mock_user): with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")):
with pytest.raises(AccountNotInitializedError): with pytest.raises(AccountNotInitializedError):
protected_view() protected_view()
@ -163,7 +163,9 @@ class TestBillingResourceLimits:
return "member_added" return "member_added"
# Act # Act
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): with patch(
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = add_member() result = add_member()
@ -185,7 +187,10 @@ class TestBillingResourceLimits:
# Act & Assert # Act & Assert
with app.test_request_context(): with app.test_request_context():
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
add_member() add_member()
@ -207,7 +212,10 @@ class TestBillingResourceLimits:
# Test 1: Should reject when source is datasets # Test 1: Should reject when source is datasets
with app.test_request_context("/?source=datasets"): with app.test_request_context("/?source=datasets"):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
upload_document() upload_document()
@ -215,7 +223,10 @@ class TestBillingResourceLimits:
# Test 2: Should allow when source is not datasets # Test 2: Should allow when source is not datasets
with app.test_request_context("/?source=other"): with app.test_request_context("/?source=other"):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = upload_document() result = upload_document()
assert result == "document_uploaded" assert result == "document_uploaded"
@ -239,7 +250,9 @@ class TestRateLimiting:
return "knowledge_success" return "knowledge_success"
# Act # Act
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): with patch(
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
):
with patch( with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
): ):
@ -271,7 +284,10 @@ class TestRateLimiting:
# Act & Assert # Act & Assert
with app.test_request_context(): with app.test_request_context():
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch( with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
): ):

View File

@ -110,19 +110,6 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes
mock_redis.set.assert_called_once() mock_redis.set.assert_called_once()
def test_config_validation(self):
"""Test configuration validation."""
# Test missing required fields
with pytest.raises(ValueError):
AlibabaCloudMySQLVectorConfig(
host="", # Empty host should raise error
port=3306,
user="test",
password="test",
database="test",
max_connection=5,
)
@patch( @patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
) )
@ -718,5 +705,29 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
@pytest.mark.parametrize(
"invalid_config_override",
[
{"host": ""}, # Test empty host
{"port": 0}, # Test invalid port
{"max_connection": 0}, # Test invalid max_connection
],
)
def test_config_validation_parametrized(invalid_config_override):
"""Test configuration validation for various invalid inputs using parametrize."""
config = {
"host": "localhost",
"port": 3306,
"user": "test",
"password": "test",
"database": "test",
"max_connection": 5,
}
config.update(invalid_config_override)
with pytest.raises(ValueError):
AlibabaCloudMySQLVectorConfig(**config)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -0,0 +1,29 @@
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
def _make_identity() -> ToolIdentity:
return ToolIdentity(
author="author",
name="tool",
label=I18nObject(en_US="Label"),
provider="builtin",
)
def test_log_message_metadata_none_defaults_to_empty_dict():
log_message = ToolInvokeMessage.LogMessage(
id="log-1",
label="Log entry",
status=ToolInvokeMessage.LogMessage.LogStatus.START,
data={},
metadata=None,
)
assert log_message.metadata == {}
def test_tool_entity_output_schema_none_defaults_to_empty_dict():
entity = ToolEntity(identity=_make_identity(), output_schema=None)
assert entity.output_schema == {}

1981
api/uv.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -329,7 +329,7 @@ services:
# The Weaviate vector store. # The Weaviate vector store.
weaviate: weaviate:
image: semitechnologies/weaviate:1.19.0 image: semitechnologies/weaviate:1.27.0
profiles: profiles:
- "" - ""
- weaviate - weaviate

View File

@ -181,7 +181,7 @@ services:
# The Weaviate vector store. # The Weaviate vector store.
weaviate: weaviate:
image: semitechnologies/weaviate:1.19.0 image: semitechnologies/weaviate:1.27.0
profiles: profiles:
- "" - ""
- weaviate - weaviate
@ -206,6 +206,7 @@ services:
AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai} AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai}
ports: ports:
- "${EXPOSE_WEAVIATE_PORT:-8080}:8080" - "${EXPOSE_WEAVIATE_PORT:-8080}:8080"
- "${EXPOSE_WEAVIATE_GRPC_PORT:-50051}:50051"
networks: networks:
# create a network between sandbox, api and ssrf_proxy, and can not access outside. # create a network between sandbox, api and ssrf_proxy, and can not access outside.

View File

@ -0,0 +1,9 @@
services:
api:
volumes:
- ../api/core/rag/datasource/vdb/weaviate/weaviate_vector.py:/app/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py:ro
command: >
sh -c "
pip install --no-cache-dir 'weaviate>=4.0.0' &&
/bin/bash /entrypoint.sh
"

View File

@ -941,7 +941,7 @@ services:
# The Weaviate vector store. # The Weaviate vector store.
weaviate: weaviate:
image: semitechnologies/weaviate:1.19.0 image: semitechnologies/weaviate:1.27.0
profiles: profiles:
- "" - ""
- weaviate - weaviate

View File

@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next'
import { useBoolean } from 'ahooks' import { useBoolean } from 'ahooks'
import TracingIcon from './tracing-icon' import TracingIcon from './tracing-icon'
import ProviderPanel from './provider-panel' import ProviderPanel from './provider-panel'
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, WeaveConfig } from './type' import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
import { TracingProvider } from './type' import { TracingProvider } from './type'
import ProviderConfigModal from './provider-config-modal' import ProviderConfigModal from './provider-config-modal'
import Indicator from '@/app/components/header/indicator' import Indicator from '@/app/components/header/indicator'
@ -30,7 +30,8 @@ export type PopupProps = {
opikConfig: OpikConfig | null opikConfig: OpikConfig | null
weaveConfig: WeaveConfig | null weaveConfig: WeaveConfig | null
aliyunConfig: AliyunConfig | null aliyunConfig: AliyunConfig | null
onConfigUpdated: (provider: TracingProvider, payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig) => void tencentConfig: TencentConfig | null
onConfigUpdated: (provider: TracingProvider, payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig) => void
onConfigRemoved: (provider: TracingProvider) => void onConfigRemoved: (provider: TracingProvider) => void
} }
@ -48,6 +49,7 @@ const ConfigPopup: FC<PopupProps> = ({
opikConfig, opikConfig,
weaveConfig, weaveConfig,
aliyunConfig, aliyunConfig,
tencentConfig,
onConfigUpdated, onConfigUpdated,
onConfigRemoved, onConfigRemoved,
}) => { }) => {
@ -81,8 +83,8 @@ const ConfigPopup: FC<PopupProps> = ({
hideConfigModal() hideConfigModal()
}, [currentProvider, hideConfigModal, onConfigRemoved]) }, [currentProvider, hideConfigModal, onConfigRemoved])
const providerAllConfigured = arizeConfig && phoenixConfig && langSmithConfig && langFuseConfig && opikConfig && weaveConfig && aliyunConfig const providerAllConfigured = arizeConfig && phoenixConfig && langSmithConfig && langFuseConfig && opikConfig && weaveConfig && aliyunConfig && tencentConfig
const providerAllNotConfigured = !arizeConfig && !phoenixConfig && !langSmithConfig && !langFuseConfig && !opikConfig && !weaveConfig && !aliyunConfig const providerAllNotConfigured = !arizeConfig && !phoenixConfig && !langSmithConfig && !langFuseConfig && !opikConfig && !weaveConfig && !aliyunConfig && !tencentConfig
const switchContent = ( const switchContent = (
<Switch <Switch
@ -182,6 +184,19 @@ const ConfigPopup: FC<PopupProps> = ({
key="aliyun-provider-panel" key="aliyun-provider-panel"
/> />
) )
const tencentPanel = (
<ProviderPanel
type={TracingProvider.tencent}
readOnly={readOnly}
config={tencentConfig}
hasConfigured={!!tencentConfig}
onConfig={handleOnConfig(TracingProvider.tencent)}
isChosen={chosenProvider === TracingProvider.tencent}
onChoose={handleOnChoose(TracingProvider.tencent)}
key="tencent-provider-panel"
/>
)
const configuredProviderPanel = () => { const configuredProviderPanel = () => {
const configuredPanels: JSX.Element[] = [] const configuredPanels: JSX.Element[] = []
@ -206,6 +221,9 @@ const ConfigPopup: FC<PopupProps> = ({
if (aliyunConfig) if (aliyunConfig)
configuredPanels.push(aliyunPanel) configuredPanels.push(aliyunPanel)
if (tencentConfig)
configuredPanels.push(tencentPanel)
return configuredPanels return configuredPanels
} }
@ -233,6 +251,9 @@ const ConfigPopup: FC<PopupProps> = ({
if (!aliyunConfig) if (!aliyunConfig)
notConfiguredPanels.push(aliyunPanel) notConfiguredPanels.push(aliyunPanel)
if (!tencentConfig)
notConfiguredPanels.push(tencentPanel)
return notConfiguredPanels return notConfiguredPanels
} }
@ -249,6 +270,8 @@ const ConfigPopup: FC<PopupProps> = ({
return opikConfig return opikConfig
if (currentProvider === TracingProvider.aliyun) if (currentProvider === TracingProvider.aliyun)
return aliyunConfig return aliyunConfig
if (currentProvider === TracingProvider.tencent)
return tencentConfig
return weaveConfig return weaveConfig
} }
@ -297,6 +320,7 @@ const ConfigPopup: FC<PopupProps> = ({
{arizePanel} {arizePanel}
{phoenixPanel} {phoenixPanel}
{aliyunPanel} {aliyunPanel}
{tencentPanel}
</div> </div>
</> </>
) )

View File

@ -8,4 +8,5 @@ export const docURL = {
[TracingProvider.opik]: 'https://www.comet.com/docs/opik/tracing/integrations/dify#setup-instructions', [TracingProvider.opik]: 'https://www.comet.com/docs/opik/tracing/integrations/dify#setup-instructions',
[TracingProvider.weave]: 'https://weave-docs.wandb.ai/', [TracingProvider.weave]: 'https://weave-docs.wandb.ai/',
[TracingProvider.aliyun]: 'https://help.aliyun.com/zh/arms/tracing-analysis/untitled-document-1750672984680', [TracingProvider.aliyun]: 'https://help.aliyun.com/zh/arms/tracing-analysis/untitled-document-1750672984680',
[TracingProvider.tencent]: 'https://cloud.tencent.com/document/product/248/116531',
} }

View File

@ -8,12 +8,12 @@ import {
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { usePathname } from 'next/navigation' import { usePathname } from 'next/navigation'
import { useBoolean } from 'ahooks' import { useBoolean } from 'ahooks'
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, WeaveConfig } from './type' import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
import { TracingProvider } from './type' import { TracingProvider } from './type'
import TracingIcon from './tracing-icon' import TracingIcon from './tracing-icon'
import ConfigButton from './config-button' import ConfigButton from './config-button'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import { AliyunIcon, ArizeIcon, LangfuseIcon, LangsmithIcon, OpikIcon, PhoenixIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing' import { AliyunIcon, ArizeIcon, LangfuseIcon, LangsmithIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing'
import Indicator from '@/app/components/header/indicator' import Indicator from '@/app/components/header/indicator'
import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps' import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps'
import type { TracingStatus } from '@/models/app' import type { TracingStatus } from '@/models/app'
@ -71,6 +71,7 @@ const Panel: FC = () => {
[TracingProvider.opik]: OpikIcon, [TracingProvider.opik]: OpikIcon,
[TracingProvider.weave]: WeaveIcon, [TracingProvider.weave]: WeaveIcon,
[TracingProvider.aliyun]: AliyunIcon, [TracingProvider.aliyun]: AliyunIcon,
[TracingProvider.tencent]: TencentIcon,
} }
const InUseProviderIcon = inUseTracingProvider ? providerIconMap[inUseTracingProvider] : undefined const InUseProviderIcon = inUseTracingProvider ? providerIconMap[inUseTracingProvider] : undefined
@ -81,7 +82,8 @@ const Panel: FC = () => {
const [opikConfig, setOpikConfig] = useState<OpikConfig | null>(null) const [opikConfig, setOpikConfig] = useState<OpikConfig | null>(null)
const [weaveConfig, setWeaveConfig] = useState<WeaveConfig | null>(null) const [weaveConfig, setWeaveConfig] = useState<WeaveConfig | null>(null)
const [aliyunConfig, setAliyunConfig] = useState<AliyunConfig | null>(null) const [aliyunConfig, setAliyunConfig] = useState<AliyunConfig | null>(null)
const hasConfiguredTracing = !!(langSmithConfig || langFuseConfig || opikConfig || weaveConfig || arizeConfig || phoenixConfig || aliyunConfig) const [tencentConfig, setTencentConfig] = useState<TencentConfig | null>(null)
const hasConfiguredTracing = !!(langSmithConfig || langFuseConfig || opikConfig || weaveConfig || arizeConfig || phoenixConfig || aliyunConfig || tencentConfig)
const fetchTracingConfig = async () => { const fetchTracingConfig = async () => {
const getArizeConfig = async () => { const getArizeConfig = async () => {
@ -119,6 +121,11 @@ const Panel: FC = () => {
if (!aliyunHasNotConfig) if (!aliyunHasNotConfig)
setAliyunConfig(aliyunConfig as AliyunConfig) setAliyunConfig(aliyunConfig as AliyunConfig)
} }
const getTencentConfig = async () => {
const { tracing_config: tencentConfig, has_not_configured: tencentHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.tencent })
if (!tencentHasNotConfig)
setTencentConfig(tencentConfig as TencentConfig)
}
Promise.all([ Promise.all([
getArizeConfig(), getArizeConfig(),
getPhoenixConfig(), getPhoenixConfig(),
@ -127,6 +134,7 @@ const Panel: FC = () => {
getOpikConfig(), getOpikConfig(),
getWeaveConfig(), getWeaveConfig(),
getAliyunConfig(), getAliyunConfig(),
getTencentConfig(),
]) ])
} }
@ -147,6 +155,8 @@ const Panel: FC = () => {
setWeaveConfig(tracing_config as WeaveConfig) setWeaveConfig(tracing_config as WeaveConfig)
else if (provider === TracingProvider.aliyun) else if (provider === TracingProvider.aliyun)
setAliyunConfig(tracing_config as AliyunConfig) setAliyunConfig(tracing_config as AliyunConfig)
else if (provider === TracingProvider.tencent)
setTencentConfig(tracing_config as TencentConfig)
} }
const handleTracingConfigRemoved = (provider: TracingProvider) => { const handleTracingConfigRemoved = (provider: TracingProvider) => {
@ -164,6 +174,8 @@ const Panel: FC = () => {
setWeaveConfig(null) setWeaveConfig(null)
else if (provider === TracingProvider.aliyun) else if (provider === TracingProvider.aliyun)
setAliyunConfig(null) setAliyunConfig(null)
else if (provider === TracingProvider.tencent)
setTencentConfig(null)
if (provider === inUseTracingProvider) { if (provider === inUseTracingProvider) {
handleTracingStatusChange({ handleTracingStatusChange({
enabled: false, enabled: false,
@ -209,6 +221,7 @@ const Panel: FC = () => {
opikConfig={opikConfig} opikConfig={opikConfig}
weaveConfig={weaveConfig} weaveConfig={weaveConfig}
aliyunConfig={aliyunConfig} aliyunConfig={aliyunConfig}
tencentConfig={tencentConfig}
onConfigUpdated={handleTracingConfigUpdated} onConfigUpdated={handleTracingConfigUpdated}
onConfigRemoved={handleTracingConfigRemoved} onConfigRemoved={handleTracingConfigRemoved}
> >
@ -245,6 +258,7 @@ const Panel: FC = () => {
opikConfig={opikConfig} opikConfig={opikConfig}
weaveConfig={weaveConfig} weaveConfig={weaveConfig}
aliyunConfig={aliyunConfig} aliyunConfig={aliyunConfig}
tencentConfig={tencentConfig}
onConfigUpdated={handleTracingConfigUpdated} onConfigUpdated={handleTracingConfigUpdated}
onConfigRemoved={handleTracingConfigRemoved} onConfigRemoved={handleTracingConfigRemoved}
> >

View File

@ -4,7 +4,7 @@ import React, { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { useBoolean } from 'ahooks' import { useBoolean } from 'ahooks'
import Field from './field' import Field from './field'
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, WeaveConfig } from './type' import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
import { TracingProvider } from './type' import { TracingProvider } from './type'
import { docURL } from './config' import { docURL } from './config'
import { import {
@ -22,10 +22,10 @@ import Divider from '@/app/components/base/divider'
type Props = { type Props = {
appId: string appId: string
type: TracingProvider type: TracingProvider
payload?: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | null payload?: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig | null
onRemoved: () => void onRemoved: () => void
onCancel: () => void onCancel: () => void
onSaved: (payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig) => void onSaved: (payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig) => void
onChosen: (provider: TracingProvider) => void onChosen: (provider: TracingProvider) => void
} }
@ -77,6 +77,12 @@ const aliyunConfigTemplate = {
endpoint: '', endpoint: '',
} }
const tencentConfigTemplate = {
token: '',
endpoint: '',
service_name: '',
}
const ProviderConfigModal: FC<Props> = ({ const ProviderConfigModal: FC<Props> = ({
appId, appId,
type, type,
@ -90,7 +96,7 @@ const ProviderConfigModal: FC<Props> = ({
const isEdit = !!payload const isEdit = !!payload
const isAdd = !isEdit const isAdd = !isEdit
const [isSaving, setIsSaving] = useState(false) const [isSaving, setIsSaving] = useState(false)
const [config, setConfig] = useState<ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig>((() => { const [config, setConfig] = useState<ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig>((() => {
if (isEdit) if (isEdit)
return payload return payload
@ -112,6 +118,9 @@ const ProviderConfigModal: FC<Props> = ({
else if (type === TracingProvider.aliyun) else if (type === TracingProvider.aliyun)
return aliyunConfigTemplate return aliyunConfigTemplate
else if (type === TracingProvider.tencent)
return tencentConfigTemplate
return weaveConfigTemplate return weaveConfigTemplate
})()) })())
const [isShowRemoveConfirm, { const [isShowRemoveConfirm, {
@ -202,6 +211,16 @@ const ProviderConfigModal: FC<Props> = ({
errorMessage = t('common.errorMsg.fieldRequired', { field: 'Endpoint' }) errorMessage = t('common.errorMsg.fieldRequired', { field: 'Endpoint' })
} }
if (type === TracingProvider.tencent) {
const postData = config as TencentConfig
if (!errorMessage && !postData.token)
errorMessage = t('common.errorMsg.fieldRequired', { field: 'Token' })
if (!errorMessage && !postData.endpoint)
errorMessage = t('common.errorMsg.fieldRequired', { field: 'Endpoint' })
if (!errorMessage && !postData.service_name)
errorMessage = t('common.errorMsg.fieldRequired', { field: 'Service Name' })
}
return errorMessage return errorMessage
}, [config, t, type]) }, [config, t, type])
const handleSave = useCallback(async () => { const handleSave = useCallback(async () => {
@ -338,6 +357,34 @@ const ProviderConfigModal: FC<Props> = ({
/> />
</> </>
)} )}
{type === TracingProvider.tencent && (
<>
<Field
label='Token'
labelClassName='!text-sm'
isRequired
value={(config as TencentConfig).token}
onChange={handleConfigChange('token')}
placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'Token' })!}
/>
<Field
label='Endpoint'
labelClassName='!text-sm'
isRequired
value={(config as TencentConfig).endpoint}
onChange={handleConfigChange('endpoint')}
placeholder='https://your-region.cls.tencentcs.com'
/>
<Field
label='Service Name'
labelClassName='!text-sm'
isRequired
value={(config as TencentConfig).service_name}
onChange={handleConfigChange('service_name')}
placeholder='dify_app'
/>
</>
)}
{type === TracingProvider.weave && ( {type === TracingProvider.weave && (
<> <>
<Field <Field

View File

@ -7,7 +7,7 @@ import {
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { TracingProvider } from './type' import { TracingProvider } from './type'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import { AliyunIconBig, ArizeIconBig, LangfuseIconBig, LangsmithIconBig, OpikIconBig, PhoenixIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing' import { AliyunIconBig, ArizeIconBig, LangfuseIconBig, LangsmithIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing'
import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general' import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general'
const I18N_PREFIX = 'app.tracing' const I18N_PREFIX = 'app.tracing'
@ -31,6 +31,7 @@ const getIcon = (type: TracingProvider) => {
[TracingProvider.opik]: OpikIconBig, [TracingProvider.opik]: OpikIconBig,
[TracingProvider.weave]: WeaveIconBig, [TracingProvider.weave]: WeaveIconBig,
[TracingProvider.aliyun]: AliyunIconBig, [TracingProvider.aliyun]: AliyunIconBig,
[TracingProvider.tencent]: TencentIconBig,
})[type] })[type]
} }

View File

@ -6,6 +6,7 @@ export enum TracingProvider {
opik = 'opik', opik = 'opik',
weave = 'weave', weave = 'weave',
aliyun = 'aliyun', aliyun = 'aliyun',
tencent = 'tencent',
} }
export type ArizeConfig = { export type ArizeConfig = {
@ -53,3 +54,9 @@ export type AliyunConfig = {
license_key: string license_key: string
endpoint: string endpoint: string
} }
export type TencentConfig = {
token: string
endpoint: string
service_name: string
}

View File

@ -53,9 +53,6 @@ const ChatWrapper = () => {
initUserVariables, initUserVariables,
} = useChatWithHistoryContext() } = useChatWithHistoryContext()
// Semantic variable for better code readability
const isHistoryConversation = !!currentConversationId
const appConfig = useMemo(() => { const appConfig = useMemo(() => {
const config = appParams || {} const config = appParams || {}
@ -66,9 +63,9 @@ const ChatWrapper = () => {
fileUploadConfig: (config as any).system_parameters, fileUploadConfig: (config as any).system_parameters,
}, },
supportFeedback: true, supportFeedback: true,
opening_statement: isHistoryConversation ? currentConversationItem?.introduction : (config as any).opening_statement, opening_statement: currentConversationItem?.introduction || (config as any).opening_statement,
} as ChatConfig } as ChatConfig
}, [appParams, currentConversationItem?.introduction, isHistoryConversation]) }, [appParams, currentConversationItem?.introduction])
const { const {
chatList, chatList,
setTargetMessageId, setTargetMessageId,
@ -79,7 +76,7 @@ const ChatWrapper = () => {
} = useChat( } = useChat(
appConfig, appConfig,
{ {
inputs: (isHistoryConversation ? currentConversationInputs : newConversationInputs) as any, inputs: (currentConversationId ? currentConversationInputs : newConversationInputs) as any,
inputsForm: inputsForms, inputsForm: inputsForms,
}, },
appPrevChatTree, appPrevChatTree,
@ -87,7 +84,7 @@ const ChatWrapper = () => {
clearChatList, clearChatList,
setClearChatList, setClearChatList,
) )
const inputsFormValue = isHistoryConversation ? currentConversationInputs : newConversationInputsRef?.current const inputsFormValue = currentConversationId ? currentConversationInputs : newConversationInputsRef?.current
const inputDisabled = useMemo(() => { const inputDisabled = useMemo(() => {
if (allInputsHidden) if (allInputsHidden)
return false return false
@ -136,7 +133,7 @@ const ChatWrapper = () => {
const data: any = { const data: any = {
query: message, query: message,
files, files,
inputs: formatBooleanInputs(inputsForms, isHistoryConversation ? currentConversationInputs : newConversationInputs), inputs: formatBooleanInputs(inputsForms, currentConversationId ? currentConversationInputs : newConversationInputs),
conversation_id: currentConversationId, conversation_id: currentConversationId,
parent_message_id: (isRegenerate ? parentAnswer?.id : getLastAnswer(chatList)?.id) || null, parent_message_id: (isRegenerate ? parentAnswer?.id : getLastAnswer(chatList)?.id) || null,
} }
@ -146,11 +143,11 @@ const ChatWrapper = () => {
data, data,
{ {
onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, isInstalledApp, appId), onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, isInstalledApp, appId),
onConversationComplete: isHistoryConversation ? undefined : handleNewConversationCompleted, onConversationComplete: currentConversationId ? undefined : handleNewConversationCompleted,
isPublicAPI: !isInstalledApp, isPublicAPI: !isInstalledApp,
}, },
) )
}, [chatList, handleNewConversationCompleted, handleSend, isHistoryConversation, currentConversationInputs, newConversationInputs, isInstalledApp, appId]) }, [chatList, handleNewConversationCompleted, handleSend, currentConversationId, currentConversationInputs, newConversationInputs, isInstalledApp, appId])
const doRegenerate = useCallback((chatItem: ChatItemInTree, editedQuestion?: { message: string, files?: FileEntity[] }) => { const doRegenerate = useCallback((chatItem: ChatItemInTree, editedQuestion?: { message: string, files?: FileEntity[] }) => {
const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)! const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)!
@ -163,30 +160,38 @@ const ChatWrapper = () => {
}, [chatList, doSend]) }, [chatList, doSend])
const messageList = useMemo(() => { const messageList = useMemo(() => {
// Always filter out opening statement from message list as it's handled separately in welcome component if (currentConversationId || chatList.length > 1)
return chatList
// Without messages we are in the welcome screen, so hide the opening statement from chatlist
return chatList.filter(item => !item.isOpeningStatement) return chatList.filter(item => !item.isOpeningStatement)
}, [chatList]) }, [chatList])
const [collapsed, setCollapsed] = useState(isHistoryConversation) const [collapsed, setCollapsed] = useState(!!currentConversationId)
const chatNode = useMemo(() => { const chatNode = useMemo(() => {
if (allInputsHidden || !inputsForms.length) if (allInputsHidden || !inputsForms.length)
return null return null
if (isMobile) { if (isMobile) {
if (!isHistoryConversation) if (!currentConversationId)
return <InputsForm collapsed={collapsed} setCollapsed={setCollapsed} /> return <InputsForm collapsed={collapsed} setCollapsed={setCollapsed} />
return null return null
} }
else { else {
return <InputsForm collapsed={collapsed} setCollapsed={setCollapsed} /> return <InputsForm collapsed={collapsed} setCollapsed={setCollapsed} />
} }
}, [inputsForms.length, isMobile, isHistoryConversation, collapsed, allInputsHidden]) },
[
inputsForms.length,
isMobile,
currentConversationId,
collapsed, allInputsHidden,
])
const welcome = useMemo(() => { const welcome = useMemo(() => {
const welcomeMessage = chatList.find(item => item.isOpeningStatement) const welcomeMessage = chatList.find(item => item.isOpeningStatement)
if (respondingState) if (respondingState)
return null return null
if (isHistoryConversation) if (currentConversationId)
return null return null
if (!welcomeMessage) if (!welcomeMessage)
return null return null
@ -227,7 +232,18 @@ const ChatWrapper = () => {
</div> </div>
</div> </div>
) )
}, [appData?.site.icon, appData?.site.icon_background, appData?.site.icon_type, appData?.site.icon_url, chatList, collapsed, isHistoryConversation, inputsForms.length, respondingState, allInputsHidden]) },
[
appData?.site.icon,
appData?.site.icon_background,
appData?.site.icon_type,
appData?.site.icon_url,
chatList, collapsed,
currentConversationId,
inputsForms.length,
respondingState,
allInputsHidden,
])
const answerIcon = (appData?.site && appData.site.use_icon_as_answer_icon) const answerIcon = (appData?.site && appData.site.use_icon_as_answer_icon)
? <AnswerIcon ? <AnswerIcon
@ -251,7 +267,7 @@ const ChatWrapper = () => {
chatFooterClassName='pb-4' chatFooterClassName='pb-4'
chatFooterInnerClassName={`mx-auto w-full max-w-[768px] ${isMobile ? 'px-2' : 'px-4'}`} chatFooterInnerClassName={`mx-auto w-full max-w-[768px] ${isMobile ? 'px-2' : 'px-4'}`}
onSend={doSend} onSend={doSend}
inputs={isHistoryConversation ? currentConversationInputs as any : newConversationInputs} inputs={currentConversationId ? currentConversationInputs as any : newConversationInputs}
inputsForm={inputsForms} inputsForm={inputsForms}
onRegenerate={doRegenerate} onRegenerate={doRegenerate}
onStopResponding={handleStop} onStopResponding={handleStop}

View File

@ -111,6 +111,8 @@ const Answer: FC<AnswerProps> = ({
} }
}, [switchSibling, item.prevSibling, item.nextSibling]) }, [switchSibling, item.prevSibling, item.nextSibling])
const contentIsEmpty = content.trim() === ''
return ( return (
<div className='mb-2 flex last:mb-0'> <div className='mb-2 flex last:mb-0'>
<div className='relative h-10 w-10 shrink-0'> <div className='relative h-10 w-10 shrink-0'>
@ -153,14 +155,14 @@ const Answer: FC<AnswerProps> = ({
) )
} }
{ {
responding && !content && !hasAgentThoughts && ( responding && contentIsEmpty && !hasAgentThoughts && (
<div className='flex h-5 w-6 items-center justify-center'> <div className='flex h-5 w-6 items-center justify-center'>
<LoadingAnim type='text' /> <LoadingAnim type='text' />
</div> </div>
) )
} }
{ {
content && !hasAgentThoughts && ( !contentIsEmpty && !hasAgentThoughts && (
<BasicContent item={item} /> <BasicContent item={item} />
) )
} }

View File

@ -83,6 +83,15 @@ const ChatInputArea = ({
const historyRef = useRef(['']) const historyRef = useRef([''])
const [currentIndex, setCurrentIndex] = useState(-1) const [currentIndex, setCurrentIndex] = useState(-1)
const isComposingRef = useRef(false) const isComposingRef = useRef(false)
const handleQueryChange = useCallback(
(value: string) => {
setQuery(value)
setTimeout(handleTextareaResize, 0)
},
[handleTextareaResize],
)
const handleSend = () => { const handleSend = () => {
if (isResponding) { if (isResponding) {
notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') }) notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') })
@ -101,7 +110,7 @@ const ChatInputArea = ({
} }
if (checkInputsForm(inputs, inputsForm)) { if (checkInputsForm(inputs, inputsForm)) {
onSend(query, files) onSend(query, files)
setQuery('') handleQueryChange('')
setFiles([]) setFiles([])
} }
} }
@ -131,19 +140,19 @@ const ChatInputArea = ({
// When the cmd + up key is pressed, output the previous element // When the cmd + up key is pressed, output the previous element
if (currentIndex > 0) { if (currentIndex > 0) {
setCurrentIndex(currentIndex - 1) setCurrentIndex(currentIndex - 1)
setQuery(historyRef.current[currentIndex - 1]) handleQueryChange(historyRef.current[currentIndex - 1])
} }
} }
else if (e.key === 'ArrowDown' && !e.shiftKey && !e.nativeEvent.isComposing && e.metaKey) { else if (e.key === 'ArrowDown' && !e.shiftKey && !e.nativeEvent.isComposing && e.metaKey) {
// When the cmd + down key is pressed, output the next element // When the cmd + down key is pressed, output the next element
if (currentIndex < historyRef.current.length - 1) { if (currentIndex < historyRef.current.length - 1) {
setCurrentIndex(currentIndex + 1) setCurrentIndex(currentIndex + 1)
setQuery(historyRef.current[currentIndex + 1]) handleQueryChange(historyRef.current[currentIndex + 1])
} }
else if (currentIndex === historyRef.current.length - 1) { else if (currentIndex === historyRef.current.length - 1) {
// If it is the last element, clear the input box // If it is the last element, clear the input box
setCurrentIndex(historyRef.current.length) setCurrentIndex(historyRef.current.length)
setQuery('') handleQueryChange('')
} }
} }
} }
@ -171,7 +180,7 @@ const ChatInputArea = ({
<> <>
<div <div
className={cn( className={cn(
'relative z-10 rounded-xl border border-components-chat-input-border bg-components-panel-bg-blur pb-[9px] shadow-md', 'relative z-10 overflow-hidden rounded-xl border border-components-chat-input-border bg-components-panel-bg-blur pb-[9px] shadow-md',
isDragActive && 'border border-dashed border-components-option-card-option-selected-border', isDragActive && 'border border-dashed border-components-option-card-option-selected-border',
disabled && 'pointer-events-none border-components-panel-border opacity-50 shadow-none', disabled && 'pointer-events-none border-components-panel-border opacity-50 shadow-none',
)} )}
@ -197,12 +206,8 @@ const ChatInputArea = ({
placeholder={t('common.chat.inputPlaceholder', { botName }) || ''} placeholder={t('common.chat.inputPlaceholder', { botName }) || ''}
autoFocus autoFocus
minRows={1} minRows={1}
onResize={handleTextareaResize}
value={query} value={query}
onChange={(e) => { onChange={e => handleQueryChange(e.target.value)}
setQuery(e.target.value)
setTimeout(handleTextareaResize, 0)
}}
onKeyDown={handleKeyDown} onKeyDown={handleKeyDown}
onCompositionStart={handleCompositionStart} onCompositionStart={handleCompositionStart}
onCompositionEnd={handleCompositionEnd} onCompositionEnd={handleCompositionEnd}
@ -221,7 +226,7 @@ const ChatInputArea = ({
showVoiceInput && ( showVoiceInput && (
<VoiceInput <VoiceInput
onCancel={() => setShowVoiceInput(false)} onCancel={() => setShowVoiceInput(false)}
onConverted={text => setQuery(text)} onConverted={text => handleQueryChange(text)}
/> />
) )
} }

View File

@ -1,3 +1,4 @@
import type { FC, Ref } from 'react'
import { memo } from 'react' import { memo } from 'react'
import { import {
RiMicLine, RiMicLine,
@ -18,20 +19,17 @@ type OperationProps = {
speechToTextConfig?: EnableType speechToTextConfig?: EnableType
onShowVoiceInput?: () => void onShowVoiceInput?: () => void
onSend: () => void onSend: () => void
theme?: Theme | null theme?: Theme | null,
ref?: Ref<HTMLDivElement>;
} }
const Operation = ( const Operation: FC<OperationProps> = ({
{ ref,
ref, fileConfig,
fileConfig, speechToTextConfig,
speechToTextConfig, onShowVoiceInput,
onShowVoiceInput, onSend,
onSend, theme,
theme, }) => {
}: OperationProps & {
ref: React.RefObject<HTMLDivElement>;
},
) => {
return ( return (
<div <div
className={cn( className={cn(

View File

@ -62,9 +62,9 @@ const ChatWrapper = () => {
fileUploadConfig: (config as any).system_parameters, fileUploadConfig: (config as any).system_parameters,
}, },
supportFeedback: true, supportFeedback: true,
opening_statement: currentConversationId ? currentConversationItem?.introduction : (config as any).opening_statement, opening_statement: currentConversationItem?.introduction || (config as any).opening_statement,
} as ChatConfig } as ChatConfig
}, [appParams, currentConversationItem?.introduction, currentConversationId]) }, [appParams, currentConversationItem?.introduction])
const { const {
chatList, chatList,
setTargetMessageId, setTargetMessageId,
@ -158,8 +158,9 @@ const ChatWrapper = () => {
}, [chatList, doSend]) }, [chatList, doSend])
const messageList = useMemo(() => { const messageList = useMemo(() => {
if (currentConversationId) if (currentConversationId || chatList.length > 1)
return chatList return chatList
// Without messages we are in the welcome screen, so hide the opening statement from chatlist
return chatList.filter(item => !item.isOpeningStatement) return chatList.filter(item => !item.isOpeningStatement)
}, [chatList, currentConversationId]) }, [chatList, currentConversationId])
@ -240,7 +241,7 @@ const ChatWrapper = () => {
config={appConfig} config={appConfig}
chatList={messageList} chatList={messageList}
isResponding={respondingState} isResponding={respondingState}
chatContainerInnerClassName={cn('mx-auto w-full max-w-full pt-4 tablet:px-4', isMobile && 'px-4')} chatContainerInnerClassName={cn('mx-auto w-full max-w-full px-4', messageList.length && 'pt-4')}
chatFooterClassName={cn('pb-4', !isMobile && 'rounded-b-2xl')} chatFooterClassName={cn('pb-4', !isMobile && 'rounded-b-2xl')}
chatFooterInnerClassName={cn('mx-auto w-full max-w-full px-4', isMobile && 'px-2')} chatFooterInnerClassName={cn('mx-auto w-full max-w-full px-4', isMobile && 'px-2')}
onSend={doSend} onSend={doSend}

View File

@ -36,6 +36,7 @@ const Header: FC<IHeaderProps> = ({
appData, appData,
currentConversationId, currentConversationId,
inputsForms, inputsForms,
allInputsHidden,
} = useEmbeddedChatbotContext() } = useEmbeddedChatbotContext()
const isClient = typeof window !== 'undefined' const isClient = typeof window !== 'undefined'
@ -124,7 +125,7 @@ const Header: FC<IHeaderProps> = ({
</ActionButton> </ActionButton>
</Tooltip> </Tooltip>
)} )}
{currentConversationId && inputsForms.length > 0 && ( {currentConversationId && inputsForms.length > 0 && !allInputsHidden && (
<ViewFormDropdown /> <ViewFormDropdown />
)} )}
</div> </div>
@ -135,7 +136,7 @@ const Header: FC<IHeaderProps> = ({
return ( return (
<div <div
className={cn('flex h-14 shrink-0 items-center justify-between rounded-t-2xl px-3')} className={cn('flex h-14 shrink-0 items-center justify-between rounded-t-2xl px-3')}
style={Object.assign({}, CssTransform(theme?.backgroundHeaderColorStyle ?? ''), CssTransform(theme?.headerBorderBottomStyle ?? ''))} style={CssTransform(theme?.headerBorderBottomStyle ?? '')}
> >
<div className="flex grow items-center space-x-3"> <div className="flex grow items-center space-x-3">
{customerIcon} {customerIcon}
@ -171,7 +172,7 @@ const Header: FC<IHeaderProps> = ({
</ActionButton> </ActionButton>
</Tooltip> </Tooltip>
)} )}
{currentConversationId && inputsForms.length > 0 && ( {currentConversationId && inputsForms.length > 0 && !allInputsHidden && (
<ViewFormDropdown iconColor={theme?.colorPathOnHeader} /> <ViewFormDropdown iconColor={theme?.colorPathOnHeader} />
)} )}
</div> </div>

View File

@ -49,8 +49,8 @@ const Chatbot = () => {
<div className='relative'> <div className='relative'>
<div <div
className={cn( className={cn(
'flex flex-col rounded-2xl border border-components-panel-border-subtle', 'flex flex-col rounded-2xl',
isMobile ? 'h-[calc(100vh_-_60px)] border-[0.5px] border-components-panel-border shadow-xs' : 'h-[100vh] bg-chatbot-bg', isMobile ? 'h-[calc(100vh_-_60px)] shadow-xs' : 'h-[100vh] bg-chatbot-bg',
)} )}
style={isMobile ? Object.assign({}, CssTransform(themeBuilder?.theme?.backgroundHeaderColorStyle ?? '')) : {}} style={isMobile ? Object.assign({}, CssTransform(themeBuilder?.theme?.backgroundHeaderColorStyle ?? '')) : {}}
> >
@ -62,7 +62,7 @@ const Chatbot = () => {
theme={themeBuilder?.theme} theme={themeBuilder?.theme}
onCreateNewChat={handleNewConversation} onCreateNewChat={handleNewConversation}
/> />
<div className={cn('flex grow flex-col overflow-y-auto', isMobile && '!h-[calc(100vh_-_3rem)] rounded-2xl bg-chatbot-bg')}> <div className={cn('flex grow flex-col overflow-y-auto', isMobile && 'm-[0.5px] !h-[calc(100vh_-_3rem)] rounded-2xl bg-chatbot-bg')}>
{appChatListDataLoading && ( {appChatListDataLoading && (
<Loading type='app' /> <Loading type='app' />
)} )}

View File

@ -1,13 +0,0 @@
export { default as Chunk } from './Chunk'
export { default as Collapse } from './Collapse'
export { default as Divider } from './Divider'
export { default as File } from './File'
export { default as GeneralType } from './GeneralType'
export { default as LayoutRight2LineMod } from './LayoutRight2LineMod'
export { default as OptionCardEffectBlueLight } from './OptionCardEffectBlueLight'
export { default as OptionCardEffectBlue } from './OptionCardEffectBlue'
export { default as OptionCardEffectOrange } from './OptionCardEffectOrange'
export { default as OptionCardEffectPurple } from './OptionCardEffectPurple'
export { default as ParentChildType } from './ParentChildType'
export { default as SelectionMod } from './SelectionMod'
export { default as Watercrawl } from './Watercrawl'

View File

@ -1,5 +0,0 @@
.wrapper {
display: inline-flex;
background: url(~@/app/components/base/icons/assets/image/llm/baichuan-text-cn.png) center center no-repeat;
background-size: contain;
}

View File

@ -1,20 +0,0 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import cn from '@/utils/classnames'
import s from './BaichuanTextCn.module.css'
const Icon = (
{
ref,
className,
...restProps
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
ref?: React.RefObject<HTMLSpanElement>;
},
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
Icon.displayName = 'BaichuanTextCn'
export default Icon

View File

@ -1,5 +0,0 @@
.wrapper {
display: inline-flex;
background: url(~@/app/components/base/icons/assets/image/llm/minimax.png) center center no-repeat;
background-size: contain;
}

View File

@ -1,20 +0,0 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import cn from '@/utils/classnames'
import s from './Minimax.module.css'
const Icon = (
{
ref,
className,
...restProps
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
ref?: React.RefObject<HTMLSpanElement>;
},
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
Icon.displayName = 'Minimax'
export default Icon

View File

@ -1,5 +0,0 @@
.wrapper {
display: inline-flex;
background: url(~@/app/components/base/icons/assets/image/llm/minimax-text.png) center center no-repeat;
background-size: contain;
}

View File

@ -1,20 +0,0 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import cn from '@/utils/classnames'
import s from './MinimaxText.module.css'
const Icon = (
{
ref,
className,
...restProps
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
ref?: React.RefObject<HTMLSpanElement>;
},
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
Icon.displayName = 'MinimaxText'
export default Icon

View File

@ -1,5 +0,0 @@
.wrapper {
display: inline-flex;
background: url(~@/app/components/base/icons/assets/image/llm/tongyi.png) center center no-repeat;
background-size: contain;
}

View File

@ -1,20 +0,0 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import cn from '@/utils/classnames'
import s from './Tongyi.module.css'
const Icon = (
{
ref,
className,
...restProps
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
ref?: React.RefObject<HTMLSpanElement>;
},
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
Icon.displayName = 'Tongyi'
export default Icon

View File

@ -1,5 +0,0 @@
.wrapper {
display: inline-flex;
background: url(~@/app/components/base/icons/assets/image/llm/tongyi-text.png) center center no-repeat;
background-size: contain;
}

View File

@ -1,20 +0,0 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import cn from '@/utils/classnames'
import s from './TongyiText.module.css'
const Icon = (
{
ref,
className,
...restProps
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
ref?: React.RefObject<HTMLSpanElement>;
},
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
Icon.displayName = 'TongyiText'
export default Icon

View File

@ -1,5 +0,0 @@
.wrapper {
display: inline-flex;
background: url(~@/app/components/base/icons/assets/image/llm/tongyi-text-cn.png) center center no-repeat;
background-size: contain;
}

View File

@ -1,20 +0,0 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import cn from '@/utils/classnames'
import s from './TongyiTextCn.module.css'
const Icon = (
{
ref,
className,
...restProps
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
ref?: React.RefObject<HTMLSpanElement>;
},
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
Icon.displayName = 'TongyiTextCn'
export default Icon

View File

@ -1,5 +0,0 @@
.wrapper {
display: inline-flex;
background: url(~@/app/components/base/icons/assets/image/llm/wxyy.png) center center no-repeat;
background-size: contain;
}

View File

@ -1,20 +0,0 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import cn from '@/utils/classnames'
import s from './Wxyy.module.css'
const Icon = (
{
ref,
className,
...restProps
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
ref?: React.RefObject<HTMLSpanElement>;
},
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
Icon.displayName = 'Wxyy'
export default Icon

View File

@ -1,5 +0,0 @@
.wrapper {
display: inline-flex;
background: url(~@/app/components/base/icons/assets/image/llm/wxyy-text.png) center center no-repeat;
background-size: contain;
}

View File

@ -1,20 +0,0 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import cn from '@/utils/classnames'
import s from './WxyyText.module.css'
const Icon = (
{
ref,
className,
...restProps
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
ref?: React.RefObject<HTMLSpanElement>;
},
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
Icon.displayName = 'WxyyText'
export default Icon

View File

@ -1,5 +0,0 @@
.wrapper {
display: inline-flex;
background: url(~@/app/components/base/icons/assets/image/llm/wxyy-text-cn.png) center center no-repeat;
background-size: contain;
}

View File

@ -1,20 +0,0 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import cn from '@/utils/classnames'
import s from './WxyyTextCn.module.css'
const Icon = (
{
ref,
className,
...restProps
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
ref?: React.RefObject<HTMLSpanElement>;
},
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
Icon.displayName = 'WxyyTextCn'
export default Icon

Some files were not shown because too many files have changed in this diff Show More