mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
Merge remote-tracking branch 'origin/main' into feat/trigger
This commit is contained in:
1
.github/workflows/expose_service_ports.sh
vendored
1
.github/workflows/expose_service_ports.sh
vendored
@ -1,6 +1,7 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
|
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
|
||||||
|
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
|
||||||
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
|
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
|
||||||
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
|
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
|
||||||
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
|
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
|
||||||
|
|||||||
@ -200,6 +200,11 @@ class PluginConfig(BaseSettings):
|
|||||||
default="plugin-api-key",
|
default="plugin-api-key",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
|
||||||
|
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
|
||||||
|
default=300.0,
|
||||||
|
)
|
||||||
|
|
||||||
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
||||||
|
|
||||||
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
|
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
|
||||||
|
|||||||
@ -7,8 +7,7 @@ from werkzeug.exceptions import Forbidden
|
|||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
from models.model import ApiToken, App
|
from models.model import ApiToken, App
|
||||||
|
|
||||||
@ -57,9 +56,9 @@ class BaseApiKeyListResource(Resource):
|
|||||||
def get(self, resource_id):
|
def get(self, resource_id):
|
||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||||
keys = db.session.scalars(
|
keys = db.session.scalars(
|
||||||
select(ApiToken).where(
|
select(ApiToken).where(
|
||||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||||
@ -71,9 +70,8 @@ class BaseApiKeyListResource(Resource):
|
|||||||
def post(self, resource_id):
|
def post(self, resource_id):
|
||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
@ -93,7 +91,7 @@ class BaseApiKeyListResource(Resource):
|
|||||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||||
api_token = ApiToken()
|
api_token = ApiToken()
|
||||||
setattr(api_token, self.resource_id_field, resource_id)
|
setattr(api_token, self.resource_id_field, resource_id)
|
||||||
api_token.tenant_id = current_user.current_tenant_id
|
api_token.tenant_id = current_tenant_id
|
||||||
api_token.token = key
|
api_token.token = key
|
||||||
api_token.type = self.resource_type
|
api_token.type = self.resource_type
|
||||||
db.session.add(api_token)
|
db.session.add(api_token)
|
||||||
@ -112,9 +110,8 @@ class BaseApiKeyResource(Resource):
|
|||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
api_key_id = str(api_key_id)
|
api_key_id = str(api_key_id)
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
@ -158,11 +155,6 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
|||||||
"""Create a new API key for an app"""
|
"""Create a new API key for an app"""
|
||||||
return super().post(resource_id)
|
return super().post(resource_id)
|
||||||
|
|
||||||
def after_request(self, resp):
|
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
|
||||||
return resp
|
|
||||||
|
|
||||||
resource_type = "app"
|
resource_type = "app"
|
||||||
resource_model = App
|
resource_model = App
|
||||||
resource_id_field = "app_id"
|
resource_id_field = "app_id"
|
||||||
@ -179,11 +171,6 @@ class AppApiKeyResource(BaseApiKeyResource):
|
|||||||
"""Delete an API key for an app"""
|
"""Delete an API key for an app"""
|
||||||
return super().delete(resource_id, api_key_id)
|
return super().delete(resource_id, api_key_id)
|
||||||
|
|
||||||
def after_request(self, resp):
|
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
|
||||||
return resp
|
|
||||||
|
|
||||||
resource_type = "app"
|
resource_type = "app"
|
||||||
resource_model = App
|
resource_model = App
|
||||||
resource_id_field = "app_id"
|
resource_id_field = "app_id"
|
||||||
@ -208,11 +195,6 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
|||||||
"""Create a new API key for a dataset"""
|
"""Create a new API key for a dataset"""
|
||||||
return super().post(resource_id)
|
return super().post(resource_id)
|
||||||
|
|
||||||
def after_request(self, resp):
|
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
|
||||||
return resp
|
|
||||||
|
|
||||||
resource_type = "dataset"
|
resource_type = "dataset"
|
||||||
resource_model = Dataset
|
resource_model = Dataset
|
||||||
resource_id_field = "dataset_id"
|
resource_id_field = "dataset_id"
|
||||||
@ -229,11 +211,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
|||||||
"""Delete an API key for a dataset"""
|
"""Delete an API key for a dataset"""
|
||||||
return super().delete(resource_id, api_key_id)
|
return super().delete(resource_id, api_key_id)
|
||||||
|
|
||||||
def after_request(self, resp):
|
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
|
||||||
return resp
|
|
||||||
|
|
||||||
resource_type = "dataset"
|
resource_type = "dataset"
|
||||||
resource_model = Dataset
|
resource_model = Dataset
|
||||||
resource_id_field = "dataset_id"
|
resource_id_field = "dataset_id"
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -17,7 +16,7 @@ from fields.annotation_fields import (
|
|||||||
annotation_fields,
|
annotation_fields,
|
||||||
annotation_hit_history_fields,
|
annotation_hit_history_fields,
|
||||||
)
|
)
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.annotation_service import AppAnnotationService
|
from services.annotation_service import AppAnnotationService
|
||||||
|
|
||||||
|
|
||||||
@ -43,7 +42,9 @@ class AnnotationReplyActionApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
@ -70,7 +71,9 @@ class AppAnnotationSettingDetailApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
@ -99,7 +102,9 @@ class AppAnnotationSettingUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, app_id, annotation_setting_id):
|
def post(self, app_id, annotation_setting_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
@ -125,7 +130,9 @@ class AnnotationReplyActionStatusApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
def get(self, app_id, job_id, action):
|
def get(self, app_id, job_id, action):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
job_id = str(job_id)
|
job_id = str(job_id)
|
||||||
@ -160,7 +167,9 @@ class AnnotationApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
@ -199,7 +208,9 @@ class AnnotationApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
@marshal_with(annotation_fields)
|
@marshal_with(annotation_fields)
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
@ -214,7 +225,9 @@ class AnnotationApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, app_id):
|
def delete(self, app_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
@ -250,7 +263,9 @@ class AnnotationExportApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
@ -273,7 +288,9 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
@marshal_with(annotation_fields)
|
@marshal_with(annotation_fields)
|
||||||
def post(self, app_id, annotation_id):
|
def post(self, app_id, annotation_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
@ -289,7 +306,9 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, app_id, annotation_id):
|
def delete(self, app_id, annotation_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
@ -311,7 +330,9 @@ class AnnotationBatchImportApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
@ -342,7 +363,9 @@ class AnnotationBatchImportStatusApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
def get(self, app_id, job_id):
|
def get(self, app_id, job_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
job_id = str(job_id)
|
job_id = str(job_id)
|
||||||
@ -377,7 +400,9 @@ class AnnotationHitHistoryListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_id, annotation_id):
|
def get(self, app_id, annotation_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
|
|||||||
@ -1,6 +1,3 @@
|
|||||||
from typing import cast
|
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
@ -13,8 +10,7 @@ from controllers.console.wraps import (
|
|||||||
)
|
)
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
|
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account
|
|
||||||
from models.model import App
|
from models.model import App
|
||||||
from services.app_dsl_service import AppDslService, ImportStatus
|
from services.app_dsl_service import AppDslService, ImportStatus
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
@ -32,7 +28,8 @@ class AppImportApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("apps")
|
@cloud_edition_billing_resource_check("apps")
|
||||||
def post(self):
|
def post(self):
|
||||||
# Check user role first
|
# Check user role first
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -51,7 +48,7 @@ class AppImportApi(Resource):
|
|||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = AppDslService(session)
|
import_service = AppDslService(session)
|
||||||
# Import app
|
# Import app
|
||||||
account = cast(Account, current_user)
|
account = current_user
|
||||||
result = import_service.import_app(
|
result = import_service.import_app(
|
||||||
account=account,
|
account=account,
|
||||||
import_mode=args["mode"],
|
import_mode=args["mode"],
|
||||||
@ -85,14 +82,15 @@ class AppImportConfirmApi(Resource):
|
|||||||
@marshal_with(app_import_fields)
|
@marshal_with(app_import_fields)
|
||||||
def post(self, import_id):
|
def post(self, import_id):
|
||||||
# Check user role first
|
# Check user role first
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
# Create service with session
|
# Create service with session
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = AppDslService(session)
|
import_service = AppDslService(session)
|
||||||
# Confirm import
|
# Confirm import
|
||||||
account = cast(Account, current_user)
|
account = current_user
|
||||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
@ -110,7 +108,8 @@ class AppImportCheckDependenciesApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(app_import_check_dependencies_fields)
|
@marshal_with(app_import_check_dependencies_fields)
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields, reqparse
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
@ -18,7 +17,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
|
|||||||
from core.llm_generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import App
|
from models import App
|
||||||
from services.workflow_service import WorkflowService
|
from services.workflow_service import WorkflowService
|
||||||
|
|
||||||
@ -49,11 +48,11 @@ class RuleGenerateApi(Resource):
|
|||||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
account = current_user
|
|
||||||
try:
|
try:
|
||||||
rules = LLMGenerator.generate_rule_config(
|
rules = LLMGenerator.generate_rule_config(
|
||||||
tenant_id=account.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
no_variable=args["no_variable"],
|
no_variable=args["no_variable"],
|
||||||
@ -100,11 +99,11 @@ class RuleCodeGenerateApi(Resource):
|
|||||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||||
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
account = current_user
|
|
||||||
try:
|
try:
|
||||||
code_result = LLMGenerator.generate_code(
|
code_result = LLMGenerator.generate_code(
|
||||||
tenant_id=account.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
code_language=args["code_language"],
|
code_language=args["code_language"],
|
||||||
@ -145,11 +144,11 @@ class RuleStructuredOutputGenerateApi(Resource):
|
|||||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
account = current_user
|
|
||||||
try:
|
try:
|
||||||
structured_output = LLMGenerator.generate_structured_output(
|
structured_output = LLMGenerator.generate_structured_output(
|
||||||
tenant_id=account.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
)
|
)
|
||||||
@ -199,6 +198,7 @@ class InstructionGenerateApi(Resource):
|
|||||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||||
code_provider: type[CodeNodeProvider] | None = next(
|
code_provider: type[CodeNodeProvider] | None = next(
|
||||||
(p for p in providers if p.is_accept_language(args["language"])), None
|
(p for p in providers if p.is_accept_language(args["language"])), None
|
||||||
@ -221,21 +221,21 @@ class InstructionGenerateApi(Resource):
|
|||||||
match node_type:
|
match node_type:
|
||||||
case "llm":
|
case "llm":
|
||||||
return LLMGenerator.generate_rule_config(
|
return LLMGenerator.generate_rule_config(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
no_variable=True,
|
no_variable=True,
|
||||||
)
|
)
|
||||||
case "agent":
|
case "agent":
|
||||||
return LLMGenerator.generate_rule_config(
|
return LLMGenerator.generate_rule_config(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
no_variable=True,
|
no_variable=True,
|
||||||
)
|
)
|
||||||
case "code":
|
case "code":
|
||||||
return LLMGenerator.generate_code(
|
return LLMGenerator.generate_code(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
code_language=args["language"],
|
code_language=args["language"],
|
||||||
@ -244,7 +244,7 @@ class InstructionGenerateApi(Resource):
|
|||||||
return {"error": f"invalid node type: {node_type}"}
|
return {"error": f"invalid node type: {node_type}"}
|
||||||
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
||||||
return LLMGenerator.instruction_modify_legacy(
|
return LLMGenerator.instruction_modify_legacy(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
flow_id=args["flow_id"],
|
flow_id=args["flow_id"],
|
||||||
current=args["current"],
|
current=args["current"],
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
@ -253,7 +253,7 @@ class InstructionGenerateApi(Resource):
|
|||||||
)
|
)
|
||||||
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
||||||
return LLMGenerator.instruction_modify_workflow(
|
return LLMGenerator.instruction_modify_workflow(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
flow_id=args["flow_id"],
|
flow_id=args["flow_id"],
|
||||||
node_id=args["node_id"],
|
node_id=args["node_id"],
|
||||||
current=args["current"],
|
current=args["current"],
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import json
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource, fields
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -15,8 +14,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
|
|||||||
from events.app_event import app_model_config_was_updated
|
from events.app_event import app_model_config_was_updated
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from models.model import AppMode, AppModelConfig
|
from models.model import AppMode, AppModelConfig
|
||||||
from services.app_model_config_service import AppModelConfigService
|
from services.app_model_config_service import AppModelConfigService
|
||||||
|
|
||||||
@ -54,16 +52,14 @@ class ModelConfigResource(Resource):
|
|||||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
"""Modify app model config"""
|
"""Modify app model config"""
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
|
||||||
# validate config
|
# validate config
|
||||||
model_configuration = AppModelConfigService.validate_configuration(
|
model_configuration = AppModelConfigService.validate_configuration(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
config=cast(dict, request.json),
|
config=cast(dict, request.json),
|
||||||
app_mode=AppMode.value_of(app_model.mode),
|
app_mode=AppMode.value_of(app_model.mode),
|
||||||
)
|
)
|
||||||
@ -95,12 +91,12 @@ class ModelConfigResource(Resource):
|
|||||||
# get tool
|
# get tool
|
||||||
try:
|
try:
|
||||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
agent_tool=agent_tool_entity,
|
agent_tool=agent_tool_entity,
|
||||||
)
|
)
|
||||||
manager = ToolParameterConfigurationManager(
|
manager = ToolParameterConfigurationManager(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
tool_runtime=tool_runtime,
|
tool_runtime=tool_runtime,
|
||||||
provider_name=agent_tool_entity.provider_id,
|
provider_name=agent_tool_entity.provider_id,
|
||||||
provider_type=agent_tool_entity.provider_type,
|
provider_type=agent_tool_entity.provider_type,
|
||||||
@ -134,7 +130,7 @@ class ModelConfigResource(Resource):
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
agent_tool=agent_tool_entity,
|
agent_tool=agent_tool_entity,
|
||||||
)
|
)
|
||||||
@ -142,7 +138,7 @@ class ModelConfigResource(Resource):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
manager = ToolParameterConfigurationManager(
|
manager = ToolParameterConfigurationManager(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
tool_runtime=tool_runtime,
|
tool_runtime=tool_runtime,
|
||||||
provider_name=agent_tool_entity.provider_id,
|
provider_name=agent_tool_entity.provider_id,
|
||||||
provider_type=agent_tool_entity.provider_type,
|
provider_type=agent_tool_entity.provider_type,
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
@ -9,7 +8,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.app_fields import app_site_fields
|
from fields.app_fields import app_site_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account, Site
|
from models import Account, Site
|
||||||
|
|
||||||
|
|
||||||
@ -76,9 +75,10 @@ class AppSite(Resource):
|
|||||||
@marshal_with(app_site_fields)
|
@marshal_with(app_site_fields)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
args = parse_app_site_args()
|
args = parse_app_site_args()
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
# The role of the current user in the ta table must be editor, admin, or owner
|
# The role of the current user in the ta table must be editor, admin, or owner
|
||||||
if not current_user.is_editor:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||||
@ -131,6 +131,8 @@ class AppSiteAccessTokenReset(Resource):
|
|||||||
@marshal_with(app_site_fields)
|
@marshal_with(app_site_fields)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,9 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||||
|
|
||||||
from ..wraps import account_initialization_required, setup_required
|
from ..wraps import account_initialization_required, setup_required
|
||||||
@ -16,7 +15,8 @@ class ApiKeyAuthDataSource(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||||
if data_source_api_key_bindings:
|
if data_source_api_key_bindings:
|
||||||
return {
|
return {
|
||||||
"sources": [
|
"sources": [
|
||||||
@ -41,6 +41,8 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the table must be admin or owner
|
# The role of the current user in the table must be admin or owner
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -50,7 +52,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||||
try:
|
try:
|
||||||
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
|
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ApiKeyAuthFailedError(str(e))
|
raise ApiKeyAuthFailedError(str(e))
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
@ -63,9 +65,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, binding_id):
|
def delete(self, binding_id):
|
||||||
# The role of the current user in the table must be admin or owner
|
# The role of the current user in the table must be admin or owner
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
|
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
|
||||||
|
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|||||||
@ -3,7 +3,6 @@ from collections.abc import Generator
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -20,7 +19,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import DataSourceOauthBinding, Document
|
from models import DataSourceOauthBinding, Document
|
||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
from services.datasource_provider_service import DatasourceProviderService
|
from services.datasource_provider_service import DatasourceProviderService
|
||||||
@ -37,10 +36,12 @@ class DataSourceApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(integrate_list_fields)
|
@marshal_with(integrate_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# get workspace data source integrates
|
# get workspace data source integrates
|
||||||
data_source_integrates = db.session.scalars(
|
data_source_integrates = db.session.scalars(
|
||||||
select(DataSourceOauthBinding).where(
|
select(DataSourceOauthBinding).where(
|
||||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
DataSourceOauthBinding.tenant_id == current_tenant_id,
|
||||||
DataSourceOauthBinding.disabled == False,
|
DataSourceOauthBinding.disabled == False,
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
@ -120,13 +121,15 @@ class DataSourceNotionListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(integrate_notion_info_list_fields)
|
@marshal_with(integrate_notion_info_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||||
if not credential_id:
|
if not credential_id:
|
||||||
raise ValueError("Credential id is required.")
|
raise ValueError("Credential id is required.")
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
credential = datasource_provider_service.get_datasource_credentials(
|
credential = datasource_provider_service.get_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
credential_id=credential_id,
|
credential_id=credential_id,
|
||||||
provider="notion_datasource",
|
provider="notion_datasource",
|
||||||
plugin_id="langgenius/notion_datasource",
|
plugin_id="langgenius/notion_datasource",
|
||||||
@ -146,7 +149,7 @@ class DataSourceNotionListApi(Resource):
|
|||||||
documents = session.scalars(
|
documents = session.scalars(
|
||||||
select(Document).filter_by(
|
select(Document).filter_by(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
data_source_type="notion_import",
|
data_source_type="notion_import",
|
||||||
enabled=True,
|
enabled=True,
|
||||||
)
|
)
|
||||||
@ -161,7 +164,7 @@ class DataSourceNotionListApi(Resource):
|
|||||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||||
provider_id="langgenius/notion_datasource/notion_datasource",
|
provider_id="langgenius/notion_datasource/notion_datasource",
|
||||||
datasource_name="notion_datasource",
|
datasource_name="notion_datasource",
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||||
)
|
)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
@ -210,12 +213,14 @@ class DataSourceNotionApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, workspace_id, page_id, page_type):
|
def get(self, workspace_id, page_id, page_type):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||||
if not credential_id:
|
if not credential_id:
|
||||||
raise ValueError("Credential id is required.")
|
raise ValueError("Credential id is required.")
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
credential = datasource_provider_service.get_datasource_credentials(
|
credential = datasource_provider_service.get_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
credential_id=credential_id,
|
credential_id=credential_id,
|
||||||
provider="notion_datasource",
|
provider="notion_datasource",
|
||||||
plugin_id="langgenius/notion_datasource",
|
plugin_id="langgenius/notion_datasource",
|
||||||
@ -229,7 +234,7 @@ class DataSourceNotionApi(Resource):
|
|||||||
notion_obj_id=page_id,
|
notion_obj_id=page_id,
|
||||||
notion_page_type=page_type,
|
notion_page_type=page_type,
|
||||||
notion_access_token=credential.get("integration_secret"),
|
notion_access_token=credential.get("integration_secret"),
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
text_docs = extractor.extract()
|
text_docs = extractor.extract()
|
||||||
@ -239,6 +244,8 @@ class DataSourceNotionApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||||
@ -263,7 +270,7 @@ class DataSourceNotionApi(Resource):
|
|||||||
"notion_workspace_id": workspace_id,
|
"notion_workspace_id": workspace_id,
|
||||||
"notion_obj_id": page["page_id"],
|
"notion_obj_id": page["page_id"],
|
||||||
"notion_page_type": page["type"],
|
"notion_page_type": page["type"],
|
||||||
"tenant_id": current_user.current_tenant_id,
|
"tenant_id": current_tenant_id,
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
document_model=args["doc_form"],
|
document_model=args["doc_form"],
|
||||||
@ -271,7 +278,7 @@ class DataSourceNotionApi(Resource):
|
|||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
indexing_runner = IndexingRunner()
|
indexing_runner = IndexingRunner()
|
||||||
response = indexing_runner.indexing_estimate(
|
response = indexing_runner.indexing_estimate(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
extract_settings,
|
extract_settings,
|
||||||
args["process_rule"],
|
args["process_rule"],
|
||||||
args["doc_form"],
|
args["doc_form"],
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal, reqparse
|
from flask_restx import Resource, marshal, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.dataset import ChildChunk, DocumentSegment
|
from models.dataset import ChildChunk, DocumentSegment
|
||||||
from models.model import UploadFile
|
from models.model import UploadFile
|
||||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||||
@ -43,6 +42,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id, document_id):
|
def get(self, dataset_id, document_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -79,7 +80,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||||||
select(DocumentSegment)
|
select(DocumentSegment)
|
||||||
.where(
|
.where(
|
||||||
DocumentSegment.document_id == str(document_id),
|
DocumentSegment.document_id == str(document_id),
|
||||||
DocumentSegment.tenant_id == current_user.current_tenant_id,
|
DocumentSegment.tenant_id == current_tenant_id,
|
||||||
)
|
)
|
||||||
.order_by(DocumentSegment.position.asc())
|
.order_by(DocumentSegment.position.asc())
|
||||||
)
|
)
|
||||||
@ -115,6 +116,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def delete(self, dataset_id, document_id):
|
def delete(self, dataset_id, document_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -148,6 +151,8 @@ class DatasetDocumentSegmentApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, document_id, action):
|
def patch(self, dataset_id, document_id, action):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
@ -171,7 +176,7 @@ class DatasetDocumentSegmentApi(Resource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -204,6 +209,8 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self, dataset_id, document_id):
|
def post(self, dataset_id, document_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -221,7 +228,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -255,6 +262,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, document_id, segment_id):
|
def patch(self, dataset_id, document_id, segment_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -272,7 +281,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -287,7 +296,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -317,6 +326,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def delete(self, dataset_id, document_id, segment_id):
|
def delete(self, dataset_id, document_id, segment_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -333,7 +344,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -361,6 +372,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self, dataset_id, document_id):
|
def post(self, dataset_id, document_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -396,7 +409,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||||||
upload_file_id,
|
upload_file_id,
|
||||||
dataset_id,
|
dataset_id,
|
||||||
document_id,
|
document_id,
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
current_user.id,
|
current_user.id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -427,6 +440,8 @@ class ChildChunkAddApi(Resource):
|
|||||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self, dataset_id, document_id, segment_id):
|
def post(self, dataset_id, document_id, segment_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -441,7 +456,7 @@ class ChildChunkAddApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -453,7 +468,7 @@ class ChildChunkAddApi(Resource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -483,6 +498,8 @@ class ChildChunkAddApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id, document_id, segment_id):
|
def get(self, dataset_id, document_id, segment_id):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -499,7 +516,7 @@ class ChildChunkAddApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -530,6 +547,8 @@ class ChildChunkAddApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, document_id, segment_id):
|
def patch(self, dataset_id, document_id, segment_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -546,7 +565,7 @@ class ChildChunkAddApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -580,6 +599,8 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -596,7 +617,7 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -607,7 +628,7 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
db.session.query(ChildChunk)
|
db.session.query(ChildChunk)
|
||||||
.where(
|
.where(
|
||||||
ChildChunk.id == str(child_chunk_id),
|
ChildChunk.id == str(child_chunk_id),
|
||||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
ChildChunk.tenant_id == current_tenant_id,
|
||||||
ChildChunk.segment_id == segment.id,
|
ChildChunk.segment_id == segment.id,
|
||||||
ChildChunk.document_id == document_id,
|
ChildChunk.document_id == document_id,
|
||||||
)
|
)
|
||||||
@ -634,6 +655,8 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -650,7 +673,7 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -661,7 +684,7 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
db.session.query(ChildChunk)
|
db.session.query(ChildChunk)
|
||||||
.where(
|
.where(
|
||||||
ChildChunk.id == str(child_chunk_id),
|
ChildChunk.id == str(child_chunk_id),
|
||||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
ChildChunk.tenant_id == current_tenant_id,
|
||||||
ChildChunk.segment_id == segment.id,
|
ChildChunk.segment_id == segment.id,
|
||||||
ChildChunk.document_id == document_id,
|
ChildChunk.document_id == document_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from flask import make_response, redirect, request
|
from flask import make_response, redirect, request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
@ -13,7 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.impl.oauth import OAuthHandler
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
from libs.helper import StrLen
|
from libs.helper import StrLen
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.provider_ids import DatasourceProviderID
|
from models.provider_ids import DatasourceProviderID
|
||||||
from services.datasource_provider_service import DatasourceProviderService
|
from services.datasource_provider_service import DatasourceProviderService
|
||||||
from services.plugin.oauth_service import OAuthProxyService
|
from services.plugin.oauth_service import OAuthProxyService
|
||||||
@ -25,9 +24,10 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider_id: str):
|
def get(self, provider_id: str):
|
||||||
user = current_user
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
if not current_user.is_editor:
|
tenant_id = current_tenant_id
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
credential_id = request.args.get("credential_id")
|
credential_id = request.args.get("credential_id")
|
||||||
@ -52,7 +52,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
|||||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||||
authorization_url_response = oauth_handler.get_authorization_url(
|
authorization_url_response = oauth_handler.get_authorization_url(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user.id,
|
user_id=current_user.id,
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
provider=provider_name,
|
provider=provider_name,
|
||||||
redirect_uri=redirect_uri,
|
redirect_uri=redirect_uri,
|
||||||
@ -131,7 +131,9 @@ class DatasourceAuth(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
if not current_user.is_editor:
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -145,7 +147,7 @@ class DatasourceAuth(Resource):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
datasource_provider_service.add_datasource_api_key_provider(
|
datasource_provider_service.add_datasource_api_key_provider(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider_id=datasource_provider_id,
|
provider_id=datasource_provider_id,
|
||||||
credentials=args["credentials"],
|
credentials=args["credentials"],
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
@ -160,8 +162,10 @@ class DatasourceAuth(Resource):
|
|||||||
def get(self, provider_id: str):
|
def get(self, provider_id: str):
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasources = datasource_provider_service.list_datasource_credentials(
|
datasources = datasource_provider_service.list_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=datasource_provider_id.provider_name,
|
provider=datasource_provider_id.provider_name,
|
||||||
plugin_id=datasource_provider_id.plugin_id,
|
plugin_id=datasource_provider_id.plugin_id,
|
||||||
)
|
)
|
||||||
@ -174,17 +178,19 @@ class DatasourceAuthDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
plugin_id = datasource_provider_id.plugin_id
|
plugin_id = datasource_provider_id.plugin_id
|
||||||
provider_name = datasource_provider_id.provider_name
|
provider_name = datasource_provider_id.provider_name
|
||||||
if not current_user.is_editor:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.remove_datasource_credentials(
|
datasource_provider_service.remove_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
auth_id=args["credential_id"],
|
auth_id=args["credential_id"],
|
||||||
provider=provider_name,
|
provider=provider_name,
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
@ -198,17 +204,19 @@ class DatasourceAuthUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if not current_user.is_editor:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.update_datasource_credentials(
|
datasource_provider_service.update_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
auth_id=args["credential_id"],
|
auth_id=args["credential_id"],
|
||||||
provider=datasource_provider_id.provider_name,
|
provider=datasource_provider_id.provider_name,
|
||||||
plugin_id=datasource_provider_id.plugin_id,
|
plugin_id=datasource_provider_id.plugin_id,
|
||||||
@ -224,10 +232,10 @@ class DatasourceAuthListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasources = datasource_provider_service.get_all_datasource_credentials(
|
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
|
||||||
tenant_id=current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
return {"result": jsonable_encoder(datasources)}, 200
|
return {"result": jsonable_encoder(datasources)}, 200
|
||||||
|
|
||||||
|
|
||||||
@ -237,10 +245,10 @@ class DatasourceHardCodeAuthListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
|
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
|
||||||
tenant_id=current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
return {"result": jsonable_encoder(datasources)}, 200
|
return {"result": jsonable_encoder(datasources)}, 200
|
||||||
|
|
||||||
|
|
||||||
@ -250,7 +258,9 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
if not current_user.is_editor:
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||||
@ -259,7 +269,7 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.setup_oauth_custom_client_params(
|
datasource_provider_service.setup_oauth_custom_client_params(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
client_params=args.get("client_params", {}),
|
client_params=args.get("client_params", {}),
|
||||||
enabled=args.get("enable_oauth_custom_client", False),
|
enabled=args.get("enable_oauth_custom_client", False),
|
||||||
@ -270,10 +280,12 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, provider_id: str):
|
def delete(self, provider_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.remove_oauth_custom_client_params(
|
datasource_provider_service.remove_oauth_custom_client_params(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
)
|
)
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
@ -285,7 +297,9 @@ class DatasourceAuthDefaultApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
if not current_user.is_editor:
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||||
@ -293,7 +307,7 @@ class DatasourceAuthDefaultApi(Resource):
|
|||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.set_default_datasource_provider(
|
datasource_provider_service.set_default_datasource_provider(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
credential_id=args["id"],
|
credential_id=args["id"],
|
||||||
)
|
)
|
||||||
@ -306,7 +320,9 @@ class DatasourceUpdateProviderNameApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
if not current_user.is_editor:
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||||
@ -315,7 +331,7 @@ class DatasourceUpdateProviderNameApi(Resource):
|
|||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.update_datasource_provider_name(
|
datasource_provider_service.update_datasource_provider_name(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
credential_id=args["credential_id"],
|
credential_id=args["credential_id"],
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal, reqparse
|
from flask_restx import Resource, marshal, reqparse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
@ -13,7 +12,7 @@ from controllers.console.wraps import (
|
|||||||
)
|
)
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.dataset_fields import dataset_detail_fields
|
from fields.dataset_fields import dataset_detail_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.dataset import DatasetPermissionEnum
|
from models.dataset import DatasetPermissionEnum
|
||||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||||
@ -38,7 +37,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
@ -58,12 +57,12 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||||
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||||
)
|
)
|
||||||
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||||
DatasetPermissionService.update_partial_member_list(
|
DatasetPermissionService.update_partial_member_list(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
import_info["dataset_id"],
|
import_info["dataset_id"],
|
||||||
rag_pipeline_dataset_create_entity.partial_member_list,
|
rag_pipeline_dataset_create_entity.partial_member_list,
|
||||||
)
|
)
|
||||||
@ -81,10 +80,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
|||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
|
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
|
||||||
name="",
|
name="",
|
||||||
description="",
|
description="",
|
||||||
|
|||||||
@ -12,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.installed_app_fields import installed_app_list_fields
|
from fields.installed_app_fields import installed_app_list_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account, App, InstalledApp, RecommendedApp
|
from models import App, InstalledApp, RecommendedApp
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
@ -29,9 +29,7 @@ class InstalledAppsListApi(Resource):
|
|||||||
@marshal_with(installed_app_list_fields)
|
@marshal_with(installed_app_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
app_id = request.args.get("app_id", default=None, type=str)
|
app_id = request.args.get("app_id", default=None, type=str)
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
current_tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
if app_id:
|
if app_id:
|
||||||
installed_apps = db.session.scalars(
|
installed_apps = db.session.scalars(
|
||||||
@ -121,9 +119,8 @@ class InstalledAppsListApi(Resource):
|
|||||||
if recommended_app is None:
|
if recommended_app is None:
|
||||||
raise NotFound("App not found")
|
raise NotFound("App not found")
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
current_tenant_id = current_user.current_tenant_id
|
|
||||||
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
||||||
|
|
||||||
if app is None:
|
if app is None:
|
||||||
@ -163,9 +160,8 @@ class InstalledAppApi(InstalledAppResource):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def delete(self, installed_app):
|
def delete(self, installed_app):
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("current_user must be an Account instance")
|
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||||
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
|
|
||||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||||
|
|
||||||
db.session.delete(installed_app)
|
db.session.delete(installed_app)
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from constants import HIDDEN_VALUE
|
|||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from fields.api_based_extension_fields import api_based_extension_fields
|
from fields.api_based_extension_fields import api_based_extension_fields
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, current_user, login_required
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.api_based_extension import APIBasedExtension
|
from models.api_based_extension import APIBasedExtension
|
||||||
from services.api_based_extension_service import APIBasedExtensionService
|
from services.api_based_extension_service import APIBasedExtensionService
|
||||||
@ -47,9 +47,7 @@ class APIBasedExtensionAPI(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_based_extension_fields)
|
@marshal_with(api_based_extension_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
assert isinstance(current_user, Account)
|
_, tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||||
|
|
||||||
@api.doc("create_api_based_extension")
|
@api.doc("create_api_based_extension")
|
||||||
@ -77,9 +75,10 @@ class APIBasedExtensionAPI(Resource):
|
|||||||
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||||
parser.add_argument("api_key", type=str, required=True, location="json")
|
parser.add_argument("api_key", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
extension_data = APIBasedExtension(
|
extension_data = APIBasedExtension(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
api_endpoint=args["api_endpoint"],
|
api_endpoint=args["api_endpoint"],
|
||||||
api_key=args["api_key"],
|
api_key=args["api_key"],
|
||||||
@ -102,7 +101,7 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
assert current_user.current_tenant_id is not None
|
assert current_user.current_tenant_id is not None
|
||||||
api_based_extension_id = str(id)
|
api_based_extension_id = str(id)
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||||
|
|
||||||
@ -128,9 +127,9 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
assert current_user.current_tenant_id is not None
|
assert current_user.current_tenant_id is not None
|
||||||
api_based_extension_id = str(id)
|
api_based_extension_id = str(id)
|
||||||
tenant_id = current_user.current_tenant_id
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
parser.add_argument("name", type=str, required=True, location="json")
|
||||||
@ -157,9 +156,9 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
assert current_user.current_tenant_id is not None
|
assert current_user.current_tenant_id is not None
|
||||||
api_based_extension_id = str(id)
|
api_based_extension_id = str(id)
|
||||||
tenant_id = current_user.current_tenant_id
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||||
|
|
||||||
APIBasedExtensionService.delete(extension_data_from_db)
|
APIBasedExtensionService.delete(extension_data_from_db)
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource, fields
|
||||||
|
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
from . import api, console_ns
|
from . import api, console_ns
|
||||||
@ -23,9 +22,9 @@ class FeatureApi(Resource):
|
|||||||
@cloud_utm_record
|
@cloud_utm_record
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Get feature configuration for current tenant"""
|
"""Get feature configuration for current tenant"""
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
|
return FeatureService.get_features(current_tenant_id).model_dump()
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/system-features")
|
@console_ns.route("/system-features")
|
||||||
|
|||||||
@ -108,4 +108,4 @@ class FileSupportTypeApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
|
return {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}
|
||||||
|
|||||||
@ -14,8 +14,7 @@ from core.file import helpers as file_helpers
|
|||||||
from core.helper import ssrf_proxy
|
from core.helper import ssrf_proxy
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||||
from libs.login import current_user
|
from libs.login import current_account_with_tenant
|
||||||
from models.account import Account
|
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
|
|
||||||
from . import console_ns
|
from . import console_ns
|
||||||
@ -64,8 +63,7 @@ class RemoteFileUploadApi(Resource):
|
|||||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||||
|
|
||||||
try:
|
try:
|
||||||
assert isinstance(current_user, Account)
|
user, _ = current_account_with_tenant()
|
||||||
user = current_user
|
|
||||||
upload_file = FileService(db.engine).upload_file(
|
upload_file = FileService(db.engine).upload_file(
|
||||||
filename=file_info.filename,
|
filename=file_info.filename,
|
||||||
content=content,
|
content=content,
|
||||||
|
|||||||
@ -5,18 +5,10 @@ from controllers.console import api, console_ns
|
|||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from services.plugin.endpoint_service import EndpointService
|
from services.plugin.endpoint_service import EndpointService
|
||||||
|
|
||||||
|
|
||||||
def _current_account_with_tenant() -> tuple[Account, str]:
|
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
assert tenant_id is not None
|
|
||||||
return current_user, tenant_id
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/endpoints/create")
|
@console_ns.route("/workspaces/current/endpoints/create")
|
||||||
class EndpointCreateApi(Resource):
|
class EndpointCreateApi(Resource):
|
||||||
@api.doc("create_endpoint")
|
@api.doc("create_endpoint")
|
||||||
@ -41,7 +33,7 @@ class EndpointCreateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user, tenant_id = _current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
@ -87,7 +79,7 @@ class EndpointListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user, tenant_id = _current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("page", type=int, required=True, location="args")
|
parser.add_argument("page", type=int, required=True, location="args")
|
||||||
@ -130,7 +122,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user, tenant_id = _current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("page", type=int, required=True, location="args")
|
parser.add_argument("page", type=int, required=True, location="args")
|
||||||
@ -172,7 +164,7 @@ class EndpointDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user, tenant_id = _current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("endpoint_id", type=str, required=True)
|
parser.add_argument("endpoint_id", type=str, required=True)
|
||||||
@ -212,7 +204,7 @@ class EndpointUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user, tenant_id = _current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("endpoint_id", type=str, required=True)
|
parser.add_argument("endpoint_id", type=str, required=True)
|
||||||
@ -255,7 +247,7 @@ class EndpointEnableApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user, tenant_id = _current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("endpoint_id", type=str, required=True)
|
parser.add_argument("endpoint_id", type=str, required=True)
|
||||||
@ -288,7 +280,7 @@ class EndpointDisableApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user, tenant_id = _current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("endpoint_id", type=str, required=True)
|
parser.add_argument("endpoint_id", type=str, required=True)
|
||||||
|
|||||||
@ -25,7 +25,7 @@ from controllers.console.wraps import (
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.member_fields import account_with_role_list_fields
|
from fields.member_fields import account_with_role_list_fields
|
||||||
from libs.helper import extract_remote_ip
|
from libs.helper import extract_remote_ip
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account, TenantAccountRole
|
from models.account import Account, TenantAccountRole
|
||||||
from services.account_service import AccountService, RegisterService, TenantService
|
from services.account_service import AccountService, RegisterService, TenantService
|
||||||
from services.errors.account import AccountAlreadyInTenantError
|
from services.errors.account import AccountAlreadyInTenantError
|
||||||
@ -41,8 +41,7 @@ class MemberListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_with_role_list_fields)
|
@marshal_with(account_with_role_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
if not current_user.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||||
@ -69,9 +68,7 @@ class MemberInviteEmailApi(Resource):
|
|||||||
interface_language = args["language"]
|
interface_language = args["language"]
|
||||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
inviter = current_user
|
inviter = current_user
|
||||||
if not inviter.current_tenant:
|
if not inviter.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
@ -120,8 +117,7 @@ class MemberCancelInviteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, member_id):
|
def delete(self, member_id):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
if not current_user.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
member = db.session.query(Account).where(Account.id == str(member_id)).first()
|
member = db.session.query(Account).where(Account.id == str(member_id)).first()
|
||||||
@ -160,9 +156,7 @@ class MemberUpdateRoleApi(Resource):
|
|||||||
|
|
||||||
if not TenantAccountRole.is_valid_role(new_role):
|
if not TenantAccountRole.is_valid_role(new_role):
|
||||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
if not current_user.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
member = db.session.get(Account, str(member_id))
|
member = db.session.get(Account, str(member_id))
|
||||||
@ -189,8 +183,7 @@ class DatasetOperatorMemberListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_with_role_list_fields)
|
@marshal_with(account_with_role_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
if not current_user.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||||
@ -212,10 +205,8 @@ class SendOwnerTransferEmailApi(Resource):
|
|||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
if AccountService.is_email_send_ip_limit(ip_address):
|
if AccountService.is_email_send_ip_limit(ip_address):
|
||||||
raise EmailSendIpLimitError()
|
raise EmailSendIpLimitError()
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
# check if the current user is the owner of the workspace
|
# check if the current user is the owner of the workspace
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
if not current_user.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||||
@ -250,8 +241,7 @@ class OwnerTransferCheckApi(Resource):
|
|||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# check if the current user is the owner of the workspace
|
# check if the current user is the owner of the workspace
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
if not current_user.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||||
@ -296,8 +286,7 @@ class OwnerTransfer(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# check if the current user is the owner of the workspace
|
# check if the current user is the owner of the workspace
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
if not current_user.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import io
|
import io
|
||||||
|
|
||||||
from flask import send_file
|
from flask import send_file
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -11,8 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from libs.helper import StrLen, uuid_value
|
from libs.helper import StrLen, uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
from services.model_provider_service import ModelProviderService
|
from services.model_provider_service import ModelProviderService
|
||||||
|
|
||||||
@ -23,11 +21,8 @@ class ModelProviderListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
tenant_id = current_tenant_id
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -52,11 +47,8 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
tenant_id = current_tenant_id
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
# if credential_id is not provided, return current used credential
|
# if credential_id is not provided, return current used credential
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||||
@ -73,8 +65,7 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
@ -85,11 +76,9 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
try:
|
try:
|
||||||
model_provider_service.create_provider_credential(
|
model_provider_service.create_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
credentials=args["credentials"],
|
credentials=args["credentials"],
|
||||||
credential_name=args["name"],
|
credential_name=args["name"],
|
||||||
@ -103,8 +92,7 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def put(self, provider: str):
|
def put(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
@ -116,11 +104,9 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
try:
|
try:
|
||||||
model_provider_service.update_provider_credential(
|
model_provider_service.update_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
credentials=args["credentials"],
|
credentials=args["credentials"],
|
||||||
credential_id=args["credential_id"],
|
credential_id=args["credential_id"],
|
||||||
@ -135,19 +121,16 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, provider: str):
|
def delete(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
model_provider_service.remove_provider_credential(
|
model_provider_service.remove_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
@ -159,19 +142,16 @@ class ModelProviderCredentialSwitchApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
service = ModelProviderService()
|
service = ModelProviderService()
|
||||||
service.switch_active_provider_credential(
|
service.switch_active_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
credential_id=args["credential_id"],
|
credential_id=args["credential_id"],
|
||||||
)
|
)
|
||||||
@ -184,15 +164,12 @@ class ModelProviderValidateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
tenant_id = current_tenant_id
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
@ -240,14 +217,11 @@ class PreferredProviderTypeUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
tenant_id = current_tenant_id
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -276,14 +250,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
|||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
if provider != "anthropic":
|
if provider != "anthropic":
|
||||||
raise ValueError(f"provider name {provider} is invalid")
|
raise ValueError(f"provider name {provider} is invalid")
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
data = BillingService.get_model_provider_payment_link(
|
data = BillingService.get_model_provider_payment_link(
|
||||||
provider_name=provider,
|
provider_name=provider,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
account_id=current_user.id,
|
account_id=current_user.id,
|
||||||
prefilled_email=current_user.email,
|
prefilled_email=current_user.email,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -10,7 +9,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from libs.helper import StrLen, uuid_value
|
from libs.helper import StrLen, uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||||
from services.model_provider_service import ModelProviderService
|
from services.model_provider_service import ModelProviderService
|
||||||
|
|
||||||
@ -23,6 +22,8 @@ class DefaultModelApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"model_type",
|
"model_type",
|
||||||
@ -34,8 +35,6 @@ class DefaultModelApi(Resource):
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
||||||
tenant_id=tenant_id, model_type=args["model_type"]
|
tenant_id=tenant_id, model_type=args["model_type"]
|
||||||
@ -47,15 +46,14 @@ class DefaultModelApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
|
current_user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
|
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
model_settings = args["model_settings"]
|
model_settings = args["model_settings"]
|
||||||
for model_setting in model_settings:
|
for model_setting in model_settings:
|
||||||
@ -92,7 +90,7 @@ class ModelProviderModelApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
|
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
|
||||||
@ -104,11 +102,11 @@ class ModelProviderModelApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
# To save the model's load balance configs
|
# To save the model's load balance configs
|
||||||
|
current_user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -129,7 +127,7 @@ class ModelProviderModelApi(Resource):
|
|||||||
raise ValueError("credential_id is required when configuring a custom-model")
|
raise ValueError("credential_id is required when configuring a custom-model")
|
||||||
service = ModelProviderService()
|
service = ModelProviderService()
|
||||||
service.switch_active_custom_model_credential(
|
service.switch_active_custom_model_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model_type=args["model_type"],
|
model_type=args["model_type"],
|
||||||
model=args["model"],
|
model=args["model"],
|
||||||
@ -164,11 +162,11 @@ class ModelProviderModelApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, provider: str):
|
def delete(self, provider: str):
|
||||||
|
current_user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -195,7 +193,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
|
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||||
@ -257,6 +255,8 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
|
current_user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
@ -274,7 +274,6 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -301,6 +300,8 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def put(self, provider: str):
|
def put(self, provider: str):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
@ -323,7 +324,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
model_provider_service.update_model_credential(
|
model_provider_service.update_model_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model_type=args["model_type"],
|
model_type=args["model_type"],
|
||||||
model=args["model"],
|
model=args["model"],
|
||||||
@ -340,6 +341,8 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, provider: str):
|
def delete(self, provider: str):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -357,7 +360,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
model_provider_service.remove_model_credential(
|
model_provider_service.remove_model_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model_type=args["model_type"],
|
model_type=args["model_type"],
|
||||||
model=args["model"],
|
model=args["model"],
|
||||||
@ -373,6 +376,8 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -390,7 +395,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
|||||||
|
|
||||||
service = ModelProviderService()
|
service = ModelProviderService()
|
||||||
service.add_model_credential_to_model_list(
|
service.add_model_credential_to_model_list(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model_type=args["model_type"],
|
model_type=args["model_type"],
|
||||||
model=args["model"],
|
model=args["model"],
|
||||||
@ -407,7 +412,7 @@ class ModelProviderModelEnableApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def patch(self, provider: str):
|
def patch(self, provider: str):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
@ -437,7 +442,7 @@ class ModelProviderModelDisableApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def patch(self, provider: str):
|
def patch(self, provider: str):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
@ -465,7 +470,7 @@ class ModelProviderModelValidateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
@ -514,8 +519,7 @@ class ModelProviderModelParameterRuleApi(Resource):
|
|||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
|
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, tenant_id = current_account_with_tenant()
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
parameter_rules = model_provider_service.get_model_parameter_rules(
|
parameter_rules = model_provider_service.get_model_parameter_rules(
|
||||||
@ -531,8 +535,7 @@ class ModelProviderAvailableModelApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, model_type):
|
def get(self, model_type):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
|
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import io
|
import io
|
||||||
|
|
||||||
from flask import request, send_file
|
from flask import request, send_file
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -11,7 +10,7 @@ from controllers.console.workspace import plugin_permission_required
|
|||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||||
from services.plugin.plugin_parameter_service import PluginParameterService
|
from services.plugin.plugin_parameter_service import PluginParameterService
|
||||||
@ -26,7 +25,7 @@ class PluginDebuggingKeyApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(debug_required=True)
|
@plugin_permission_required(debug_required=True)
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return {
|
return {
|
||||||
@ -44,7 +43,7 @@ class PluginListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("page", type=int, required=False, location="args", default=1)
|
parser.add_argument("page", type=int, required=False, location="args", default=1)
|
||||||
parser.add_argument("page_size", type=int, required=False, location="args", default=256)
|
parser.add_argument("page_size", type=int, required=False, location="args", default=256)
|
||||||
@ -81,7 +80,7 @@ class PluginListInstallationsFromIdsApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("plugin_ids", type=list, required=True, location="json")
|
parser.add_argument("plugin_ids", type=list, required=True, location="json")
|
||||||
@ -120,7 +119,7 @@ class PluginUploadFromPkgApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
file = request.files["pkg"]
|
file = request.files["pkg"]
|
||||||
|
|
||||||
@ -144,7 +143,7 @@ class PluginUploadFromGithubApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("repo", type=str, required=True, location="json")
|
parser.add_argument("repo", type=str, required=True, location="json")
|
||||||
@ -167,7 +166,7 @@ class PluginUploadFromBundleApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
file = request.files["bundle"]
|
file = request.files["bundle"]
|
||||||
|
|
||||||
@ -191,7 +190,7 @@ class PluginInstallFromPkgApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
||||||
@ -217,7 +216,7 @@ class PluginInstallFromGithubApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("repo", type=str, required=True, location="json")
|
parser.add_argument("repo", type=str, required=True, location="json")
|
||||||
@ -247,7 +246,7 @@ class PluginInstallFromMarketplaceApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
||||||
@ -273,7 +272,7 @@ class PluginFetchMarketplacePkgApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||||
@ -299,7 +298,7 @@ class PluginFetchManifestApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||||
@ -324,7 +323,7 @@ class PluginFetchInstallTasksApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("page", type=int, required=True, location="args")
|
parser.add_argument("page", type=int, required=True, location="args")
|
||||||
@ -346,7 +345,7 @@ class PluginFetchInstallTaskApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def get(self, task_id: str):
|
def get(self, task_id: str):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
|
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
|
||||||
@ -361,7 +360,7 @@ class PluginDeleteInstallTaskApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self, task_id: str):
|
def post(self, task_id: str):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
|
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
|
||||||
@ -376,7 +375,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
|
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
|
||||||
@ -391,7 +390,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self, task_id: str, identifier: str):
|
def post(self, task_id: str, identifier: str):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
|
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
|
||||||
@ -406,7 +405,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||||
@ -430,7 +429,7 @@ class PluginUpgradeFromGithubApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||||
@ -466,7 +465,7 @@ class PluginUninstallApi(Resource):
|
|||||||
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
|
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
|
||||||
args = req.parse_args()
|
args = req.parse_args()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
|
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
|
||||||
@ -480,6 +479,7 @@ class PluginChangePermissionApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
user = current_user
|
user = current_user
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
@ -492,7 +492,7 @@ class PluginChangePermissionApi(Resource):
|
|||||||
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
|
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
|
||||||
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
|
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
|
||||||
|
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = current_tenant_id
|
||||||
|
|
||||||
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
|
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
|
||||||
|
|
||||||
@ -503,7 +503,7 @@ class PluginFetchPermissionApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
permission = PluginPermissionService.get_permission(tenant_id)
|
permission = PluginPermissionService.get_permission(tenant_id)
|
||||||
if not permission:
|
if not permission:
|
||||||
@ -529,10 +529,10 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
# check if the user is admin or owner
|
# check if the user is admin or owner
|
||||||
|
current_user, tenant_id = current_account_with_tenant()
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
user_id = current_user.id
|
user_id = current_user.id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -567,7 +567,7 @@ class PluginChangePreferencesApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
@ -576,8 +576,6 @@ class PluginChangePreferencesApi(Resource):
|
|||||||
req.add_argument("auto_upgrade", type=dict, required=True, location="json")
|
req.add_argument("auto_upgrade", type=dict, required=True, location="json")
|
||||||
args = req.parse_args()
|
args = req.parse_args()
|
||||||
|
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
permission = args["permission"]
|
permission = args["permission"]
|
||||||
|
|
||||||
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
|
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
|
||||||
@ -623,7 +621,7 @@ class PluginFetchPreferencesApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
permission = PluginPermissionService.get_permission(tenant_id)
|
permission = PluginPermissionService.get_permission(tenant_id)
|
||||||
permission_dict = {
|
permission_dict = {
|
||||||
@ -663,7 +661,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
# exclude one single plugin
|
# exclude one single plugin
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
req = reqparse.RequestParser()
|
req = reqparse.RequestParser()
|
||||||
req.add_argument("plugin_id", type=str, required=True, location="json")
|
req.add_argument("plugin_id", type=str, required=True, location="json")
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import io
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from flask import make_response, redirect, request, send_file
|
from flask import make_response, redirect, request, send_file
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import (
|
from flask_restx import (
|
||||||
Resource,
|
Resource,
|
||||||
reqparse,
|
reqparse,
|
||||||
@ -24,7 +23,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||||||
from core.plugin.entities.plugin_daemon import CredentialType
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
from core.plugin.impl.oauth import OAuthHandler
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.provider_ids import ToolProviderID
|
from models.provider_ids import ToolProviderID
|
||||||
|
|
||||||
# from models.provider_ids import ToolProviderID
|
# from models.provider_ids import ToolProviderID
|
||||||
@ -55,10 +54,9 @@ class ToolProviderListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
req = reqparse.RequestParser()
|
req = reqparse.RequestParser()
|
||||||
req.add_argument(
|
req.add_argument(
|
||||||
@ -80,9 +78,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
user = current_user
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
BuiltinToolManageService.list_builtin_tool_provider_tools(
|
BuiltinToolManageService.list_builtin_tool_provider_tools(
|
||||||
@ -98,9 +94,7 @@ class ToolBuiltinProviderInfoApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
user = current_user
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||||
|
|
||||||
@ -111,11 +105,10 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider):
|
def post(self, provider):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
req = reqparse.RequestParser()
|
req = reqparse.RequestParser()
|
||||||
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
args = req.parse_args()
|
args = req.parse_args()
|
||||||
@ -133,10 +126,9 @@ class ToolBuiltinProviderAddApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider):
|
def post(self, provider):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
@ -163,13 +155,12 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider):
|
def post(self, provider):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
@ -195,7 +186,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
||||||
@ -220,13 +211,12 @@ class ToolApiProviderAddApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
@ -260,10 +250,9 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
@ -284,10 +273,9 @@ class ToolApiProviderListToolsApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
@ -310,13 +298,12 @@ class ToolApiProviderUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
@ -352,13 +339,12 @@ class ToolApiProviderDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
@ -379,10 +365,9 @@ class ToolApiProviderGetApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
@ -403,8 +388,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider, credential_type):
|
def get(self, provider, credential_type):
|
||||||
user = current_user
|
_, tenant_id = current_account_with_tenant()
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
BuiltinToolManageService.list_builtin_provider_credentials_schema(
|
BuiltinToolManageService.list_builtin_provider_credentials_schema(
|
||||||
@ -446,9 +430,9 @@ class ToolApiProviderPreviousTestApi(Resource):
|
|||||||
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
return ApiToolManageService.test_api_tool_preview(
|
return ApiToolManageService.test_api_tool_preview(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
args["provider_name"] or "",
|
args["provider_name"] or "",
|
||||||
args["tool_name"],
|
args["tool_name"],
|
||||||
args["credentials"],
|
args["credentials"],
|
||||||
@ -464,13 +448,12 @@ class ToolWorkflowProviderCreateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
reqparser = reqparse.RequestParser()
|
reqparser = reqparse.RequestParser()
|
||||||
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
|
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
@ -504,13 +487,12 @@ class ToolWorkflowProviderUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
reqparser = reqparse.RequestParser()
|
reqparser = reqparse.RequestParser()
|
||||||
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
@ -547,13 +529,12 @@ class ToolWorkflowProviderDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
reqparser = reqparse.RequestParser()
|
reqparser = reqparse.RequestParser()
|
||||||
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
@ -573,10 +554,9 @@ class ToolWorkflowProviderGetApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
|
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||||
@ -608,10 +588,9 @@ class ToolWorkflowProviderListToolApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
|
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
|
||||||
@ -633,10 +612,9 @@ class ToolBuiltinListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
[
|
[
|
||||||
@ -655,8 +633,7 @@ class ToolApiListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
_, tenant_id = current_account_with_tenant()
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
[
|
[
|
||||||
@ -674,10 +651,9 @@ class ToolWorkflowListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
[
|
[
|
||||||
@ -711,19 +687,18 @@ class ToolPluginOAuthApi(Resource):
|
|||||||
provider_name = tool_provider.provider_name
|
provider_name = tool_provider.provider_name
|
||||||
|
|
||||||
# todo check permission
|
# todo check permission
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
|
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
|
||||||
if oauth_client_params is None:
|
if oauth_client_params is None:
|
||||||
raise Forbidden("no oauth available client config found for this tool provider")
|
raise Forbidden("no oauth available client config found for this tool provider")
|
||||||
|
|
||||||
oauth_handler = OAuthHandler()
|
oauth_handler = OAuthHandler()
|
||||||
context_id = OAuthProxyService.create_proxy_context(
|
context_id = OAuthProxyService.create_proxy_context(
|
||||||
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
|
user_id=user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
|
||||||
)
|
)
|
||||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
||||||
authorization_url_response = oauth_handler.get_authorization_url(
|
authorization_url_response = oauth_handler.get_authorization_url(
|
||||||
@ -802,11 +777,12 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider):
|
def post(self, provider):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return BuiltinToolManageService.set_default_provider(
|
return BuiltinToolManageService.set_default_provider(
|
||||||
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
|
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -821,13 +797,13 @@ class ToolOAuthCustomClient(Resource):
|
|||||||
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
return BuiltinToolManageService.save_custom_oauth_client_params(
|
return BuiltinToolManageService.save_custom_oauth_client_params(
|
||||||
tenant_id=user.current_tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
client_params=args.get("client_params", {}),
|
client_params=args.get("client_params", {}),
|
||||||
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
|
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
|
||||||
@ -837,20 +813,18 @@ class ToolOAuthCustomClient(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
BuiltinToolManageService.get_custom_oauth_client_params(
|
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||||
tenant_id=current_user.current_tenant_id, provider=provider
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, provider):
|
def delete(self, provider):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
BuiltinToolManageService.delete_custom_oauth_client_params(
|
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||||
tenant_id=current_user.current_tenant_id, provider=provider
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -860,9 +834,10 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
|
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
|
||||||
tenant_id=current_user.current_tenant_id, provider_name=provider
|
tenant_id=current_tenant_id, provider_name=provider
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -873,7 +848,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
|
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
|
||||||
@ -902,12 +877,12 @@ class ToolProviderMCPApi(Resource):
|
|||||||
)
|
)
|
||||||
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
if not is_valid_url(args["server_url"]):
|
if not is_valid_url(args["server_url"]):
|
||||||
raise ValueError("Server URL is not valid.")
|
raise ValueError("Server URL is not valid.")
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
MCPToolManageService.create_mcp_provider(
|
MCPToolManageService.create_mcp_provider(
|
||||||
tenant_id=user.current_tenant_id,
|
tenant_id=tenant_id,
|
||||||
server_url=args["server_url"],
|
server_url=args["server_url"],
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
icon=args["icon"],
|
icon=args["icon"],
|
||||||
@ -942,8 +917,9 @@ class ToolProviderMCPApi(Resource):
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise ValueError("Server URL is not valid.")
|
raise ValueError("Server URL is not valid.")
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
MCPToolManageService.update_mcp_provider(
|
MCPToolManageService.update_mcp_provider(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider_id=args["provider_id"],
|
provider_id=args["provider_id"],
|
||||||
server_url=args["server_url"],
|
server_url=args["server_url"],
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
@ -964,7 +940,8 @@ class ToolProviderMCPApi(Resource):
|
|||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
@ -979,7 +956,7 @@ class ToolMCPAuthApi(Resource):
|
|||||||
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
|
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
provider_id = args["provider_id"]
|
provider_id = args["provider_id"]
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||||
if not provider:
|
if not provider:
|
||||||
raise ValueError("provider not found")
|
raise ValueError("provider not found")
|
||||||
@ -1020,8 +997,8 @@ class ToolMCPDetailApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider_id):
|
def get(self, provider_id):
|
||||||
user = current_user
|
_, tenant_id = current_account_with_tenant()
|
||||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
|
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||||
|
|
||||||
|
|
||||||
@ -1031,8 +1008,7 @@ class ToolMCPListAllApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
_, tenant_id = current_account_with_tenant()
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
|
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
|
||||||
|
|
||||||
@ -1045,7 +1021,7 @@ class ToolMCPUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider_id):
|
def get(self, provider_id):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
|
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
|
|||||||
@ -12,8 +12,8 @@ from configs import dify_config
|
|||||||
from controllers.console.workspace.error import AccountNotInitializedError
|
from controllers.console.workspace.error import AccountNotInitializedError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from libs.login import current_user
|
from libs.login import current_account_with_tenant
|
||||||
from models.account import Account, AccountStatus
|
from models.account import AccountStatus
|
||||||
from models.dataset import RateLimitLog
|
from models.dataset import RateLimitLog
|
||||||
from models.model import DifySetup
|
from models.model import DifySetup
|
||||||
from services.feature_service import FeatureService, LicenseStatus
|
from services.feature_service import FeatureService, LicenseStatus
|
||||||
@ -25,16 +25,13 @@ P = ParamSpec("P")
|
|||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
def _current_account() -> Account:
|
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
|
|
||||||
def account_initialization_required(view: Callable[P, R]):
|
def account_initialization_required(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
# check account initialization
|
# check account initialization
|
||||||
account = _current_account()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
account = current_user
|
||||||
|
|
||||||
if account.status == AccountStatus.UNINITIALIZED:
|
if account.status == AccountStatus.UNINITIALIZED:
|
||||||
raise AccountNotInitializedError()
|
raise AccountNotInitializedError()
|
||||||
@ -80,9 +77,8 @@ def only_edition_self_hosted(view: Callable[P, R]):
|
|||||||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
account = _current_account()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert account.current_tenant_id is not None
|
features = FeatureService.get_features(current_tenant_id)
|
||||||
features = FeatureService.get_features(account.current_tenant_id)
|
|
||||||
if not features.billing.enabled:
|
if not features.billing.enabled:
|
||||||
abort(403, "Billing feature is not enabled.")
|
abort(403, "Billing feature is not enabled.")
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
@ -94,10 +90,8 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
account = _current_account()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert account.current_tenant_id is not None
|
features = FeatureService.get_features(current_tenant_id)
|
||||||
tenant_id = account.current_tenant_id
|
|
||||||
features = FeatureService.get_features(tenant_id)
|
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
members = features.members
|
members = features.members
|
||||||
apps = features.apps
|
apps = features.apps
|
||||||
@ -138,9 +132,8 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
|||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
account = _current_account()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert account.current_tenant_id is not None
|
features = FeatureService.get_features(current_tenant_id)
|
||||||
features = FeatureService.get_features(account.current_tenant_id)
|
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
if resource == "add_segment":
|
if resource == "add_segment":
|
||||||
if features.billing.subscription.plan == "sandbox":
|
if features.billing.subscription.plan == "sandbox":
|
||||||
@ -163,13 +156,11 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if resource == "knowledge":
|
if resource == "knowledge":
|
||||||
account = _current_account()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert account.current_tenant_id is not None
|
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_tenant_id)
|
||||||
tenant_id = account.current_tenant_id
|
|
||||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
|
|
||||||
if knowledge_rate_limit.enabled:
|
if knowledge_rate_limit.enabled:
|
||||||
current_time = int(time.time() * 1000)
|
current_time = int(time.time() * 1000)
|
||||||
key = f"rate_limit_{tenant_id}"
|
key = f"rate_limit_{current_tenant_id}"
|
||||||
|
|
||||||
redis_client.zadd(key, {current_time: current_time})
|
redis_client.zadd(key, {current_time: current_time})
|
||||||
|
|
||||||
@ -180,7 +171,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||||||
if request_count > knowledge_rate_limit.limit:
|
if request_count > knowledge_rate_limit.limit:
|
||||||
# add ratelimit record
|
# add ratelimit record
|
||||||
rate_limit_log = RateLimitLog(
|
rate_limit_log = RateLimitLog(
|
||||||
tenant_id=tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||||
operation="knowledge",
|
operation="knowledge",
|
||||||
)
|
)
|
||||||
@ -200,17 +191,15 @@ def cloud_utm_record(view: Callable[P, R]):
|
|||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
account = _current_account()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert account.current_tenant_id is not None
|
features = FeatureService.get_features(current_tenant_id)
|
||||||
tenant_id = account.current_tenant_id
|
|
||||||
features = FeatureService.get_features(tenant_id)
|
|
||||||
|
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
utm_info = request.cookies.get("utm_info")
|
utm_info = request.cookies.get("utm_info")
|
||||||
|
|
||||||
if utm_info:
|
if utm_info:
|
||||||
utm_info_dict: dict = json.loads(utm_info)
|
utm_info_dict: dict = json.loads(utm_info)
|
||||||
OperationService.record_utm(tenant_id, utm_info_dict)
|
OperationService.record_utm(current_tenant_id, utm_info_dict)
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
@ -289,9 +278,8 @@ def enable_change_email(view: Callable[P, R]):
|
|||||||
def is_allow_transfer_owner(view: Callable[P, R]):
|
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
account = _current_account()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert account.current_tenant_id is not None
|
features = FeatureService.get_features(current_tenant_id)
|
||||||
features = FeatureService.get_features(account.current_tenant_id)
|
|
||||||
if features.is_allow_transfer_workspace:
|
if features.is_allow_transfer_workspace:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
@ -301,12 +289,11 @@ def is_allow_transfer_owner(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def knowledge_pipeline_publish_enabled(view):
|
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
account = _current_account()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert account.current_tenant_id is not None
|
features = FeatureService.get_features(current_tenant_id)
|
||||||
features = FeatureService.get_features(account.current_tenant_id)
|
|
||||||
if features.knowledge_pipeline.publish_enabled:
|
if features.knowledge_pipeline.publish_enabled:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
abort(403)
|
abort(403)
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import marshal, reqparse
|
from flask_restx import marshal, reqparse
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
@ -16,6 +15,7 @@ from core.model_manager import ModelManager
|
|||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||||
|
from libs.login import current_account_with_tenant
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
|
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
|
||||||
@ -66,6 +66,7 @@ class SegmentApi(DatasetApiResource):
|
|||||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||||
def post(self, tenant_id: str, dataset_id: str, document_id: str):
|
def post(self, tenant_id: str, dataset_id: str, document_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
"""Create single segment."""
|
"""Create single segment."""
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
@ -84,7 +85,7 @@ class SegmentApi(DatasetApiResource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -117,6 +118,7 @@ class SegmentApi(DatasetApiResource):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
def get(self, tenant_id: str, dataset_id: str, document_id: str):
|
def get(self, tenant_id: str, dataset_id: str, document_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
"""Get segments."""
|
"""Get segments."""
|
||||||
# check dataset
|
# check dataset
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
@ -133,7 +135,7 @@ class SegmentApi(DatasetApiResource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -149,7 +151,7 @@ class SegmentApi(DatasetApiResource):
|
|||||||
|
|
||||||
segments, total = SegmentService.get_segments(
|
segments, total = SegmentService.get_segments(
|
||||||
document_id=document_id,
|
document_id=document_id,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
status_list=args["status"],
|
status_list=args["status"],
|
||||||
keyword=args["keyword"],
|
keyword=args["keyword"],
|
||||||
page=page,
|
page=page,
|
||||||
@ -184,6 +186,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||||||
)
|
)
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||||
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
if not dataset:
|
if not dataset:
|
||||||
@ -195,7 +198,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||||||
if not document:
|
if not document:
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
# check segment
|
# check segment
|
||||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
SegmentService.delete_segment(segment, document, dataset)
|
SegmentService.delete_segment(segment, document, dataset)
|
||||||
@ -217,6 +220,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||||
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
if not dataset:
|
if not dataset:
|
||||||
@ -232,7 +236,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -244,7 +248,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
# check segment
|
# check segment
|
||||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
|
|
||||||
@ -266,6 +270,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
if not dataset:
|
if not dataset:
|
||||||
@ -277,7 +282,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||||||
if not document:
|
if not document:
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
# check segment
|
# check segment
|
||||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
|
|
||||||
@ -307,6 +312,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||||
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
"""Create child chunk."""
|
"""Create child chunk."""
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
@ -319,7 +325,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
# check segment
|
# check segment
|
||||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
|
|
||||||
@ -328,7 +334,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -364,6 +370,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
"""Get child chunks."""
|
"""Get child chunks."""
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
@ -376,7 +383,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
# check segment
|
# check segment
|
||||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
|
|
||||||
@ -423,6 +430,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||||
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
|
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
"""Delete child chunk."""
|
"""Delete child chunk."""
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
@ -435,7 +443,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
# check segment
|
# check segment
|
||||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
|
|
||||||
@ -444,9 +452,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
# check child chunk
|
# check child chunk
|
||||||
child_chunk = SegmentService.get_child_chunk_by_id(
|
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
|
||||||
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
if not child_chunk:
|
if not child_chunk:
|
||||||
raise NotFound("Child chunk not found.")
|
raise NotFound("Child chunk not found.")
|
||||||
|
|
||||||
@ -483,6 +489,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||||
def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
|
def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
"""Update child chunk."""
|
"""Update child chunk."""
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
@ -495,7 +502,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
# get segment
|
# get segment
|
||||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
|
|
||||||
@ -504,9 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
|
|
||||||
# get child chunk
|
# get child chunk
|
||||||
child_chunk = SegmentService.get_child_chunk_by_id(
|
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
|
||||||
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
if not child_chunk:
|
if not child_chunk:
|
||||||
raise NotFound("Child chunk not found.")
|
raise NotFound("Child chunk not found.")
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from cachetools import TTLCache, cachedmethod
|
||||||
from redis.exceptions import RedisError
|
from redis.exceptions import RedisError
|
||||||
from sqlalchemy.orm import DeclarativeMeta
|
from sqlalchemy.orm import DeclarativeMeta
|
||||||
|
|
||||||
@ -45,6 +47,8 @@ class AppQueueManager:
|
|||||||
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
|
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
|
||||||
|
|
||||||
self._q = q
|
self._q = q
|
||||||
|
self._stopped_cache: TTLCache[tuple, bool] = TTLCache(maxsize=1, ttl=1)
|
||||||
|
self._cache_lock = threading.Lock()
|
||||||
|
|
||||||
def listen(self):
|
def listen(self):
|
||||||
"""
|
"""
|
||||||
@ -157,6 +161,7 @@ class AppQueueManager:
|
|||||||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||||
redis_client.setex(stopped_cache_key, 600, 1)
|
redis_client.setex(stopped_cache_key, 600, 1)
|
||||||
|
|
||||||
|
@cachedmethod(lambda self: self._stopped_cache, lock=lambda self: self._cache_lock)
|
||||||
def _is_stopped(self) -> bool:
|
def _is_stopped(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if task is stopped
|
Check if task is stopped
|
||||||
|
|||||||
@ -472,6 +472,9 @@ class ProviderConfiguration(BaseModel):
|
|||||||
provider_model_credentials_cache.delete()
|
provider_model_credentials_cache.delete()
|
||||||
|
|
||||||
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
|
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
|
||||||
|
else:
|
||||||
|
# some historical data may have a provider record but not be set as valid
|
||||||
|
provider_record.is_valid = True
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import uuid
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Final
|
from typing import Final, cast
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@ -199,7 +199,7 @@ def convert_to_trace_id(uuid_v4: str | None) -> int:
|
|||||||
raise ValueError("UUID cannot be None")
|
raise ValueError("UUID cannot be None")
|
||||||
try:
|
try:
|
||||||
uuid_obj = uuid.UUID(uuid_v4)
|
uuid_obj = uuid.UUID(uuid_v4)
|
||||||
return uuid_obj.int
|
return cast(int, uuid_obj.int)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e
|
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,7 @@ class TracingProviderEnum(StrEnum):
|
|||||||
OPIK = "opik"
|
OPIK = "opik"
|
||||||
WEAVE = "weave"
|
WEAVE = "weave"
|
||||||
ALIYUN = "aliyun"
|
ALIYUN = "aliyun"
|
||||||
|
TENCENT = "tencent"
|
||||||
|
|
||||||
|
|
||||||
class BaseTracingConfig(BaseModel):
|
class BaseTracingConfig(BaseModel):
|
||||||
@ -195,5 +196,32 @@ class AliyunConfig(BaseTracingConfig):
|
|||||||
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
|
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||||
|
|
||||||
|
|
||||||
|
class TencentConfig(BaseTracingConfig):
|
||||||
|
"""
|
||||||
|
Tencent APM tracing config
|
||||||
|
"""
|
||||||
|
|
||||||
|
token: str
|
||||||
|
endpoint: str
|
||||||
|
service_name: str
|
||||||
|
|
||||||
|
@field_validator("token")
|
||||||
|
@classmethod
|
||||||
|
def token_validator(cls, v, info: ValidationInfo):
|
||||||
|
if not v or v.strip() == "":
|
||||||
|
raise ValueError("Token cannot be empty")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("endpoint")
|
||||||
|
@classmethod
|
||||||
|
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||||
|
return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
|
||||||
|
|
||||||
|
@field_validator("service_name")
|
||||||
|
@classmethod
|
||||||
|
def service_name_validator(cls, v, info: ValidationInfo):
|
||||||
|
return cls.validate_project_field(v, "dify_app")
|
||||||
|
|
||||||
|
|
||||||
OPS_FILE_PATH = "ops_trace/"
|
OPS_FILE_PATH = "ops_trace/"
|
||||||
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
|
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
|
||||||
|
|||||||
@ -90,6 +90,7 @@ class SuggestedQuestionTraceInfo(BaseTraceInfo):
|
|||||||
|
|
||||||
class DatasetRetrievalTraceInfo(BaseTraceInfo):
|
class DatasetRetrievalTraceInfo(BaseTraceInfo):
|
||||||
documents: Any = None
|
documents: Any = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ToolTraceInfo(BaseTraceInfo):
|
class ToolTraceInfo(BaseTraceInfo):
|
||||||
|
|||||||
@ -120,6 +120,17 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
|||||||
"trace_instance": AliyunDataTrace,
|
"trace_instance": AliyunDataTrace,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case TracingProviderEnum.TENCENT:
|
||||||
|
from core.ops.entities.config_entity import TencentConfig
|
||||||
|
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
|
||||||
|
|
||||||
|
return {
|
||||||
|
"config_class": TencentConfig,
|
||||||
|
"secret_keys": ["token"],
|
||||||
|
"other_keys": ["endpoint", "service_name"],
|
||||||
|
"trace_instance": TencentDataTrace,
|
||||||
|
}
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise KeyError(f"Unsupported tracing provider: {provider}")
|
raise KeyError(f"Unsupported tracing provider: {provider}")
|
||||||
|
|
||||||
@ -723,6 +734,7 @@ class TraceTask:
|
|||||||
end_time=timer.get("end"),
|
end_time=timer.get("end"),
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
message_data=message_data.to_dict(),
|
message_data=message_data.to_dict(),
|
||||||
|
error=kwargs.get("error"),
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataset_retrieval_trace_info
|
return dataset_retrieval_trace_info
|
||||||
@ -889,6 +901,7 @@ class TraceQueueManager:
|
|||||||
continue
|
continue
|
||||||
file_id = uuid4().hex
|
file_id = uuid4().hex
|
||||||
trace_info = task.execute()
|
trace_info = task.execute()
|
||||||
|
|
||||||
task_data = TaskData(
|
task_data = TaskData(
|
||||||
app_id=task.app_id,
|
app_id=task.app_id,
|
||||||
trace_info_type=type(trace_info).__name__,
|
trace_info_type=type(trace_info).__name__,
|
||||||
|
|||||||
0
api/core/ops/tencent_trace/__init__.py
Normal file
0
api/core/ops/tencent_trace/__init__.py
Normal file
337
api/core/ops/tencent_trace/client.py
Normal file
337
api/core/ops/tencent_trace/client.py
Normal file
@ -0,0 +1,337 @@
|
|||||||
|
"""
|
||||||
|
Tencent APM Trace Client - handles network operations, metrics, and API communication
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from opentelemetry.metrics import Meter
|
||||||
|
from opentelemetry.metrics._internal.instrument import Histogram
|
||||||
|
from opentelemetry.sdk.metrics.export import MetricReader
|
||||||
|
|
||||||
|
from opentelemetry import trace as trace_api
|
||||||
|
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||||
|
from opentelemetry.sdk.resources import Resource
|
||||||
|
from opentelemetry.sdk.trace import TracerProvider
|
||||||
|
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||||
|
from opentelemetry.semconv.resource import ResourceAttributes
|
||||||
|
from opentelemetry.trace import SpanKind
|
||||||
|
from opentelemetry.util.types import AttributeValue
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
from .entities.tencent_semconv import LLM_OPERATION_DURATION
|
||||||
|
from .entities.tencent_trace_entity import SpanData
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TencentTraceClient:
|
||||||
|
"""Tencent APM trace client using OpenTelemetry OTLP exporter"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
service_name: str,
|
||||||
|
endpoint: str,
|
||||||
|
token: str,
|
||||||
|
max_queue_size: int = 1000,
|
||||||
|
schedule_delay_sec: int = 5,
|
||||||
|
max_export_batch_size: int = 50,
|
||||||
|
metrics_export_interval_sec: int = 10,
|
||||||
|
):
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.token = token
|
||||||
|
self.service_name = service_name
|
||||||
|
self.metrics_export_interval_sec = metrics_export_interval_sec
|
||||||
|
|
||||||
|
self.resource = Resource(
|
||||||
|
attributes={
|
||||||
|
ResourceAttributes.SERVICE_NAME: service_name,
|
||||||
|
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||||
|
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||||
|
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Prepare gRPC endpoint/metadata
|
||||||
|
grpc_endpoint, insecure, _, _ = self._resolve_grpc_target(endpoint)
|
||||||
|
|
||||||
|
headers = (("authorization", f"Bearer {token}"),)
|
||||||
|
|
||||||
|
self.exporter = OTLPSpanExporter(
|
||||||
|
endpoint=grpc_endpoint,
|
||||||
|
headers=headers,
|
||||||
|
insecure=insecure,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.tracer_provider = TracerProvider(resource=self.resource)
|
||||||
|
self.span_processor = BatchSpanProcessor(
|
||||||
|
span_exporter=self.exporter,
|
||||||
|
max_queue_size=max_queue_size,
|
||||||
|
schedule_delay_millis=schedule_delay_sec * 1000,
|
||||||
|
max_export_batch_size=max_export_batch_size,
|
||||||
|
)
|
||||||
|
self.tracer_provider.add_span_processor(self.span_processor)
|
||||||
|
|
||||||
|
self.tracer = self.tracer_provider.get_tracer("dify.tencent_apm")
|
||||||
|
|
||||||
|
# Store span contexts for parent-child relationships
|
||||||
|
self.span_contexts: dict[int, trace_api.SpanContext] = {}
|
||||||
|
|
||||||
|
self.meter: Meter | None = None
|
||||||
|
self.hist_llm_duration: Histogram | None = None
|
||||||
|
self.metric_reader: MetricReader | None = None
|
||||||
|
|
||||||
|
# Metrics exporter and instruments
|
||||||
|
try:
|
||||||
|
from opentelemetry import metrics
|
||||||
|
from opentelemetry.sdk.metrics import Histogram, MeterProvider
|
||||||
|
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader
|
||||||
|
|
||||||
|
protocol = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL", "").strip().lower()
|
||||||
|
use_http_protobuf = protocol in {"http/protobuf", "http-protobuf"}
|
||||||
|
use_http_json = protocol in {"http/json", "http-json"}
|
||||||
|
|
||||||
|
# Set preferred temporality for histograms to DELTA
|
||||||
|
preferred_temporality: dict[type, AggregationTemporality] = {Histogram: AggregationTemporality.DELTA}
|
||||||
|
|
||||||
|
def _create_metric_exporter(exporter_cls, **kwargs):
|
||||||
|
"""Create metric exporter with preferred_temporality support"""
|
||||||
|
try:
|
||||||
|
return exporter_cls(**kwargs, preferred_temporality=preferred_temporality)
|
||||||
|
except Exception:
|
||||||
|
return exporter_cls(**kwargs)
|
||||||
|
|
||||||
|
metric_reader = None
|
||||||
|
if use_http_json:
|
||||||
|
exporter_cls = None
|
||||||
|
for mod_path in (
|
||||||
|
"opentelemetry.exporter.otlp.http.json.metric_exporter",
|
||||||
|
"opentelemetry.exporter.otlp.json.metric_exporter",
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
mod = importlib.import_module(mod_path)
|
||||||
|
exporter_cls = getattr(mod, "OTLPMetricExporter", None)
|
||||||
|
if exporter_cls:
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if exporter_cls is not None:
|
||||||
|
metric_exporter = _create_metric_exporter(
|
||||||
|
exporter_cls,
|
||||||
|
endpoint=endpoint,
|
||||||
|
headers={"authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
|
||||||
|
OTLPMetricExporter as HttpMetricExporter,
|
||||||
|
)
|
||||||
|
|
||||||
|
metric_exporter = _create_metric_exporter(
|
||||||
|
HttpMetricExporter,
|
||||||
|
endpoint=endpoint,
|
||||||
|
headers={"authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
metric_reader = PeriodicExportingMetricReader(
|
||||||
|
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
|
||||||
|
)
|
||||||
|
|
||||||
|
elif use_http_protobuf:
|
||||||
|
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
|
||||||
|
OTLPMetricExporter as HttpMetricExporter,
|
||||||
|
)
|
||||||
|
|
||||||
|
metric_exporter = _create_metric_exporter(
|
||||||
|
HttpMetricExporter,
|
||||||
|
endpoint=endpoint,
|
||||||
|
headers={"authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
metric_reader = PeriodicExportingMetricReader(
|
||||||
|
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import (
|
||||||
|
OTLPMetricExporter as GrpcMetricExporter,
|
||||||
|
)
|
||||||
|
|
||||||
|
m_grpc_endpoint, m_insecure, _, _ = self._resolve_grpc_target(endpoint)
|
||||||
|
|
||||||
|
metric_exporter = _create_metric_exporter(
|
||||||
|
GrpcMetricExporter,
|
||||||
|
endpoint=m_grpc_endpoint,
|
||||||
|
headers={"authorization": f"Bearer {token}"},
|
||||||
|
insecure=m_insecure,
|
||||||
|
)
|
||||||
|
metric_reader = PeriodicExportingMetricReader(
|
||||||
|
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
|
||||||
|
)
|
||||||
|
|
||||||
|
if metric_reader is not None:
|
||||||
|
provider = MeterProvider(resource=self.resource, metric_readers=[metric_reader])
|
||||||
|
metrics.set_meter_provider(provider)
|
||||||
|
self.meter = metrics.get_meter("dify-sdk", dify_config.project.version)
|
||||||
|
self.hist_llm_duration = self.meter.create_histogram(
|
||||||
|
name=LLM_OPERATION_DURATION,
|
||||||
|
unit="s",
|
||||||
|
description="LLM operation duration (seconds)",
|
||||||
|
)
|
||||||
|
self.metric_reader = metric_reader
|
||||||
|
else:
|
||||||
|
self.meter = None
|
||||||
|
self.hist_llm_duration = None
|
||||||
|
self.metric_reader = None
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Tencent APM] Metrics initialization failed; metrics disabled")
|
||||||
|
self.meter = None
|
||||||
|
self.hist_llm_duration = None
|
||||||
|
self.metric_reader = None
|
||||||
|
|
||||||
|
def add_span(self, span_data: SpanData) -> None:
|
||||||
|
"""Create and export span using OpenTelemetry Tracer API"""
|
||||||
|
try:
|
||||||
|
self._create_and_export_span(span_data)
|
||||||
|
logger.debug("[Tencent APM] Created span: %s", span_data.name)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Tencent APM] Failed to create span: %s", span_data.name)
|
||||||
|
|
||||||
|
# Metrics recording API
|
||||||
|
def record_llm_duration(self, latency_seconds: float, attributes: dict[str, str] | None = None) -> None:
|
||||||
|
"""Record LLM operation duration histogram in seconds."""
|
||||||
|
try:
|
||||||
|
if not hasattr(self, "hist_llm_duration") or self.hist_llm_duration is None:
|
||||||
|
return
|
||||||
|
attrs: dict[str, str] = {}
|
||||||
|
if attributes:
|
||||||
|
for k, v in attributes.items():
|
||||||
|
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
|
||||||
|
self.hist_llm_duration.record(latency_seconds, attrs) # type: ignore[attr-defined]
|
||||||
|
except Exception:
|
||||||
|
logger.debug("[Tencent APM] Failed to record LLM duration", exc_info=True)
|
||||||
|
|
||||||
|
def _create_and_export_span(self, span_data: SpanData) -> None:
|
||||||
|
"""Create span using OpenTelemetry Tracer API"""
|
||||||
|
try:
|
||||||
|
parent_context = None
|
||||||
|
if span_data.parent_span_id and span_data.parent_span_id in self.span_contexts:
|
||||||
|
parent_context = trace_api.set_span_in_context(
|
||||||
|
trace_api.NonRecordingSpan(self.span_contexts[span_data.parent_span_id])
|
||||||
|
)
|
||||||
|
|
||||||
|
span = self.tracer.start_span(
|
||||||
|
name=span_data.name,
|
||||||
|
context=parent_context,
|
||||||
|
kind=SpanKind.INTERNAL,
|
||||||
|
start_time=span_data.start_time,
|
||||||
|
)
|
||||||
|
self.span_contexts[span_data.span_id] = span.get_span_context()
|
||||||
|
|
||||||
|
if span_data.attributes:
|
||||||
|
attributes: dict[str, AttributeValue] = {}
|
||||||
|
for key, value in span_data.attributes.items():
|
||||||
|
if isinstance(value, (int, float, bool)):
|
||||||
|
attributes[key] = value
|
||||||
|
else:
|
||||||
|
attributes[key] = str(value)
|
||||||
|
span.set_attributes(attributes)
|
||||||
|
|
||||||
|
if span_data.events:
|
||||||
|
for event in span_data.events:
|
||||||
|
span.add_event(event.name, event.attributes, event.timestamp)
|
||||||
|
|
||||||
|
if span_data.status:
|
||||||
|
span.set_status(span_data.status)
|
||||||
|
|
||||||
|
# Manually end span; do not use context manager to avoid double-end warnings
|
||||||
|
span.end(end_time=span_data.end_time)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Tencent APM] Error creating span: %s", span_data.name)
|
||||||
|
|
||||||
|
def api_check(self) -> bool:
|
||||||
|
"""Check API connectivity using socket connection test for gRPC endpoints"""
|
||||||
|
try:
|
||||||
|
# Resolve gRPC target consistently with exporters
|
||||||
|
_, _, host, port = self._resolve_grpc_target(self.endpoint)
|
||||||
|
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
sock.settimeout(5)
|
||||||
|
result = sock.connect_ex((host, port))
|
||||||
|
sock.close()
|
||||||
|
|
||||||
|
if result == 0:
|
||||||
|
logger.info("[Tencent APM] Endpoint %s:%s is accessible", host, port)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning("[Tencent APM] Endpoint %s:%s is not accessible", host, port)
|
||||||
|
if host in ["127.0.0.1", "localhost"]:
|
||||||
|
logger.info("[Tencent APM] Development environment detected, allowing config save")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Tencent APM] API check failed")
|
||||||
|
if "127.0.0.1" in self.endpoint or "localhost" in self.endpoint:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_project_url(self) -> str:
|
||||||
|
"""Get project console URL"""
|
||||||
|
return "https://console.cloud.tencent.com/apm"
|
||||||
|
|
||||||
|
def shutdown(self) -> None:
|
||||||
|
"""Shutdown the client and export remaining spans"""
|
||||||
|
try:
|
||||||
|
if self.span_processor:
|
||||||
|
logger.info("[Tencent APM] Flushing remaining spans before shutdown")
|
||||||
|
_ = self.span_processor.force_flush()
|
||||||
|
self.span_processor.shutdown()
|
||||||
|
|
||||||
|
if self.tracer_provider:
|
||||||
|
self.tracer_provider.shutdown()
|
||||||
|
if self.metric_reader is not None:
|
||||||
|
try:
|
||||||
|
self.metric_reader.shutdown() # type: ignore[attr-defined]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Tencent APM] Error during client shutdown")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_grpc_target(endpoint: str, default_port: int = 4317) -> tuple[str, bool, str, int]:
|
||||||
|
"""Normalize endpoint to gRPC target and security flag.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(grpc_endpoint, insecure, host, port)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if endpoint.startswith(("http://", "https://")):
|
||||||
|
parsed = urlparse(endpoint)
|
||||||
|
host = parsed.hostname or "localhost"
|
||||||
|
port = parsed.port or default_port
|
||||||
|
insecure = parsed.scheme == "http"
|
||||||
|
return f"{host}:{port}", insecure, host, port
|
||||||
|
|
||||||
|
host = endpoint
|
||||||
|
port = default_port
|
||||||
|
if ":" in endpoint:
|
||||||
|
parts = endpoint.rsplit(":", 1)
|
||||||
|
host = parts[0] or "localhost"
|
||||||
|
try:
|
||||||
|
port = int(parts[1])
|
||||||
|
except Exception:
|
||||||
|
port = default_port
|
||||||
|
|
||||||
|
insecure = ("localhost" in host) or ("127.0.0.1" in host)
|
||||||
|
return f"{host}:{port}", insecure, host, port
|
||||||
|
except Exception:
|
||||||
|
host, port = "localhost", default_port
|
||||||
|
return f"{host}:{port}", True, host, port
|
||||||
1
api/core/ops/tencent_trace/entities/__init__.py
Normal file
1
api/core/ops/tencent_trace/entities/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Tencent trace entities module
|
||||||
73
api/core/ops/tencent_trace/entities/tencent_semconv.py
Normal file
73
api/core/ops/tencent_trace/entities/tencent_semconv.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
# public
|
||||||
|
GEN_AI_SESSION_ID = "gen_ai.session.id"
|
||||||
|
|
||||||
|
GEN_AI_USER_ID = "gen_ai.user.id"
|
||||||
|
|
||||||
|
GEN_AI_USER_NAME = "gen_ai.user.name"
|
||||||
|
|
||||||
|
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
|
||||||
|
|
||||||
|
GEN_AI_FRAMEWORK = "gen_ai.framework"
|
||||||
|
|
||||||
|
GEN_AI_IS_ENTRY = "gen_ai.is_entry" # mark to count the LLM-related traces
|
||||||
|
|
||||||
|
# Chain
|
||||||
|
INPUT_VALUE = "gen_ai.entity.input"
|
||||||
|
|
||||||
|
OUTPUT_VALUE = "gen_ai.entity.output"
|
||||||
|
|
||||||
|
|
||||||
|
# Retriever
|
||||||
|
RETRIEVAL_QUERY = "retrieval.query"
|
||||||
|
|
||||||
|
RETRIEVAL_DOCUMENT = "retrieval.document"
|
||||||
|
|
||||||
|
|
||||||
|
# GENERATION
|
||||||
|
GEN_AI_MODEL_NAME = "gen_ai.response.model"
|
||||||
|
|
||||||
|
GEN_AI_PROVIDER = "gen_ai.provider.name"
|
||||||
|
|
||||||
|
|
||||||
|
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
|
||||||
|
|
||||||
|
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
|
||||||
|
|
||||||
|
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
|
||||||
|
|
||||||
|
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
|
||||||
|
|
||||||
|
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
|
||||||
|
|
||||||
|
GEN_AI_PROMPT = "gen_ai.prompt"
|
||||||
|
|
||||||
|
GEN_AI_COMPLETION = "gen_ai.completion"
|
||||||
|
|
||||||
|
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
|
||||||
|
|
||||||
|
# Tool
|
||||||
|
TOOL_NAME = "tool.name"
|
||||||
|
|
||||||
|
TOOL_DESCRIPTION = "tool.description"
|
||||||
|
|
||||||
|
TOOL_PARAMETERS = "tool.parameters"
|
||||||
|
|
||||||
|
# Instrumentation Library
|
||||||
|
INSTRUMENTATION_NAME = "dify-sdk"
|
||||||
|
INSTRUMENTATION_VERSION = "0.1.0"
|
||||||
|
INSTRUMENTATION_LANGUAGE = "python"
|
||||||
|
|
||||||
|
|
||||||
|
# Metrics
|
||||||
|
LLM_OPERATION_DURATION = "gen_ai.client.operation.duration"
|
||||||
|
|
||||||
|
|
||||||
|
class GenAISpanKind(Enum):
|
||||||
|
WORKFLOW = "WORKFLOW" # OpenLLMetry
|
||||||
|
RETRIEVER = "RETRIEVER" # RAG
|
||||||
|
GENERATION = "GENERATION" # Langfuse
|
||||||
|
TOOL = "TOOL" # OpenLLMetry
|
||||||
|
AGENT = "AGENT" # OpenLLMetry
|
||||||
|
TASK = "TASK" # OpenLLMetry
|
||||||
21
api/core/ops/tencent_trace/entities/tencent_trace_entity.py
Normal file
21
api/core/ops/tencent_trace/entities/tencent_trace_entity.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from opentelemetry import trace as trace_api
|
||||||
|
from opentelemetry.sdk.trace import Event
|
||||||
|
from opentelemetry.trace import Status, StatusCode
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class SpanData(BaseModel):
|
||||||
|
model_config = {"arbitrary_types_allowed": True}
|
||||||
|
|
||||||
|
trace_id: int = Field(..., description="The unique identifier for the trace.")
|
||||||
|
parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.")
|
||||||
|
span_id: int = Field(..., description="The unique identifier for this span.")
|
||||||
|
name: str = Field(..., description="The name of the span.")
|
||||||
|
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
|
||||||
|
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
|
||||||
|
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
|
||||||
|
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
|
||||||
|
start_time: int = Field(..., description="The start time of the span in nanoseconds.")
|
||||||
|
end_time: int = Field(..., description="The end time of the span in nanoseconds.")
|
||||||
372
api/core/ops/tencent_trace/span_builder.py
Normal file
372
api/core/ops/tencent_trace/span_builder.py
Normal file
@ -0,0 +1,372 @@
|
|||||||
|
"""
|
||||||
|
Tencent APM Span Builder - handles all span construction logic
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from opentelemetry.trace import Status, StatusCode
|
||||||
|
|
||||||
|
from core.ops.entities.trace_entity import (
|
||||||
|
DatasetRetrievalTraceInfo,
|
||||||
|
MessageTraceInfo,
|
||||||
|
ToolTraceInfo,
|
||||||
|
WorkflowTraceInfo,
|
||||||
|
)
|
||||||
|
from core.ops.tencent_trace.entities.tencent_semconv import (
|
||||||
|
GEN_AI_COMPLETION,
|
||||||
|
GEN_AI_FRAMEWORK,
|
||||||
|
GEN_AI_IS_ENTRY,
|
||||||
|
GEN_AI_MODEL_NAME,
|
||||||
|
GEN_AI_PROMPT,
|
||||||
|
GEN_AI_PROVIDER,
|
||||||
|
GEN_AI_RESPONSE_FINISH_REASON,
|
||||||
|
GEN_AI_SESSION_ID,
|
||||||
|
GEN_AI_SPAN_KIND,
|
||||||
|
GEN_AI_USAGE_INPUT_TOKENS,
|
||||||
|
GEN_AI_USAGE_OUTPUT_TOKENS,
|
||||||
|
GEN_AI_USAGE_TOTAL_TOKENS,
|
||||||
|
GEN_AI_USER_ID,
|
||||||
|
INPUT_VALUE,
|
||||||
|
OUTPUT_VALUE,
|
||||||
|
RETRIEVAL_DOCUMENT,
|
||||||
|
RETRIEVAL_QUERY,
|
||||||
|
TOOL_DESCRIPTION,
|
||||||
|
TOOL_NAME,
|
||||||
|
TOOL_PARAMETERS,
|
||||||
|
GenAISpanKind,
|
||||||
|
)
|
||||||
|
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
|
||||||
|
from core.ops.tencent_trace.utils import TencentTraceUtils
|
||||||
|
from core.rag.models.document import Document
|
||||||
|
from core.workflow.entities.workflow_node_execution import (
|
||||||
|
WorkflowNodeExecution,
|
||||||
|
WorkflowNodeExecutionMetadataKey,
|
||||||
|
WorkflowNodeExecutionStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TencentSpanBuilder:
|
||||||
|
"""Builder class for constructing different types of spans"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_time_nanoseconds(time_value: datetime | None) -> int:
|
||||||
|
"""Convert datetime to nanoseconds for span creation."""
|
||||||
|
return TencentTraceUtils.convert_datetime_to_nanoseconds(time_value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_workflow_spans(
|
||||||
|
trace_info: WorkflowTraceInfo, trace_id: int, user_id: str, links: list | None = None
|
||||||
|
) -> list[SpanData]:
|
||||||
|
"""Build workflow-related spans"""
|
||||||
|
spans = []
|
||||||
|
links = links or []
|
||||||
|
|
||||||
|
message_span_id = None
|
||||||
|
workflow_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "workflow")
|
||||||
|
|
||||||
|
if hasattr(trace_info, "metadata") and trace_info.metadata.get("conversation_id"):
|
||||||
|
message_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "message")
|
||||||
|
|
||||||
|
status = Status(StatusCode.OK)
|
||||||
|
if trace_info.error:
|
||||||
|
status = Status(StatusCode.ERROR, trace_info.error)
|
||||||
|
|
||||||
|
if message_span_id:
|
||||||
|
message_span = TencentSpanBuilder._build_message_span(
|
||||||
|
trace_info, trace_id, message_span_id, user_id, status, links
|
||||||
|
)
|
||||||
|
spans.append(message_span)
|
||||||
|
|
||||||
|
workflow_span = TencentSpanBuilder._build_workflow_span(
|
||||||
|
trace_info, trace_id, workflow_span_id, message_span_id, user_id, status, links
|
||||||
|
)
|
||||||
|
spans.append(workflow_span)
|
||||||
|
|
||||||
|
return spans
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_message_span(
|
||||||
|
trace_info: WorkflowTraceInfo, trace_id: int, message_span_id: int, user_id: str, status: Status, links: list
|
||||||
|
) -> SpanData:
|
||||||
|
"""Build message span for chatflow"""
|
||||||
|
return SpanData(
|
||||||
|
trace_id=trace_id,
|
||||||
|
parent_span_id=None,
|
||||||
|
span_id=message_span_id,
|
||||||
|
name="message",
|
||||||
|
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||||
|
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||||
|
attributes={
|
||||||
|
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||||
|
GEN_AI_USER_ID: str(user_id),
|
||||||
|
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
|
||||||
|
GEN_AI_FRAMEWORK: "dify",
|
||||||
|
GEN_AI_IS_ENTRY: "true",
|
||||||
|
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
|
||||||
|
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||||
|
},
|
||||||
|
status=status,
|
||||||
|
links=links,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_workflow_span(
|
||||||
|
trace_info: WorkflowTraceInfo,
|
||||||
|
trace_id: int,
|
||||||
|
workflow_span_id: int,
|
||||||
|
message_span_id: int | None,
|
||||||
|
user_id: str,
|
||||||
|
status: Status,
|
||||||
|
links: list,
|
||||||
|
) -> SpanData:
|
||||||
|
"""Build workflow span"""
|
||||||
|
attributes = {
|
||||||
|
GEN_AI_USER_ID: str(user_id),
|
||||||
|
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
|
||||||
|
GEN_AI_FRAMEWORK: "dify",
|
||||||
|
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
|
||||||
|
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||||
|
}
|
||||||
|
|
||||||
|
if message_span_id is None:
|
||||||
|
attributes[GEN_AI_IS_ENTRY] = "true"
|
||||||
|
|
||||||
|
return SpanData(
|
||||||
|
trace_id=trace_id,
|
||||||
|
parent_span_id=message_span_id,
|
||||||
|
span_id=workflow_span_id,
|
||||||
|
name="workflow",
|
||||||
|
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||||
|
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||||
|
attributes=attributes,
|
||||||
|
status=status,
|
||||||
|
links=links,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_workflow_llm_span(
|
||||||
|
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||||
|
) -> SpanData:
|
||||||
|
"""Build LLM span for workflow nodes."""
|
||||||
|
process_data = node_execution.process_data or {}
|
||||||
|
outputs = node_execution.outputs or {}
|
||||||
|
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||||
|
|
||||||
|
return SpanData(
|
||||||
|
trace_id=trace_id,
|
||||||
|
parent_span_id=workflow_span_id,
|
||||||
|
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
|
||||||
|
name="GENERATION",
|
||||||
|
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
|
||||||
|
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
|
||||||
|
attributes={
|
||||||
|
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||||
|
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
|
||||||
|
GEN_AI_FRAMEWORK: "dify",
|
||||||
|
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
|
||||||
|
GEN_AI_PROVIDER: process_data.get("model_provider", ""),
|
||||||
|
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
|
||||||
|
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
|
||||||
|
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
|
||||||
|
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||||
|
GEN_AI_COMPLETION: str(outputs.get("text", "")),
|
||||||
|
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
|
||||||
|
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||||
|
OUTPUT_VALUE: str(outputs.get("text", "")),
|
||||||
|
},
|
||||||
|
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_message_span(
|
||||||
|
trace_info: MessageTraceInfo, trace_id: int, user_id: str, links: list | None = None
|
||||||
|
) -> SpanData:
|
||||||
|
"""Build message span."""
|
||||||
|
links = links or []
|
||||||
|
status = Status(StatusCode.OK)
|
||||||
|
if trace_info.error:
|
||||||
|
status = Status(StatusCode.ERROR, trace_info.error)
|
||||||
|
|
||||||
|
return SpanData(
|
||||||
|
trace_id=trace_id,
|
||||||
|
parent_span_id=None,
|
||||||
|
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message"),
|
||||||
|
name="message",
|
||||||
|
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||||
|
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||||
|
attributes={
|
||||||
|
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||||
|
GEN_AI_USER_ID: str(user_id),
|
||||||
|
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
|
||||||
|
GEN_AI_FRAMEWORK: "dify",
|
||||||
|
GEN_AI_IS_ENTRY: "true",
|
||||||
|
INPUT_VALUE: str(trace_info.inputs or ""),
|
||||||
|
OUTPUT_VALUE: str(trace_info.outputs or ""),
|
||||||
|
},
|
||||||
|
status=status,
|
||||||
|
links=links,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
|
||||||
|
"""Build tool span."""
|
||||||
|
status = Status(StatusCode.OK)
|
||||||
|
if trace_info.error:
|
||||||
|
status = Status(StatusCode.ERROR, trace_info.error)
|
||||||
|
|
||||||
|
return SpanData(
|
||||||
|
trace_id=trace_id,
|
||||||
|
parent_span_id=parent_span_id,
|
||||||
|
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "tool"),
|
||||||
|
name=trace_info.tool_name,
|
||||||
|
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||||
|
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||||
|
attributes={
|
||||||
|
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
|
||||||
|
GEN_AI_FRAMEWORK: "dify",
|
||||||
|
TOOL_NAME: trace_info.tool_name,
|
||||||
|
TOOL_DESCRIPTION: "",
|
||||||
|
TOOL_PARAMETERS: json.dumps(trace_info.tool_parameters, ensure_ascii=False),
|
||||||
|
INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
|
||||||
|
OUTPUT_VALUE: str(trace_info.tool_outputs),
|
||||||
|
},
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_retrieval_span(trace_info: DatasetRetrievalTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
|
||||||
|
"""Build dataset retrieval span."""
|
||||||
|
status = Status(StatusCode.OK)
|
||||||
|
if getattr(trace_info, "error", None):
|
||||||
|
status = Status(StatusCode.ERROR, trace_info.error) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
documents_data = TencentSpanBuilder._extract_retrieval_documents(trace_info.documents)
|
||||||
|
|
||||||
|
return SpanData(
|
||||||
|
trace_id=trace_id,
|
||||||
|
parent_span_id=parent_span_id,
|
||||||
|
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "retrieval"),
|
||||||
|
name="retrieval",
|
||||||
|
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||||
|
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||||
|
attributes={
|
||||||
|
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
|
||||||
|
GEN_AI_FRAMEWORK: "dify",
|
||||||
|
RETRIEVAL_QUERY: str(trace_info.inputs or ""),
|
||||||
|
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
|
||||||
|
INPUT_VALUE: str(trace_info.inputs or ""),
|
||||||
|
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
|
||||||
|
},
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
|
||||||
|
"""Get workflow node execution status."""
|
||||||
|
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||||
|
return Status(StatusCode.OK)
|
||||||
|
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
|
||||||
|
return Status(StatusCode.ERROR, str(node_execution.error))
|
||||||
|
return Status(StatusCode.UNSET)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_workflow_retrieval_span(
|
||||||
|
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||||
|
) -> SpanData:
|
||||||
|
"""Build knowledge retrieval span for workflow nodes."""
|
||||||
|
input_value = ""
|
||||||
|
if node_execution.inputs:
|
||||||
|
input_value = str(node_execution.inputs.get("query", ""))
|
||||||
|
output_value = ""
|
||||||
|
if node_execution.outputs:
|
||||||
|
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
|
||||||
|
|
||||||
|
return SpanData(
|
||||||
|
trace_id=trace_id,
|
||||||
|
parent_span_id=workflow_span_id,
|
||||||
|
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
|
||||||
|
name=node_execution.title,
|
||||||
|
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
|
||||||
|
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
|
||||||
|
attributes={
|
||||||
|
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||||
|
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
|
||||||
|
GEN_AI_FRAMEWORK: "dify",
|
||||||
|
RETRIEVAL_QUERY: input_value,
|
||||||
|
RETRIEVAL_DOCUMENT: output_value,
|
||||||
|
INPUT_VALUE: input_value,
|
||||||
|
OUTPUT_VALUE: output_value,
|
||||||
|
},
|
||||||
|
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_workflow_tool_span(
|
||||||
|
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||||
|
) -> SpanData:
|
||||||
|
"""Build tool span for workflow nodes."""
|
||||||
|
tool_des = {}
|
||||||
|
if node_execution.metadata:
|
||||||
|
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
|
||||||
|
|
||||||
|
return SpanData(
|
||||||
|
trace_id=trace_id,
|
||||||
|
parent_span_id=workflow_span_id,
|
||||||
|
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
|
||||||
|
name=node_execution.title,
|
||||||
|
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
|
||||||
|
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
|
||||||
|
attributes={
|
||||||
|
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
|
||||||
|
GEN_AI_FRAMEWORK: "dify",
|
||||||
|
TOOL_NAME: node_execution.title,
|
||||||
|
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
|
||||||
|
TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
|
||||||
|
INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
|
||||||
|
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||||
|
},
|
||||||
|
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_workflow_task_span(
|
||||||
|
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||||
|
) -> SpanData:
|
||||||
|
"""Build generic task span for workflow nodes."""
|
||||||
|
return SpanData(
|
||||||
|
trace_id=trace_id,
|
||||||
|
parent_span_id=workflow_span_id,
|
||||||
|
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
|
||||||
|
name=node_execution.title,
|
||||||
|
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
|
||||||
|
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
|
||||||
|
attributes={
|
||||||
|
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||||
|
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
|
||||||
|
GEN_AI_FRAMEWORK: "dify",
|
||||||
|
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
|
||||||
|
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||||
|
},
|
||||||
|
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_retrieval_documents(documents: list[Document]):
|
||||||
|
"""Extract documents data for retrieval tracing."""
|
||||||
|
documents_data = []
|
||||||
|
for document in documents:
|
||||||
|
document_data = {
|
||||||
|
"content": document.page_content,
|
||||||
|
"metadata": {
|
||||||
|
"dataset_id": document.metadata.get("dataset_id"),
|
||||||
|
"doc_id": document.metadata.get("doc_id"),
|
||||||
|
"document_id": document.metadata.get("document_id"),
|
||||||
|
},
|
||||||
|
"score": document.metadata.get("score"),
|
||||||
|
}
|
||||||
|
documents_data.append(document_data)
|
||||||
|
return documents_data
|
||||||
317
api/core/ops/tencent_trace/tencent_trace.py
Normal file
317
api/core/ops/tencent_trace/tencent_trace.py
Normal file
@ -0,0 +1,317 @@
|
|||||||
|
"""
|
||||||
|
Tencent APM tracing implementation with separated concerns
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
from core.ops.base_trace_instance import BaseTraceInstance
|
||||||
|
from core.ops.entities.config_entity import TencentConfig
|
||||||
|
from core.ops.entities.trace_entity import (
|
||||||
|
BaseTraceInfo,
|
||||||
|
DatasetRetrievalTraceInfo,
|
||||||
|
GenerateNameTraceInfo,
|
||||||
|
MessageTraceInfo,
|
||||||
|
ModerationTraceInfo,
|
||||||
|
SuggestedQuestionTraceInfo,
|
||||||
|
ToolTraceInfo,
|
||||||
|
WorkflowTraceInfo,
|
||||||
|
)
|
||||||
|
from core.ops.tencent_trace.client import TencentTraceClient
|
||||||
|
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
|
||||||
|
from core.ops.tencent_trace.span_builder import TencentSpanBuilder
|
||||||
|
from core.ops.tencent_trace.utils import TencentTraceUtils
|
||||||
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.entities.workflow_node_execution import (
|
||||||
|
WorkflowNodeExecution,
|
||||||
|
)
|
||||||
|
from core.workflow.nodes import NodeType
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TencentDataTrace(BaseTraceInstance):
|
||||||
|
"""
|
||||||
|
Tencent APM trace implementation with single responsibility principle.
|
||||||
|
Acts as a coordinator that delegates specific tasks to specialized classes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tencent_config: TencentConfig):
|
||||||
|
super().__init__(tencent_config)
|
||||||
|
self.trace_client = TencentTraceClient(
|
||||||
|
service_name=tencent_config.service_name,
|
||||||
|
endpoint=tencent_config.endpoint,
|
||||||
|
token=tencent_config.token,
|
||||||
|
metrics_export_interval_sec=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
def trace(self, trace_info: BaseTraceInfo) -> None:
|
||||||
|
"""Main tracing entry point - coordinates different trace types."""
|
||||||
|
if isinstance(trace_info, WorkflowTraceInfo):
|
||||||
|
self.workflow_trace(trace_info)
|
||||||
|
elif isinstance(trace_info, MessageTraceInfo):
|
||||||
|
self.message_trace(trace_info)
|
||||||
|
elif isinstance(trace_info, ModerationTraceInfo):
|
||||||
|
pass
|
||||||
|
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||||
|
self.suggested_question_trace(trace_info)
|
||||||
|
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||||
|
self.dataset_retrieval_trace(trace_info)
|
||||||
|
elif isinstance(trace_info, ToolTraceInfo):
|
||||||
|
self.tool_trace(trace_info)
|
||||||
|
elif isinstance(trace_info, GenerateNameTraceInfo):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def api_check(self) -> bool:
|
||||||
|
return self.trace_client.api_check()
|
||||||
|
|
||||||
|
def get_project_url(self) -> str:
|
||||||
|
return self.trace_client.get_project_url()
|
||||||
|
|
||||||
|
def workflow_trace(self, trace_info: WorkflowTraceInfo) -> None:
|
||||||
|
"""Handle workflow tracing by coordinating data retrieval and span construction."""
|
||||||
|
try:
|
||||||
|
trace_id = TencentTraceUtils.convert_to_trace_id(trace_info.workflow_run_id)
|
||||||
|
|
||||||
|
links = []
|
||||||
|
if trace_info.trace_id:
|
||||||
|
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
|
||||||
|
|
||||||
|
user_id = self._get_user_id(trace_info)
|
||||||
|
|
||||||
|
workflow_spans = TencentSpanBuilder.build_workflow_spans(trace_info, trace_id, str(user_id), links)
|
||||||
|
|
||||||
|
for span in workflow_spans:
|
||||||
|
self.trace_client.add_span(span)
|
||||||
|
|
||||||
|
self._process_workflow_nodes(trace_info, trace_id)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Tencent APM] Failed to process workflow trace")
|
||||||
|
|
||||||
|
def message_trace(self, trace_info: MessageTraceInfo) -> None:
|
||||||
|
"""Handle message tracing."""
|
||||||
|
try:
|
||||||
|
trace_id = TencentTraceUtils.convert_to_trace_id(trace_info.message_id)
|
||||||
|
user_id = self._get_user_id(trace_info)
|
||||||
|
|
||||||
|
links = []
|
||||||
|
if trace_info.trace_id:
|
||||||
|
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
|
||||||
|
|
||||||
|
message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
|
||||||
|
|
||||||
|
self.trace_client.add_span(message_span)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Tencent APM] Failed to process message trace")
|
||||||
|
|
||||||
|
def tool_trace(self, trace_info: ToolTraceInfo) -> None:
|
||||||
|
"""Handle tool tracing."""
|
||||||
|
try:
|
||||||
|
parent_span_id = None
|
||||||
|
trace_root_id = None
|
||||||
|
|
||||||
|
if trace_info.message_id:
|
||||||
|
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
|
||||||
|
trace_root_id = trace_info.message_id
|
||||||
|
|
||||||
|
if parent_span_id and trace_root_id:
|
||||||
|
trace_id = TencentTraceUtils.convert_to_trace_id(trace_root_id)
|
||||||
|
|
||||||
|
tool_span = TencentSpanBuilder.build_tool_span(trace_info, trace_id, parent_span_id)
|
||||||
|
|
||||||
|
self.trace_client.add_span(tool_span)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Tencent APM] Failed to process tool trace")
|
||||||
|
|
||||||
|
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo) -> None:
|
||||||
|
"""Handle dataset retrieval tracing."""
|
||||||
|
try:
|
||||||
|
parent_span_id = None
|
||||||
|
trace_root_id = None
|
||||||
|
|
||||||
|
if trace_info.message_id:
|
||||||
|
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
|
||||||
|
trace_root_id = trace_info.message_id
|
||||||
|
|
||||||
|
if parent_span_id and trace_root_id:
|
||||||
|
trace_id = TencentTraceUtils.convert_to_trace_id(trace_root_id)
|
||||||
|
|
||||||
|
retrieval_span = TencentSpanBuilder.build_retrieval_span(trace_info, trace_id, parent_span_id)
|
||||||
|
|
||||||
|
self.trace_client.add_span(retrieval_span)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Tencent APM] Failed to process dataset retrieval trace")
|
||||||
|
|
||||||
|
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo) -> None:
|
||||||
|
"""Handle suggested question tracing"""
|
||||||
|
try:
|
||||||
|
logger.info("[Tencent APM] Processing suggested question trace")
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Tencent APM] Failed to process suggested question trace")
|
||||||
|
|
||||||
|
def _process_workflow_nodes(self, trace_info: WorkflowTraceInfo, trace_id: int) -> None:
|
||||||
|
"""Process workflow node executions."""
|
||||||
|
try:
|
||||||
|
workflow_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "workflow")
|
||||||
|
|
||||||
|
node_executions = self._get_workflow_node_executions(trace_info)
|
||||||
|
|
||||||
|
for node_execution in node_executions:
|
||||||
|
try:
|
||||||
|
node_span = self._build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
|
||||||
|
if node_span:
|
||||||
|
self.trace_client.add_span(node_span)
|
||||||
|
|
||||||
|
if node_execution.node_type == NodeType.LLM:
|
||||||
|
self._record_llm_metrics(node_execution)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Tencent APM] Failed to process node execution: %s", node_execution.id)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Tencent APM] Failed to process workflow nodes")
|
||||||
|
|
||||||
|
def _build_workflow_node_span(
|
||||||
|
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
|
||||||
|
) -> SpanData | None:
|
||||||
|
"""Build span for different node types"""
|
||||||
|
try:
|
||||||
|
if node_execution.node_type == NodeType.LLM:
|
||||||
|
return TencentSpanBuilder.build_workflow_llm_span(
|
||||||
|
trace_id, workflow_span_id, trace_info, node_execution
|
||||||
|
)
|
||||||
|
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||||
|
return TencentSpanBuilder.build_workflow_retrieval_span(
|
||||||
|
trace_id, workflow_span_id, trace_info, node_execution
|
||||||
|
)
|
||||||
|
elif node_execution.node_type == NodeType.TOOL:
|
||||||
|
return TencentSpanBuilder.build_workflow_tool_span(
|
||||||
|
trace_id, workflow_span_id, trace_info, node_execution
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Handle all other node types as generic tasks
|
||||||
|
return TencentSpanBuilder.build_workflow_task_span(
|
||||||
|
trace_id, workflow_span_id, trace_info, node_execution
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug(
|
||||||
|
"[Tencent APM] Error building span for node %s: %s",
|
||||||
|
node_execution.id,
|
||||||
|
node_execution.node_type,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> list[WorkflowNodeExecution]:
|
||||||
|
"""Retrieve workflow node executions from database."""
|
||||||
|
try:
|
||||||
|
session_maker = sessionmaker(bind=db.engine)
|
||||||
|
|
||||||
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
|
app_id = trace_info.metadata.get("app_id")
|
||||||
|
if not app_id:
|
||||||
|
raise ValueError("No app_id found in trace_info metadata")
|
||||||
|
|
||||||
|
app_stmt = select(App).where(App.id == app_id)
|
||||||
|
app = session.scalar(app_stmt)
|
||||||
|
if not app:
|
||||||
|
raise ValueError(f"App with id {app_id} not found")
|
||||||
|
|
||||||
|
if not app.created_by:
|
||||||
|
raise ValueError(f"App with id {app_id} has no creator")
|
||||||
|
|
||||||
|
account_stmt = select(Account).where(Account.id == app.created_by)
|
||||||
|
service_account = session.scalar(account_stmt)
|
||||||
|
if not service_account:
|
||||||
|
raise ValueError(f"Creator account not found for app {app_id}")
|
||||||
|
|
||||||
|
current_tenant = (
|
||||||
|
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
|
||||||
|
)
|
||||||
|
if not current_tenant:
|
||||||
|
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||||
|
|
||||||
|
service_account.set_tenant_id(current_tenant.tenant_id)
|
||||||
|
|
||||||
|
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
|
session_factory=session_maker,
|
||||||
|
user=service_account,
|
||||||
|
app_id=trace_info.metadata.get("app_id"),
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||||
|
)
|
||||||
|
|
||||||
|
executions = repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id)
|
||||||
|
return list(executions)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Tencent APM] Failed to get workflow node executions")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _get_user_id(self, trace_info: BaseTraceInfo) -> str:
|
||||||
|
"""Get user ID from trace info."""
|
||||||
|
try:
|
||||||
|
tenant_id = None
|
||||||
|
user_id = None
|
||||||
|
|
||||||
|
if isinstance(trace_info, (WorkflowTraceInfo, GenerateNameTraceInfo)):
|
||||||
|
tenant_id = trace_info.tenant_id
|
||||||
|
|
||||||
|
if hasattr(trace_info, "metadata") and trace_info.metadata:
|
||||||
|
user_id = trace_info.metadata.get("user_id")
|
||||||
|
|
||||||
|
if user_id and tenant_id:
|
||||||
|
stmt = (
|
||||||
|
select(Account.name)
|
||||||
|
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||||
|
.where(Account.id == user_id, TenantAccountJoin.tenant_id == tenant_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
session_maker = sessionmaker(bind=db.engine)
|
||||||
|
with session_maker() as session:
|
||||||
|
account_name = session.scalar(stmt)
|
||||||
|
return account_name or str(user_id)
|
||||||
|
elif user_id:
|
||||||
|
return str(user_id)
|
||||||
|
|
||||||
|
return "anonymous"
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Tencent APM] Failed to get user ID")
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
def _record_llm_metrics(self, node_execution: WorkflowNodeExecution) -> None:
|
||||||
|
"""Record LLM performance metrics"""
|
||||||
|
try:
|
||||||
|
if not hasattr(self.trace_client, "record_llm_duration"):
|
||||||
|
return
|
||||||
|
|
||||||
|
process_data = node_execution.process_data or {}
|
||||||
|
usage = process_data.get("usage", {})
|
||||||
|
latency_s = float(usage.get("latency", 0.0))
|
||||||
|
|
||||||
|
if latency_s > 0:
|
||||||
|
attributes = {
|
||||||
|
"provider": process_data.get("model_provider", ""),
|
||||||
|
"model": process_data.get("model_name", ""),
|
||||||
|
"span_kind": "GENERATION",
|
||||||
|
}
|
||||||
|
self.trace_client.record_llm_duration(latency_s, attributes)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.debug("[Tencent APM] Failed to record LLM metrics")
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""Ensure proper cleanup on garbage collection."""
|
||||||
|
try:
|
||||||
|
if hasattr(self, "trace_client"):
|
||||||
|
self.trace_client.shutdown()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
65
api/core/ops/tencent_trace/utils.py
Normal file
65
api/core/ops/tencent_trace/utils.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
"""
|
||||||
|
Utility functions for Tencent APM tracing
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import random
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from opentelemetry.trace import Link, SpanContext, TraceFlags
|
||||||
|
|
||||||
|
|
||||||
|
class TencentTraceUtils:
|
||||||
|
"""Utility class for common tracing operations."""
|
||||||
|
|
||||||
|
INVALID_SPAN_ID = 0x0000000000000000
|
||||||
|
INVALID_TRACE_ID = 0x00000000000000000000000000000000
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_to_trace_id(uuid_v4: str | None) -> int:
|
||||||
|
try:
|
||||||
|
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Invalid UUID input: {e}")
|
||||||
|
return cast(int, uuid_obj.int)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
|
||||||
|
try:
|
||||||
|
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Invalid UUID input: {e}")
|
||||||
|
combined_key = f"{uuid_obj.hex}-{span_type}"
|
||||||
|
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
|
||||||
|
return int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_span_id() -> int:
|
||||||
|
span_id = random.getrandbits(64)
|
||||||
|
while span_id == TencentTraceUtils.INVALID_SPAN_ID:
|
||||||
|
span_id = random.getrandbits(64)
|
||||||
|
return span_id
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_datetime_to_nanoseconds(start_time: datetime | None) -> int:
|
||||||
|
if start_time is None:
|
||||||
|
start_time = datetime.now()
|
||||||
|
timestamp_in_seconds = start_time.timestamp()
|
||||||
|
return int(timestamp_in_seconds * 1e9)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_link(trace_id_str: str) -> Link:
|
||||||
|
try:
|
||||||
|
trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else cast(int, uuid.UUID(trace_id_str).int)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
trace_id = cast(int, uuid.uuid4().int)
|
||||||
|
|
||||||
|
span_context = SpanContext(
|
||||||
|
trace_id=trace_id,
|
||||||
|
span_id=TencentTraceUtils.INVALID_SPAN_ID,
|
||||||
|
is_remote=False,
|
||||||
|
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||||
|
)
|
||||||
|
return Link(span_context)
|
||||||
@ -2,7 +2,7 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Callable, Generator
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar, cast
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -31,6 +31,17 @@ from core.plugin.impl.exc import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
|
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
|
||||||
|
_plugin_daemon_timeout_config = cast(
|
||||||
|
float | httpx.Timeout | None,
|
||||||
|
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
|
||||||
|
)
|
||||||
|
plugin_daemon_request_timeout: httpx.Timeout | None
|
||||||
|
if _plugin_daemon_timeout_config is None:
|
||||||
|
plugin_daemon_request_timeout = None
|
||||||
|
elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
|
||||||
|
plugin_daemon_request_timeout = _plugin_daemon_timeout_config
|
||||||
|
else:
|
||||||
|
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
|
||||||
|
|
||||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||||
|
|
||||||
@ -58,6 +69,7 @@ class BasePluginClient:
|
|||||||
"headers": headers,
|
"headers": headers,
|
||||||
"params": params,
|
"params": params,
|
||||||
"files": files,
|
"files": files,
|
||||||
|
"timeout": plugin_daemon_request_timeout,
|
||||||
}
|
}
|
||||||
if isinstance(prepared_data, dict):
|
if isinstance(prepared_data, dict):
|
||||||
request_kwargs["data"] = prepared_data
|
request_kwargs["data"] = prepared_data
|
||||||
@ -116,6 +128,7 @@ class BasePluginClient:
|
|||||||
"headers": headers,
|
"headers": headers,
|
||||||
"params": params,
|
"params": params,
|
||||||
"files": files,
|
"files": files,
|
||||||
|
"timeout": plugin_daemon_request_timeout,
|
||||||
}
|
}
|
||||||
if isinstance(prepared_data, dict):
|
if isinstance(prepared_data, dict):
|
||||||
stream_kwargs["data"] = prepared_data
|
stream_kwargs["data"] = prepared_data
|
||||||
|
|||||||
@ -1,9 +1,24 @@
|
|||||||
|
"""
|
||||||
|
Weaviate vector database implementation for Dify's RAG system.
|
||||||
|
|
||||||
|
This module provides integration with Weaviate vector database for storing and retrieving
|
||||||
|
document embeddings used in retrieval-augmented generation workflows.
|
||||||
|
"""
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid as _uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import weaviate # type: ignore
|
import weaviate
|
||||||
|
import weaviate.classes.config as wc
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
from weaviate.classes.data import DataObject
|
||||||
|
from weaviate.classes.init import Auth
|
||||||
|
from weaviate.classes.query import Filter, MetadataQuery
|
||||||
|
from weaviate.exceptions import UnexpectedStatusCodeError
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import Field
|
from core.rag.datasource.vdb.field import Field
|
||||||
@ -15,265 +30,394 @@ from core.rag.models.document import Document
|
|||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WeaviateConfig(BaseModel):
|
class WeaviateConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Configuration model for Weaviate connection settings.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
endpoint: Weaviate server endpoint URL
|
||||||
|
api_key: Optional API key for authentication
|
||||||
|
batch_size: Number of objects to batch per insert operation
|
||||||
|
"""
|
||||||
|
|
||||||
endpoint: str
|
endpoint: str
|
||||||
api_key: str | None = None
|
api_key: str | None = None
|
||||||
batch_size: int = 100
|
batch_size: int = 100
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_config(cls, values: dict):
|
def validate_config(cls, values: dict) -> dict:
|
||||||
|
"""Validates that required configuration values are present."""
|
||||||
if not values["endpoint"]:
|
if not values["endpoint"]:
|
||||||
raise ValueError("config WEAVIATE_ENDPOINT is required")
|
raise ValueError("config WEAVIATE_ENDPOINT is required")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
class WeaviateVector(BaseVector):
|
class WeaviateVector(BaseVector):
|
||||||
|
"""
|
||||||
|
Weaviate vector database implementation for document storage and retrieval.
|
||||||
|
|
||||||
|
Handles creation, insertion, deletion, and querying of document embeddings
|
||||||
|
in a Weaviate collection.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
|
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
|
||||||
|
"""
|
||||||
|
Initializes the Weaviate vector store.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Name of the Weaviate collection
|
||||||
|
config: Weaviate configuration settings
|
||||||
|
attributes: List of metadata attributes to store
|
||||||
|
"""
|
||||||
super().__init__(collection_name)
|
super().__init__(collection_name)
|
||||||
self._client = self._init_client(config)
|
self._client = self._init_client(config)
|
||||||
self._attributes = attributes
|
self._attributes = attributes
|
||||||
|
|
||||||
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
|
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
|
||||||
auth_config = weaviate.AuthApiKey(api_key=config.api_key or "")
|
"""
|
||||||
|
Initializes and returns a connected Weaviate client.
|
||||||
|
|
||||||
weaviate.connect.connection.has_grpc = False # ty: ignore [unresolved-attribute]
|
Configures both HTTP and gRPC connections with proper authentication.
|
||||||
|
"""
|
||||||
|
p = urlparse(config.endpoint)
|
||||||
|
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
|
||||||
|
http_secure = p.scheme == "https"
|
||||||
|
http_port = p.port or (443 if http_secure else 80)
|
||||||
|
|
||||||
try:
|
grpc_host = host
|
||||||
client = weaviate.Client(
|
grpc_secure = http_secure
|
||||||
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
|
grpc_port = 443 if grpc_secure else 50051
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
raise ConnectionError("Vector database connection error") from exc
|
|
||||||
|
|
||||||
client.batch.configure(
|
client = weaviate.connect_to_custom(
|
||||||
# `batch_size` takes an `int` value to enable auto-batching
|
http_host=host,
|
||||||
# (`None` is used for manual batching)
|
http_port=http_port,
|
||||||
batch_size=config.batch_size,
|
http_secure=http_secure,
|
||||||
# dynamically update the `batch_size` based on import speed
|
grpc_host=grpc_host,
|
||||||
dynamic=True,
|
grpc_port=grpc_port,
|
||||||
# `timeout_retries` takes an `int` value to retry on time outs
|
grpc_secure=grpc_secure,
|
||||||
timeout_retries=3,
|
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not client.is_ready():
|
||||||
|
raise ConnectionError("Vector database is not ready")
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
|
"""Returns the vector database type identifier."""
|
||||||
return VectorType.WEAVIATE
|
return VectorType.WEAVIATE
|
||||||
|
|
||||||
def get_collection_name(self, dataset: Dataset) -> str:
|
def get_collection_name(self, dataset: Dataset) -> str:
|
||||||
|
"""
|
||||||
|
Retrieves or generates the collection name for a dataset.
|
||||||
|
|
||||||
|
Uses existing index structure if available, otherwise generates from dataset ID.
|
||||||
|
"""
|
||||||
if dataset.index_struct_dict:
|
if dataset.index_struct_dict:
|
||||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||||
if not class_prefix.endswith("_Node"):
|
if not class_prefix.endswith("_Node"):
|
||||||
# original class_prefix
|
|
||||||
class_prefix += "_Node"
|
class_prefix += "_Node"
|
||||||
|
|
||||||
return class_prefix
|
return class_prefix
|
||||||
|
|
||||||
dataset_id = dataset.id
|
dataset_id = dataset.id
|
||||||
return Dataset.gen_collection_name_by_id(dataset_id)
|
return Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
|
|
||||||
def to_index_struct(self):
|
def to_index_struct(self) -> dict:
|
||||||
|
"""Returns the index structure dictionary for persistence."""
|
||||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||||
|
|
||||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
# create collection
|
"""
|
||||||
|
Creates a new collection and adds initial documents with embeddings.
|
||||||
|
"""
|
||||||
self._create_collection()
|
self._create_collection()
|
||||||
# create vector
|
|
||||||
self.add_texts(texts, embeddings)
|
self.add_texts(texts, embeddings)
|
||||||
|
|
||||||
def _create_collection(self):
|
def _create_collection(self):
|
||||||
|
"""
|
||||||
|
Creates the Weaviate collection with required schema if it doesn't exist.
|
||||||
|
|
||||||
|
Uses Redis locking to prevent concurrent creation attempts.
|
||||||
|
"""
|
||||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||||
with redis_client.lock(lock_name, timeout=20):
|
with redis_client.lock(lock_name, timeout=20):
|
||||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
cache_key = f"vector_indexing_{self._collection_name}"
|
||||||
if redis_client.get(collection_exist_cache_key):
|
if redis_client.get(cache_key):
|
||||||
return
|
return
|
||||||
schema = self._default_schema(self._collection_name)
|
|
||||||
if not self._client.schema.contains(schema):
|
try:
|
||||||
# create collection
|
if not self._client.collections.exists(self._collection_name):
|
||||||
self._client.schema.create_class(schema)
|
self._client.collections.create(
|
||||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
name=self._collection_name,
|
||||||
|
properties=[
|
||||||
|
wc.Property(
|
||||||
|
name=Field.TEXT_KEY.value,
|
||||||
|
data_type=wc.DataType.TEXT,
|
||||||
|
tokenization=wc.Tokenization.WORD,
|
||||||
|
),
|
||||||
|
wc.Property(name="document_id", data_type=wc.DataType.TEXT),
|
||||||
|
wc.Property(name="doc_id", data_type=wc.DataType.TEXT),
|
||||||
|
wc.Property(name="chunk_index", data_type=wc.DataType.INT),
|
||||||
|
],
|
||||||
|
vector_config=wc.Configure.Vectors.self_provided(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._ensure_properties()
|
||||||
|
redis_client.set(cache_key, 1, ex=3600)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error creating collection %s", self._collection_name)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _ensure_properties(self) -> None:
|
||||||
|
"""
|
||||||
|
Ensures all required properties exist in the collection schema.
|
||||||
|
|
||||||
|
Adds missing properties if the collection exists but lacks them.
|
||||||
|
"""
|
||||||
|
if not self._client.collections.exists(self._collection_name):
|
||||||
|
return
|
||||||
|
|
||||||
|
col = self._client.collections.use(self._collection_name)
|
||||||
|
cfg = col.config.get()
|
||||||
|
existing = {p.name for p in (cfg.properties or [])}
|
||||||
|
|
||||||
|
to_add = []
|
||||||
|
if "document_id" not in existing:
|
||||||
|
to_add.append(wc.Property(name="document_id", data_type=wc.DataType.TEXT))
|
||||||
|
if "doc_id" not in existing:
|
||||||
|
to_add.append(wc.Property(name="doc_id", data_type=wc.DataType.TEXT))
|
||||||
|
if "chunk_index" not in existing:
|
||||||
|
to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT))
|
||||||
|
|
||||||
|
for prop in to_add:
|
||||||
|
try:
|
||||||
|
col.config.add_property(prop)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Could not add property %s: %s", prop.name, e)
|
||||||
|
|
||||||
|
def _get_uuids(self, documents: list[Document]) -> list[str]:
|
||||||
|
"""
|
||||||
|
Generates deterministic UUIDs for documents based on their content.
|
||||||
|
|
||||||
|
Uses UUID5 with URL namespace to ensure consistent IDs for identical content.
|
||||||
|
"""
|
||||||
|
URL_NAMESPACE = _uuid.UUID("6ba7b811-9dad-11d1-80b4-00c04fd430c8")
|
||||||
|
|
||||||
|
uuids = []
|
||||||
|
for doc in documents:
|
||||||
|
uuid_val = _uuid.uuid5(URL_NAMESPACE, doc.page_content)
|
||||||
|
uuids.append(str(uuid_val))
|
||||||
|
|
||||||
|
return uuids
|
||||||
|
|
||||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
"""
|
||||||
|
Adds documents with their embeddings to the collection.
|
||||||
|
|
||||||
|
Batches insertions for efficiency and returns the list of inserted object IDs.
|
||||||
|
"""
|
||||||
uuids = self._get_uuids(documents)
|
uuids = self._get_uuids(documents)
|
||||||
texts = [d.page_content for d in documents]
|
texts = [d.page_content for d in documents]
|
||||||
metadatas = [d.metadata for d in documents]
|
metadatas = [d.metadata for d in documents]
|
||||||
|
|
||||||
ids = []
|
col = self._client.collections.use(self._collection_name)
|
||||||
|
objs: list[DataObject] = []
|
||||||
|
ids_out: list[str] = []
|
||||||
|
|
||||||
with self._client.batch as batch:
|
for i, text in enumerate(texts):
|
||||||
for i, text in enumerate(texts):
|
props: dict[str, Any] = {Field.TEXT_KEY.value: text}
|
||||||
data_properties = {Field.TEXT_KEY: text}
|
meta = metadatas[i] or {}
|
||||||
if metadatas is not None:
|
for k, v in meta.items():
|
||||||
# metadata maybe None
|
props[k] = self._json_serializable(v)
|
||||||
for key, val in (metadatas[i] or {}).items():
|
|
||||||
data_properties[key] = self._json_serializable(val)
|
|
||||||
|
|
||||||
batch.add_data_object(
|
candidate = uuids[i] if uuids else None
|
||||||
data_object=data_properties,
|
uid = candidate if (candidate and self._is_uuid(candidate)) else str(_uuid.uuid4())
|
||||||
class_name=self._collection_name,
|
ids_out.append(uid)
|
||||||
uuid=uuids[i],
|
|
||||||
vector=embeddings[i] if embeddings else None,
|
vec_payload = None
|
||||||
|
if embeddings and i < len(embeddings) and embeddings[i]:
|
||||||
|
vec_payload = {"default": embeddings[i]}
|
||||||
|
|
||||||
|
objs.append(
|
||||||
|
DataObject(
|
||||||
|
uuid=uid,
|
||||||
|
properties=props, # type: ignore[arg-type] # mypy incorrectly infers DataObject signature
|
||||||
|
vector=vec_payload,
|
||||||
)
|
)
|
||||||
ids.append(uuids[i])
|
)
|
||||||
return ids
|
|
||||||
|
|
||||||
def delete_by_metadata_field(self, key: str, value: str):
|
batch_size = max(1, int(dify_config.WEAVIATE_BATCH_SIZE or 100))
|
||||||
# check whether the index already exists
|
with col.batch.dynamic() as batch:
|
||||||
schema = self._default_schema(self._collection_name)
|
for obj in objs:
|
||||||
if self._client.schema.contains(schema):
|
batch.add_object(properties=obj.properties, uuid=obj.uuid, vector=obj.vector)
|
||||||
where_filter = {"operator": "Equal", "path": [key], "valueText": value}
|
|
||||||
|
|
||||||
self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal")
|
return ids_out
|
||||||
|
|
||||||
|
def _is_uuid(self, val: str) -> bool:
|
||||||
|
"""Validates whether a string is a valid UUID format."""
|
||||||
|
try:
|
||||||
|
_uuid.UUID(str(val))
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||||
|
"""Deletes all objects matching a specific metadata field value."""
|
||||||
|
if not self._client.collections.exists(self._collection_name):
|
||||||
|
return
|
||||||
|
|
||||||
|
col = self._client.collections.use(self._collection_name)
|
||||||
|
col.data.delete_many(where=Filter.by_property(key).equal(value))
|
||||||
|
|
||||||
def delete(self):
|
def delete(self):
|
||||||
# check whether the index already exists
|
"""Deletes the entire collection from Weaviate."""
|
||||||
schema = self._default_schema(self._collection_name)
|
if self._client.collections.exists(self._collection_name):
|
||||||
if self._client.schema.contains(schema):
|
self._client.collections.delete(self._collection_name)
|
||||||
self._client.schema.delete_class(self._collection_name)
|
|
||||||
|
|
||||||
def text_exists(self, id: str) -> bool:
|
def text_exists(self, id: str) -> bool:
|
||||||
collection_name = self._collection_name
|
"""Checks if a document with the given doc_id exists in the collection."""
|
||||||
schema = self._default_schema(self._collection_name)
|
if not self._client.collections.exists(self._collection_name):
|
||||||
|
|
||||||
# check whether the index already exists
|
|
||||||
if not self._client.schema.contains(schema):
|
|
||||||
return False
|
return False
|
||||||
result = (
|
|
||||||
self._client.query.get(collection_name)
|
col = self._client.collections.use(self._collection_name)
|
||||||
.with_additional(["id"])
|
res = col.query.fetch_objects(
|
||||||
.with_where(
|
filters=Filter.by_property("doc_id").equal(id),
|
||||||
{
|
limit=1,
|
||||||
"path": ["doc_id"],
|
return_properties=["doc_id"],
|
||||||
"operator": "Equal",
|
|
||||||
"valueText": id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
.with_limit(1)
|
|
||||||
.do()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if "errors" in result:
|
return len(res.objects) > 0
|
||||||
raise ValueError(f"Error during query: {result['errors']}")
|
|
||||||
|
|
||||||
entries = result["data"]["Get"][collection_name]
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
if len(entries) == 0:
|
"""
|
||||||
return False
|
Deletes objects by their UUID identifiers.
|
||||||
|
|
||||||
return True
|
Silently ignores 404 errors for non-existent IDs.
|
||||||
|
"""
|
||||||
|
if not self._client.collections.exists(self._collection_name):
|
||||||
|
return
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]):
|
col = self._client.collections.use(self._collection_name)
|
||||||
# check whether the index already exists
|
|
||||||
schema = self._default_schema(self._collection_name)
|
for uid in ids:
|
||||||
if self._client.schema.contains(schema):
|
try:
|
||||||
for uuid in ids:
|
col.data.delete_by_id(uid)
|
||||||
try:
|
except UnexpectedStatusCodeError as e:
|
||||||
self._client.data_object.delete(
|
if getattr(e, "status_code", None) != 404:
|
||||||
class_name=self._collection_name,
|
raise
|
||||||
uuid=uuid,
|
|
||||||
)
|
|
||||||
except weaviate.UnexpectedStatusCodeException as e:
|
|
||||||
# tolerate not found error
|
|
||||||
if e.status_code != 404:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
"""Look up similar documents by embedding vector in Weaviate."""
|
"""
|
||||||
collection_name = self._collection_name
|
Performs vector similarity search using the provided query vector.
|
||||||
properties = self._attributes
|
|
||||||
properties.append(Field.TEXT_KEY)
|
|
||||||
query_obj = self._client.query.get(collection_name, properties)
|
|
||||||
|
|
||||||
vector = {"vector": query_vector}
|
Filters by document IDs if provided and applies score threshold.
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
Returns documents sorted by relevance score.
|
||||||
if document_ids_filter:
|
"""
|
||||||
operands = []
|
if not self._client.collections.exists(self._collection_name):
|
||||||
for document_id_filter in document_ids_filter:
|
return []
|
||||||
operands.append({"path": ["document_id"], "operator": "Equal", "valueText": document_id_filter})
|
|
||||||
where_filter = {"operator": "Or", "operands": operands}
|
col = self._client.collections.use(self._collection_name)
|
||||||
query_obj = query_obj.with_where(where_filter)
|
props = list({*self._attributes, "document_id", Field.TEXT_KEY.value})
|
||||||
result = (
|
|
||||||
query_obj.with_near_vector(vector)
|
where = None
|
||||||
.with_limit(kwargs.get("top_k", 4))
|
doc_ids = kwargs.get("document_ids_filter") or []
|
||||||
.with_additional(["vector", "distance"])
|
if doc_ids:
|
||||||
.do()
|
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
|
||||||
|
where = ors[0]
|
||||||
|
for f in ors[1:]:
|
||||||
|
where = where | f
|
||||||
|
|
||||||
|
top_k = int(kwargs.get("top_k", 4))
|
||||||
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
|
|
||||||
|
res = col.query.near_vector(
|
||||||
|
near_vector=query_vector,
|
||||||
|
limit=top_k,
|
||||||
|
return_properties=props,
|
||||||
|
return_metadata=MetadataQuery(distance=True),
|
||||||
|
include_vector=False,
|
||||||
|
filters=where,
|
||||||
|
target_vector="default",
|
||||||
)
|
)
|
||||||
if "errors" in result:
|
|
||||||
raise ValueError(f"Error during query: {result['errors']}")
|
|
||||||
|
|
||||||
docs_and_scores = []
|
docs: list[Document] = []
|
||||||
for res in result["data"]["Get"][collection_name]:
|
for obj in res.objects:
|
||||||
text = res.pop(Field.TEXT_KEY)
|
properties = dict(obj.properties or {})
|
||||||
score = 1 - res["_additional"]["distance"]
|
text = properties.pop(Field.TEXT_KEY.value, "")
|
||||||
docs_and_scores.append((Document(page_content=text, metadata=res), score))
|
distance = (obj.metadata.distance if obj.metadata else None) or 1.0
|
||||||
|
score = 1.0 - distance
|
||||||
|
|
||||||
docs = []
|
if score > score_threshold:
|
||||||
for doc, score in docs_and_scores:
|
properties["score"] = score
|
||||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
docs.append(Document(page_content=text, metadata=properties))
|
||||||
# check score threshold
|
|
||||||
if score >= score_threshold:
|
docs.sort(key=lambda d: d.metadata.get("score", 0.0), reverse=True)
|
||||||
if doc.metadata is not None:
|
|
||||||
doc.metadata["score"] = score
|
|
||||||
docs.append(doc)
|
|
||||||
# Sort the documents by score in descending order
|
|
||||||
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
|
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
"""Return docs using BM25F.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Text to look up documents similar to.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of Documents most similar to the query.
|
|
||||||
"""
|
"""
|
||||||
collection_name = self._collection_name
|
Performs BM25 full-text search on document content.
|
||||||
content: dict[str, Any] = {"concepts": [query]}
|
|
||||||
properties = self._attributes
|
Filters by document IDs if provided and returns matching documents with vectors.
|
||||||
properties.append(Field.TEXT_KEY)
|
"""
|
||||||
if kwargs.get("search_distance"):
|
if not self._client.collections.exists(self._collection_name):
|
||||||
content["certainty"] = kwargs.get("search_distance")
|
return []
|
||||||
query_obj = self._client.query.get(collection_name, properties)
|
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
col = self._client.collections.use(self._collection_name)
|
||||||
if document_ids_filter:
|
props = list({*self._attributes, Field.TEXT_KEY.value})
|
||||||
operands = []
|
|
||||||
for document_id_filter in document_ids_filter:
|
where = None
|
||||||
operands.append({"path": ["document_id"], "operator": "Equal", "valueText": document_id_filter})
|
doc_ids = kwargs.get("document_ids_filter") or []
|
||||||
where_filter = {"operator": "Or", "operands": operands}
|
if doc_ids:
|
||||||
query_obj = query_obj.with_where(where_filter)
|
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
|
||||||
query_obj = query_obj.with_additional(["vector"])
|
where = ors[0]
|
||||||
properties = ["text"]
|
for f in ors[1:]:
|
||||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()
|
where = where | f
|
||||||
if "errors" in result:
|
|
||||||
raise ValueError(f"Error during query: {result['errors']}")
|
top_k = int(kwargs.get("top_k", 4))
|
||||||
docs = []
|
|
||||||
for res in result["data"]["Get"][collection_name]:
|
res = col.query.bm25(
|
||||||
text = res.pop(Field.TEXT_KEY)
|
query=query,
|
||||||
additional = res.pop("_additional")
|
query_properties=[Field.TEXT_KEY.value],
|
||||||
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
|
limit=top_k,
|
||||||
|
return_properties=props,
|
||||||
|
include_vector=True,
|
||||||
|
filters=where,
|
||||||
|
)
|
||||||
|
|
||||||
|
docs: list[Document] = []
|
||||||
|
for obj in res.objects:
|
||||||
|
properties = dict(obj.properties or {})
|
||||||
|
text = properties.pop(Field.TEXT_KEY.value, "")
|
||||||
|
|
||||||
|
vec = obj.vector
|
||||||
|
if isinstance(vec, dict):
|
||||||
|
vec = vec.get("default") or next(iter(vec.values()), None)
|
||||||
|
|
||||||
|
docs.append(Document(page_content=text, vector=vec, metadata=properties))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def _default_schema(self, index_name: str):
|
def _json_serializable(self, value: Any) -> Any:
|
||||||
return {
|
"""Converts values to JSON-serializable format, handling datetime objects."""
|
||||||
"class": index_name,
|
|
||||||
"properties": [
|
|
||||||
{
|
|
||||||
"name": "text",
|
|
||||||
"dataType": ["text"],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
def _json_serializable(self, value: Any):
|
|
||||||
if isinstance(value, datetime.datetime):
|
if isinstance(value, datetime.datetime):
|
||||||
return value.isoformat()
|
return value.isoformat()
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
class WeaviateVectorFactory(AbstractVectorFactory):
|
class WeaviateVectorFactory(AbstractVectorFactory):
|
||||||
|
"""Factory class for creating WeaviateVector instances."""
|
||||||
|
|
||||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
|
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
|
||||||
|
"""
|
||||||
|
Initializes a WeaviateVector instance for the given dataset.
|
||||||
|
|
||||||
|
Uses existing collection name from dataset index structure or generates a new one.
|
||||||
|
Updates dataset index structure if not already set.
|
||||||
|
"""
|
||||||
if dataset.index_struct_dict:
|
if dataset.index_struct_dict:
|
||||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||||
collection_name = class_prefix
|
collection_name = class_prefix
|
||||||
@ -281,7 +425,6 @@ class WeaviateVectorFactory(AbstractVectorFactory):
|
|||||||
dataset_id = dataset.id
|
dataset_id = dataset.id
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
|
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
|
||||||
|
|
||||||
return WeaviateVector(
|
return WeaviateVector(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
config=WeaviateConfig(
|
config=WeaviateConfig(
|
||||||
|
|||||||
@ -43,8 +43,7 @@ class CacheEmbedding(Embeddings):
|
|||||||
else:
|
else:
|
||||||
embedding_queue_indices.append(i)
|
embedding_queue_indices.append(i)
|
||||||
|
|
||||||
# release database connection, because embedding may take a long time
|
# NOTE: avoid closing the shared scoped session here; downstream code may still have pending work
|
||||||
db.session.close()
|
|
||||||
|
|
||||||
if embedding_queue_indices:
|
if embedding_queue_indices:
|
||||||
embedding_queue_texts = [texts[i] for i in embedding_queue_indices]
|
embedding_queue_texts = [texts[i] for i in embedding_queue_indices]
|
||||||
|
|||||||
@ -189,6 +189,11 @@ class ToolInvokeMessage(BaseModel):
|
|||||||
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
||||||
metadata: Mapping[str, Any] = Field(default_factory=dict, description="The metadata of the log")
|
metadata: Mapping[str, Any] = Field(default_factory=dict, description="The metadata of the log")
|
||||||
|
|
||||||
|
@field_validator("metadata", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_metadata(cls, value: Mapping[str, Any] | None) -> Mapping[str, Any]:
|
||||||
|
return value or {}
|
||||||
|
|
||||||
class RetrieverResourceMessage(BaseModel):
|
class RetrieverResourceMessage(BaseModel):
|
||||||
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||||
context: str = Field(..., description="context")
|
context: str = Field(..., description="context")
|
||||||
@ -377,6 +382,11 @@ class ToolEntity(BaseModel):
|
|||||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
|
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
|
||||||
return v or []
|
return v or []
|
||||||
|
|
||||||
|
@field_validator("output_schema", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_output_schema(cls, value: Mapping[str, object] | None) -> Mapping[str, object]:
|
||||||
|
return value or {}
|
||||||
|
|
||||||
|
|
||||||
class OAuthSchema(BaseModel):
|
class OAuthSchema(BaseModel):
|
||||||
client_schema: list[ProviderConfig] = Field(
|
client_schema: list[ProviderConfig] = Field(
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
import time as time_module
|
import time as time_module
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
from sqlalchemy.engine import CursorResult
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -267,7 +268,7 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]
|
|||||||
|
|
||||||
# Build and execute the update statement
|
# Build and execute the update statement
|
||||||
stmt = update(Provider).where(*where_conditions).values(**update_values)
|
stmt = update(Provider).where(*where_conditions).values(**update_values)
|
||||||
result = session.execute(stmt)
|
result = cast(CursorResult, session.execute(stmt))
|
||||||
rows_affected = result.rowcount
|
rows_affected = result.rowcount
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|||||||
@ -64,7 +64,10 @@ def build_from_mapping(
|
|||||||
config: FileUploadConfig | None = None,
|
config: FileUploadConfig | None = None,
|
||||||
strict_type_validation: bool = False,
|
strict_type_validation: bool = False,
|
||||||
) -> File:
|
) -> File:
|
||||||
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
|
transfer_method_value = mapping.get("transfer_method")
|
||||||
|
if not transfer_method_value:
|
||||||
|
raise ValueError("transfer_method is required in file mapping")
|
||||||
|
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||||
|
|
||||||
build_functions: dict[FileTransferMethod, Callable] = {
|
build_functions: dict[FileTransferMethod, Callable] = {
|
||||||
FileTransferMethod.LOCAL_FILE: _build_from_local_file,
|
FileTransferMethod.LOCAL_FILE: _build_from_local_file,
|
||||||
@ -104,6 +107,8 @@ def build_from_mappings(
|
|||||||
) -> Sequence[File]:
|
) -> Sequence[File]:
|
||||||
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
|
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
|
||||||
# Implement batch processing to reduce database load when handling multiple files.
|
# Implement batch processing to reduce database load when handling multiple files.
|
||||||
|
# Filter out None/empty mappings to avoid errors
|
||||||
|
valid_mappings = [m for m in mappings if m and m.get("transfer_method")]
|
||||||
files = [
|
files = [
|
||||||
build_from_mapping(
|
build_from_mapping(
|
||||||
mapping=mapping,
|
mapping=mapping,
|
||||||
@ -111,7 +116,7 @@ def build_from_mappings(
|
|||||||
config=config,
|
config=config,
|
||||||
strict_type_validation=strict_type_validation,
|
strict_type_validation=strict_type_validation,
|
||||||
)
|
)
|
||||||
for mapping in mappings
|
for mapping in valid_mappings
|
||||||
]
|
]
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|||||||
@ -13,6 +13,15 @@ from models.model import EndUser
|
|||||||
#: A proxy for the current user. If no user is logged in, this will be an
|
#: A proxy for the current user. If no user is logged in, this will be an
|
||||||
#: anonymous user
|
#: anonymous user
|
||||||
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
|
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
|
||||||
|
|
||||||
|
|
||||||
|
def current_account_with_tenant():
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
|
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
||||||
|
return current_user, current_user.current_tenant_id
|
||||||
|
|
||||||
|
|
||||||
from typing import ParamSpec, TypeVar
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
|
|||||||
@ -13,7 +13,7 @@ dependencies = [
|
|||||||
"celery~=5.5.2",
|
"celery~=5.5.2",
|
||||||
"chardet~=5.1.0",
|
"chardet~=5.1.0",
|
||||||
"flask~=3.1.2",
|
"flask~=3.1.2",
|
||||||
"flask-compress~=1.17",
|
"flask-compress>=1.17,<1.18",
|
||||||
"flask-cors~=6.0.0",
|
"flask-cors~=6.0.0",
|
||||||
"flask-login~=0.6.3",
|
"flask-login~=0.6.3",
|
||||||
"flask-migrate~=4.0.7",
|
"flask-migrate~=4.0.7",
|
||||||
@ -87,6 +87,7 @@ dependencies = [
|
|||||||
"flask-restx~=1.3.0",
|
"flask-restx~=1.3.0",
|
||||||
"packaging~=23.2",
|
"packaging~=23.2",
|
||||||
"croniter>=6.0.0",
|
"croniter>=6.0.0",
|
||||||
|
"weaviate-client==4.17.0",
|
||||||
]
|
]
|
||||||
# Before adding new dependency, consider place it in
|
# Before adding new dependency, consider place it in
|
||||||
# alphabet order (a-z) and suitable group.
|
# alphabet order (a-z) and suitable group.
|
||||||
@ -215,7 +216,7 @@ vdb = [
|
|||||||
"tidb-vector==0.0.9",
|
"tidb-vector==0.0.9",
|
||||||
"upstash-vector==0.6.0",
|
"upstash-vector==0.6.0",
|
||||||
"volcengine-compat~=1.0.0",
|
"volcengine-compat~=1.0.0",
|
||||||
"weaviate-client~=3.24.0",
|
"weaviate-client>=4.0.0,<5.0.0",
|
||||||
"xinference-client~=1.2.2",
|
"xinference-client~=1.2.2",
|
||||||
"mo-vector~=0.1.13",
|
"mo-vector~=0.1.13",
|
||||||
"mysql-connector-python>=9.3.0",
|
"mysql-connector-python>=9.3.0",
|
||||||
|
|||||||
@ -7,8 +7,10 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
|||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from sqlalchemy import asc, delete, desc, select
|
from sqlalchemy import asc, delete, desc, select
|
||||||
|
from sqlalchemy.engine import CursorResult
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from models.workflow import WorkflowNodeExecutionModel
|
from models.workflow import WorkflowNodeExecutionModel
|
||||||
@ -181,7 +183,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
|||||||
|
|
||||||
# Delete the batch
|
# Delete the batch
|
||||||
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
||||||
result = session.execute(delete_stmt)
|
result = cast(CursorResult, session.execute(delete_stmt))
|
||||||
session.commit()
|
session.commit()
|
||||||
total_deleted += result.rowcount
|
total_deleted += result.rowcount
|
||||||
|
|
||||||
@ -228,7 +230,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
|||||||
|
|
||||||
# Delete the batch
|
# Delete the batch
|
||||||
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
||||||
result = session.execute(delete_stmt)
|
result = cast(CursorResult, session.execute(delete_stmt))
|
||||||
session.commit()
|
session.commit()
|
||||||
total_deleted += result.rowcount
|
total_deleted += result.rowcount
|
||||||
|
|
||||||
@ -285,6 +287,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
|||||||
|
|
||||||
with self._session_maker() as session:
|
with self._session_maker() as session:
|
||||||
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
||||||
result = session.execute(stmt)
|
result = cast(CursorResult, session.execute(stmt))
|
||||||
session.commit()
|
session.commit()
|
||||||
return result.rowcount
|
return result.rowcount
|
||||||
|
|||||||
@ -22,8 +22,10 @@ Implementation Notes:
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from sqlalchemy import delete, select
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.engine import CursorResult
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
@ -150,7 +152,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||||||
|
|
||||||
with self._session_maker() as session:
|
with self._session_maker() as session:
|
||||||
stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
|
stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
|
||||||
result = session.execute(stmt)
|
result = cast(CursorResult, session.execute(stmt))
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
deleted_count = result.rowcount
|
deleted_count = result.rowcount
|
||||||
@ -186,7 +188,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||||||
|
|
||||||
# Delete the batch
|
# Delete the batch
|
||||||
delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
|
delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
|
||||||
result = session.execute(delete_stmt)
|
result = cast(CursorResult, session.execute(delete_stmt))
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
batch_deleted = result.rowcount
|
batch_deleted = result.rowcount
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
import app
|
import app
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -35,50 +38,53 @@ def clean_workflow_runlogs_precise():
|
|||||||
|
|
||||||
retention_days = dify_config.WORKFLOW_LOG_RETENTION_DAYS
|
retention_days = dify_config.WORKFLOW_LOG_RETENTION_DAYS
|
||||||
cutoff_date = datetime.datetime.now() - datetime.timedelta(days=retention_days)
|
cutoff_date = datetime.datetime.now() - datetime.timedelta(days=retention_days)
|
||||||
|
session_factory = sessionmaker(db.engine, expire_on_commit=False)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
total_workflow_runs = db.session.query(WorkflowRun).where(WorkflowRun.created_at < cutoff_date).count()
|
with session_factory.begin() as session:
|
||||||
if total_workflow_runs == 0:
|
total_workflow_runs = session.query(WorkflowRun).where(WorkflowRun.created_at < cutoff_date).count()
|
||||||
logger.info("No expired workflow run logs found")
|
if total_workflow_runs == 0:
|
||||||
return
|
logger.info("No expired workflow run logs found")
|
||||||
logger.info("Found %s expired workflow run logs to clean", total_workflow_runs)
|
return
|
||||||
|
logger.info("Found %s expired workflow run logs to clean", total_workflow_runs)
|
||||||
|
|
||||||
total_deleted = 0
|
total_deleted = 0
|
||||||
failed_batches = 0
|
failed_batches = 0
|
||||||
batch_count = 0
|
batch_count = 0
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
workflow_runs = (
|
with session_factory.begin() as session:
|
||||||
db.session.query(WorkflowRun.id).where(WorkflowRun.created_at < cutoff_date).limit(BATCH_SIZE).all()
|
workflow_run_ids = session.scalars(
|
||||||
)
|
select(WorkflowRun.id)
|
||||||
|
.where(WorkflowRun.created_at < cutoff_date)
|
||||||
|
.order_by(WorkflowRun.created_at, WorkflowRun.id)
|
||||||
|
.limit(BATCH_SIZE)
|
||||||
|
).all()
|
||||||
|
|
||||||
if not workflow_runs:
|
if not workflow_run_ids:
|
||||||
break
|
|
||||||
|
|
||||||
workflow_run_ids = [run.id for run in workflow_runs]
|
|
||||||
batch_count += 1
|
|
||||||
|
|
||||||
success = _delete_batch_with_retry(workflow_run_ids, failed_batches)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
total_deleted += len(workflow_run_ids)
|
|
||||||
failed_batches = 0
|
|
||||||
else:
|
|
||||||
failed_batches += 1
|
|
||||||
if failed_batches >= MAX_RETRIES:
|
|
||||||
logger.error("Failed to delete batch after %s retries, aborting cleanup for today", MAX_RETRIES)
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
batch_count += 1
|
||||||
|
|
||||||
|
success = _delete_batch(session, workflow_run_ids, failed_batches)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
total_deleted += len(workflow_run_ids)
|
||||||
|
failed_batches = 0
|
||||||
else:
|
else:
|
||||||
# Calculate incremental delay times: 5, 10, 15 minutes
|
failed_batches += 1
|
||||||
retry_delay_minutes = failed_batches * 5
|
if failed_batches >= MAX_RETRIES:
|
||||||
logger.warning("Batch deletion failed, retrying in %s minutes...", retry_delay_minutes)
|
logger.error("Failed to delete batch after %s retries, aborting cleanup for today", MAX_RETRIES)
|
||||||
time.sleep(retry_delay_minutes * 60)
|
break
|
||||||
continue
|
else:
|
||||||
|
# Calculate incremental delay times: 5, 10, 15 minutes
|
||||||
|
retry_delay_minutes = failed_batches * 5
|
||||||
|
logger.warning("Batch deletion failed, retrying in %s minutes...", retry_delay_minutes)
|
||||||
|
time.sleep(retry_delay_minutes * 60)
|
||||||
|
continue
|
||||||
|
|
||||||
logger.info("Cleanup completed: %s expired workflow run logs deleted", total_deleted)
|
logger.info("Cleanup completed: %s expired workflow run logs deleted", total_deleted)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
db.session.rollback()
|
|
||||||
logger.exception("Unexpected error in workflow log cleanup")
|
logger.exception("Unexpected error in workflow log cleanup")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -87,69 +93,56 @@ def clean_workflow_runlogs_precise():
|
|||||||
click.echo(click.style(f"Cleaned workflow run logs from db success latency: {execution_time:.2f}s", fg="green"))
|
click.echo(click.style(f"Cleaned workflow run logs from db success latency: {execution_time:.2f}s", fg="green"))
|
||||||
|
|
||||||
|
|
||||||
def _delete_batch_with_retry(workflow_run_ids: list[str], attempt_count: int) -> bool:
|
def _delete_batch(session: Session, workflow_run_ids: Sequence[str], attempt_count: int) -> bool:
|
||||||
"""Delete a single batch with a retry mechanism and complete cascading deletion"""
|
"""Delete a single batch of workflow runs and all related data within a nested transaction."""
|
||||||
try:
|
try:
|
||||||
with db.session.begin_nested():
|
with session.begin_nested():
|
||||||
message_data = (
|
message_data = (
|
||||||
db.session.query(Message.id, Message.conversation_id)
|
session.query(Message.id, Message.conversation_id)
|
||||||
.where(Message.workflow_run_id.in_(workflow_run_ids))
|
.where(Message.workflow_run_id.in_(workflow_run_ids))
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
message_id_list = [msg.id for msg in message_data]
|
message_id_list = [msg.id for msg in message_data]
|
||||||
conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id})
|
conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id})
|
||||||
if message_id_list:
|
if message_id_list:
|
||||||
db.session.query(AppAnnotationHitHistory).where(
|
message_related_models = [
|
||||||
AppAnnotationHitHistory.message_id.in_(message_id_list)
|
AppAnnotationHitHistory,
|
||||||
).delete(synchronize_session=False)
|
MessageAgentThought,
|
||||||
|
MessageChain,
|
||||||
|
MessageFile,
|
||||||
|
MessageAnnotation,
|
||||||
|
MessageFeedback,
|
||||||
|
]
|
||||||
|
for model in message_related_models:
|
||||||
|
session.query(model).where(model.message_id.in_(message_id_list)).delete(synchronize_session=False) # type: ignore
|
||||||
|
# error: "DeclarativeAttributeIntercept" has no attribute "message_id". But this type is only in lib
|
||||||
|
# and these 6 types all have the message_id field.
|
||||||
|
|
||||||
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_id_list)).delete(
|
session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete(
|
||||||
synchronize_session=False
|
synchronize_session=False
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.query(MessageChain).where(MessageChain.message_id.in_(message_id_list)).delete(
|
session.query(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete(
|
||||||
synchronize_session=False
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.query(MessageFile).where(MessageFile.message_id.in_(message_id_list)).delete(
|
|
||||||
synchronize_session=False
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_id_list)).delete(
|
|
||||||
synchronize_session=False
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_id_list)).delete(
|
|
||||||
synchronize_session=False
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete(
|
|
||||||
synchronize_session=False
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.query(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete(
|
|
||||||
synchronize_session=False
|
synchronize_session=False
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.query(WorkflowNodeExecutionModel).where(
|
session.query(WorkflowNodeExecutionModel).where(
|
||||||
WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids)
|
WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids)
|
||||||
).delete(synchronize_session=False)
|
).delete(synchronize_session=False)
|
||||||
|
|
||||||
if conversation_id_list:
|
if conversation_id_list:
|
||||||
db.session.query(ConversationVariable).where(
|
session.query(ConversationVariable).where(
|
||||||
ConversationVariable.conversation_id.in_(conversation_id_list)
|
ConversationVariable.conversation_id.in_(conversation_id_list)
|
||||||
).delete(synchronize_session=False)
|
).delete(synchronize_session=False)
|
||||||
|
|
||||||
db.session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete(
|
session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete(
|
||||||
synchronize_session=False
|
synchronize_session=False
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False)
|
session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False)
|
||||||
|
|
||||||
db.session.commit()
|
return True
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
db.session.rollback()
|
|
||||||
logger.exception("Batch deletion failed (attempt %s)", attempt_count + 1)
|
logger.exception("Batch deletion failed (attempt %s)", attempt_count + 1)
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -8,8 +8,7 @@ from werkzeug.exceptions import NotFound
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import current_user
|
from libs.login import current_account_with_tenant
|
||||||
from models.account import Account
|
|
||||||
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
|
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
|
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
|
||||||
@ -24,10 +23,10 @@ class AppAnnotationService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
|
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||||
# get app info
|
# get app info
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
app = (
|
app = (
|
||||||
db.session.query(App)
|
db.session.query(App)
|
||||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -63,12 +62,12 @@ class AppAnnotationService:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
# if annotation reply is enabled , add annotation to index
|
# if annotation reply is enabled , add annotation to index
|
||||||
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||||
assert current_user.current_tenant_id is not None
|
assert current_tenant_id is not None
|
||||||
if annotation_setting:
|
if annotation_setting:
|
||||||
add_annotation_to_index_task.delay(
|
add_annotation_to_index_task.delay(
|
||||||
annotation.id,
|
annotation.id,
|
||||||
args["question"],
|
args["question"],
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
app_id,
|
app_id,
|
||||||
annotation_setting.collection_binding_id,
|
annotation_setting.collection_binding_id,
|
||||||
)
|
)
|
||||||
@ -86,13 +85,12 @@ class AppAnnotationService:
|
|||||||
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
|
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
|
||||||
# send batch add segments task
|
# send batch add segments task
|
||||||
redis_client.setnx(enable_app_annotation_job_key, "waiting")
|
redis_client.setnx(enable_app_annotation_job_key, "waiting")
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
enable_annotation_reply_task.delay(
|
enable_annotation_reply_task.delay(
|
||||||
str(job_id),
|
str(job_id),
|
||||||
app_id,
|
app_id,
|
||||||
current_user.id,
|
current_user.id,
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
args["score_threshold"],
|
args["score_threshold"],
|
||||||
args["embedding_provider_name"],
|
args["embedding_provider_name"],
|
||||||
args["embedding_model_name"],
|
args["embedding_model_name"],
|
||||||
@ -101,8 +99,7 @@ class AppAnnotationService:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def disable_app_annotation(cls, app_id: str):
|
def disable_app_annotation(cls, app_id: str):
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
|
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
|
||||||
cache_result = redis_client.get(disable_app_annotation_key)
|
cache_result = redis_client.get(disable_app_annotation_key)
|
||||||
if cache_result is not None:
|
if cache_result is not None:
|
||||||
@ -113,17 +110,16 @@ class AppAnnotationService:
|
|||||||
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
|
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
|
||||||
# send batch add segments task
|
# send batch add segments task
|
||||||
redis_client.setnx(disable_app_annotation_job_key, "waiting")
|
redis_client.setnx(disable_app_annotation_job_key, "waiting")
|
||||||
disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id)
|
disable_annotation_reply_task.delay(str(job_id), app_id, current_tenant_id)
|
||||||
return {"job_id": job_id, "job_status": "waiting"}
|
return {"job_id": job_id, "job_status": "waiting"}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
|
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
|
||||||
# get app info
|
# get app info
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
app = (
|
app = (
|
||||||
db.session.query(App)
|
db.session.query(App)
|
||||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -153,11 +149,10 @@ class AppAnnotationService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def export_annotation_list_by_app_id(cls, app_id: str):
|
def export_annotation_list_by_app_id(cls, app_id: str):
|
||||||
# get app info
|
# get app info
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
app = (
|
app = (
|
||||||
db.session.query(App)
|
db.session.query(App)
|
||||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -174,11 +169,10 @@ class AppAnnotationService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
|
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||||
# get app info
|
# get app info
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
app = (
|
app = (
|
||||||
db.session.query(App)
|
db.session.query(App)
|
||||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -196,7 +190,7 @@ class AppAnnotationService:
|
|||||||
add_annotation_to_index_task.delay(
|
add_annotation_to_index_task.delay(
|
||||||
annotation.id,
|
annotation.id,
|
||||||
args["question"],
|
args["question"],
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
app_id,
|
app_id,
|
||||||
annotation_setting.collection_binding_id,
|
annotation_setting.collection_binding_id,
|
||||||
)
|
)
|
||||||
@ -205,11 +199,10 @@ class AppAnnotationService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
|
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
|
||||||
# get app info
|
# get app info
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
app = (
|
app = (
|
||||||
db.session.query(App)
|
db.session.query(App)
|
||||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -234,7 +227,7 @@ class AppAnnotationService:
|
|||||||
update_annotation_to_index_task.delay(
|
update_annotation_to_index_task.delay(
|
||||||
annotation.id,
|
annotation.id,
|
||||||
annotation.question,
|
annotation.question,
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
app_id,
|
app_id,
|
||||||
app_annotation_setting.collection_binding_id,
|
app_annotation_setting.collection_binding_id,
|
||||||
)
|
)
|
||||||
@ -244,11 +237,10 @@ class AppAnnotationService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def delete_app_annotation(cls, app_id: str, annotation_id: str):
|
def delete_app_annotation(cls, app_id: str, annotation_id: str):
|
||||||
# get app info
|
# get app info
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
app = (
|
app = (
|
||||||
db.session.query(App)
|
db.session.query(App)
|
||||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -277,17 +269,16 @@ class AppAnnotationService:
|
|||||||
|
|
||||||
if app_annotation_setting:
|
if app_annotation_setting:
|
||||||
delete_annotation_index_task.delay(
|
delete_annotation_index_task.delay(
|
||||||
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
|
annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
|
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
|
||||||
# get app info
|
# get app info
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
app = (
|
app = (
|
||||||
db.session.query(App)
|
db.session.query(App)
|
||||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -317,7 +308,7 @@ class AppAnnotationService:
|
|||||||
for annotation, annotation_setting in annotations_to_delete:
|
for annotation, annotation_setting in annotations_to_delete:
|
||||||
if annotation_setting:
|
if annotation_setting:
|
||||||
delete_annotation_index_task.delay(
|
delete_annotation_index_task.delay(
|
||||||
annotation.id, app_id, current_user.current_tenant_id, annotation_setting.collection_binding_id
|
annotation.id, app_id, current_tenant_id, annotation_setting.collection_binding_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 4: Bulk delete annotations in a single query
|
# Step 4: Bulk delete annotations in a single query
|
||||||
@ -333,11 +324,10 @@ class AppAnnotationService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def batch_import_app_annotations(cls, app_id, file: FileStorage):
|
def batch_import_app_annotations(cls, app_id, file: FileStorage):
|
||||||
# get app info
|
# get app info
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
app = (
|
app = (
|
||||||
db.session.query(App)
|
db.session.query(App)
|
||||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -354,7 +344,7 @@ class AppAnnotationService:
|
|||||||
if len(result) == 0:
|
if len(result) == 0:
|
||||||
raise ValueError("The CSV file is empty.")
|
raise ValueError("The CSV file is empty.")
|
||||||
# check annotation limit
|
# check annotation limit
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_tenant_id)
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
annotation_quota_limit = features.annotation_quota_limit
|
annotation_quota_limit = features.annotation_quota_limit
|
||||||
if annotation_quota_limit.limit < len(result) + annotation_quota_limit.size:
|
if annotation_quota_limit.limit < len(result) + annotation_quota_limit.size:
|
||||||
@ -364,21 +354,18 @@ class AppAnnotationService:
|
|||||||
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
|
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
|
||||||
# send batch add segments task
|
# send batch add segments task
|
||||||
redis_client.setnx(indexing_cache_key, "waiting")
|
redis_client.setnx(indexing_cache_key, "waiting")
|
||||||
batch_import_annotations_task.delay(
|
batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id)
|
||||||
str(job_id), result, app_id, current_user.current_tenant_id, current_user.id
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"error_msg": str(e)}
|
return {"error_msg": str(e)}
|
||||||
return {"job_id": job_id, "job_status": "waiting"}
|
return {"job_id": job_id, "job_status": "waiting"}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
|
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
# get app info
|
# get app info
|
||||||
app = (
|
app = (
|
||||||
db.session.query(App)
|
db.session.query(App)
|
||||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -445,12 +432,11 @@ class AppAnnotationService:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_app_annotation_setting_by_app_id(cls, app_id: str):
|
def get_app_annotation_setting_by_app_id(cls, app_id: str):
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
# get app info
|
# get app info
|
||||||
app = (
|
app = (
|
||||||
db.session.query(App)
|
db.session.query(App)
|
||||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -481,12 +467,11 @@ class AppAnnotationService:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
|
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
# get app info
|
# get app info
|
||||||
app = (
|
app = (
|
||||||
db.session.query(App)
|
db.session.query(App)
|
||||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -531,11 +516,10 @@ class AppAnnotationService:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def clear_all_annotations(cls, app_id: str):
|
def clear_all_annotations(cls, app_id: str):
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
app = (
|
app = (
|
||||||
db.session.query(App)
|
db.session.query(App)
|
||||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -558,7 +542,7 @@ class AppAnnotationService:
|
|||||||
# if annotation reply is enabled, delete annotation index
|
# if annotation reply is enabled, delete annotation index
|
||||||
if app_annotation_setting:
|
if app_annotation_setting:
|
||||||
delete_annotation_index_task.delay(
|
delete_annotation_index_task.delay(
|
||||||
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
|
annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.delete(annotation)
|
db.session.delete(annotation)
|
||||||
|
|||||||
@ -3,7 +3,6 @@ import time
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -18,6 +17,7 @@ from core.plugin.impl.oauth import OAuthHandler
|
|||||||
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
|
from libs.login import current_account_with_tenant
|
||||||
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
|
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
|
||||||
from models.provider_ids import DatasourceProviderID
|
from models.provider_ids import DatasourceProviderID
|
||||||
from services.plugin.plugin_service import PluginService
|
from services.plugin.plugin_service import PluginService
|
||||||
@ -93,6 +93,8 @@ class DatasourceProviderService:
|
|||||||
"""
|
"""
|
||||||
get credential by id
|
get credential by id
|
||||||
"""
|
"""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
if credential_id:
|
if credential_id:
|
||||||
datasource_provider = (
|
datasource_provider = (
|
||||||
@ -157,6 +159,8 @@ class DatasourceProviderService:
|
|||||||
"""
|
"""
|
||||||
get all datasource credentials by provider
|
get all datasource credentials by provider
|
||||||
"""
|
"""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
datasource_providers = (
|
datasource_providers = (
|
||||||
session.query(DatasourceProvider)
|
session.query(DatasourceProvider)
|
||||||
@ -604,6 +608,8 @@ class DatasourceProviderService:
|
|||||||
"""
|
"""
|
||||||
provider_name = provider_id.provider_name
|
provider_name = provider_id.provider_name
|
||||||
plugin_id = provider_id.plugin_id
|
plugin_id = provider_id.plugin_id
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
|
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
|
||||||
with redis_client.lock(lock, timeout=20):
|
with redis_client.lock(lock, timeout=20):
|
||||||
@ -901,6 +907,8 @@ class DatasourceProviderService:
|
|||||||
"""
|
"""
|
||||||
update datasource credentials.
|
update datasource credentials.
|
||||||
"""
|
"""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
datasource_provider = (
|
datasource_provider = (
|
||||||
session.query(DatasourceProvider)
|
session.query(DatasourceProvider)
|
||||||
|
|||||||
@ -102,6 +102,15 @@ class OpsService:
|
|||||||
except Exception:
|
except Exception:
|
||||||
new_decrypt_tracing_config.update({"project_url": "https://arms.console.aliyun.com/"})
|
new_decrypt_tracing_config.update({"project_url": "https://arms.console.aliyun.com/"})
|
||||||
|
|
||||||
|
if tracing_provider == "tencent" and (
|
||||||
|
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
|
||||||
|
new_decrypt_tracing_config.update({"project_url": project_url})
|
||||||
|
except Exception:
|
||||||
|
new_decrypt_tracing_config.update({"project_url": "https://console.cloud.tencent.com/apm"})
|
||||||
|
|
||||||
trace_config_data.tracing_config = new_decrypt_tracing_config
|
trace_config_data.tracing_config = new_decrypt_tracing_config
|
||||||
return trace_config_data.to_dict()
|
return trace_config_data.to_dict()
|
||||||
|
|
||||||
@ -144,7 +153,7 @@ class OpsService:
|
|||||||
project_url = f"{tracing_config.get('host')}/project/{project_key}"
|
project_url = f"{tracing_config.get('host')}/project/{project_key}"
|
||||||
except Exception:
|
except Exception:
|
||||||
project_url = None
|
project_url = None
|
||||||
elif tracing_provider in ("langsmith", "opik"):
|
elif tracing_provider in ("langsmith", "opik", "tencent"):
|
||||||
try:
|
try:
|
||||||
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
|
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@ -86,12 +86,16 @@ class WorkflowAppService:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
if created_by_account:
|
if created_by_account:
|
||||||
|
account = session.scalar(select(Account).where(Account.email == created_by_account))
|
||||||
|
if not account:
|
||||||
|
raise ValueError(f"Account not found: {created_by_account}")
|
||||||
|
|
||||||
stmt = stmt.join(
|
stmt = stmt.join(
|
||||||
Account,
|
Account,
|
||||||
and_(
|
and_(
|
||||||
WorkflowAppLog.created_by == Account.id,
|
WorkflowAppLog.created_by == Account.id,
|
||||||
WorkflowAppLog.created_by_role == CreatorUserRole.ACCOUNT,
|
WorkflowAppLog.created_by_role == CreatorUserRole.ACCOUNT,
|
||||||
Account.email == created_by_account,
|
Account.id == account.id,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,6 @@ import click
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
@ -50,54 +49,48 @@ def batch_create_segment_to_index_task(
|
|||||||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with Session(db.engine) as session:
|
dataset = db.session.get(Dataset, dataset_id)
|
||||||
dataset = session.get(Dataset, dataset_id)
|
if not dataset:
|
||||||
if not dataset:
|
raise ValueError("Dataset not exist.")
|
||||||
raise ValueError("Dataset not exist.")
|
|
||||||
|
|
||||||
dataset_document = session.get(Document, document_id)
|
dataset_document = db.session.get(Document, document_id)
|
||||||
if not dataset_document:
|
if not dataset_document:
|
||||||
raise ValueError("Document not exist.")
|
raise ValueError("Document not exist.")
|
||||||
|
|
||||||
if (
|
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||||
not dataset_document.enabled
|
raise ValueError("Document is not available.")
|
||||||
or dataset_document.archived
|
|
||||||
or dataset_document.indexing_status != "completed"
|
|
||||||
):
|
|
||||||
raise ValueError("Document is not available.")
|
|
||||||
|
|
||||||
upload_file = session.get(UploadFile, upload_file_id)
|
upload_file = db.session.get(UploadFile, upload_file_id)
|
||||||
if not upload_file:
|
if not upload_file:
|
||||||
raise ValueError("UploadFile not found.")
|
raise ValueError("UploadFile not found.")
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
suffix = Path(upload_file.key).suffix
|
suffix = Path(upload_file.key).suffix
|
||||||
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
|
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
storage.download(upload_file.key, file_path)
|
||||||
storage.download(upload_file.key, file_path)
|
|
||||||
|
|
||||||
# Skip the first row
|
df = pd.read_csv(file_path)
|
||||||
df = pd.read_csv(file_path)
|
content = []
|
||||||
content = []
|
for _, row in df.iterrows():
|
||||||
for _, row in df.iterrows():
|
if dataset_document.doc_form == "qa_model":
|
||||||
if dataset_document.doc_form == "qa_model":
|
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
else:
|
||||||
else:
|
data = {"content": row.iloc[0]}
|
||||||
data = {"content": row.iloc[0]}
|
content.append(data)
|
||||||
content.append(data)
|
if len(content) == 0:
|
||||||
if len(content) == 0:
|
raise ValueError("The CSV file is empty.")
|
||||||
raise ValueError("The CSV file is empty.")
|
|
||||||
|
document_segments = []
|
||||||
|
embedding_model = None
|
||||||
|
if dataset.indexing_technique == "high_quality":
|
||||||
|
model_manager = ModelManager()
|
||||||
|
embedding_model = model_manager.get_model_instance(
|
||||||
|
tenant_id=dataset.tenant_id,
|
||||||
|
provider=dataset.embedding_model_provider,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
model=dataset.embedding_model,
|
||||||
|
)
|
||||||
|
|
||||||
document_segments = []
|
|
||||||
embedding_model = None
|
|
||||||
if dataset.indexing_technique == "high_quality":
|
|
||||||
model_manager = ModelManager()
|
|
||||||
embedding_model = model_manager.get_model_instance(
|
|
||||||
tenant_id=dataset.tenant_id,
|
|
||||||
provider=dataset.embedding_model_provider,
|
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
|
||||||
model=dataset.embedding_model,
|
|
||||||
)
|
|
||||||
word_count_change = 0
|
word_count_change = 0
|
||||||
if embedding_model:
|
if embedding_model:
|
||||||
tokens_list = embedding_model.get_text_embedding_num_tokens(
|
tokens_list = embedding_model.get_text_embedding_num_tokens(
|
||||||
@ -105,6 +98,7 @@ def batch_create_segment_to_index_task(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
tokens_list = [0] * len(content)
|
tokens_list = [0] * len(content)
|
||||||
|
|
||||||
for segment, tokens in zip(content, tokens_list):
|
for segment, tokens in zip(content, tokens_list):
|
||||||
content = segment["content"]
|
content = segment["content"]
|
||||||
doc_id = str(uuid.uuid4())
|
doc_id = str(uuid.uuid4())
|
||||||
@ -135,11 +129,11 @@ def batch_create_segment_to_index_task(
|
|||||||
word_count_change += segment_document.word_count
|
word_count_change += segment_document.word_count
|
||||||
db.session.add(segment_document)
|
db.session.add(segment_document)
|
||||||
document_segments.append(segment_document)
|
document_segments.append(segment_document)
|
||||||
# update document word count
|
|
||||||
assert dataset_document.word_count is not None
|
assert dataset_document.word_count is not None
|
||||||
dataset_document.word_count += word_count_change
|
dataset_document.word_count += word_count_change
|
||||||
db.session.add(dataset_document)
|
db.session.add(dataset_document)
|
||||||
# add index to db
|
|
||||||
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
|
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
redis_client.setex(indexing_cache_key, 600, "completed")
|
redis_client.setex(indexing_cache_key, 600, "completed")
|
||||||
|
|||||||
@ -75,10 +75,7 @@
|
|||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div class="container">
|
<div class="container">
|
||||||
<div class="header">
|
<div class="header"></div>
|
||||||
<!-- Optional: Add a logo or a header image here -->
|
|
||||||
<img src="https://assets.dify.ai/images/logo.png" alt="Dify Logo">
|
|
||||||
</div>
|
|
||||||
<div class="content">
|
<div class="content">
|
||||||
<p class="content1">Dear {{ to }},</p>
|
<p class="content1">Dear {{ to }},</p>
|
||||||
<p class="content2">{{ inviter_name }} is pleased to invite you to join our workspace on {{application_title}}, a platform specifically designed for LLM application development. On {{application_title}}, you can explore, create, and collaborate to build and operate AI applications.</p>
|
<p class="content2">{{ inviter_name }} is pleased to invite you to join our workspace on {{application_title}}, a platform specifically designed for LLM application development. On {{application_title}}, you can explore, create, and collaborate to build and operate AI applications.</p>
|
||||||
|
|||||||
@ -25,9 +25,7 @@ class TestAnnotationService:
|
|||||||
patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task,
|
patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task,
|
||||||
patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task,
|
patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task,
|
||||||
patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task,
|
patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task,
|
||||||
patch(
|
patch("services.annotation_service.current_account_with_tenant") as mock_current_account_with_tenant,
|
||||||
"services.annotation_service.current_user", create_autospec(Account, instance=True)
|
|
||||||
) as mock_current_user,
|
|
||||||
):
|
):
|
||||||
# Setup default mock returns
|
# Setup default mock returns
|
||||||
mock_account_feature_service.get_features.return_value.billing.enabled = False
|
mock_account_feature_service.get_features.return_value.billing.enabled = False
|
||||||
@ -38,6 +36,9 @@ class TestAnnotationService:
|
|||||||
mock_disable_task.delay.return_value = None
|
mock_disable_task.delay.return_value = None
|
||||||
mock_batch_import_task.delay.return_value = None
|
mock_batch_import_task.delay.return_value = None
|
||||||
|
|
||||||
|
# Create mock user that will be returned by current_account_with_tenant
|
||||||
|
mock_user = create_autospec(Account, instance=True)
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"account_feature_service": mock_account_feature_service,
|
"account_feature_service": mock_account_feature_service,
|
||||||
"feature_service": mock_feature_service,
|
"feature_service": mock_feature_service,
|
||||||
@ -47,7 +48,8 @@ class TestAnnotationService:
|
|||||||
"enable_task": mock_enable_task,
|
"enable_task": mock_enable_task,
|
||||||
"disable_task": mock_disable_task,
|
"disable_task": mock_disable_task,
|
||||||
"batch_import_task": mock_batch_import_task,
|
"batch_import_task": mock_batch_import_task,
|
||||||
"current_user": mock_current_user,
|
"current_account_with_tenant": mock_current_account_with_tenant,
|
||||||
|
"current_user": mock_user,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||||
@ -107,6 +109,11 @@ class TestAnnotationService:
|
|||||||
"""
|
"""
|
||||||
mock_external_service_dependencies["current_user"].id = account_id
|
mock_external_service_dependencies["current_user"].id = account_id
|
||||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id
|
mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id
|
||||||
|
# Configure current_account_with_tenant to return (user, tenant_id)
|
||||||
|
mock_external_service_dependencies["current_account_with_tenant"].return_value = (
|
||||||
|
mock_external_service_dependencies["current_user"],
|
||||||
|
tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
def _create_test_conversation(self, app, account, fake):
|
def _create_test_conversation(self, app, account, fake):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -789,6 +789,31 @@ class TestWorkflowAppService:
|
|||||||
assert result_account_filter["total"] == 3
|
assert result_account_filter["total"] == 3
|
||||||
assert all(log.created_by_role == CreatorUserRole.ACCOUNT for log in result_account_filter["data"])
|
assert all(log.created_by_role == CreatorUserRole.ACCOUNT for log in result_account_filter["data"])
|
||||||
|
|
||||||
|
# Test filtering by changed account email
|
||||||
|
original_email = account.email
|
||||||
|
new_email = "changed@example.com"
|
||||||
|
account.email = new_email
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
assert account.email == new_email
|
||||||
|
|
||||||
|
# Results for new email, is expected to be the same as the original email
|
||||||
|
result_with_new_email = service.get_paginate_workflow_app_logs(
|
||||||
|
session=db_session_with_containers, app_model=app, created_by_account=new_email, page=1, limit=20
|
||||||
|
)
|
||||||
|
assert result_with_new_email["total"] == 3
|
||||||
|
assert all(log.created_by_role == CreatorUserRole.ACCOUNT for log in result_with_new_email["data"])
|
||||||
|
|
||||||
|
# Old email unbound, is unexpected input, should raise ValueError
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
service.get_paginate_workflow_app_logs(
|
||||||
|
session=db_session_with_containers, app_model=app, created_by_account=original_email, page=1, limit=20
|
||||||
|
)
|
||||||
|
assert "Account not found" in str(exc_info.value)
|
||||||
|
|
||||||
|
account.email = original_email
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
# Test filtering by non-existent session ID
|
# Test filtering by non-existent session ID
|
||||||
result_no_session = service.get_paginate_workflow_app_logs(
|
result_no_session = service.get_paginate_workflow_app_logs(
|
||||||
session=db_session_with_containers,
|
session=db_session_with_containers,
|
||||||
@ -799,15 +824,16 @@ class TestWorkflowAppService:
|
|||||||
)
|
)
|
||||||
assert result_no_session["total"] == 0
|
assert result_no_session["total"] == 0
|
||||||
|
|
||||||
# Test filtering by non-existent account email
|
# Test filtering by non-existent account email, is unexpected input, should raise ValueError
|
||||||
result_no_account = service.get_paginate_workflow_app_logs(
|
with pytest.raises(ValueError) as exc_info:
|
||||||
session=db_session_with_containers,
|
service.get_paginate_workflow_app_logs(
|
||||||
app_model=app,
|
session=db_session_with_containers,
|
||||||
created_by_account="nonexistent@example.com",
|
app_model=app,
|
||||||
page=1,
|
created_by_account="nonexistent@example.com",
|
||||||
limit=20,
|
page=1,
|
||||||
)
|
limit=20,
|
||||||
assert result_no_account["total"] == 0
|
)
|
||||||
|
assert "Account not found" in str(exc_info.value)
|
||||||
|
|
||||||
def test_get_paginate_workflow_app_logs_with_uuid_keyword_search(
|
def test_get_paginate_workflow_app_logs_with_uuid_keyword_search(
|
||||||
self, db_session_with_containers, mock_external_service_dependencies
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
@ -1057,15 +1083,15 @@ class TestWorkflowAppService:
|
|||||||
assert len(result_no_session["data"]) == 0
|
assert len(result_no_session["data"]) == 0
|
||||||
|
|
||||||
# Test with account email that doesn't exist
|
# Test with account email that doesn't exist
|
||||||
result_no_account = service.get_paginate_workflow_app_logs(
|
with pytest.raises(ValueError) as exc_info:
|
||||||
session=db_session_with_containers,
|
service.get_paginate_workflow_app_logs(
|
||||||
app_model=app,
|
session=db_session_with_containers,
|
||||||
created_by_account="nonexistent@example.com",
|
app_model=app,
|
||||||
page=1,
|
created_by_account="nonexistent@example.com",
|
||||||
limit=20,
|
page=1,
|
||||||
)
|
limit=20,
|
||||||
assert result_no_account["total"] == 0
|
)
|
||||||
assert len(result_no_account["data"]) == 0
|
assert "Account not found" in str(exc_info.value)
|
||||||
|
|
||||||
def test_get_paginate_workflow_app_logs_with_complex_query_combinations(
|
def test_get_paginate_workflow_app_logs_with_complex_query_combinations(
|
||||||
self, db_session_with_containers, mock_external_service_dependencies
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
|||||||
@ -0,0 +1,401 @@
|
|||||||
|
"""
|
||||||
|
TestContainers-based integration tests for mail_owner_transfer_task.
|
||||||
|
|
||||||
|
This module provides comprehensive integration tests for the mail owner transfer tasks
|
||||||
|
using TestContainers to ensure real email service integration and proper functionality
|
||||||
|
testing with actual database and service dependencies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from faker import Faker
|
||||||
|
|
||||||
|
from libs.email_i18n import EmailType
|
||||||
|
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||||
|
from tasks.mail_owner_transfer_task import (
|
||||||
|
send_new_owner_transfer_notify_email_task,
|
||||||
|
send_old_owner_transfer_notify_email_task,
|
||||||
|
send_owner_transfer_confirm_task,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMailOwnerTransferTask:
|
||||||
|
"""Integration tests for mail owner transfer tasks using testcontainers."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_mail_dependencies(self):
|
||||||
|
"""Mock setup for mail service dependencies."""
|
||||||
|
with (
|
||||||
|
patch("tasks.mail_owner_transfer_task.mail") as mock_mail,
|
||||||
|
patch("tasks.mail_owner_transfer_task.get_email_i18n_service") as mock_get_email_service,
|
||||||
|
):
|
||||||
|
# Setup mock mail service
|
||||||
|
mock_mail.is_inited.return_value = True
|
||||||
|
|
||||||
|
# Setup mock email service
|
||||||
|
mock_email_service = MagicMock()
|
||||||
|
mock_get_email_service.return_value = mock_email_service
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"mail": mock_mail,
|
||||||
|
"email_service": mock_email_service,
|
||||||
|
"get_email_service": mock_get_email_service,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_test_account_and_tenant(self, db_session_with_containers):
|
||||||
|
"""
|
||||||
|
Helper method to create test account and tenant for testing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session_with_containers: Database session from testcontainers infrastructure
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (account, tenant) - Created account and tenant instances
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
|
||||||
|
# Create account
|
||||||
|
account = Account(
|
||||||
|
email=fake.email(),
|
||||||
|
name=fake.name(),
|
||||||
|
interface_language="en-US",
|
||||||
|
status="active",
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(account)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
# Create tenant
|
||||||
|
tenant = Tenant(
|
||||||
|
name=fake.company(),
|
||||||
|
status="normal",
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(tenant)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
# Create tenant-account join
|
||||||
|
join = TenantAccountJoin(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
account_id=account.id,
|
||||||
|
role=TenantAccountRole.OWNER.value,
|
||||||
|
current=True,
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(join)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
return account, tenant
|
||||||
|
|
||||||
|
def test_send_owner_transfer_confirm_task_success(self, db_session_with_containers, mock_mail_dependencies):
|
||||||
|
"""
|
||||||
|
Test successful owner transfer confirmation email sending.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Proper email service initialization check
|
||||||
|
- Correct email service method calls with right parameters
|
||||||
|
- Email template context is properly constructed
|
||||||
|
"""
|
||||||
|
# Arrange: Create test data
|
||||||
|
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||||
|
|
||||||
|
test_language = "en-US"
|
||||||
|
test_email = account.email
|
||||||
|
test_code = "123456"
|
||||||
|
test_workspace = tenant.name
|
||||||
|
|
||||||
|
# Act: Execute the task
|
||||||
|
send_owner_transfer_confirm_task(
|
||||||
|
language=test_language,
|
||||||
|
to=test_email,
|
||||||
|
code=test_code,
|
||||||
|
workspace=test_workspace,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert: Verify the expected outcomes
|
||||||
|
mock_mail_dependencies["mail"].is_inited.assert_called_once()
|
||||||
|
mock_mail_dependencies["get_email_service"].assert_called_once()
|
||||||
|
|
||||||
|
# Verify email service was called with correct parameters
|
||||||
|
mock_mail_dependencies["email_service"].send_email.assert_called_once()
|
||||||
|
call_args = mock_mail_dependencies["email_service"].send_email.call_args
|
||||||
|
|
||||||
|
assert call_args[1]["email_type"] == EmailType.OWNER_TRANSFER_CONFIRM
|
||||||
|
assert call_args[1]["language_code"] == test_language
|
||||||
|
assert call_args[1]["to"] == test_email
|
||||||
|
assert call_args[1]["template_context"]["to"] == test_email
|
||||||
|
assert call_args[1]["template_context"]["code"] == test_code
|
||||||
|
assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace
|
||||||
|
|
||||||
|
def test_send_owner_transfer_confirm_task_mail_not_initialized(
|
||||||
|
self, db_session_with_containers, mock_mail_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test owner transfer confirmation email when mail service is not initialized.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Early return when mail service is not initialized
|
||||||
|
- No email service calls are made
|
||||||
|
- No exceptions are raised
|
||||||
|
"""
|
||||||
|
# Arrange: Set mail service as not initialized
|
||||||
|
mock_mail_dependencies["mail"].is_inited.return_value = False
|
||||||
|
|
||||||
|
test_language = "en-US"
|
||||||
|
test_email = "test@example.com"
|
||||||
|
test_code = "123456"
|
||||||
|
test_workspace = "Test Workspace"
|
||||||
|
|
||||||
|
# Act: Execute the task
|
||||||
|
send_owner_transfer_confirm_task(
|
||||||
|
language=test_language,
|
||||||
|
to=test_email,
|
||||||
|
code=test_code,
|
||||||
|
workspace=test_workspace,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert: Verify no email service calls were made
|
||||||
|
mock_mail_dependencies["get_email_service"].assert_not_called()
|
||||||
|
mock_mail_dependencies["email_service"].send_email.assert_not_called()
|
||||||
|
|
||||||
|
def test_send_owner_transfer_confirm_task_exception_handling(
|
||||||
|
self, db_session_with_containers, mock_mail_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test exception handling in owner transfer confirmation email.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Exceptions are properly caught and logged
|
||||||
|
- No exceptions are propagated to caller
|
||||||
|
- Email service calls are attempted
|
||||||
|
- Error logging works correctly
|
||||||
|
"""
|
||||||
|
# Arrange: Setup email service to raise exception
|
||||||
|
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
|
||||||
|
|
||||||
|
test_language = "en-US"
|
||||||
|
test_email = "test@example.com"
|
||||||
|
test_code = "123456"
|
||||||
|
test_workspace = "Test Workspace"
|
||||||
|
|
||||||
|
# Act & Assert: Verify no exception is raised
|
||||||
|
try:
|
||||||
|
send_owner_transfer_confirm_task(
|
||||||
|
language=test_language,
|
||||||
|
to=test_email,
|
||||||
|
code=test_code,
|
||||||
|
workspace=test_workspace,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Task should not raise exceptions, but raised: {e}")
|
||||||
|
|
||||||
|
# Verify email service was called despite the exception
|
||||||
|
mock_mail_dependencies["email_service"].send_email.assert_called_once()
|
||||||
|
|
||||||
|
def test_send_old_owner_transfer_notify_email_task_success(
|
||||||
|
self, db_session_with_containers, mock_mail_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test successful old owner transfer notification email sending.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Proper email service initialization check
|
||||||
|
- Correct email service method calls with right parameters
|
||||||
|
- Email template context includes new owner email
|
||||||
|
"""
|
||||||
|
# Arrange: Create test data
|
||||||
|
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||||
|
|
||||||
|
test_language = "en-US"
|
||||||
|
test_email = account.email
|
||||||
|
test_workspace = tenant.name
|
||||||
|
test_new_owner_email = "newowner@example.com"
|
||||||
|
|
||||||
|
# Act: Execute the task
|
||||||
|
send_old_owner_transfer_notify_email_task(
|
||||||
|
language=test_language,
|
||||||
|
to=test_email,
|
||||||
|
workspace=test_workspace,
|
||||||
|
new_owner_email=test_new_owner_email,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert: Verify the expected outcomes
|
||||||
|
mock_mail_dependencies["mail"].is_inited.assert_called_once()
|
||||||
|
mock_mail_dependencies["get_email_service"].assert_called_once()
|
||||||
|
|
||||||
|
# Verify email service was called with correct parameters
|
||||||
|
mock_mail_dependencies["email_service"].send_email.assert_called_once()
|
||||||
|
call_args = mock_mail_dependencies["email_service"].send_email.call_args
|
||||||
|
|
||||||
|
assert call_args[1]["email_type"] == EmailType.OWNER_TRANSFER_OLD_NOTIFY
|
||||||
|
assert call_args[1]["language_code"] == test_language
|
||||||
|
assert call_args[1]["to"] == test_email
|
||||||
|
assert call_args[1]["template_context"]["to"] == test_email
|
||||||
|
assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace
|
||||||
|
assert call_args[1]["template_context"]["NewOwnerEmail"] == test_new_owner_email
|
||||||
|
|
||||||
|
def test_send_old_owner_transfer_notify_email_task_mail_not_initialized(
|
||||||
|
self, db_session_with_containers, mock_mail_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test old owner transfer notification email when mail service is not initialized.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Early return when mail service is not initialized
|
||||||
|
- No email service calls are made
|
||||||
|
- No exceptions are raised
|
||||||
|
"""
|
||||||
|
# Arrange: Set mail service as not initialized
|
||||||
|
mock_mail_dependencies["mail"].is_inited.return_value = False
|
||||||
|
|
||||||
|
test_language = "en-US"
|
||||||
|
test_email = "test@example.com"
|
||||||
|
test_workspace = "Test Workspace"
|
||||||
|
test_new_owner_email = "newowner@example.com"
|
||||||
|
|
||||||
|
# Act: Execute the task
|
||||||
|
send_old_owner_transfer_notify_email_task(
|
||||||
|
language=test_language,
|
||||||
|
to=test_email,
|
||||||
|
workspace=test_workspace,
|
||||||
|
new_owner_email=test_new_owner_email,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert: Verify no email service calls were made
|
||||||
|
mock_mail_dependencies["get_email_service"].assert_not_called()
|
||||||
|
mock_mail_dependencies["email_service"].send_email.assert_not_called()
|
||||||
|
|
||||||
|
def test_send_old_owner_transfer_notify_email_task_exception_handling(
|
||||||
|
self, db_session_with_containers, mock_mail_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test exception handling in old owner transfer notification email.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Exceptions are properly caught and logged
|
||||||
|
- No exceptions are propagated to caller
|
||||||
|
- Email service calls are attempted
|
||||||
|
- Error logging works correctly
|
||||||
|
"""
|
||||||
|
# Arrange: Setup email service to raise exception
|
||||||
|
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
|
||||||
|
|
||||||
|
test_language = "en-US"
|
||||||
|
test_email = "test@example.com"
|
||||||
|
test_workspace = "Test Workspace"
|
||||||
|
test_new_owner_email = "newowner@example.com"
|
||||||
|
|
||||||
|
# Act & Assert: Verify no exception is raised
|
||||||
|
try:
|
||||||
|
send_old_owner_transfer_notify_email_task(
|
||||||
|
language=test_language,
|
||||||
|
to=test_email,
|
||||||
|
workspace=test_workspace,
|
||||||
|
new_owner_email=test_new_owner_email,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Task should not raise exceptions, but raised: {e}")
|
||||||
|
|
||||||
|
# Verify email service was called despite the exception
|
||||||
|
mock_mail_dependencies["email_service"].send_email.assert_called_once()
|
||||||
|
|
||||||
|
def test_send_new_owner_transfer_notify_email_task_success(
|
||||||
|
self, db_session_with_containers, mock_mail_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test successful new owner transfer notification email sending.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Proper email service initialization check
|
||||||
|
- Correct email service method calls with right parameters
|
||||||
|
- Email template context is properly constructed
|
||||||
|
"""
|
||||||
|
# Arrange: Create test data
|
||||||
|
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||||
|
|
||||||
|
test_language = "en-US"
|
||||||
|
test_email = account.email
|
||||||
|
test_workspace = tenant.name
|
||||||
|
|
||||||
|
# Act: Execute the task
|
||||||
|
send_new_owner_transfer_notify_email_task(
|
||||||
|
language=test_language,
|
||||||
|
to=test_email,
|
||||||
|
workspace=test_workspace,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert: Verify the expected outcomes
|
||||||
|
mock_mail_dependencies["mail"].is_inited.assert_called_once()
|
||||||
|
mock_mail_dependencies["get_email_service"].assert_called_once()
|
||||||
|
|
||||||
|
# Verify email service was called with correct parameters
|
||||||
|
mock_mail_dependencies["email_service"].send_email.assert_called_once()
|
||||||
|
call_args = mock_mail_dependencies["email_service"].send_email.call_args
|
||||||
|
|
||||||
|
assert call_args[1]["email_type"] == EmailType.OWNER_TRANSFER_NEW_NOTIFY
|
||||||
|
assert call_args[1]["language_code"] == test_language
|
||||||
|
assert call_args[1]["to"] == test_email
|
||||||
|
assert call_args[1]["template_context"]["to"] == test_email
|
||||||
|
assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace
|
||||||
|
|
||||||
|
def test_send_new_owner_transfer_notify_email_task_mail_not_initialized(
|
||||||
|
self, db_session_with_containers, mock_mail_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test new owner transfer notification email when mail service is not initialized.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Early return when mail service is not initialized
|
||||||
|
- No email service calls are made
|
||||||
|
- No exceptions are raised
|
||||||
|
"""
|
||||||
|
# Arrange: Set mail service as not initialized
|
||||||
|
mock_mail_dependencies["mail"].is_inited.return_value = False
|
||||||
|
|
||||||
|
test_language = "en-US"
|
||||||
|
test_email = "test@example.com"
|
||||||
|
test_workspace = "Test Workspace"
|
||||||
|
|
||||||
|
# Act: Execute the task
|
||||||
|
send_new_owner_transfer_notify_email_task(
|
||||||
|
language=test_language,
|
||||||
|
to=test_email,
|
||||||
|
workspace=test_workspace,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert: Verify no email service calls were made
|
||||||
|
mock_mail_dependencies["get_email_service"].assert_not_called()
|
||||||
|
mock_mail_dependencies["email_service"].send_email.assert_not_called()
|
||||||
|
|
||||||
|
def test_send_new_owner_transfer_notify_email_task_exception_handling(
|
||||||
|
self, db_session_with_containers, mock_mail_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test exception handling in new owner transfer notification email.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Exceptions are properly caught and logged
|
||||||
|
- No exceptions are propagated to caller
|
||||||
|
- Email service calls are attempted
|
||||||
|
- Error logging works correctly
|
||||||
|
"""
|
||||||
|
# Arrange: Setup email service to raise exception
|
||||||
|
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
|
||||||
|
|
||||||
|
test_language = "en-US"
|
||||||
|
test_email = "test@example.com"
|
||||||
|
test_workspace = "Test Workspace"
|
||||||
|
|
||||||
|
# Act & Assert: Verify no exception is raised
|
||||||
|
try:
|
||||||
|
send_new_owner_transfer_notify_email_task(
|
||||||
|
language=test_language,
|
||||||
|
to=test_email,
|
||||||
|
workspace=test_workspace,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Task should not raise exceptions, but raised: {e}")
|
||||||
|
|
||||||
|
# Verify email service was called despite the exception
|
||||||
|
mock_mail_dependencies["email_service"].send_email.assert_called_once()
|
||||||
@ -60,7 +60,7 @@ class TestAccountInitialization:
|
|||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with patch("controllers.console.wraps._current_account", return_value=mock_user):
|
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")):
|
||||||
result = protected_view()
|
result = protected_view()
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
@ -77,7 +77,7 @@ class TestAccountInitialization:
|
|||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with patch("controllers.console.wraps._current_account", return_value=mock_user):
|
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")):
|
||||||
with pytest.raises(AccountNotInitializedError):
|
with pytest.raises(AccountNotInitializedError):
|
||||||
protected_view()
|
protected_view()
|
||||||
|
|
||||||
@ -163,7 +163,9 @@ class TestBillingResourceLimits:
|
|||||||
return "member_added"
|
return "member_added"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
with patch(
|
||||||
|
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
|
||||||
|
):
|
||||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||||
result = add_member()
|
result = add_member()
|
||||||
|
|
||||||
@ -185,7 +187,10 @@ class TestBillingResourceLimits:
|
|||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with app.test_request_context():
|
with app.test_request_context():
|
||||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
with patch(
|
||||||
|
"controllers.console.wraps.current_account_with_tenant",
|
||||||
|
return_value=(MockUser("test_user"), "tenant123"),
|
||||||
|
):
|
||||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||||
with pytest.raises(Exception) as exc_info:
|
with pytest.raises(Exception) as exc_info:
|
||||||
add_member()
|
add_member()
|
||||||
@ -207,7 +212,10 @@ class TestBillingResourceLimits:
|
|||||||
|
|
||||||
# Test 1: Should reject when source is datasets
|
# Test 1: Should reject when source is datasets
|
||||||
with app.test_request_context("/?source=datasets"):
|
with app.test_request_context("/?source=datasets"):
|
||||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
with patch(
|
||||||
|
"controllers.console.wraps.current_account_with_tenant",
|
||||||
|
return_value=(MockUser("test_user"), "tenant123"),
|
||||||
|
):
|
||||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||||
with pytest.raises(Exception) as exc_info:
|
with pytest.raises(Exception) as exc_info:
|
||||||
upload_document()
|
upload_document()
|
||||||
@ -215,7 +223,10 @@ class TestBillingResourceLimits:
|
|||||||
|
|
||||||
# Test 2: Should allow when source is not datasets
|
# Test 2: Should allow when source is not datasets
|
||||||
with app.test_request_context("/?source=other"):
|
with app.test_request_context("/?source=other"):
|
||||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
with patch(
|
||||||
|
"controllers.console.wraps.current_account_with_tenant",
|
||||||
|
return_value=(MockUser("test_user"), "tenant123"),
|
||||||
|
):
|
||||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||||
result = upload_document()
|
result = upload_document()
|
||||||
assert result == "document_uploaded"
|
assert result == "document_uploaded"
|
||||||
@ -239,7 +250,9 @@ class TestRateLimiting:
|
|||||||
return "knowledge_success"
|
return "knowledge_success"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
with patch(
|
||||||
|
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
|
||||||
|
):
|
||||||
with patch(
|
with patch(
|
||||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||||
):
|
):
|
||||||
@ -271,7 +284,10 @@ class TestRateLimiting:
|
|||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with app.test_request_context():
|
with app.test_request_context():
|
||||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
with patch(
|
||||||
|
"controllers.console.wraps.current_account_with_tenant",
|
||||||
|
return_value=(MockUser("test_user"), "tenant123"),
|
||||||
|
):
|
||||||
with patch(
|
with patch(
|
||||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||||
):
|
):
|
||||||
|
|||||||
@ -110,19 +110,6 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
|||||||
assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes
|
assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes
|
||||||
mock_redis.set.assert_called_once()
|
mock_redis.set.assert_called_once()
|
||||||
|
|
||||||
def test_config_validation(self):
|
|
||||||
"""Test configuration validation."""
|
|
||||||
# Test missing required fields
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
AlibabaCloudMySQLVectorConfig(
|
|
||||||
host="", # Empty host should raise error
|
|
||||||
port=3306,
|
|
||||||
user="test",
|
|
||||||
password="test",
|
|
||||||
database="test",
|
|
||||||
max_connection=5,
|
|
||||||
)
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||||
)
|
)
|
||||||
@ -718,5 +705,29 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
|||||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"invalid_config_override",
|
||||||
|
[
|
||||||
|
{"host": ""}, # Test empty host
|
||||||
|
{"port": 0}, # Test invalid port
|
||||||
|
{"max_connection": 0}, # Test invalid max_connection
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_config_validation_parametrized(invalid_config_override):
|
||||||
|
"""Test configuration validation for various invalid inputs using parametrize."""
|
||||||
|
config = {
|
||||||
|
"host": "localhost",
|
||||||
|
"port": 3306,
|
||||||
|
"user": "test",
|
||||||
|
"password": "test",
|
||||||
|
"database": "test",
|
||||||
|
"max_connection": 5,
|
||||||
|
}
|
||||||
|
config.update(invalid_config_override)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
AlibabaCloudMySQLVectorConfig(**config)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
29
api/tests/unit_tests/core/tools/test_tool_entities.py
Normal file
29
api/tests/unit_tests/core/tools/test_tool_entities.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
|
||||||
|
|
||||||
|
|
||||||
|
def _make_identity() -> ToolIdentity:
|
||||||
|
return ToolIdentity(
|
||||||
|
author="author",
|
||||||
|
name="tool",
|
||||||
|
label=I18nObject(en_US="Label"),
|
||||||
|
provider="builtin",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_message_metadata_none_defaults_to_empty_dict():
|
||||||
|
log_message = ToolInvokeMessage.LogMessage(
|
||||||
|
id="log-1",
|
||||||
|
label="Log entry",
|
||||||
|
status=ToolInvokeMessage.LogMessage.LogStatus.START,
|
||||||
|
data={},
|
||||||
|
metadata=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert log_message.metadata == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_entity_output_schema_none_defaults_to_empty_dict():
|
||||||
|
entity = ToolEntity(identity=_make_identity(), output_schema=None)
|
||||||
|
|
||||||
|
assert entity.output_schema == {}
|
||||||
1981
api/uv.lock
generated
1981
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@ -329,7 +329,7 @@ services:
|
|||||||
|
|
||||||
# The Weaviate vector store.
|
# The Weaviate vector store.
|
||||||
weaviate:
|
weaviate:
|
||||||
image: semitechnologies/weaviate:1.19.0
|
image: semitechnologies/weaviate:1.27.0
|
||||||
profiles:
|
profiles:
|
||||||
- ""
|
- ""
|
||||||
- weaviate
|
- weaviate
|
||||||
|
|||||||
@ -181,7 +181,7 @@ services:
|
|||||||
|
|
||||||
# The Weaviate vector store.
|
# The Weaviate vector store.
|
||||||
weaviate:
|
weaviate:
|
||||||
image: semitechnologies/weaviate:1.19.0
|
image: semitechnologies/weaviate:1.27.0
|
||||||
profiles:
|
profiles:
|
||||||
- ""
|
- ""
|
||||||
- weaviate
|
- weaviate
|
||||||
@ -206,6 +206,7 @@ services:
|
|||||||
AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai}
|
AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai}
|
||||||
ports:
|
ports:
|
||||||
- "${EXPOSE_WEAVIATE_PORT:-8080}:8080"
|
- "${EXPOSE_WEAVIATE_PORT:-8080}:8080"
|
||||||
|
- "${EXPOSE_WEAVIATE_GRPC_PORT:-50051}:50051"
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
# create a network between sandbox, api and ssrf_proxy, and can not access outside.
|
# create a network between sandbox, api and ssrf_proxy, and can not access outside.
|
||||||
|
|||||||
9
docker/docker-compose.override.yml
Normal file
9
docker/docker-compose.override.yml
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
services:
|
||||||
|
api:
|
||||||
|
volumes:
|
||||||
|
- ../api/core/rag/datasource/vdb/weaviate/weaviate_vector.py:/app/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py:ro
|
||||||
|
command: >
|
||||||
|
sh -c "
|
||||||
|
pip install --no-cache-dir 'weaviate>=4.0.0' &&
|
||||||
|
/bin/bash /entrypoint.sh
|
||||||
|
"
|
||||||
@ -941,7 +941,7 @@ services:
|
|||||||
|
|
||||||
# The Weaviate vector store.
|
# The Weaviate vector store.
|
||||||
weaviate:
|
weaviate:
|
||||||
image: semitechnologies/weaviate:1.19.0
|
image: semitechnologies/weaviate:1.27.0
|
||||||
profiles:
|
profiles:
|
||||||
- ""
|
- ""
|
||||||
- weaviate
|
- weaviate
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next'
|
|||||||
import { useBoolean } from 'ahooks'
|
import { useBoolean } from 'ahooks'
|
||||||
import TracingIcon from './tracing-icon'
|
import TracingIcon from './tracing-icon'
|
||||||
import ProviderPanel from './provider-panel'
|
import ProviderPanel from './provider-panel'
|
||||||
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, WeaveConfig } from './type'
|
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
|
||||||
import { TracingProvider } from './type'
|
import { TracingProvider } from './type'
|
||||||
import ProviderConfigModal from './provider-config-modal'
|
import ProviderConfigModal from './provider-config-modal'
|
||||||
import Indicator from '@/app/components/header/indicator'
|
import Indicator from '@/app/components/header/indicator'
|
||||||
@ -30,7 +30,8 @@ export type PopupProps = {
|
|||||||
opikConfig: OpikConfig | null
|
opikConfig: OpikConfig | null
|
||||||
weaveConfig: WeaveConfig | null
|
weaveConfig: WeaveConfig | null
|
||||||
aliyunConfig: AliyunConfig | null
|
aliyunConfig: AliyunConfig | null
|
||||||
onConfigUpdated: (provider: TracingProvider, payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig) => void
|
tencentConfig: TencentConfig | null
|
||||||
|
onConfigUpdated: (provider: TracingProvider, payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig) => void
|
||||||
onConfigRemoved: (provider: TracingProvider) => void
|
onConfigRemoved: (provider: TracingProvider) => void
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,6 +49,7 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||||||
opikConfig,
|
opikConfig,
|
||||||
weaveConfig,
|
weaveConfig,
|
||||||
aliyunConfig,
|
aliyunConfig,
|
||||||
|
tencentConfig,
|
||||||
onConfigUpdated,
|
onConfigUpdated,
|
||||||
onConfigRemoved,
|
onConfigRemoved,
|
||||||
}) => {
|
}) => {
|
||||||
@ -81,8 +83,8 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||||||
hideConfigModal()
|
hideConfigModal()
|
||||||
}, [currentProvider, hideConfigModal, onConfigRemoved])
|
}, [currentProvider, hideConfigModal, onConfigRemoved])
|
||||||
|
|
||||||
const providerAllConfigured = arizeConfig && phoenixConfig && langSmithConfig && langFuseConfig && opikConfig && weaveConfig && aliyunConfig
|
const providerAllConfigured = arizeConfig && phoenixConfig && langSmithConfig && langFuseConfig && opikConfig && weaveConfig && aliyunConfig && tencentConfig
|
||||||
const providerAllNotConfigured = !arizeConfig && !phoenixConfig && !langSmithConfig && !langFuseConfig && !opikConfig && !weaveConfig && !aliyunConfig
|
const providerAllNotConfigured = !arizeConfig && !phoenixConfig && !langSmithConfig && !langFuseConfig && !opikConfig && !weaveConfig && !aliyunConfig && !tencentConfig
|
||||||
|
|
||||||
const switchContent = (
|
const switchContent = (
|
||||||
<Switch
|
<Switch
|
||||||
@ -182,6 +184,19 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||||||
key="aliyun-provider-panel"
|
key="aliyun-provider-panel"
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const tencentPanel = (
|
||||||
|
<ProviderPanel
|
||||||
|
type={TracingProvider.tencent}
|
||||||
|
readOnly={readOnly}
|
||||||
|
config={tencentConfig}
|
||||||
|
hasConfigured={!!tencentConfig}
|
||||||
|
onConfig={handleOnConfig(TracingProvider.tencent)}
|
||||||
|
isChosen={chosenProvider === TracingProvider.tencent}
|
||||||
|
onChoose={handleOnChoose(TracingProvider.tencent)}
|
||||||
|
key="tencent-provider-panel"
|
||||||
|
/>
|
||||||
|
)
|
||||||
const configuredProviderPanel = () => {
|
const configuredProviderPanel = () => {
|
||||||
const configuredPanels: JSX.Element[] = []
|
const configuredPanels: JSX.Element[] = []
|
||||||
|
|
||||||
@ -206,6 +221,9 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||||||
if (aliyunConfig)
|
if (aliyunConfig)
|
||||||
configuredPanels.push(aliyunPanel)
|
configuredPanels.push(aliyunPanel)
|
||||||
|
|
||||||
|
if (tencentConfig)
|
||||||
|
configuredPanels.push(tencentPanel)
|
||||||
|
|
||||||
return configuredPanels
|
return configuredPanels
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -233,6 +251,9 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||||||
if (!aliyunConfig)
|
if (!aliyunConfig)
|
||||||
notConfiguredPanels.push(aliyunPanel)
|
notConfiguredPanels.push(aliyunPanel)
|
||||||
|
|
||||||
|
if (!tencentConfig)
|
||||||
|
notConfiguredPanels.push(tencentPanel)
|
||||||
|
|
||||||
return notConfiguredPanels
|
return notConfiguredPanels
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -249,6 +270,8 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||||||
return opikConfig
|
return opikConfig
|
||||||
if (currentProvider === TracingProvider.aliyun)
|
if (currentProvider === TracingProvider.aliyun)
|
||||||
return aliyunConfig
|
return aliyunConfig
|
||||||
|
if (currentProvider === TracingProvider.tencent)
|
||||||
|
return tencentConfig
|
||||||
return weaveConfig
|
return weaveConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -297,6 +320,7 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||||||
{arizePanel}
|
{arizePanel}
|
||||||
{phoenixPanel}
|
{phoenixPanel}
|
||||||
{aliyunPanel}
|
{aliyunPanel}
|
||||||
|
{tencentPanel}
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
)
|
)
|
||||||
|
|||||||
@ -8,4 +8,5 @@ export const docURL = {
|
|||||||
[TracingProvider.opik]: 'https://www.comet.com/docs/opik/tracing/integrations/dify#setup-instructions',
|
[TracingProvider.opik]: 'https://www.comet.com/docs/opik/tracing/integrations/dify#setup-instructions',
|
||||||
[TracingProvider.weave]: 'https://weave-docs.wandb.ai/',
|
[TracingProvider.weave]: 'https://weave-docs.wandb.ai/',
|
||||||
[TracingProvider.aliyun]: 'https://help.aliyun.com/zh/arms/tracing-analysis/untitled-document-1750672984680',
|
[TracingProvider.aliyun]: 'https://help.aliyun.com/zh/arms/tracing-analysis/untitled-document-1750672984680',
|
||||||
|
[TracingProvider.tencent]: 'https://cloud.tencent.com/document/product/248/116531',
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,12 +8,12 @@ import {
|
|||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { usePathname } from 'next/navigation'
|
import { usePathname } from 'next/navigation'
|
||||||
import { useBoolean } from 'ahooks'
|
import { useBoolean } from 'ahooks'
|
||||||
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, WeaveConfig } from './type'
|
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
|
||||||
import { TracingProvider } from './type'
|
import { TracingProvider } from './type'
|
||||||
import TracingIcon from './tracing-icon'
|
import TracingIcon from './tracing-icon'
|
||||||
import ConfigButton from './config-button'
|
import ConfigButton from './config-button'
|
||||||
import cn from '@/utils/classnames'
|
import cn from '@/utils/classnames'
|
||||||
import { AliyunIcon, ArizeIcon, LangfuseIcon, LangsmithIcon, OpikIcon, PhoenixIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing'
|
import { AliyunIcon, ArizeIcon, LangfuseIcon, LangsmithIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing'
|
||||||
import Indicator from '@/app/components/header/indicator'
|
import Indicator from '@/app/components/header/indicator'
|
||||||
import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps'
|
import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps'
|
||||||
import type { TracingStatus } from '@/models/app'
|
import type { TracingStatus } from '@/models/app'
|
||||||
@ -71,6 +71,7 @@ const Panel: FC = () => {
|
|||||||
[TracingProvider.opik]: OpikIcon,
|
[TracingProvider.opik]: OpikIcon,
|
||||||
[TracingProvider.weave]: WeaveIcon,
|
[TracingProvider.weave]: WeaveIcon,
|
||||||
[TracingProvider.aliyun]: AliyunIcon,
|
[TracingProvider.aliyun]: AliyunIcon,
|
||||||
|
[TracingProvider.tencent]: TencentIcon,
|
||||||
}
|
}
|
||||||
const InUseProviderIcon = inUseTracingProvider ? providerIconMap[inUseTracingProvider] : undefined
|
const InUseProviderIcon = inUseTracingProvider ? providerIconMap[inUseTracingProvider] : undefined
|
||||||
|
|
||||||
@ -81,7 +82,8 @@ const Panel: FC = () => {
|
|||||||
const [opikConfig, setOpikConfig] = useState<OpikConfig | null>(null)
|
const [opikConfig, setOpikConfig] = useState<OpikConfig | null>(null)
|
||||||
const [weaveConfig, setWeaveConfig] = useState<WeaveConfig | null>(null)
|
const [weaveConfig, setWeaveConfig] = useState<WeaveConfig | null>(null)
|
||||||
const [aliyunConfig, setAliyunConfig] = useState<AliyunConfig | null>(null)
|
const [aliyunConfig, setAliyunConfig] = useState<AliyunConfig | null>(null)
|
||||||
const hasConfiguredTracing = !!(langSmithConfig || langFuseConfig || opikConfig || weaveConfig || arizeConfig || phoenixConfig || aliyunConfig)
|
const [tencentConfig, setTencentConfig] = useState<TencentConfig | null>(null)
|
||||||
|
const hasConfiguredTracing = !!(langSmithConfig || langFuseConfig || opikConfig || weaveConfig || arizeConfig || phoenixConfig || aliyunConfig || tencentConfig)
|
||||||
|
|
||||||
const fetchTracingConfig = async () => {
|
const fetchTracingConfig = async () => {
|
||||||
const getArizeConfig = async () => {
|
const getArizeConfig = async () => {
|
||||||
@ -119,6 +121,11 @@ const Panel: FC = () => {
|
|||||||
if (!aliyunHasNotConfig)
|
if (!aliyunHasNotConfig)
|
||||||
setAliyunConfig(aliyunConfig as AliyunConfig)
|
setAliyunConfig(aliyunConfig as AliyunConfig)
|
||||||
}
|
}
|
||||||
|
const getTencentConfig = async () => {
|
||||||
|
const { tracing_config: tencentConfig, has_not_configured: tencentHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.tencent })
|
||||||
|
if (!tencentHasNotConfig)
|
||||||
|
setTencentConfig(tencentConfig as TencentConfig)
|
||||||
|
}
|
||||||
Promise.all([
|
Promise.all([
|
||||||
getArizeConfig(),
|
getArizeConfig(),
|
||||||
getPhoenixConfig(),
|
getPhoenixConfig(),
|
||||||
@ -127,6 +134,7 @@ const Panel: FC = () => {
|
|||||||
getOpikConfig(),
|
getOpikConfig(),
|
||||||
getWeaveConfig(),
|
getWeaveConfig(),
|
||||||
getAliyunConfig(),
|
getAliyunConfig(),
|
||||||
|
getTencentConfig(),
|
||||||
])
|
])
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -147,6 +155,8 @@ const Panel: FC = () => {
|
|||||||
setWeaveConfig(tracing_config as WeaveConfig)
|
setWeaveConfig(tracing_config as WeaveConfig)
|
||||||
else if (provider === TracingProvider.aliyun)
|
else if (provider === TracingProvider.aliyun)
|
||||||
setAliyunConfig(tracing_config as AliyunConfig)
|
setAliyunConfig(tracing_config as AliyunConfig)
|
||||||
|
else if (provider === TracingProvider.tencent)
|
||||||
|
setTencentConfig(tracing_config as TencentConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleTracingConfigRemoved = (provider: TracingProvider) => {
|
const handleTracingConfigRemoved = (provider: TracingProvider) => {
|
||||||
@ -164,6 +174,8 @@ const Panel: FC = () => {
|
|||||||
setWeaveConfig(null)
|
setWeaveConfig(null)
|
||||||
else if (provider === TracingProvider.aliyun)
|
else if (provider === TracingProvider.aliyun)
|
||||||
setAliyunConfig(null)
|
setAliyunConfig(null)
|
||||||
|
else if (provider === TracingProvider.tencent)
|
||||||
|
setTencentConfig(null)
|
||||||
if (provider === inUseTracingProvider) {
|
if (provider === inUseTracingProvider) {
|
||||||
handleTracingStatusChange({
|
handleTracingStatusChange({
|
||||||
enabled: false,
|
enabled: false,
|
||||||
@ -209,6 +221,7 @@ const Panel: FC = () => {
|
|||||||
opikConfig={opikConfig}
|
opikConfig={opikConfig}
|
||||||
weaveConfig={weaveConfig}
|
weaveConfig={weaveConfig}
|
||||||
aliyunConfig={aliyunConfig}
|
aliyunConfig={aliyunConfig}
|
||||||
|
tencentConfig={tencentConfig}
|
||||||
onConfigUpdated={handleTracingConfigUpdated}
|
onConfigUpdated={handleTracingConfigUpdated}
|
||||||
onConfigRemoved={handleTracingConfigRemoved}
|
onConfigRemoved={handleTracingConfigRemoved}
|
||||||
>
|
>
|
||||||
@ -245,6 +258,7 @@ const Panel: FC = () => {
|
|||||||
opikConfig={opikConfig}
|
opikConfig={opikConfig}
|
||||||
weaveConfig={weaveConfig}
|
weaveConfig={weaveConfig}
|
||||||
aliyunConfig={aliyunConfig}
|
aliyunConfig={aliyunConfig}
|
||||||
|
tencentConfig={tencentConfig}
|
||||||
onConfigUpdated={handleTracingConfigUpdated}
|
onConfigUpdated={handleTracingConfigUpdated}
|
||||||
onConfigRemoved={handleTracingConfigRemoved}
|
onConfigRemoved={handleTracingConfigRemoved}
|
||||||
>
|
>
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import React, { useCallback, useState } from 'react'
|
|||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { useBoolean } from 'ahooks'
|
import { useBoolean } from 'ahooks'
|
||||||
import Field from './field'
|
import Field from './field'
|
||||||
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, WeaveConfig } from './type'
|
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
|
||||||
import { TracingProvider } from './type'
|
import { TracingProvider } from './type'
|
||||||
import { docURL } from './config'
|
import { docURL } from './config'
|
||||||
import {
|
import {
|
||||||
@ -22,10 +22,10 @@ import Divider from '@/app/components/base/divider'
|
|||||||
type Props = {
|
type Props = {
|
||||||
appId: string
|
appId: string
|
||||||
type: TracingProvider
|
type: TracingProvider
|
||||||
payload?: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | null
|
payload?: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig | null
|
||||||
onRemoved: () => void
|
onRemoved: () => void
|
||||||
onCancel: () => void
|
onCancel: () => void
|
||||||
onSaved: (payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig) => void
|
onSaved: (payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig) => void
|
||||||
onChosen: (provider: TracingProvider) => void
|
onChosen: (provider: TracingProvider) => void
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,6 +77,12 @@ const aliyunConfigTemplate = {
|
|||||||
endpoint: '',
|
endpoint: '',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const tencentConfigTemplate = {
|
||||||
|
token: '',
|
||||||
|
endpoint: '',
|
||||||
|
service_name: '',
|
||||||
|
}
|
||||||
|
|
||||||
const ProviderConfigModal: FC<Props> = ({
|
const ProviderConfigModal: FC<Props> = ({
|
||||||
appId,
|
appId,
|
||||||
type,
|
type,
|
||||||
@ -90,7 +96,7 @@ const ProviderConfigModal: FC<Props> = ({
|
|||||||
const isEdit = !!payload
|
const isEdit = !!payload
|
||||||
const isAdd = !isEdit
|
const isAdd = !isEdit
|
||||||
const [isSaving, setIsSaving] = useState(false)
|
const [isSaving, setIsSaving] = useState(false)
|
||||||
const [config, setConfig] = useState<ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig>((() => {
|
const [config, setConfig] = useState<ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig>((() => {
|
||||||
if (isEdit)
|
if (isEdit)
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
@ -112,6 +118,9 @@ const ProviderConfigModal: FC<Props> = ({
|
|||||||
else if (type === TracingProvider.aliyun)
|
else if (type === TracingProvider.aliyun)
|
||||||
return aliyunConfigTemplate
|
return aliyunConfigTemplate
|
||||||
|
|
||||||
|
else if (type === TracingProvider.tencent)
|
||||||
|
return tencentConfigTemplate
|
||||||
|
|
||||||
return weaveConfigTemplate
|
return weaveConfigTemplate
|
||||||
})())
|
})())
|
||||||
const [isShowRemoveConfirm, {
|
const [isShowRemoveConfirm, {
|
||||||
@ -202,6 +211,16 @@ const ProviderConfigModal: FC<Props> = ({
|
|||||||
errorMessage = t('common.errorMsg.fieldRequired', { field: 'Endpoint' })
|
errorMessage = t('common.errorMsg.fieldRequired', { field: 'Endpoint' })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (type === TracingProvider.tencent) {
|
||||||
|
const postData = config as TencentConfig
|
||||||
|
if (!errorMessage && !postData.token)
|
||||||
|
errorMessage = t('common.errorMsg.fieldRequired', { field: 'Token' })
|
||||||
|
if (!errorMessage && !postData.endpoint)
|
||||||
|
errorMessage = t('common.errorMsg.fieldRequired', { field: 'Endpoint' })
|
||||||
|
if (!errorMessage && !postData.service_name)
|
||||||
|
errorMessage = t('common.errorMsg.fieldRequired', { field: 'Service Name' })
|
||||||
|
}
|
||||||
|
|
||||||
return errorMessage
|
return errorMessage
|
||||||
}, [config, t, type])
|
}, [config, t, type])
|
||||||
const handleSave = useCallback(async () => {
|
const handleSave = useCallback(async () => {
|
||||||
@ -338,6 +357,34 @@ const ProviderConfigModal: FC<Props> = ({
|
|||||||
/>
|
/>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
{type === TracingProvider.tencent && (
|
||||||
|
<>
|
||||||
|
<Field
|
||||||
|
label='Token'
|
||||||
|
labelClassName='!text-sm'
|
||||||
|
isRequired
|
||||||
|
value={(config as TencentConfig).token}
|
||||||
|
onChange={handleConfigChange('token')}
|
||||||
|
placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'Token' })!}
|
||||||
|
/>
|
||||||
|
<Field
|
||||||
|
label='Endpoint'
|
||||||
|
labelClassName='!text-sm'
|
||||||
|
isRequired
|
||||||
|
value={(config as TencentConfig).endpoint}
|
||||||
|
onChange={handleConfigChange('endpoint')}
|
||||||
|
placeholder='https://your-region.cls.tencentcs.com'
|
||||||
|
/>
|
||||||
|
<Field
|
||||||
|
label='Service Name'
|
||||||
|
labelClassName='!text-sm'
|
||||||
|
isRequired
|
||||||
|
value={(config as TencentConfig).service_name}
|
||||||
|
onChange={handleConfigChange('service_name')}
|
||||||
|
placeholder='dify_app'
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
{type === TracingProvider.weave && (
|
{type === TracingProvider.weave && (
|
||||||
<>
|
<>
|
||||||
<Field
|
<Field
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import {
|
|||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { TracingProvider } from './type'
|
import { TracingProvider } from './type'
|
||||||
import cn from '@/utils/classnames'
|
import cn from '@/utils/classnames'
|
||||||
import { AliyunIconBig, ArizeIconBig, LangfuseIconBig, LangsmithIconBig, OpikIconBig, PhoenixIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing'
|
import { AliyunIconBig, ArizeIconBig, LangfuseIconBig, LangsmithIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing'
|
||||||
import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general'
|
import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general'
|
||||||
|
|
||||||
const I18N_PREFIX = 'app.tracing'
|
const I18N_PREFIX = 'app.tracing'
|
||||||
@ -31,6 +31,7 @@ const getIcon = (type: TracingProvider) => {
|
|||||||
[TracingProvider.opik]: OpikIconBig,
|
[TracingProvider.opik]: OpikIconBig,
|
||||||
[TracingProvider.weave]: WeaveIconBig,
|
[TracingProvider.weave]: WeaveIconBig,
|
||||||
[TracingProvider.aliyun]: AliyunIconBig,
|
[TracingProvider.aliyun]: AliyunIconBig,
|
||||||
|
[TracingProvider.tencent]: TencentIconBig,
|
||||||
})[type]
|
})[type]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,7 @@ export enum TracingProvider {
|
|||||||
opik = 'opik',
|
opik = 'opik',
|
||||||
weave = 'weave',
|
weave = 'weave',
|
||||||
aliyun = 'aliyun',
|
aliyun = 'aliyun',
|
||||||
|
tencent = 'tencent',
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ArizeConfig = {
|
export type ArizeConfig = {
|
||||||
@ -53,3 +54,9 @@ export type AliyunConfig = {
|
|||||||
license_key: string
|
license_key: string
|
||||||
endpoint: string
|
endpoint: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type TencentConfig = {
|
||||||
|
token: string
|
||||||
|
endpoint: string
|
||||||
|
service_name: string
|
||||||
|
}
|
||||||
|
|||||||
@ -53,9 +53,6 @@ const ChatWrapper = () => {
|
|||||||
initUserVariables,
|
initUserVariables,
|
||||||
} = useChatWithHistoryContext()
|
} = useChatWithHistoryContext()
|
||||||
|
|
||||||
// Semantic variable for better code readability
|
|
||||||
const isHistoryConversation = !!currentConversationId
|
|
||||||
|
|
||||||
const appConfig = useMemo(() => {
|
const appConfig = useMemo(() => {
|
||||||
const config = appParams || {}
|
const config = appParams || {}
|
||||||
|
|
||||||
@ -66,9 +63,9 @@ const ChatWrapper = () => {
|
|||||||
fileUploadConfig: (config as any).system_parameters,
|
fileUploadConfig: (config as any).system_parameters,
|
||||||
},
|
},
|
||||||
supportFeedback: true,
|
supportFeedback: true,
|
||||||
opening_statement: isHistoryConversation ? currentConversationItem?.introduction : (config as any).opening_statement,
|
opening_statement: currentConversationItem?.introduction || (config as any).opening_statement,
|
||||||
} as ChatConfig
|
} as ChatConfig
|
||||||
}, [appParams, currentConversationItem?.introduction, isHistoryConversation])
|
}, [appParams, currentConversationItem?.introduction])
|
||||||
const {
|
const {
|
||||||
chatList,
|
chatList,
|
||||||
setTargetMessageId,
|
setTargetMessageId,
|
||||||
@ -79,7 +76,7 @@ const ChatWrapper = () => {
|
|||||||
} = useChat(
|
} = useChat(
|
||||||
appConfig,
|
appConfig,
|
||||||
{
|
{
|
||||||
inputs: (isHistoryConversation ? currentConversationInputs : newConversationInputs) as any,
|
inputs: (currentConversationId ? currentConversationInputs : newConversationInputs) as any,
|
||||||
inputsForm: inputsForms,
|
inputsForm: inputsForms,
|
||||||
},
|
},
|
||||||
appPrevChatTree,
|
appPrevChatTree,
|
||||||
@ -87,7 +84,7 @@ const ChatWrapper = () => {
|
|||||||
clearChatList,
|
clearChatList,
|
||||||
setClearChatList,
|
setClearChatList,
|
||||||
)
|
)
|
||||||
const inputsFormValue = isHistoryConversation ? currentConversationInputs : newConversationInputsRef?.current
|
const inputsFormValue = currentConversationId ? currentConversationInputs : newConversationInputsRef?.current
|
||||||
const inputDisabled = useMemo(() => {
|
const inputDisabled = useMemo(() => {
|
||||||
if (allInputsHidden)
|
if (allInputsHidden)
|
||||||
return false
|
return false
|
||||||
@ -136,7 +133,7 @@ const ChatWrapper = () => {
|
|||||||
const data: any = {
|
const data: any = {
|
||||||
query: message,
|
query: message,
|
||||||
files,
|
files,
|
||||||
inputs: formatBooleanInputs(inputsForms, isHistoryConversation ? currentConversationInputs : newConversationInputs),
|
inputs: formatBooleanInputs(inputsForms, currentConversationId ? currentConversationInputs : newConversationInputs),
|
||||||
conversation_id: currentConversationId,
|
conversation_id: currentConversationId,
|
||||||
parent_message_id: (isRegenerate ? parentAnswer?.id : getLastAnswer(chatList)?.id) || null,
|
parent_message_id: (isRegenerate ? parentAnswer?.id : getLastAnswer(chatList)?.id) || null,
|
||||||
}
|
}
|
||||||
@ -146,11 +143,11 @@ const ChatWrapper = () => {
|
|||||||
data,
|
data,
|
||||||
{
|
{
|
||||||
onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, isInstalledApp, appId),
|
onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, isInstalledApp, appId),
|
||||||
onConversationComplete: isHistoryConversation ? undefined : handleNewConversationCompleted,
|
onConversationComplete: currentConversationId ? undefined : handleNewConversationCompleted,
|
||||||
isPublicAPI: !isInstalledApp,
|
isPublicAPI: !isInstalledApp,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}, [chatList, handleNewConversationCompleted, handleSend, isHistoryConversation, currentConversationInputs, newConversationInputs, isInstalledApp, appId])
|
}, [chatList, handleNewConversationCompleted, handleSend, currentConversationId, currentConversationInputs, newConversationInputs, isInstalledApp, appId])
|
||||||
|
|
||||||
const doRegenerate = useCallback((chatItem: ChatItemInTree, editedQuestion?: { message: string, files?: FileEntity[] }) => {
|
const doRegenerate = useCallback((chatItem: ChatItemInTree, editedQuestion?: { message: string, files?: FileEntity[] }) => {
|
||||||
const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)!
|
const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)!
|
||||||
@ -163,30 +160,38 @@ const ChatWrapper = () => {
|
|||||||
}, [chatList, doSend])
|
}, [chatList, doSend])
|
||||||
|
|
||||||
const messageList = useMemo(() => {
|
const messageList = useMemo(() => {
|
||||||
// Always filter out opening statement from message list as it's handled separately in welcome component
|
if (currentConversationId || chatList.length > 1)
|
||||||
|
return chatList
|
||||||
|
// Without messages we are in the welcome screen, so hide the opening statement from chatlist
|
||||||
return chatList.filter(item => !item.isOpeningStatement)
|
return chatList.filter(item => !item.isOpeningStatement)
|
||||||
}, [chatList])
|
}, [chatList])
|
||||||
|
|
||||||
const [collapsed, setCollapsed] = useState(isHistoryConversation)
|
const [collapsed, setCollapsed] = useState(!!currentConversationId)
|
||||||
|
|
||||||
const chatNode = useMemo(() => {
|
const chatNode = useMemo(() => {
|
||||||
if (allInputsHidden || !inputsForms.length)
|
if (allInputsHidden || !inputsForms.length)
|
||||||
return null
|
return null
|
||||||
if (isMobile) {
|
if (isMobile) {
|
||||||
if (!isHistoryConversation)
|
if (!currentConversationId)
|
||||||
return <InputsForm collapsed={collapsed} setCollapsed={setCollapsed} />
|
return <InputsForm collapsed={collapsed} setCollapsed={setCollapsed} />
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
return <InputsForm collapsed={collapsed} setCollapsed={setCollapsed} />
|
return <InputsForm collapsed={collapsed} setCollapsed={setCollapsed} />
|
||||||
}
|
}
|
||||||
}, [inputsForms.length, isMobile, isHistoryConversation, collapsed, allInputsHidden])
|
},
|
||||||
|
[
|
||||||
|
inputsForms.length,
|
||||||
|
isMobile,
|
||||||
|
currentConversationId,
|
||||||
|
collapsed, allInputsHidden,
|
||||||
|
])
|
||||||
|
|
||||||
const welcome = useMemo(() => {
|
const welcome = useMemo(() => {
|
||||||
const welcomeMessage = chatList.find(item => item.isOpeningStatement)
|
const welcomeMessage = chatList.find(item => item.isOpeningStatement)
|
||||||
if (respondingState)
|
if (respondingState)
|
||||||
return null
|
return null
|
||||||
if (isHistoryConversation)
|
if (currentConversationId)
|
||||||
return null
|
return null
|
||||||
if (!welcomeMessage)
|
if (!welcomeMessage)
|
||||||
return null
|
return null
|
||||||
@ -227,7 +232,18 @@ const ChatWrapper = () => {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}, [appData?.site.icon, appData?.site.icon_background, appData?.site.icon_type, appData?.site.icon_url, chatList, collapsed, isHistoryConversation, inputsForms.length, respondingState, allInputsHidden])
|
},
|
||||||
|
[
|
||||||
|
appData?.site.icon,
|
||||||
|
appData?.site.icon_background,
|
||||||
|
appData?.site.icon_type,
|
||||||
|
appData?.site.icon_url,
|
||||||
|
chatList, collapsed,
|
||||||
|
currentConversationId,
|
||||||
|
inputsForms.length,
|
||||||
|
respondingState,
|
||||||
|
allInputsHidden,
|
||||||
|
])
|
||||||
|
|
||||||
const answerIcon = (appData?.site && appData.site.use_icon_as_answer_icon)
|
const answerIcon = (appData?.site && appData.site.use_icon_as_answer_icon)
|
||||||
? <AnswerIcon
|
? <AnswerIcon
|
||||||
@ -251,7 +267,7 @@ const ChatWrapper = () => {
|
|||||||
chatFooterClassName='pb-4'
|
chatFooterClassName='pb-4'
|
||||||
chatFooterInnerClassName={`mx-auto w-full max-w-[768px] ${isMobile ? 'px-2' : 'px-4'}`}
|
chatFooterInnerClassName={`mx-auto w-full max-w-[768px] ${isMobile ? 'px-2' : 'px-4'}`}
|
||||||
onSend={doSend}
|
onSend={doSend}
|
||||||
inputs={isHistoryConversation ? currentConversationInputs as any : newConversationInputs}
|
inputs={currentConversationId ? currentConversationInputs as any : newConversationInputs}
|
||||||
inputsForm={inputsForms}
|
inputsForm={inputsForms}
|
||||||
onRegenerate={doRegenerate}
|
onRegenerate={doRegenerate}
|
||||||
onStopResponding={handleStop}
|
onStopResponding={handleStop}
|
||||||
|
|||||||
@ -111,6 +111,8 @@ const Answer: FC<AnswerProps> = ({
|
|||||||
}
|
}
|
||||||
}, [switchSibling, item.prevSibling, item.nextSibling])
|
}, [switchSibling, item.prevSibling, item.nextSibling])
|
||||||
|
|
||||||
|
const contentIsEmpty = content.trim() === ''
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className='mb-2 flex last:mb-0'>
|
<div className='mb-2 flex last:mb-0'>
|
||||||
<div className='relative h-10 w-10 shrink-0'>
|
<div className='relative h-10 w-10 shrink-0'>
|
||||||
@ -153,14 +155,14 @@ const Answer: FC<AnswerProps> = ({
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
responding && !content && !hasAgentThoughts && (
|
responding && contentIsEmpty && !hasAgentThoughts && (
|
||||||
<div className='flex h-5 w-6 items-center justify-center'>
|
<div className='flex h-5 w-6 items-center justify-center'>
|
||||||
<LoadingAnim type='text' />
|
<LoadingAnim type='text' />
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
content && !hasAgentThoughts && (
|
!contentIsEmpty && !hasAgentThoughts && (
|
||||||
<BasicContent item={item} />
|
<BasicContent item={item} />
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -83,6 +83,15 @@ const ChatInputArea = ({
|
|||||||
const historyRef = useRef([''])
|
const historyRef = useRef([''])
|
||||||
const [currentIndex, setCurrentIndex] = useState(-1)
|
const [currentIndex, setCurrentIndex] = useState(-1)
|
||||||
const isComposingRef = useRef(false)
|
const isComposingRef = useRef(false)
|
||||||
|
|
||||||
|
const handleQueryChange = useCallback(
|
||||||
|
(value: string) => {
|
||||||
|
setQuery(value)
|
||||||
|
setTimeout(handleTextareaResize, 0)
|
||||||
|
},
|
||||||
|
[handleTextareaResize],
|
||||||
|
)
|
||||||
|
|
||||||
const handleSend = () => {
|
const handleSend = () => {
|
||||||
if (isResponding) {
|
if (isResponding) {
|
||||||
notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') })
|
notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') })
|
||||||
@ -101,7 +110,7 @@ const ChatInputArea = ({
|
|||||||
}
|
}
|
||||||
if (checkInputsForm(inputs, inputsForm)) {
|
if (checkInputsForm(inputs, inputsForm)) {
|
||||||
onSend(query, files)
|
onSend(query, files)
|
||||||
setQuery('')
|
handleQueryChange('')
|
||||||
setFiles([])
|
setFiles([])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -131,19 +140,19 @@ const ChatInputArea = ({
|
|||||||
// When the cmd + up key is pressed, output the previous element
|
// When the cmd + up key is pressed, output the previous element
|
||||||
if (currentIndex > 0) {
|
if (currentIndex > 0) {
|
||||||
setCurrentIndex(currentIndex - 1)
|
setCurrentIndex(currentIndex - 1)
|
||||||
setQuery(historyRef.current[currentIndex - 1])
|
handleQueryChange(historyRef.current[currentIndex - 1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (e.key === 'ArrowDown' && !e.shiftKey && !e.nativeEvent.isComposing && e.metaKey) {
|
else if (e.key === 'ArrowDown' && !e.shiftKey && !e.nativeEvent.isComposing && e.metaKey) {
|
||||||
// When the cmd + down key is pressed, output the next element
|
// When the cmd + down key is pressed, output the next element
|
||||||
if (currentIndex < historyRef.current.length - 1) {
|
if (currentIndex < historyRef.current.length - 1) {
|
||||||
setCurrentIndex(currentIndex + 1)
|
setCurrentIndex(currentIndex + 1)
|
||||||
setQuery(historyRef.current[currentIndex + 1])
|
handleQueryChange(historyRef.current[currentIndex + 1])
|
||||||
}
|
}
|
||||||
else if (currentIndex === historyRef.current.length - 1) {
|
else if (currentIndex === historyRef.current.length - 1) {
|
||||||
// If it is the last element, clear the input box
|
// If it is the last element, clear the input box
|
||||||
setCurrentIndex(historyRef.current.length)
|
setCurrentIndex(historyRef.current.length)
|
||||||
setQuery('')
|
handleQueryChange('')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -171,7 +180,7 @@ const ChatInputArea = ({
|
|||||||
<>
|
<>
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
'relative z-10 rounded-xl border border-components-chat-input-border bg-components-panel-bg-blur pb-[9px] shadow-md',
|
'relative z-10 overflow-hidden rounded-xl border border-components-chat-input-border bg-components-panel-bg-blur pb-[9px] shadow-md',
|
||||||
isDragActive && 'border border-dashed border-components-option-card-option-selected-border',
|
isDragActive && 'border border-dashed border-components-option-card-option-selected-border',
|
||||||
disabled && 'pointer-events-none border-components-panel-border opacity-50 shadow-none',
|
disabled && 'pointer-events-none border-components-panel-border opacity-50 shadow-none',
|
||||||
)}
|
)}
|
||||||
@ -197,12 +206,8 @@ const ChatInputArea = ({
|
|||||||
placeholder={t('common.chat.inputPlaceholder', { botName }) || ''}
|
placeholder={t('common.chat.inputPlaceholder', { botName }) || ''}
|
||||||
autoFocus
|
autoFocus
|
||||||
minRows={1}
|
minRows={1}
|
||||||
onResize={handleTextareaResize}
|
|
||||||
value={query}
|
value={query}
|
||||||
onChange={(e) => {
|
onChange={e => handleQueryChange(e.target.value)}
|
||||||
setQuery(e.target.value)
|
|
||||||
setTimeout(handleTextareaResize, 0)
|
|
||||||
}}
|
|
||||||
onKeyDown={handleKeyDown}
|
onKeyDown={handleKeyDown}
|
||||||
onCompositionStart={handleCompositionStart}
|
onCompositionStart={handleCompositionStart}
|
||||||
onCompositionEnd={handleCompositionEnd}
|
onCompositionEnd={handleCompositionEnd}
|
||||||
@ -221,7 +226,7 @@ const ChatInputArea = ({
|
|||||||
showVoiceInput && (
|
showVoiceInput && (
|
||||||
<VoiceInput
|
<VoiceInput
|
||||||
onCancel={() => setShowVoiceInput(false)}
|
onCancel={() => setShowVoiceInput(false)}
|
||||||
onConverted={text => setQuery(text)}
|
onConverted={text => handleQueryChange(text)}
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import type { FC, Ref } from 'react'
|
||||||
import { memo } from 'react'
|
import { memo } from 'react'
|
||||||
import {
|
import {
|
||||||
RiMicLine,
|
RiMicLine,
|
||||||
@ -18,20 +19,17 @@ type OperationProps = {
|
|||||||
speechToTextConfig?: EnableType
|
speechToTextConfig?: EnableType
|
||||||
onShowVoiceInput?: () => void
|
onShowVoiceInput?: () => void
|
||||||
onSend: () => void
|
onSend: () => void
|
||||||
theme?: Theme | null
|
theme?: Theme | null,
|
||||||
|
ref?: Ref<HTMLDivElement>;
|
||||||
}
|
}
|
||||||
const Operation = (
|
const Operation: FC<OperationProps> = ({
|
||||||
{
|
ref,
|
||||||
ref,
|
fileConfig,
|
||||||
fileConfig,
|
speechToTextConfig,
|
||||||
speechToTextConfig,
|
onShowVoiceInput,
|
||||||
onShowVoiceInput,
|
onSend,
|
||||||
onSend,
|
theme,
|
||||||
theme,
|
}) => {
|
||||||
}: OperationProps & {
|
|
||||||
ref: React.RefObject<HTMLDivElement>;
|
|
||||||
},
|
|
||||||
) => {
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
|
|||||||
@ -62,9 +62,9 @@ const ChatWrapper = () => {
|
|||||||
fileUploadConfig: (config as any).system_parameters,
|
fileUploadConfig: (config as any).system_parameters,
|
||||||
},
|
},
|
||||||
supportFeedback: true,
|
supportFeedback: true,
|
||||||
opening_statement: currentConversationId ? currentConversationItem?.introduction : (config as any).opening_statement,
|
opening_statement: currentConversationItem?.introduction || (config as any).opening_statement,
|
||||||
} as ChatConfig
|
} as ChatConfig
|
||||||
}, [appParams, currentConversationItem?.introduction, currentConversationId])
|
}, [appParams, currentConversationItem?.introduction])
|
||||||
const {
|
const {
|
||||||
chatList,
|
chatList,
|
||||||
setTargetMessageId,
|
setTargetMessageId,
|
||||||
@ -158,8 +158,9 @@ const ChatWrapper = () => {
|
|||||||
}, [chatList, doSend])
|
}, [chatList, doSend])
|
||||||
|
|
||||||
const messageList = useMemo(() => {
|
const messageList = useMemo(() => {
|
||||||
if (currentConversationId)
|
if (currentConversationId || chatList.length > 1)
|
||||||
return chatList
|
return chatList
|
||||||
|
// Without messages we are in the welcome screen, so hide the opening statement from chatlist
|
||||||
return chatList.filter(item => !item.isOpeningStatement)
|
return chatList.filter(item => !item.isOpeningStatement)
|
||||||
}, [chatList, currentConversationId])
|
}, [chatList, currentConversationId])
|
||||||
|
|
||||||
@ -240,7 +241,7 @@ const ChatWrapper = () => {
|
|||||||
config={appConfig}
|
config={appConfig}
|
||||||
chatList={messageList}
|
chatList={messageList}
|
||||||
isResponding={respondingState}
|
isResponding={respondingState}
|
||||||
chatContainerInnerClassName={cn('mx-auto w-full max-w-full pt-4 tablet:px-4', isMobile && 'px-4')}
|
chatContainerInnerClassName={cn('mx-auto w-full max-w-full px-4', messageList.length && 'pt-4')}
|
||||||
chatFooterClassName={cn('pb-4', !isMobile && 'rounded-b-2xl')}
|
chatFooterClassName={cn('pb-4', !isMobile && 'rounded-b-2xl')}
|
||||||
chatFooterInnerClassName={cn('mx-auto w-full max-w-full px-4', isMobile && 'px-2')}
|
chatFooterInnerClassName={cn('mx-auto w-full max-w-full px-4', isMobile && 'px-2')}
|
||||||
onSend={doSend}
|
onSend={doSend}
|
||||||
|
|||||||
@ -36,6 +36,7 @@ const Header: FC<IHeaderProps> = ({
|
|||||||
appData,
|
appData,
|
||||||
currentConversationId,
|
currentConversationId,
|
||||||
inputsForms,
|
inputsForms,
|
||||||
|
allInputsHidden,
|
||||||
} = useEmbeddedChatbotContext()
|
} = useEmbeddedChatbotContext()
|
||||||
|
|
||||||
const isClient = typeof window !== 'undefined'
|
const isClient = typeof window !== 'undefined'
|
||||||
@ -124,7 +125,7 @@ const Header: FC<IHeaderProps> = ({
|
|||||||
</ActionButton>
|
</ActionButton>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
)}
|
)}
|
||||||
{currentConversationId && inputsForms.length > 0 && (
|
{currentConversationId && inputsForms.length > 0 && !allInputsHidden && (
|
||||||
<ViewFormDropdown />
|
<ViewFormDropdown />
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
@ -135,7 +136,7 @@ const Header: FC<IHeaderProps> = ({
|
|||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={cn('flex h-14 shrink-0 items-center justify-between rounded-t-2xl px-3')}
|
className={cn('flex h-14 shrink-0 items-center justify-between rounded-t-2xl px-3')}
|
||||||
style={Object.assign({}, CssTransform(theme?.backgroundHeaderColorStyle ?? ''), CssTransform(theme?.headerBorderBottomStyle ?? ''))}
|
style={CssTransform(theme?.headerBorderBottomStyle ?? '')}
|
||||||
>
|
>
|
||||||
<div className="flex grow items-center space-x-3">
|
<div className="flex grow items-center space-x-3">
|
||||||
{customerIcon}
|
{customerIcon}
|
||||||
@ -171,7 +172,7 @@ const Header: FC<IHeaderProps> = ({
|
|||||||
</ActionButton>
|
</ActionButton>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
)}
|
)}
|
||||||
{currentConversationId && inputsForms.length > 0 && (
|
{currentConversationId && inputsForms.length > 0 && !allInputsHidden && (
|
||||||
<ViewFormDropdown iconColor={theme?.colorPathOnHeader} />
|
<ViewFormDropdown iconColor={theme?.colorPathOnHeader} />
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@ -49,8 +49,8 @@ const Chatbot = () => {
|
|||||||
<div className='relative'>
|
<div className='relative'>
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
'flex flex-col rounded-2xl border border-components-panel-border-subtle',
|
'flex flex-col rounded-2xl',
|
||||||
isMobile ? 'h-[calc(100vh_-_60px)] border-[0.5px] border-components-panel-border shadow-xs' : 'h-[100vh] bg-chatbot-bg',
|
isMobile ? 'h-[calc(100vh_-_60px)] shadow-xs' : 'h-[100vh] bg-chatbot-bg',
|
||||||
)}
|
)}
|
||||||
style={isMobile ? Object.assign({}, CssTransform(themeBuilder?.theme?.backgroundHeaderColorStyle ?? '')) : {}}
|
style={isMobile ? Object.assign({}, CssTransform(themeBuilder?.theme?.backgroundHeaderColorStyle ?? '')) : {}}
|
||||||
>
|
>
|
||||||
@ -62,7 +62,7 @@ const Chatbot = () => {
|
|||||||
theme={themeBuilder?.theme}
|
theme={themeBuilder?.theme}
|
||||||
onCreateNewChat={handleNewConversation}
|
onCreateNewChat={handleNewConversation}
|
||||||
/>
|
/>
|
||||||
<div className={cn('flex grow flex-col overflow-y-auto', isMobile && '!h-[calc(100vh_-_3rem)] rounded-2xl bg-chatbot-bg')}>
|
<div className={cn('flex grow flex-col overflow-y-auto', isMobile && 'm-[0.5px] !h-[calc(100vh_-_3rem)] rounded-2xl bg-chatbot-bg')}>
|
||||||
{appChatListDataLoading && (
|
{appChatListDataLoading && (
|
||||||
<Loading type='app' />
|
<Loading type='app' />
|
||||||
)}
|
)}
|
||||||
|
|||||||
@ -1,13 +0,0 @@
|
|||||||
export { default as Chunk } from './Chunk'
|
|
||||||
export { default as Collapse } from './Collapse'
|
|
||||||
export { default as Divider } from './Divider'
|
|
||||||
export { default as File } from './File'
|
|
||||||
export { default as GeneralType } from './GeneralType'
|
|
||||||
export { default as LayoutRight2LineMod } from './LayoutRight2LineMod'
|
|
||||||
export { default as OptionCardEffectBlueLight } from './OptionCardEffectBlueLight'
|
|
||||||
export { default as OptionCardEffectBlue } from './OptionCardEffectBlue'
|
|
||||||
export { default as OptionCardEffectOrange } from './OptionCardEffectOrange'
|
|
||||||
export { default as OptionCardEffectPurple } from './OptionCardEffectPurple'
|
|
||||||
export { default as ParentChildType } from './ParentChildType'
|
|
||||||
export { default as SelectionMod } from './SelectionMod'
|
|
||||||
export { default as Watercrawl } from './Watercrawl'
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
.wrapper {
|
|
||||||
display: inline-flex;
|
|
||||||
background: url(~@/app/components/base/icons/assets/image/llm/baichuan-text-cn.png) center center no-repeat;
|
|
||||||
background-size: contain;
|
|
||||||
}
|
|
||||||
@ -1,20 +0,0 @@
|
|||||||
// GENERATE BY script
|
|
||||||
// DON NOT EDIT IT MANUALLY
|
|
||||||
|
|
||||||
import * as React from 'react'
|
|
||||||
import cn from '@/utils/classnames'
|
|
||||||
import s from './BaichuanTextCn.module.css'
|
|
||||||
|
|
||||||
const Icon = (
|
|
||||||
{
|
|
||||||
ref,
|
|
||||||
className,
|
|
||||||
...restProps
|
|
||||||
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
|
|
||||||
ref?: React.RefObject<HTMLSpanElement>;
|
|
||||||
},
|
|
||||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
|
|
||||||
|
|
||||||
Icon.displayName = 'BaichuanTextCn'
|
|
||||||
|
|
||||||
export default Icon
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
.wrapper {
|
|
||||||
display: inline-flex;
|
|
||||||
background: url(~@/app/components/base/icons/assets/image/llm/minimax.png) center center no-repeat;
|
|
||||||
background-size: contain;
|
|
||||||
}
|
|
||||||
@ -1,20 +0,0 @@
|
|||||||
// GENERATE BY script
|
|
||||||
// DON NOT EDIT IT MANUALLY
|
|
||||||
|
|
||||||
import * as React from 'react'
|
|
||||||
import cn from '@/utils/classnames'
|
|
||||||
import s from './Minimax.module.css'
|
|
||||||
|
|
||||||
const Icon = (
|
|
||||||
{
|
|
||||||
ref,
|
|
||||||
className,
|
|
||||||
...restProps
|
|
||||||
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
|
|
||||||
ref?: React.RefObject<HTMLSpanElement>;
|
|
||||||
},
|
|
||||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
|
|
||||||
|
|
||||||
Icon.displayName = 'Minimax'
|
|
||||||
|
|
||||||
export default Icon
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
.wrapper {
|
|
||||||
display: inline-flex;
|
|
||||||
background: url(~@/app/components/base/icons/assets/image/llm/minimax-text.png) center center no-repeat;
|
|
||||||
background-size: contain;
|
|
||||||
}
|
|
||||||
@ -1,20 +0,0 @@
|
|||||||
// GENERATE BY script
|
|
||||||
// DON NOT EDIT IT MANUALLY
|
|
||||||
|
|
||||||
import * as React from 'react'
|
|
||||||
import cn from '@/utils/classnames'
|
|
||||||
import s from './MinimaxText.module.css'
|
|
||||||
|
|
||||||
const Icon = (
|
|
||||||
{
|
|
||||||
ref,
|
|
||||||
className,
|
|
||||||
...restProps
|
|
||||||
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
|
|
||||||
ref?: React.RefObject<HTMLSpanElement>;
|
|
||||||
},
|
|
||||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
|
|
||||||
|
|
||||||
Icon.displayName = 'MinimaxText'
|
|
||||||
|
|
||||||
export default Icon
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
.wrapper {
|
|
||||||
display: inline-flex;
|
|
||||||
background: url(~@/app/components/base/icons/assets/image/llm/tongyi.png) center center no-repeat;
|
|
||||||
background-size: contain;
|
|
||||||
}
|
|
||||||
@ -1,20 +0,0 @@
|
|||||||
// GENERATE BY script
|
|
||||||
// DON NOT EDIT IT MANUALLY
|
|
||||||
|
|
||||||
import * as React from 'react'
|
|
||||||
import cn from '@/utils/classnames'
|
|
||||||
import s from './Tongyi.module.css'
|
|
||||||
|
|
||||||
const Icon = (
|
|
||||||
{
|
|
||||||
ref,
|
|
||||||
className,
|
|
||||||
...restProps
|
|
||||||
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
|
|
||||||
ref?: React.RefObject<HTMLSpanElement>;
|
|
||||||
},
|
|
||||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
|
|
||||||
|
|
||||||
Icon.displayName = 'Tongyi'
|
|
||||||
|
|
||||||
export default Icon
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
.wrapper {
|
|
||||||
display: inline-flex;
|
|
||||||
background: url(~@/app/components/base/icons/assets/image/llm/tongyi-text.png) center center no-repeat;
|
|
||||||
background-size: contain;
|
|
||||||
}
|
|
||||||
@ -1,20 +0,0 @@
|
|||||||
// GENERATE BY script
|
|
||||||
// DON NOT EDIT IT MANUALLY
|
|
||||||
|
|
||||||
import * as React from 'react'
|
|
||||||
import cn from '@/utils/classnames'
|
|
||||||
import s from './TongyiText.module.css'
|
|
||||||
|
|
||||||
const Icon = (
|
|
||||||
{
|
|
||||||
ref,
|
|
||||||
className,
|
|
||||||
...restProps
|
|
||||||
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
|
|
||||||
ref?: React.RefObject<HTMLSpanElement>;
|
|
||||||
},
|
|
||||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
|
|
||||||
|
|
||||||
Icon.displayName = 'TongyiText'
|
|
||||||
|
|
||||||
export default Icon
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
.wrapper {
|
|
||||||
display: inline-flex;
|
|
||||||
background: url(~@/app/components/base/icons/assets/image/llm/tongyi-text-cn.png) center center no-repeat;
|
|
||||||
background-size: contain;
|
|
||||||
}
|
|
||||||
@ -1,20 +0,0 @@
|
|||||||
// GENERATE BY script
|
|
||||||
// DON NOT EDIT IT MANUALLY
|
|
||||||
|
|
||||||
import * as React from 'react'
|
|
||||||
import cn from '@/utils/classnames'
|
|
||||||
import s from './TongyiTextCn.module.css'
|
|
||||||
|
|
||||||
const Icon = (
|
|
||||||
{
|
|
||||||
ref,
|
|
||||||
className,
|
|
||||||
...restProps
|
|
||||||
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
|
|
||||||
ref?: React.RefObject<HTMLSpanElement>;
|
|
||||||
},
|
|
||||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
|
|
||||||
|
|
||||||
Icon.displayName = 'TongyiTextCn'
|
|
||||||
|
|
||||||
export default Icon
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
.wrapper {
|
|
||||||
display: inline-flex;
|
|
||||||
background: url(~@/app/components/base/icons/assets/image/llm/wxyy.png) center center no-repeat;
|
|
||||||
background-size: contain;
|
|
||||||
}
|
|
||||||
@ -1,20 +0,0 @@
|
|||||||
// GENERATE BY script
|
|
||||||
// DON NOT EDIT IT MANUALLY
|
|
||||||
|
|
||||||
import * as React from 'react'
|
|
||||||
import cn from '@/utils/classnames'
|
|
||||||
import s from './Wxyy.module.css'
|
|
||||||
|
|
||||||
const Icon = (
|
|
||||||
{
|
|
||||||
ref,
|
|
||||||
className,
|
|
||||||
...restProps
|
|
||||||
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
|
|
||||||
ref?: React.RefObject<HTMLSpanElement>;
|
|
||||||
},
|
|
||||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
|
|
||||||
|
|
||||||
Icon.displayName = 'Wxyy'
|
|
||||||
|
|
||||||
export default Icon
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
.wrapper {
|
|
||||||
display: inline-flex;
|
|
||||||
background: url(~@/app/components/base/icons/assets/image/llm/wxyy-text.png) center center no-repeat;
|
|
||||||
background-size: contain;
|
|
||||||
}
|
|
||||||
@ -1,20 +0,0 @@
|
|||||||
// GENERATE BY script
|
|
||||||
// DON NOT EDIT IT MANUALLY
|
|
||||||
|
|
||||||
import * as React from 'react'
|
|
||||||
import cn from '@/utils/classnames'
|
|
||||||
import s from './WxyyText.module.css'
|
|
||||||
|
|
||||||
const Icon = (
|
|
||||||
{
|
|
||||||
ref,
|
|
||||||
className,
|
|
||||||
...restProps
|
|
||||||
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
|
|
||||||
ref?: React.RefObject<HTMLSpanElement>;
|
|
||||||
},
|
|
||||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
|
|
||||||
|
|
||||||
Icon.displayName = 'WxyyText'
|
|
||||||
|
|
||||||
export default Icon
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
.wrapper {
|
|
||||||
display: inline-flex;
|
|
||||||
background: url(~@/app/components/base/icons/assets/image/llm/wxyy-text-cn.png) center center no-repeat;
|
|
||||||
background-size: contain;
|
|
||||||
}
|
|
||||||
@ -1,20 +0,0 @@
|
|||||||
// GENERATE BY script
|
|
||||||
// DON NOT EDIT IT MANUALLY
|
|
||||||
|
|
||||||
import * as React from 'react'
|
|
||||||
import cn from '@/utils/classnames'
|
|
||||||
import s from './WxyyTextCn.module.css'
|
|
||||||
|
|
||||||
const Icon = (
|
|
||||||
{
|
|
||||||
ref,
|
|
||||||
className,
|
|
||||||
...restProps
|
|
||||||
}: React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement> & {
|
|
||||||
ref?: React.RefObject<HTMLSpanElement>;
|
|
||||||
},
|
|
||||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />
|
|
||||||
|
|
||||||
Icon.displayName = 'WxyyTextCn'
|
|
||||||
|
|
||||||
export default Icon
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user