Merge branch 'main' into jzh

This commit is contained in:
JzoNg
2026-04-13 13:35:13 +08:00
43 changed files with 1109 additions and 963 deletions

100
.github/dependabot.yml vendored
View File

@ -1,106 +1,6 @@
version: 2
updates:
- package-ecosystem: "pip"
directory: "/api"
open-pull-requests-limit: 10
schedule:
interval: "weekly"
groups:
flask:
patterns:
- "flask"
- "flask-*"
- "werkzeug"
- "gunicorn"
google:
patterns:
- "google-*"
- "googleapis-*"
opentelemetry:
patterns:
- "opentelemetry-*"
pydantic:
patterns:
- "pydantic"
- "pydantic-*"
llm:
patterns:
- "langfuse"
- "langsmith"
- "litellm"
- "mlflow*"
- "opik"
- "weave*"
- "arize*"
- "tiktoken"
- "transformers"
database:
patterns:
- "sqlalchemy"
- "psycopg2*"
- "psycogreen"
- "redis*"
- "alembic*"
storage:
patterns:
- "boto3*"
- "botocore*"
- "azure-*"
- "bce-*"
- "cos-python-*"
- "esdk-obs-*"
- "google-cloud-storage"
- "opendal"
- "oss2"
- "supabase*"
- "tos*"
vdb:
patterns:
- "alibabacloud*"
- "chromadb"
- "clickhouse-*"
- "clickzetta-*"
- "couchbase"
- "elasticsearch"
- "opensearch-py"
- "oracledb"
- "pgvect*"
- "pymilvus"
- "pymochow"
- "pyobvector"
- "qdrant-client"
- "intersystems-*"
- "tablestore"
- "tcvectordb"
- "tidb-vector"
- "upstash-*"
- "volcengine-*"
- "weaviate-*"
- "xinference-*"
- "mo-vector"
- "mysql-connector-*"
dev:
patterns:
- "coverage"
- "dotenv-linter"
- "faker"
- "lxml-stubs"
- "basedpyright"
- "ruff"
- "pytest*"
- "types-*"
- "boto3-stubs"
- "hypothesis"
- "pandas-stubs"
- "scipy-stubs"
- "import-linter"
- "celery-types"
- "mypy*"
- "pyrefly"
python-packages:
patterns:
- "*"
- package-ecosystem: "uv"
directory: "/api"
open-pull-requests-limit: 10

View File

@ -18,7 +18,7 @@
## Checklist
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
- [ ] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [ ] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [ ] I've updated the documentation accordingly.
- [ ] I ran `make lint && make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods

View File

@ -54,7 +54,7 @@ jobs:
run: uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Upload unit coverage data
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: api-coverage-unit
path: coverage-unit
@ -129,7 +129,7 @@ jobs:
api/tests/test_containers_integration_tests
- name: Upload integration coverage data
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: api-coverage-integration
path: coverage-integration

View File

@ -81,7 +81,7 @@ jobs:
- name: Build Docker image
id: build
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
with:
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
@ -101,7 +101,7 @@ jobs:
touch "/tmp/digests/${sanitized_digest}"
- name: Upload digest
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*

View File

@ -50,7 +50,7 @@ jobs:
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Build Docker Image
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
with:
push: false
context: ${{ matrix.context }}

View File

@ -21,7 +21,7 @@ jobs:
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
steps:
- name: Download pyrefly diff artifact
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
@ -49,7 +49,7 @@ jobs:
run: unzip -o pyrefly_diff.zip
- name: Post comment
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |

View File

@ -66,7 +66,7 @@ jobs:
echo ${{ github.event.pull_request.number }} > pr_number.txt
- name: Upload pyrefly diff
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: pyrefly_diff
path: |
@ -75,7 +75,7 @@ jobs:
- name: Comment PR with pyrefly diff
if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }}
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |

View File

@ -32,7 +32,7 @@ jobs:
run: uv sync --project api --dev
- name: Download type coverage artifact
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
@ -73,7 +73,7 @@ jobs:
} > /tmp/type_coverage_comment.md
- name: Post comment
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |

View File

@ -71,7 +71,7 @@ jobs:
echo ${{ github.event.pull_request.number }} > pr_number.txt
- name: Upload type coverage artifact
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: pyrefly_type_coverage
path: |
@ -81,7 +81,7 @@ jobs:
- name: Comment PR with type coverage
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |

View File

@ -23,8 +23,8 @@ jobs:
days-before-issue-stale: 15
days-before-issue-close: 3
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
stale-pr-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
stale-issue-message: "Closed due to inactivity. If you have any questions, you can reopen it."
stale-pr-message: "Closed due to inactivity. If you have any questions, you can reopen it."
stale-issue-label: 'no-issue-activity'
stale-pr-label: 'no-pr-activity'
any-of-labels: 'duplicate,question,invalid,wontfix,no-issue-activity,no-pr-activity,enhancement,cant-reproduce,help-wanted'
any-of-labels: '🌚 invalid,🙋‍♂️ question,wont-fix,no-issue-activity,no-pr-activity,💪 enhancement,🤔 cant-reproduce,🙏 help wanted'

View File

@ -158,7 +158,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89
uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -56,7 +56,7 @@ jobs:
- name: Trigger i18n sync workflow
if: steps.detect.outputs.has_changes == 'true'
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
env:
BASE_SHA: ${{ steps.detect.outputs.base_sha }}
HEAD_SHA: ${{ steps.detect.outputs.head_sha }}

View File

@ -53,7 +53,7 @@ jobs:
- name: Upload Cucumber report
if: ${{ !cancelled() }}
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: cucumber-report
path: e2e/cucumber-report
@ -61,7 +61,7 @@ jobs:
- name: Upload E2E logs
if: ${{ !cancelled() }}
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: e2e-logs
path: e2e/.logs

View File

@ -43,7 +43,7 @@ jobs:
- name: Upload blob report
if: ${{ !cancelled() }}
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: blob-report-${{ matrix.shardIndex }}
path: web/.vitest-reports/*

View File

@ -1,12 +1,16 @@
from datetime import datetime
import flask_restx
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource
from flask_restx._http import HTTPStatus
from pydantic import field_validator
from sqlalchemy import delete, func, select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from extensions.ext_database import db
from libs.helper import TimestampField
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.enums import ApiTokenType
@ -16,21 +20,31 @@ from services.api_token_service import ApiTokenCache
from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = {
"id": fields.String,
"type": fields.String,
"token": fields.String,
"last_used_at": TimestampField,
"created_at": TimestampField,
}
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
api_key_list_model = console_ns.model(
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
)
class ApiKeyItem(ResponseModel):
id: str
type: str
token: str
last_used_at: int | None = None
created_at: int | None = None
@field_validator("last_used_at", "created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class ApiKeyList(ResponseModel):
data: list[ApiKeyItem]
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
def _get_resource(resource_id, tenant_id, resource_model):
@ -54,7 +68,6 @@ class BaseApiKeyListResource(Resource):
token_prefix: str | None = None
max_keys = 10
@marshal_with(api_key_list_model)
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
@ -66,9 +79,8 @@ class BaseApiKeyListResource(Resource):
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
)
).all()
return {"items": keys}
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
@marshal_with(api_key_item_model)
@edit_permission_required
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
@ -100,7 +112,7 @@ class BaseApiKeyListResource(Resource):
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return api_token, 201
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
class BaseApiKeyResource(Resource):
@ -147,7 +159,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("get_app_api_keys")
@console_ns.doc(description="Get all API keys for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(200, "Success", api_key_list_model)
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
def get(self, resource_id): # type: ignore
"""Get all API keys for an app"""
return super().get(resource_id)
@ -155,7 +167,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("create_app_api_key")
@console_ns.doc(description="Create a new API key for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(201, "API key created successfully", api_key_item_model)
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore
"""Create a new API key for an app"""
@ -187,7 +199,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get all API keys for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(200, "Success", api_key_list_model)
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
def get(self, resource_id): # type: ignore
"""Get all API keys for a dataset"""
return super().get(resource_id)
@ -195,7 +207,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("create_dataset_api_key")
@console_ns.doc(description="Create a new API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(201, "API key created successfully", api_key_item_model)
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore
"""Create a new API key for a dataset"""

View File

@ -25,7 +25,13 @@ from fields.annotation_fields import (
)
from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import AppAnnotationService
from services.annotation_service import (
AppAnnotationService,
EnableAnnotationArgs,
UpdateAnnotationArgs,
UpdateAnnotationSettingArgs,
UpsertAnnotationArgs,
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -120,7 +126,12 @@ class AnnotationReplyActionApi(Resource):
args = AnnotationReplyPayload.model_validate(console_ns.payload)
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
enable_args: EnableAnnotationArgs = {
"score_threshold": args.score_threshold,
"embedding_provider_name": args.embedding_provider_name,
"embedding_model_name": args.embedding_model_name,
}
result = AppAnnotationService.enable_app_annotation(enable_args, app_id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@ -161,7 +172,8 @@ class AppAnnotationSettingUpdateApi(Resource):
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args)
return result, 200
@ -237,8 +249,16 @@ class AnnotationApi(Resource):
def post(self, app_id):
app_id = str(app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
upsert_args: UpsertAnnotationArgs = {}
if args.answer is not None:
upsert_args["answer"] = args.answer
if args.content is not None:
upsert_args["content"] = args.content
if args.message_id is not None:
upsert_args["message_id"] = args.message_id
if args.question is not None:
upsert_args["question"] = args.question
annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@ -315,9 +335,12 @@ class AnnotationUpdateDeleteApi(Resource):
app_id = str(app_id)
annotation_id = str(annotation_id)
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
update_args: UpdateAnnotationArgs = {}
if args.answer is not None:
update_args["answer"] = args.answer
if args.question is not None:
update_args["question"] = args.question
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required

View File

@ -1,7 +1,8 @@
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
@ -10,35 +11,15 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import (
app_import_check_dependencies_fields,
app_import_fields,
leaked_dependency_fields,
)
from libs.login import current_account_with_tenant, login_required
from models.model import App
from services.app_dsl_service import AppDslService
from services.app_dsl_service import AppDslService, Import
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.dsl_entities import ImportStatus
from services.entities.dsl_entities import CheckDependenciesResult, ImportStatus
from services.feature_service import FeatureService
from .. import console_ns
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields)
app_import_model = console_ns.model("AppImport", app_import_fields)
# For nested models, need to replace nested dict with registered model
app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy()
app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model))
app_import_check_dependencies_model = console_ns.model(
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppImportPayload(BaseModel):
mode: str = Field(..., description="Import mode")
@ -52,18 +33,18 @@ class AppImportPayload(BaseModel):
app_id: str | None = Field(None)
console_ns.schema_model(
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
register_schema_models(console_ns, AppImportPayload, Import, CheckDependenciesResult)
@console_ns.route("/apps/imports")
class AppImportApi(Resource):
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
@console_ns.response(200, "Import completed", console_ns.models[Import.__name__])
@console_ns.response(202, "Import pending confirmation", console_ns.models[Import.__name__])
@console_ns.response(400, "Import failed", console_ns.models[Import.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@ -104,10 +85,11 @@ class AppImportApi(Resource):
@console_ns.route("/apps/imports/<string:import_id>/confirm")
class AppImportConfirmApi(Resource):
@console_ns.response(200, "Import confirmed", console_ns.models[Import.__name__])
@console_ns.response(400, "Import failed", console_ns.models[Import.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_model)
@edit_permission_required
def post(self, import_id):
# Check user role first
@ -128,11 +110,11 @@ class AppImportConfirmApi(Resource):
@console_ns.route("/apps/imports/<string:app_id>/check-dependencies")
class AppImportCheckDependenciesApi(Resource):
@console_ns.response(200, "Dependencies checked", console_ns.models[CheckDependenciesResult.__name__])
@setup_required
@login_required
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_model)
@edit_permission_required
def get(self, app_model: App):
with sessionmaker(db.engine).begin() as session:

View File

@ -1,23 +1,27 @@
import json
from datetime import datetime
from typing import Any
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.app_fields import app_server_fields
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models.enums import AppMCPServerStatus
from models.model import AppMCPServer
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
# Register model for flask_restx to avoid dict type issues in Swagger
app_server_model = console_ns.model("AppServer", app_server_fields)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class MCPServerCreatePayload(BaseModel):
@ -32,8 +36,33 @@ class MCPServerUpdatePayload(BaseModel):
status: str | None = Field(default=None, description="Server status")
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class AppMCPServerResponse(ResponseModel):
id: str
name: str
server_code: str
description: str
status: str
parameters: dict[str, Any] | list[Any] | str
created_at: int | None = None
updated_at: int | None = None
@field_validator("parameters", mode="before")
@classmethod
def _parse_json_string(cls, value: Any) -> Any:
if isinstance(value, str):
try:
return json.loads(value)
except (json.JSONDecodeError, TypeError):
return value
return value
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
register_schema_models(console_ns, MCPServerCreatePayload, MCPServerUpdatePayload, AppMCPServerResponse)
@console_ns.route("/apps/<uuid:app_id>/server")
@ -41,27 +70,27 @@ class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server")
@console_ns.doc(description="Get MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model)
@console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__])
@login_required
@account_initialization_required
@setup_required
@get_app_model
@marshal_with(app_server_model)
def get(self, app_model):
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
return server
if server is None:
return {}
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
@console_ns.doc("create_app_mcp_server")
@console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
@console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@account_initialization_required
@get_app_model
@login_required
@setup_required
@marshal_with(app_server_model)
@edit_permission_required
def post(self, app_model):
_, current_tenant_id = current_account_with_tenant()
@ -82,20 +111,19 @@ class AppMCPServerController(Resource):
)
db.session.add(server)
db.session.commit()
return server
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
@console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
@console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@get_app_model
@login_required
@setup_required
@account_initialization_required
@marshal_with(app_server_model)
@edit_permission_required
def put(self, app_model):
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
@ -118,7 +146,7 @@ class AppMCPServerController(Resource):
except ValueError:
raise ValueError("Invalid status")
db.session.commit()
return server
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
@console_ns.route("/apps/<uuid:server_id>/server/refresh")
@ -126,13 +154,12 @@ class AppMCPServerRefreshController(Resource):
@console_ns.doc("refresh_app_mcp_server")
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
@console_ns.doc(params={"server_id": "Server ID"})
@console_ns.response(200, "MCP server refreshed successfully", app_server_model)
@console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_server_model)
@edit_permission_required
def get(self, server_id):
_, current_tenant_id = current_account_with_tenant()
@ -145,4 +172,4 @@ class AppMCPServerRefreshController(Resource):
raise NotFound()
server.server_code = AppMCPServer.generate_server_code(16)
db.session.commit()
return server
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")

View File

@ -1,11 +1,12 @@
from typing import Literal
from flask_restx import Resource, marshal_with
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
@ -15,13 +16,11 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_site_fields
from fields.base import ResponseModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import Site
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppSiteUpdatePayload(BaseModel):
title: str | None = Field(default=None)
@ -49,13 +48,26 @@ class AppSiteUpdatePayload(BaseModel):
return supported_language(value)
console_ns.schema_model(
AppSiteUpdatePayload.__name__,
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
class AppSiteResponse(ResponseModel):
app_id: str
access_token: str | None = Field(default=None, validation_alias="code")
code: str | None = None
title: str
icon: str | None = None
icon_background: str | None = None
description: str | None = None
default_language: str
customize_domain: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
customize_token_strategy: str
prompt_public: bool
show_workflow_steps: bool
use_icon_as_answer_icon: bool
# Register model for flask_restx to avoid dict type issues in Swagger
app_site_model = console_ns.model("AppSite", app_site_fields)
register_schema_models(console_ns, AppSiteUpdatePayload, AppSiteResponse)
@console_ns.route("/apps/<uuid:app_id>/site")
@ -64,7 +76,7 @@ class AppSite(Resource):
@console_ns.doc(description="Update application site configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
@console_ns.response(200, "Site configuration updated successfully", console_ns.models[AppSiteResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "App not found")
@setup_required
@ -72,7 +84,6 @@ class AppSite(Resource):
@edit_permission_required
@account_initialization_required
@get_app_model
@marshal_with(app_site_model)
def post(self, app_model):
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
@ -106,7 +117,7 @@ class AppSite(Resource):
site.updated_at = naive_utc_now()
db.session.commit()
return site
return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
@ -114,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
@console_ns.doc("reset_app_site_access_token")
@console_ns.doc(description="Reset access token for application site")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Access token reset successfully", app_site_model)
@console_ns.response(200, "Access token reset successfully", console_ns.models[AppSiteResponse.__name__])
@console_ns.response(403, "Insufficient permissions (admin/owner required)")
@console_ns.response(404, "App or site not found")
@setup_required
@ -122,7 +133,6 @@ class AppSiteAccessTokenReset(Resource):
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
@marshal_with(app_site_model)
def post(self, app_model):
current_user, _ = current_account_with_tenant()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
@ -135,4 +145,4 @@ class AppSiteAccessTokenReset(Resource):
site.updated_at = naive_utc_now()
db.session.commit()
return site
return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")

View File

@ -1,8 +1,9 @@
from flask import request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
@ -11,8 +12,6 @@ from libs.helper import EmailStr, timezone
from models import AccountStatus
from services.account_service import RegisterService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ActivateCheckQuery(BaseModel):
workspace_id: str | None = Field(default=None)
@ -39,8 +38,16 @@ class ActivatePayload(BaseModel):
return timezone(value)
for model in (ActivateCheckQuery, ActivatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class ActivationCheckResponse(BaseModel):
is_valid: bool = Field(description="Whether token is valid")
data: dict | None = Field(default=None, description="Activation data if valid")
class ActivationResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse)
@console_ns.route("/activate/check")
@ -51,13 +58,7 @@ class ActivateCheckApi(Resource):
@console_ns.response(
200,
"Success",
console_ns.model(
"ActivationCheckResponse",
{
"is_valid": fields.Boolean(description="Whether token is valid"),
"data": fields.Raw(description="Activation data if valid"),
},
),
console_ns.models[ActivationCheckResponse.__name__],
)
def get(self):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@ -95,12 +96,7 @@ class ActivateApi(Resource):
@console_ns.response(
200,
"Account activated successfully",
console_ns.model(
"ActivationResponse",
{
"result": fields.String(description="Operation result"),
},
),
console_ns.models[ActivationResponse.__name__],
)
@console_ns.response(400, "Already activated or invalid token")
def post(self):

View File

@ -11,10 +11,7 @@ import services
from configs import dify_config
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import (
api_key_item_model,
api_key_list_model,
)
from controllers.console.apikey import ApiKeyItem, ApiKeyList
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import (
@ -785,23 +782,23 @@ class DatasetApiKeyApi(Resource):
@console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get dataset API keys")
@console_ns.response(200, "API keys retrieved successfully", api_key_list_model)
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_list_model)
def get(self):
_, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars(
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
).all()
return {"items": keys}
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
@console_ns.response(200, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@marshal_with(api_key_item_model)
def post(self):
_, current_tenant_id = current_account_with_tenant()
@ -828,7 +825,7 @@ class DatasetApiKeyApi(Resource):
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return api_token, 200
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 200
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")

View File

@ -94,10 +94,9 @@ def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
def plugin_data[**P, R](
view: Callable[P, R] | None = None,
*,
payload_type: type[BaseModel],
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
@ -116,7 +115,4 @@ def plugin_data[**P, R](
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
return decorator

View File

@ -12,7 +12,12 @@ from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import Annotation, AnnotationList
from models.model import App
from services.annotation_service import AppAnnotationService
from services.annotation_service import (
AppAnnotationService,
EnableAnnotationArgs,
InsertAnnotationArgs,
UpdateAnnotationArgs,
)
class AnnotationCreatePayload(BaseModel):
@ -46,10 +51,15 @@ class AnnotationReplyActionApi(Resource):
@validate_app_token
def post(self, app_model: App, action: Literal["enable", "disable"]):
"""Enable or disable annotation reply feature."""
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
payload = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {})
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
enable_args: EnableAnnotationArgs = {
"score_threshold": payload.score_threshold,
"embedding_provider_name": payload.embedding_provider_name,
"embedding_model_name": payload.embedding_model_name,
}
result = AppAnnotationService.enable_app_annotation(enable_args, app_model.id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
return result, 200
@ -135,8 +145,9 @@ class AnnotationListApi(Resource):
@validate_app_token
def post(self, app_model: App):
"""Create a new annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
insert_args: InsertAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.insert_app_annotation_directly(insert_args, app_model.id)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json"), HTTPStatus.CREATED
@ -164,8 +175,9 @@ class AnnotationUpdateDeleteApi(Resource):
@edit_permission_required
def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")

View File

@ -41,7 +41,23 @@ class AbstractVectorFactory(ABC):
class Vector:
def __init__(self, dataset: Dataset, attributes: list | None = None):
if attributes is None:
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
# `is_summary` and `original_chunk_id` are stored on summary vectors
# by `SummaryIndexService` and read back by `RetrievalService` to
# route summary hits through their original parent chunks. They
# must be listed here so vector backends that use this list as an
# explicit return-properties projection (notably Weaviate) actually
# return those fields; without them, summary hits silently
# collapse into `is_summary = False` branches and the summary
# retrieval path is a no-op. See #34884.
attributes = [
"doc_id",
"dataset_id",
"document_id",
"doc_hash",
"doc_type",
"is_summary",
"original_chunk_id",
]
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes

View File

@ -244,7 +244,7 @@ class DatasetDocumentStore:
return document_segment
def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None):
if multimodel_documents:
if multimodel_documents and self._document_id is not None:
for multimodel_document in multimodel_documents:
binding = SegmentAttachmentBinding(
tenant_id=self._dataset.tenant_id,

View File

@ -17,7 +17,6 @@ def http_status_message(code):
def register_external_error_handlers(api: Api):
@api.errorhandler(HTTPException)
def handle_http_exception(e: HTTPException):
got_request_exception.send(current_app, exception=e)
@ -74,27 +73,18 @@ def register_external_error_handlers(api: Api):
headers["Set-Cookie"] = build_force_logout_cookie_headers()
return data, status_code, headers
_ = handle_http_exception
@api.errorhandler(ValueError)
def handle_value_error(e: ValueError):
got_request_exception.send(current_app, exception=e)
status_code = 400
data = {"code": "invalid_param", "message": str(e), "status": status_code}
return data, status_code
_ = handle_value_error
@api.errorhandler(AppInvokeQuotaExceededError)
def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
got_request_exception.send(current_app, exception=e)
status_code = 429
data = {"code": "too_many_requests", "message": str(e), "status": status_code}
return data, status_code
_ = handle_quota_exceeded
@api.errorhandler(Exception)
def handle_general_exception(e: Exception):
got_request_exception.send(current_app, exception=e)
@ -113,7 +103,10 @@ def register_external_error_handlers(api: Api):
return data, status_code
_ = handle_general_exception
api.errorhandler(HTTPException)(handle_http_exception)
api.errorhandler(ValueError)(handle_value_error)
api.errorhandler(AppInvokeQuotaExceededError)(handle_quota_exceeded)
api.errorhandler(Exception)(handle_general_exception)
class ExternalApi(Api):

View File

@ -1688,7 +1688,7 @@ class PipelineRecommendedPlugin(TypeBase):
)
class SegmentAttachmentBinding(Base):
class SegmentAttachmentBinding(TypeBase):
__tablename__ = "segment_attachment_bindings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"),
@ -1701,13 +1701,17 @@ class SegmentAttachmentBinding(Base):
),
sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
class DocumentSegmentSummary(Base):

View File

@ -838,7 +838,7 @@ class AppModelConfig(TypeBase):
return self
class RecommendedApp(Base): # bug
class RecommendedApp(TypeBase):
__tablename__ = "recommended_apps"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
@ -846,20 +846,37 @@ class RecommendedApp(Base): # bug
sa.Index("recommended_app_is_listed_idx", "is_listed", "language"),
)
id = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
app_id = mapped_column(StringUUID, nullable=False)
description = mapped_column(sa.JSON, nullable=False)
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
description: Mapped[Any] = mapped_column(sa.JSON, nullable=False)
copyright: Mapped[str] = mapped_column(String(255), nullable=False)
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False)
custom_disclaimer: Mapped[str] = mapped_column(LongText, default="")
category: Mapped[str] = mapped_column(String(255), nullable=False)
custom_disclaimer: Mapped[str] = mapped_column(LongText, default="")
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
language: Mapped[str] = mapped_column(
String(255),
nullable=False,
server_default=sa.text("'en-US'"),
default="en-US",
)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
@property

View File

@ -4,15 +4,15 @@ version = "1.13.3"
requires-python = "~=3.12.0"
dependencies = [
"aliyun-log-python-sdk~=0.9.37",
"aliyun-log-python-sdk~=0.9.44",
"arize-phoenix-otel~=0.15.0",
"azure-identity==1.25.3",
"beautifulsoup4==4.14.3",
"boto3==1.42.88",
"bs4~=0.0.1",
"cachetools~=5.3.0",
"celery~=5.6.2",
"charset-normalizer>=3.4.4",
"cachetools~=7.0.5",
"celery~=5.6.3",
"charset-normalizer>=3.4.7",
"flask~=3.1.3",
"flask-compress>=1.24,<1.25",
"flask-cors~=6.0.2",
@ -20,7 +20,7 @@ dependencies = [
"flask-migrate~=4.1.0",
"flask-orjson~=2.0.0",
"flask-sqlalchemy~=3.1.1",
"gevent~=25.9.1",
"gevent~=26.4.0",
"gmpy2~=2.3.0",
"google-api-core>=2.30.3",
"google-api-python-client==2.194.0",
@ -30,49 +30,49 @@ dependencies = [
"googleapis-common-protos>=1.74.0",
"graphon>=0.1.2",
"gunicorn~=25.3.0",
"httpx[socks]~=0.28.0",
"httpx[socks]~=0.28.1",
"jieba==0.42.1",
"json-repair>=0.55.1",
"json-repair>=0.59.2",
"langfuse>=4.2.0,<5.0.0",
"langsmith~=0.7.30",
"markdown~=3.10.2",
"mlflow-skinny>=3.11.1",
"numpy~=1.26.4",
"numpy~=2.4.4",
"openpyxl~=3.1.5",
"opik~=1.11.2",
"litellm==1.83.0", # Pinned to avoid madoka dependency issue
"opentelemetry-api==1.40.0",
"opentelemetry-distro==0.61b0",
"opentelemetry-exporter-otlp==1.40.0",
"opentelemetry-exporter-otlp-proto-common==1.40.0",
"opentelemetry-exporter-otlp-proto-grpc==1.40.0",
"opentelemetry-exporter-otlp-proto-http==1.40.0",
"opentelemetry-instrumentation==0.61b0",
"opentelemetry-instrumentation-celery==0.61b0",
"opentelemetry-instrumentation-flask==0.61b0",
"opentelemetry-instrumentation-httpx==0.61b0",
"opentelemetry-instrumentation-redis==0.61b0",
"opentelemetry-instrumentation-sqlalchemy==0.61b0",
"opentelemetry-api==1.41.0",
"opentelemetry-distro==0.62b0",
"opentelemetry-exporter-otlp==1.41.0",
"opentelemetry-exporter-otlp-proto-common==1.41.0",
"opentelemetry-exporter-otlp-proto-grpc==1.41.0",
"opentelemetry-exporter-otlp-proto-http==1.41.0",
"opentelemetry-instrumentation==0.62b0",
"opentelemetry-instrumentation-celery==0.62b0",
"opentelemetry-instrumentation-flask==0.62b0",
"opentelemetry-instrumentation-httpx==0.62b0",
"opentelemetry-instrumentation-redis==0.62b0",
"opentelemetry-instrumentation-sqlalchemy==0.62b0",
"opentelemetry-propagator-b3==1.41.0",
"opentelemetry-proto==1.40.0",
"opentelemetry-sdk==1.40.0",
"opentelemetry-semantic-conventions==0.61b0",
"opentelemetry-util-http==0.61b0",
"pandas[excel,output-formatting,performance]~=3.0.1",
"opentelemetry-proto==1.41.0",
"opentelemetry-sdk==1.41.0",
"opentelemetry-semantic-conventions==0.62b0",
"opentelemetry-util-http==0.62b0",
"pandas[excel,output-formatting,performance]~=3.0.2",
"psycogreen~=1.0.2",
"psycopg2-binary~=2.9.11",
"pycryptodome==3.23.0",
"pydantic~=2.12.5",
"pydantic-settings~=2.13.1",
"pyjwt~=2.12.0",
"pyjwt~=2.12.1",
"pypdfium2==5.6.0",
"python-docx~=1.2.0",
"python-dotenv==1.2.2",
"pyyaml~=6.0.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=7.4.0",
"resend~=2.26.0",
"sentry-sdk[flask]~=2.55.0",
"resend~=2.27.0",
"sentry-sdk[flask]~=2.57.0",
"sqlalchemy~=2.0.49",
"starlette==1.0.0",
"tiktoken~=0.12.0",
@ -82,13 +82,13 @@ dependencies = [
"yarl~=1.23.0",
"sseclient-py~=1.9.0",
"httpx-sse~=0.4.0",
"sendgrid~=6.12.3",
"sendgrid~=6.12.5",
"flask-restx~=1.3.2",
"packaging~=23.2",
"croniter>=6.0.0",
"packaging~=26.0",
"croniter>=6.2.2",
"weaviate-client==4.20.5",
"apscheduler>=3.11.0",
"weave>=0.52.16",
"apscheduler>=3.11.2",
"weave>=0.52.36",
"fastopenapi[flask]>=0.7.0",
"bleach~=6.3.0",
]
@ -120,7 +120,7 @@ dev = [
"pytest-cov~=7.1.0",
"pytest-env~=1.6.0",
"pytest-mock~=3.15.1",
"testcontainers~=4.14.1",
"testcontainers~=4.14.2",
"types-aiofiles~=25.1.0",
"types-beautifulsoup4~=4.12.0",
"types-cachetools~=6.2.0",
@ -166,7 +166,7 @@ dev = [
"import-linter>=2.3",
"types-redis>=4.6.0.20241004",
"celery-types>=0.23.0",
"mypy~=1.20.0",
"mypy~=1.20.1",
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"sseclient-py>=1.8.0",
"pytest-timeout>=2.4.0",
@ -225,7 +225,7 @@ vdb = [
"xinference-client~=2.4.0",
"mo-vector~=0.1.13",
"mysql-connector-python>=9.3.0",
"holo-search-sdk>=0.4.1",
"holo-search-sdk>=0.4.2",
]
[tool.pyrefly]

View File

@ -47,7 +47,6 @@
"reportMissingTypeArgument": "hint",
"reportUnnecessaryComparison": "hint",
"reportUnnecessaryIsInstance": "hint",
"reportUntypedFunctionDecorator": "hint",
"reportUnnecessaryTypeIgnoreComment": "hint",
"reportAttributeAccessIssue": "hint",
"pythonVersion": "3.12",

View File

@ -1,11 +1,8 @@
import logging
import uuid
import pandas as pd
logger = logging.getLogger(__name__)
from typing import TypedDict
import pandas as pd
from sqlalchemy import delete, or_, select, update
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@ -24,6 +21,8 @@ from tasks.annotation.disable_annotation_reply_task import disable_annotation_re
from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
logger = logging.getLogger(__name__)
class AnnotationJobStatusDict(TypedDict):
job_id: str
@ -46,9 +45,50 @@ class AnnotationSettingDisabledDict(TypedDict):
enabled: bool
class EnableAnnotationArgs(TypedDict):
"""Expected shape of the args dict passed to enable_app_annotation."""
score_threshold: float
embedding_provider_name: str
embedding_model_name: str
class UpsertAnnotationArgs(TypedDict, total=False):
"""Expected shape of the args dict passed to up_insert_app_annotation_from_message."""
answer: str
content: str
message_id: str
question: str
class InsertAnnotationArgs(TypedDict):
"""Expected shape of the args dict passed to insert_app_annotation_directly."""
question: str
answer: str
class UpdateAnnotationArgs(TypedDict, total=False):
"""Expected shape of the args dict passed to update_app_annotation_directly.
Both fields are optional at the type level; the service validates at runtime
and raises ValueError if either is missing.
"""
answer: str
question: str
class UpdateAnnotationSettingArgs(TypedDict):
"""Expected shape of the args dict passed to update_app_annotation_setting."""
score_threshold: float
class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
def up_insert_app_annotation_from_message(cls, args: UpsertAnnotationArgs, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@ -62,8 +102,9 @@ class AppAnnotationService:
if answer is None:
raise ValueError("Either 'answer' or 'content' must be provided")
if args.get("message_id"):
message_id = str(args["message_id"])
raw_message_id = args.get("message_id")
if raw_message_id:
message_id = str(raw_message_id)
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1)
)
@ -87,9 +128,10 @@ class AppAnnotationService:
account_id=current_user.id,
)
else:
question = args.get("question")
if not question:
maybe_question = args.get("question")
if not maybe_question:
raise ValueError("'question' is required when 'message_id' is not provided")
question = maybe_question
annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
db.session.add(annotation)
@ -110,7 +152,7 @@ class AppAnnotationService:
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict:
def enable_app_annotation(cls, args: EnableAnnotationArgs, app_id: str) -> AnnotationJobStatusDict:
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
@ -217,7 +259,7 @@ class AppAnnotationService:
return annotations
@classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
def insert_app_annotation_directly(cls, args: InsertAnnotationArgs, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@ -251,7 +293,7 @@ class AppAnnotationService:
return annotation
@classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@ -270,7 +312,11 @@ class AppAnnotationService:
if question is None:
raise ValueError("'question' is required")
annotation.content = args["answer"]
answer = args.get("answer")
if answer is None:
raise ValueError("'answer' is required")
annotation.content = answer
annotation.question = question
db.session.commit()
@ -613,7 +659,7 @@ class AppAnnotationService:
@classmethod
def update_app_annotation_setting(
cls, app_id: str, annotation_setting_id: str, args: dict
cls, app_id: str, annotation_setting_id: str, args: UpdateAnnotationSettingArgs
) -> AnnotationSettingDict:
current_user, current_tenant_id = current_account_with_tenant()
# get app info

View File

@ -4,7 +4,7 @@ from datetime import datetime
from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import delete, or_, select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
@ -42,32 +42,43 @@ class WorkflowToolManageService:
labels: list[str] | None = None,
):
# check if the name is unique
existing_workflow_tool_provider = db.session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
existing_workflow_tool_provider: WorkflowToolProvider | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
# query if the name or app_id exists
existing_workflow_tool_provider = _session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
)
.limit(1)
)
.limit(1)
)
# if the name or app_id exists raise error
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
app: App | None = db.session.scalar(
select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1)
)
# query the app
app: App | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
app = _session.scalar(select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1))
# if not found raise error
if app is None:
raise ValueError(f"App {workflow_app_id} not found")
# query the workflow
workflow: Workflow | None = app.workflow
# if not found raise error
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_app_id}")
# check if workflow configuration is synced
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict)
# create workflow tool provider
workflow_tool_provider = WorkflowToolProvider(
tenant_id=tenant_id,
user_id=user_id,
@ -87,13 +98,15 @@ class WorkflowToolManageService:
logger.warning(e, exc_info=True)
raise ValueError(str(e))
with Session(db.engine, expire_on_commit=False) as session, session.begin():
session.add(workflow_tool_provider)
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
_session.add(workflow_tool_provider)
# keep the session open to make orm instances in the same session
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
return {"result": "success"}
@classmethod
@ -112,6 +125,7 @@ class WorkflowToolManageService:
):
"""
Update a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_tool_id: workflow tool id
@ -187,28 +201,32 @@ class WorkflowToolManageService:
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
"""
List workflow tools.
:param user_id: the user id
:param tenant_id: the tenant id
:return: the list of tools
"""
db_tools = db.session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()
providers: list[WorkflowToolProvider] = []
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
providers = list(
_session.scalars(select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)).all()
)
# Create a mapping from provider_id to app_id
provider_id_to_app_id = {provider.id: provider.app_id for provider in db_tools}
provider_id_to_app_id = {provider.id: provider.app_id for provider in providers}
tools: list[WorkflowToolProviderController] = []
for provider in db_tools:
for provider in providers:
try:
tools.append(ToolTransformService.workflow_provider_to_controller(provider))
except Exception:
# skip deleted tools
logger.exception("Failed to load workflow tool provider %s", provider.id)
labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)])
labels = ToolLabelManager.get_tools_labels([tool for tool in tools if isinstance(tool, ToolProviderController)])
result = []
result: list[ToolProviderApiEntity] = []
for tool in tools:
workflow_app_id = provider_id_to_app_id.get(tool.provider_id)
@ -233,17 +251,18 @@ class WorkflowToolManageService:
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Delete a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
"""
db.session.execute(
delete(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
)
)
db.session.commit()
with sessionmaker(db.engine).begin() as _session:
_ = _session.execute(
delete(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
)
)
return {"result": "success"}
@ -251,47 +270,59 @@ class WorkflowToolManageService:
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Get a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.limit(1)
)
return cls._get_workflow_tool(tenant_id, db_tool)
tool_provider: WorkflowToolProvider | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
tool_provider = _session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.limit(1)
)
return cls._get_workflow_tool(tenant_id, tool_provider)
@classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str):
"""
Get a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.limit(1)
)
return cls._get_workflow_tool(tenant_id, db_tool)
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
tool_provider: WorkflowToolProvider | None = _session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.limit(1)
)
return cls._get_workflow_tool(tenant_id, tool_provider)
@classmethod
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
"""
Get a workflow tool.
:db_tool: the database tool
:return: the tool
"""
if db_tool is None:
raise ValueError("Tool not found")
workflow_app: App | None = db.session.scalar(
select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1)
)
workflow_app: App | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
workflow_app = _session.scalar(
select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1)
)
if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found")
@ -331,28 +362,32 @@ class WorkflowToolManageService:
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]:
"""
List workflow tool provider tools.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
:return: the list of tools
"""
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.limit(1)
)
if db_tool is None:
provider: WorkflowToolProvider | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
provider = _session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.limit(1)
)
if provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
tool = ToolTransformService.workflow_provider_to_controller(provider)
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
if len(workflow_tools) == 0:
raise ValueError(f"Tool {workflow_tool_id} not found")
return [
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool.get_tools(db_tool.tenant_id)[0],
tool=tool.get_tools(provider.tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool),
tenant_id=tenant_id,
)

View File

@ -313,6 +313,21 @@ class TestSiteEndpoints:
method = _unwrap(api.post)
site = MagicMock()
site.app_id = "app-1"
site.code = "test-code"
site.title = "My Site"
site.icon = None
site.icon_background = None
site.description = "Test site"
site.default_language = "en-US"
site.customize_domain = None
site.copyright = None
site.privacy_policy = None
site.custom_disclaimer = ""
site.customize_token_strategy = "not_allow"
site.prompt_public = False
site.show_workflow_steps = True
site.use_icon_as_answer_icon = False
monkeypatch.setattr(
site_module.db,
"session",
@ -328,13 +343,29 @@ class TestSiteEndpoints:
with app.test_request_context("/", json={"title": "My Site"}):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result is site
assert isinstance(result, dict)
assert result["title"] == "My Site"
def test_app_site_access_token_reset(self, app, monkeypatch):
api = site_module.AppSiteAccessTokenReset()
method = _unwrap(api.post)
site = MagicMock()
site.app_id = "app-1"
site.code = "old-code"
site.title = "My Site"
site.icon = None
site.icon_background = None
site.description = None
site.default_language = "en-US"
site.customize_domain = None
site.copyright = None
site.privacy_policy = None
site.custom_disclaimer = ""
site.customize_token_strategy = "not_allow"
site.prompt_public = False
site.show_workflow_steps = True
site.use_icon_as_answer_icon = False
monkeypatch.setattr(
site_module.db,
"session",
@ -351,7 +382,8 @@ class TestSiteEndpoints:
with app.test_request_context("/"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result is site
assert isinstance(result, dict)
assert result["access_token"] == "code"
class TestWorkflowEndpoints:

View File

@ -148,14 +148,18 @@ def test_chat_message_list_success(
account.id,
created_at_offset_seconds=1,
)
# Capture IDs before the HTTP request detaches ORM instances from the session
app_id = app.id
conversation_id = conversation.id
second_id = second.id
with patch(
"controllers.console.app.message.attach_message_extra_contents",
side_effect=_attach_message_extra_contents,
):
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/chat-messages",
query_string={"conversation_id": conversation.id, "limit": 1},
f"/console/api/apps/{app_id}/chat-messages",
query_string={"conversation_id": conversation_id, "limit": 1},
headers=authenticate_console_client(test_client_with_containers, account),
)
@ -165,7 +169,7 @@ def test_chat_message_list_success(
assert payload["limit"] == 1
assert payload["has_more"] is True
assert len(payload["data"]) == 1
assert payload["data"][0]["id"] == second.id
assert payload["data"][0]["id"] == second_id
def test_message_feedback_not_found(

View File

@ -1,79 +1,202 @@
# import secrets
"""
Integration tests for Account and Tenant model methods that interact with the database.
# import pytest
# from sqlalchemy import select
# from sqlalchemy.orm import Session
# from sqlalchemy.orm.exc import DetachedInstanceError
Migrated from unit_tests/models/test_account_models.py, replacing
@patch("models.account.db") mock patches with real PostgreSQL operations.
# from libs.datetime_utils import naive_utc_now
# from models.account import Account, Tenant, TenantAccountJoin
Covers:
- Account.current_tenant setter (sets _current_tenant and role from TenantAccountJoin)
- Account.set_tenant_id (resolves tenant + role from real join row)
- Account.get_by_openid (AccountIntegrate lookup then Account fetch)
- Tenant.get_accounts (returns accounts linked via TenantAccountJoin)
"""
from collections.abc import Generator
from uuid import uuid4
import pytest
from sqlalchemy import delete
from sqlalchemy.orm import Session
from models.account import Account, AccountIntegrate, Tenant, TenantAccountJoin, TenantAccountRole
# @pytest.fixture
# def session(db_session_with_containers):
# with Session(db_session_with_containers.get_bind()) as session:
# yield session
def _cleanup_tracked_rows(db_session: Session, tracked: list) -> None:
"""Delete rows tracked during the test so committed state does not leak into the DB.
Rolls back any pending (uncommitted) session state first, then issues DELETE
statements by primary key for each tracked entity (in reverse creation order)
and commits. This cleans up rows created via either flush() or commit().
"""
db_session.rollback()
for entity in reversed(tracked):
db_session.execute(delete(type(entity)).where(type(entity).id == entity.id))
db_session.commit()
# @pytest.fixture
# def account(session):
# account = Account(
# name="test account",
# email=f"test_{secrets.token_hex(8)}@example.com",
# )
# session.add(account)
# session.commit()
# return account
def _build_tenant() -> Tenant:
return Tenant(name=f"Tenant {uuid4()}")
# @pytest.fixture
# def tenant(session):
# tenant = Tenant(name="test tenant")
# session.add(tenant)
# session.commit()
# return tenant
def _build_account(email_prefix: str = "account") -> Account:
return Account(
name=f"Account {uuid4()}",
email=f"{email_prefix}_{uuid4()}@example.com",
password="hashed-password",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
# @pytest.fixture
# def tenant_account_join(session, account, tenant):
# tenant_join = TenantAccountJoin(account_id=account.id, tenant_id=tenant.id)
# session.add(tenant_join)
# session.commit()
# yield tenant_join
# session.delete(tenant_join)
# session.commit()
class _DBTrackingTestBase:
"""Base class providing a tracker list and shared row factories for account/tenant tests."""
_tracked: list
@pytest.fixture(autouse=True)
def _setup_cleanup(self, db_session_with_containers: Session) -> Generator[None, None, None]:
self._tracked = []
yield
_cleanup_tracked_rows(db_session_with_containers, self._tracked)
def _create_tenant(self, db_session: Session) -> Tenant:
tenant = _build_tenant()
db_session.add(tenant)
db_session.flush()
self._tracked.append(tenant)
return tenant
def _create_account(self, db_session: Session, email_prefix: str = "account") -> Account:
account = _build_account(email_prefix)
db_session.add(account)
db_session.flush()
self._tracked.append(account)
return account
def _create_join(
self, db_session: Session, tenant_id: str, account_id: str, role: TenantAccountRole, current: bool = True
) -> TenantAccountJoin:
join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id, role=role, current=current)
db_session.add(join)
db_session.flush()
self._tracked.append(join)
return join
# class TestAccountTenant:
# def test_set_current_tenant_should_reload_tenant(
# self,
# db_session_with_containers,
# account,
# tenant,
# tenant_account_join,
# ):
# with Session(db_session_with_containers.get_bind(), expire_on_commit=True) as session:
# scoped_tenant = session.scalars(select(Tenant).where(Tenant.id == tenant.id)).one()
# account.current_tenant = scoped_tenant
# scoped_tenant.created_at = naive_utc_now()
# # session.commit()
class TestAccountCurrentTenantSetter(_DBTrackingTestBase):
"""Integration tests for Account.current_tenant property setter."""
# # Ensure the tenant used in assignment is detached.
# with pytest.raises(DetachedInstanceError):
# _ = scoped_tenant.name
def test_current_tenant_property_returns_cached_tenant(self, db_session_with_containers: Session) -> None:
"""current_tenant getter returns the in-memory _current_tenant without DB access."""
account = self._create_account(db_session_with_containers)
tenant = self._create_tenant(db_session_with_containers)
account._current_tenant = tenant
# assert account._current_tenant.id == tenant.id
# assert account._current_tenant.id == tenant.id
assert account.current_tenant is tenant
# def test_set_tenant_id_should_load_tenant_as_not_expire(
# self,
# flask_app_with_containers,
# account,
# tenant,
# tenant_account_join,
# ):
# with flask_app_with_containers.test_request_context():
# account.set_tenant_id(tenant.id)
def test_current_tenant_setter_sets_tenant_and_role_when_join_exists(
self, db_session_with_containers: Session
) -> None:
"""Setting current_tenant loads the join row and assigns role when relationship exists."""
tenant = self._create_tenant(db_session_with_containers)
account = self._create_account(db_session_with_containers)
self._create_join(db_session_with_containers, tenant.id, account.id, TenantAccountRole.OWNER)
db_session_with_containers.commit()
# assert account._current_tenant.id == tenant.id
# assert account._current_tenant.id == tenant.id
account.current_tenant = tenant
assert account._current_tenant is not None
assert account._current_tenant.id == tenant.id
assert account.role == TenantAccountRole.OWNER
def test_current_tenant_setter_sets_none_when_no_join_exists(self, db_session_with_containers: Session) -> None:
"""Setting current_tenant results in _current_tenant=None when no join row exists."""
tenant = self._create_tenant(db_session_with_containers)
account = self._create_account(db_session_with_containers)
db_session_with_containers.commit()
account.current_tenant = tenant
assert account._current_tenant is None
class TestAccountSetTenantId(_DBTrackingTestBase):
"""Integration tests for Account.set_tenant_id method."""
def test_set_tenant_id_sets_tenant_and_role_when_relationship_exists(
self, db_session_with_containers: Session
) -> None:
"""set_tenant_id loads the tenant and assigns role when a join row exists."""
tenant = self._create_tenant(db_session_with_containers)
account = self._create_account(db_session_with_containers)
self._create_join(db_session_with_containers, tenant.id, account.id, TenantAccountRole.ADMIN)
db_session_with_containers.commit()
account.set_tenant_id(tenant.id)
assert account._current_tenant is not None
assert account._current_tenant.id == tenant.id
assert account.role == TenantAccountRole.ADMIN
def test_set_tenant_id_does_not_set_tenant_when_no_relationship_exists(
self, db_session_with_containers: Session
) -> None:
"""set_tenant_id does nothing when no join row matches the tenant."""
tenant = self._create_tenant(db_session_with_containers)
account = self._create_account(db_session_with_containers)
db_session_with_containers.commit()
account.set_tenant_id(tenant.id)
assert account._current_tenant is None
class TestAccountGetByOpenId(_DBTrackingTestBase):
"""Integration tests for Account.get_by_openid class method."""
def test_get_by_openid_returns_account_when_integrate_exists(self, db_session_with_containers: Session) -> None:
"""get_by_openid returns the Account when a matching AccountIntegrate row exists."""
account = self._create_account(db_session_with_containers, email_prefix="openid")
provider = "google"
open_id = f"google_{uuid4()}"
integrate = AccountIntegrate(
account_id=account.id,
provider=provider,
open_id=open_id,
encrypted_token="token",
)
db_session_with_containers.add(integrate)
db_session_with_containers.flush()
self._tracked.append(integrate)
result = Account.get_by_openid(provider, open_id)
assert result is not None
assert result.id == account.id
def test_get_by_openid_returns_none_when_no_integrate_exists(self, db_session_with_containers: Session) -> None:
"""get_by_openid returns None when no AccountIntegrate row matches."""
result = Account.get_by_openid("github", f"github_{uuid4()}")
assert result is None
class TestTenantGetAccounts(_DBTrackingTestBase):
"""Integration tests for Tenant.get_accounts method."""
def test_get_accounts_returns_linked_accounts(self, db_session_with_containers: Session) -> None:
"""get_accounts returns all accounts linked to the tenant via TenantAccountJoin."""
tenant = self._create_tenant(db_session_with_containers)
account1 = self._create_account(db_session_with_containers, email_prefix="tenant_member")
account2 = self._create_account(db_session_with_containers, email_prefix="tenant_member")
self._create_join(db_session_with_containers, tenant.id, account1.id, TenantAccountRole.OWNER, current=False)
self._create_join(db_session_with_containers, tenant.id, account2.id, TenantAccountRole.NORMAL, current=False)
accounts = tenant.get_accounts()
assert len(accounts) == 2
account_ids = {a.id for a in accounts}
assert account1.id in account_ids
assert account2.id in account_ids

View File

@ -1,9 +1,13 @@
import json
from collections.abc import Generator
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from services.billing_service import BillingService
@ -363,3 +367,62 @@ class TestBillingServiceGetPlanBulkWithCache:
assert ttl_1_new <= 600
assert ttl_2 > 0
assert ttl_2 <= 600
class TestBillingServiceIsTenantOwnerOrAdmin:
"""
Integration tests for BillingService.is_tenant_owner_or_admin.
Verifies that non-privileged roles (EDITOR, DATASET_OPERATOR) raise ValueError
when checked against real TenantAccountJoin rows in PostgreSQL.
"""
@pytest.fixture(autouse=True)
def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
yield
db_session_with_containers.rollback()
def _create_account_with_tenant_role(self, db_session: Session, role: TenantAccountRole) -> tuple[Account, Tenant]:
tenant = Tenant(name=f"Tenant {uuid4()}")
db_session.add(tenant)
db_session.flush()
account = Account(
name=f"Account {uuid4()}",
email=f"billing_{uuid4()}@example.com",
password="hashed-password",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
db_session.add(account)
db_session.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=role,
current=True,
)
db_session.add(join)
db_session.flush()
# Wire up in-memory reference so current_tenant_id resolves
account._current_tenant = tenant
return account, tenant
def test_is_tenant_owner_or_admin_editor_role_raises_error(self, db_session_with_containers: Session) -> None:
"""is_tenant_owner_or_admin raises ValueError for EDITOR role."""
account, _ = self._create_account_with_tenant_role(db_session_with_containers, TenantAccountRole.EDITOR)
with pytest.raises(ValueError, match="Only team owner or team admin can perform this action"):
BillingService.is_tenant_owner_or_admin(account)
def test_is_tenant_owner_or_admin_dataset_operator_raises_error(self, db_session_with_containers: Session) -> None:
"""is_tenant_owner_or_admin raises ValueError for DATASET_OPERATOR role."""
account, _ = self._create_account_with_tenant_role(
db_session_with_containers, TenantAccountRole.DATASET_OPERATOR
)
with pytest.raises(ValueError, match="Only team owner or team admin can perform this action"):
BillingService.is_tenant_owner_or_admin(account)

View File

@ -0,0 +1,70 @@
import datetime
from controllers.console.app.mcp_server import AppMCPServerResponse
class TestAppMCPServerResponse:
def test_parameters_json_string_parsed(self):
data = {
"id": "s1",
"name": "test",
"server_code": "code",
"description": "desc",
"status": "active",
"parameters": '{"key": "value"}',
}
resp = AppMCPServerResponse.model_validate(data)
assert resp.parameters == {"key": "value"}
def test_parameters_invalid_json_returns_original(self):
data = {
"id": "s1",
"name": "test",
"server_code": "code",
"description": "desc",
"status": "active",
"parameters": "not-valid-json",
}
resp = AppMCPServerResponse.model_validate(data)
assert resp.parameters == "not-valid-json"
def test_parameters_dict_passthrough(self):
data = {
"id": "s1",
"name": "test",
"server_code": "code",
"description": "desc",
"status": "active",
"parameters": {"already": "parsed"},
}
resp = AppMCPServerResponse.model_validate(data)
assert resp.parameters == {"already": "parsed"}
def test_timestamps_normalized(self):
dt = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
data = {
"id": "s1",
"name": "test",
"server_code": "code",
"description": "desc",
"status": "active",
"parameters": {},
"created_at": dt,
"updated_at": dt,
}
resp = AppMCPServerResponse.model_validate(data)
assert resp.created_at == int(dt.timestamp())
assert resp.updated_at == int(dt.timestamp())
def test_timestamps_none(self):
data = {
"id": "s1",
"name": "test",
"server_code": "code",
"description": "desc",
"status": "active",
"parameters": {},
}
resp = AppMCPServerResponse.model_validate(data)
assert resp.created_at is None
assert resp.updated_at is None

View File

@ -1555,7 +1555,17 @@ class TestDatasetApiKeyApi:
method = unwrap(api.get)
mock_key_1 = MagicMock(spec=ApiToken)
mock_key_1.id = "key-1"
mock_key_1.type = "dataset"
mock_key_1.token = "ds-abc"
mock_key_1.last_used_at = None
mock_key_1.created_at = None
mock_key_2 = MagicMock(spec=ApiToken)
mock_key_2.id = "key-2"
mock_key_2.type = "dataset"
mock_key_2.token = "ds-def"
mock_key_2.last_used_at = None
mock_key_2.created_at = None
with (
app.test_request_context("/"),
@ -1570,13 +1580,26 @@ class TestDatasetApiKeyApi:
):
response = method(api)
assert "items" in response
assert response["items"] == [mock_key_1, mock_key_2]
assert "data" in response
assert len(response["data"]) == 2
assert response["data"][0]["id"] == "key-1"
assert response["data"][0]["token"] == "ds-abc"
assert response["data"][1]["id"] == "key-2"
assert response["data"][1]["token"] == "ds-def"
def test_post_create_api_key_success(self, app):
api = DatasetApiKeyApi()
method = unwrap(api.post)
mock_token = MagicMock()
mock_token.id = "new-key-id"
mock_token.last_used_at = None
mock_token.created_at = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
mock_api_token_cls = MagicMock()
mock_api_token_cls.return_value = mock_token
mock_api_token_cls.generate_api_key.return_value = "dataset-abc123"
with (
app.test_request_context("/"),
patch(
@ -1588,8 +1611,8 @@ class TestDatasetApiKeyApi:
return_value=3,
),
patch(
"controllers.console.datasets.datasets.ApiToken.generate_api_key",
return_value="dataset-abc123",
"controllers.console.datasets.datasets.ApiToken",
mock_api_token_cls,
),
patch(
"controllers.console.datasets.datasets.db.session.add",
@ -1603,9 +1626,11 @@ class TestDatasetApiKeyApi:
response, status = method(api)
assert status == 200
assert isinstance(response, ApiToken)
assert response.token == "dataset-abc123"
assert response.type == "dataset"
assert isinstance(response, dict)
assert response["id"] == "new-key-id"
assert response["token"] == "dataset-abc123"
assert response["type"] == "dataset"
assert response["created_at"] is not None
def test_post_exceed_max_keys(self, app):
api = DatasetApiKeyApi()

View File

@ -121,7 +121,18 @@ def test_vector_init_uses_default_and_custom_attributes(vector_factory_module):
default_vector = vector_factory_module.Vector(dataset)
custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"])
assert default_vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
# `is_summary` and `original_chunk_id` must be in the default return-properties
# projection so summary index retrieval works on backends that honor the list
# as an explicit projection (e.g. Weaviate). See #34884.
assert default_vector._attributes == [
"doc_id",
"dataset_id",
"document_id",
"doc_hash",
"doc_type",
"is_summary",
"original_chunk_id",
]
assert custom_vector._attributes == ["doc_id"]
assert default_vector._embeddings == "embeddings"
assert default_vector._vector_processor == "processor"

View File

@ -721,6 +721,30 @@ class TestDatasetDocumentStoreMultimodelBinding:
mock_db.session.add.assert_not_called()
def test_add_multimodel_documents_binding_with_none_document_id(self):
"""Test that no bindings are added when document_id is None."""
mock_dataset = MagicMock(spec=Dataset)
mock_dataset.id = "test-dataset-id"
mock_dataset.tenant_id = "tenant-1"
mock_attachment = MagicMock(spec=AttachmentDocument)
mock_attachment.metadata = {"doc_id": "attachment-1"}
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock()
mock_db.session = mock_session
store = DatasetDocumentStore(
dataset=mock_dataset,
user_id="test-user-id",
document_id=None,
)
store.add_multimodel_documents_binding("seg-1", [mock_attachment])
mock_db.session.add.assert_not_called()
class TestDatasetDocumentStoreAddDocumentsUpdateChild:
"""Tests for add_documents when updating existing documents with children."""

View File

@ -12,7 +12,6 @@ This test suite covers:
import base64
import secrets
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
@ -310,90 +309,6 @@ class TestAccountStatusTransitions:
class TestTenantRelationshipIntegrity:
"""Test suite for tenant relationship integrity."""
@patch("models.account.db")
def test_account_current_tenant_property(self, mock_db):
"""Test the current_tenant property getter."""
# Arrange
account = Account(
name="Test User",
email="test@example.com",
)
account.id = str(uuid4())
tenant = Tenant(name="Test Tenant")
tenant.id = str(uuid4())
account._current_tenant = tenant
# Act
result = account.current_tenant
# Assert
assert result == tenant
@patch("models.account.Session")
@patch("models.account.db")
def test_account_current_tenant_setter_with_valid_tenant(self, mock_db, mock_session_class):
"""Test setting current_tenant with a valid tenant relationship."""
# Arrange
account = Account(
name="Test User",
email="test@example.com",
)
account.id = str(uuid4())
tenant = Tenant(name="Test Tenant")
tenant.id = str(uuid4())
# Mock the session and queries
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock TenantAccountJoin query result
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
)
mock_session.scalar.return_value = tenant_join
# Mock Tenant query result
mock_session.scalars.return_value.one.return_value = tenant
# Act
account.current_tenant = tenant
# Assert
assert account._current_tenant == tenant
assert account.role == TenantAccountRole.OWNER
@patch("models.account.Session")
@patch("models.account.db")
def test_account_current_tenant_setter_without_relationship(self, mock_db, mock_session_class):
"""Test setting current_tenant when no relationship exists."""
# Arrange
account = Account(
name="Test User",
email="test@example.com",
)
account.id = str(uuid4())
tenant = Tenant(name="Test Tenant")
tenant.id = str(uuid4())
# Mock the session and queries
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock no TenantAccountJoin found
mock_session.scalar.return_value = None
# Act
account.current_tenant = tenant
# Assert
assert account._current_tenant is None
def test_account_current_tenant_id_property(self):
"""Test the current_tenant_id property."""
# Arrange
@ -418,61 +333,6 @@ class TestTenantRelationshipIntegrity:
# Assert
assert tenant_id_none is None
@patch("models.account.Session")
@patch("models.account.db")
def test_account_set_tenant_id_method(self, mock_db, mock_session_class):
"""Test the set_tenant_id method."""
# Arrange
account = Account(
name="Test User",
email="test@example.com",
)
account.id = str(uuid4())
tenant = Tenant(name="Test Tenant")
tenant.id = str(uuid4())
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.ADMIN,
)
# Mock the session and queries
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.first.return_value = (tenant, tenant_join)
# Act
account.set_tenant_id(tenant.id)
# Assert
assert account._current_tenant == tenant
assert account.role == TenantAccountRole.ADMIN
@patch("models.account.Session")
@patch("models.account.db")
def test_account_set_tenant_id_with_no_relationship(self, mock_db, mock_session_class):
"""Test set_tenant_id when no relationship exists."""
# Arrange
account = Account(
name="Test User",
email="test@example.com",
)
account.id = str(uuid4())
tenant_id = str(uuid4())
# Mock the session and queries
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.first.return_value = None
# Act
account.set_tenant_id(tenant_id)
# Assert - should not set tenant when no relationship exists
# The method returns early without setting _current_tenant
class TestAccountRolePermissions:
"""Test suite for account role permissions."""
@ -605,51 +465,6 @@ class TestAccountRolePermissions:
assert current_role == TenantAccountRole.EDITOR
class TestAccountGetByOpenId:
"""Test suite for get_by_openid class method."""
@patch("models.account.db")
def test_get_by_openid_success(self, mock_db):
"""Test successful retrieval of account by OpenID."""
# Arrange
provider = "google"
open_id = "google_user_123"
account_id = str(uuid4())
mock_account_integrate = MagicMock()
mock_account_integrate.account_id = account_id
mock_account = Account(name="Test User", email="test@example.com")
mock_account.id = account_id
# Mock db.session.execute().scalar_one_or_none() for AccountIntegrate lookup
mock_db.session.execute.return_value.scalar_one_or_none.return_value = mock_account_integrate
# Mock db.session.scalar() for Account lookup
mock_db.session.scalar.return_value = mock_account
# Act
result = Account.get_by_openid(provider, open_id)
# Assert
assert result == mock_account
@patch("models.account.db")
def test_get_by_openid_not_found(self, mock_db):
"""Test get_by_openid when account integrate doesn't exist."""
# Arrange
provider = "github"
open_id = "github_user_456"
# Mock db.session.execute().scalar_one_or_none() to return None
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
# Act
result = Account.get_by_openid(provider, open_id)
# Assert
assert result is None
class TestTenantAccountJoinModel:
"""Test suite for TenantAccountJoin model."""
@ -760,31 +575,6 @@ class TestTenantModel:
# Assert
assert tenant.custom_config == '{"feature1": true, "feature2": "value"}'
@patch("models.account.db")
def test_tenant_get_accounts(self, mock_db):
"""Test getting accounts associated with a tenant."""
# Arrange
tenant = Tenant(name="Test Workspace")
tenant.id = str(uuid4())
account1 = Account(name="User 1", email="user1@example.com")
account1.id = str(uuid4())
account2 = Account(name="User 2", email="user2@example.com")
account2.id = str(uuid4())
# Mock the query chain
mock_scalars = MagicMock()
mock_scalars.all.return_value = [account1, account2]
mock_db.session.scalars.return_value = mock_scalars
# Act
accounts = tenant.get_accounts()
# Assert
assert len(accounts) == 2
assert account1 in accounts
assert account2 in accounts
class TestTenantStatusEnum:
"""Test suite for TenantStatus enum."""

View File

@ -1117,42 +1117,6 @@ class TestBillingServiceEdgeCases:
# Assert
assert result["history_id"] == history_id
def test_is_tenant_owner_or_admin_editor_role_raises_error(self):
"""Test tenant owner/admin check raises error for editor role."""
# Arrange
current_user = MagicMock(spec=Account)
current_user.id = "account-123"
current_user.current_tenant_id = "tenant-456"
mock_join = MagicMock(spec=TenantAccountJoin)
mock_join.role = TenantAccountRole.EDITOR # Editor is not privileged
with patch("services.billing_service.db.session") as mock_session:
mock_session.scalar.return_value = mock_join
# Act & Assert
with pytest.raises(ValueError) as exc_info:
BillingService.is_tenant_owner_or_admin(current_user)
assert "Only team owner or team admin can perform this action" in str(exc_info.value)
def test_is_tenant_owner_or_admin_dataset_operator_raises_error(self):
"""Test tenant owner/admin check raises error for dataset operator role."""
# Arrange
current_user = MagicMock(spec=Account)
current_user.id = "account-123"
current_user.current_tenant_id = "tenant-456"
mock_join = MagicMock(spec=TenantAccountJoin)
mock_join.role = TenantAccountRole.DATASET_OPERATOR # Dataset operator is not privileged
with patch("services.billing_service.db.session") as mock_session:
mock_session.scalar.return_value = mock_join
# Act & Assert
with pytest.raises(ValueError) as exc_info:
BillingService.is_tenant_owner_or_admin(current_user)
assert "Only team owner or team admin can perform this action" in str(exc_info.value)
class TestBillingServiceSubscriptionOperations:
"""Unit tests for subscription operations in BillingService.

485
api/uv.lock generated

File diff suppressed because it is too large Load Diff