mirror of
https://github.com/langgenius/dify.git
synced 2026-04-24 21:05:48 +08:00
Merge branch 'main' into jzh
This commit is contained in:
100
.github/dependabot.yml
vendored
100
.github/dependabot.yml
vendored
@ -1,106 +1,6 @@
|
||||
version: 2
|
||||
|
||||
updates:
|
||||
- package-ecosystem: "pip"
|
||||
directory: "/api"
|
||||
open-pull-requests-limit: 10
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
flask:
|
||||
patterns:
|
||||
- "flask"
|
||||
- "flask-*"
|
||||
- "werkzeug"
|
||||
- "gunicorn"
|
||||
google:
|
||||
patterns:
|
||||
- "google-*"
|
||||
- "googleapis-*"
|
||||
opentelemetry:
|
||||
patterns:
|
||||
- "opentelemetry-*"
|
||||
pydantic:
|
||||
patterns:
|
||||
- "pydantic"
|
||||
- "pydantic-*"
|
||||
llm:
|
||||
patterns:
|
||||
- "langfuse"
|
||||
- "langsmith"
|
||||
- "litellm"
|
||||
- "mlflow*"
|
||||
- "opik"
|
||||
- "weave*"
|
||||
- "arize*"
|
||||
- "tiktoken"
|
||||
- "transformers"
|
||||
database:
|
||||
patterns:
|
||||
- "sqlalchemy"
|
||||
- "psycopg2*"
|
||||
- "psycogreen"
|
||||
- "redis*"
|
||||
- "alembic*"
|
||||
storage:
|
||||
patterns:
|
||||
- "boto3*"
|
||||
- "botocore*"
|
||||
- "azure-*"
|
||||
- "bce-*"
|
||||
- "cos-python-*"
|
||||
- "esdk-obs-*"
|
||||
- "google-cloud-storage"
|
||||
- "opendal"
|
||||
- "oss2"
|
||||
- "supabase*"
|
||||
- "tos*"
|
||||
vdb:
|
||||
patterns:
|
||||
- "alibabacloud*"
|
||||
- "chromadb"
|
||||
- "clickhouse-*"
|
||||
- "clickzetta-*"
|
||||
- "couchbase"
|
||||
- "elasticsearch"
|
||||
- "opensearch-py"
|
||||
- "oracledb"
|
||||
- "pgvect*"
|
||||
- "pymilvus"
|
||||
- "pymochow"
|
||||
- "pyobvector"
|
||||
- "qdrant-client"
|
||||
- "intersystems-*"
|
||||
- "tablestore"
|
||||
- "tcvectordb"
|
||||
- "tidb-vector"
|
||||
- "upstash-*"
|
||||
- "volcengine-*"
|
||||
- "weaviate-*"
|
||||
- "xinference-*"
|
||||
- "mo-vector"
|
||||
- "mysql-connector-*"
|
||||
dev:
|
||||
patterns:
|
||||
- "coverage"
|
||||
- "dotenv-linter"
|
||||
- "faker"
|
||||
- "lxml-stubs"
|
||||
- "basedpyright"
|
||||
- "ruff"
|
||||
- "pytest*"
|
||||
- "types-*"
|
||||
- "boto3-stubs"
|
||||
- "hypothesis"
|
||||
- "pandas-stubs"
|
||||
- "scipy-stubs"
|
||||
- "import-linter"
|
||||
- "celery-types"
|
||||
- "mypy*"
|
||||
- "pyrefly"
|
||||
python-packages:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "uv"
|
||||
directory: "/api"
|
||||
open-pull-requests-limit: 10
|
||||
|
||||
8
.github/pull_request_template.md
vendored
8
.github/pull_request_template.md
vendored
@ -18,7 +18,7 @@
|
||||
## Checklist
|
||||
|
||||
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
|
||||
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
|
||||
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
|
||||
- [x] I've updated the documentation accordingly.
|
||||
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
|
||||
- [ ] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
|
||||
- [ ] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
|
||||
- [ ] I've updated the documentation accordingly.
|
||||
- [ ] I ran `make lint && make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
|
||||
|
||||
4
.github/workflows/api-tests.yml
vendored
4
.github/workflows/api-tests.yml
vendored
@ -54,7 +54,7 @@ jobs:
|
||||
run: uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Upload unit coverage data
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||
with:
|
||||
name: api-coverage-unit
|
||||
path: coverage-unit
|
||||
@ -129,7 +129,7 @@ jobs:
|
||||
api/tests/test_containers_integration_tests
|
||||
|
||||
- name: Upload integration coverage data
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||
with:
|
||||
name: api-coverage-integration
|
||||
path: coverage-integration
|
||||
|
||||
4
.github/workflows/build-push.yml
vendored
4
.github/workflows/build-push.yml
vendored
@ -81,7 +81,7 @@ jobs:
|
||||
|
||||
- name: Build Docker image
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
|
||||
with:
|
||||
context: ${{ matrix.build_context }}
|
||||
file: ${{ matrix.file }}
|
||||
@ -101,7 +101,7 @@ jobs:
|
||||
touch "/tmp/digests/${sanitized_digest}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||
with:
|
||||
name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
|
||||
2
.github/workflows/docker-build.yml
vendored
2
.github/workflows/docker-build.yml
vendored
@ -50,7 +50,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
- name: Build Docker Image
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
|
||||
with:
|
||||
push: false
|
||||
context: ${{ matrix.context }}
|
||||
|
||||
4
.github/workflows/pyrefly-diff-comment.yml
vendored
4
.github/workflows/pyrefly-diff-comment.yml
vendored
@ -21,7 +21,7 @@ jobs:
|
||||
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
|
||||
steps:
|
||||
- name: Download pyrefly diff artifact
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
@ -49,7 +49,7 @@ jobs:
|
||||
run: unzip -o pyrefly_diff.zip
|
||||
|
||||
- name: Post comment
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
|
||||
4
.github/workflows/pyrefly-diff.yml
vendored
4
.github/workflows/pyrefly-diff.yml
vendored
@ -66,7 +66,7 @@ jobs:
|
||||
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||
|
||||
- name: Upload pyrefly diff
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||
with:
|
||||
name: pyrefly_diff
|
||||
path: |
|
||||
@ -75,7 +75,7 @@ jobs:
|
||||
|
||||
- name: Comment PR with pyrefly diff
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }}
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
|
||||
@ -32,7 +32,7 @@ jobs:
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Download type coverage artifact
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
@ -73,7 +73,7 @@ jobs:
|
||||
} > /tmp/type_coverage_comment.md
|
||||
|
||||
- name: Post comment
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
|
||||
4
.github/workflows/pyrefly-type-coverage.yml
vendored
4
.github/workflows/pyrefly-type-coverage.yml
vendored
@ -71,7 +71,7 @@ jobs:
|
||||
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||
|
||||
- name: Upload type coverage artifact
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||
with:
|
||||
name: pyrefly_type_coverage
|
||||
path: |
|
||||
@ -81,7 +81,7 @@ jobs:
|
||||
|
||||
- name: Comment PR with type coverage
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
|
||||
6
.github/workflows/stale.yml
vendored
6
.github/workflows/stale.yml
vendored
@ -23,8 +23,8 @@ jobs:
|
||||
days-before-issue-stale: 15
|
||||
days-before-issue-close: 3
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
|
||||
stale-pr-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
|
||||
stale-issue-message: "Closed due to inactivity. If you have any questions, you can reopen it."
|
||||
stale-pr-message: "Closed due to inactivity. If you have any questions, you can reopen it."
|
||||
stale-issue-label: 'no-issue-activity'
|
||||
stale-pr-label: 'no-pr-activity'
|
||||
any-of-labels: 'duplicate,question,invalid,wontfix,no-issue-activity,no-pr-activity,enhancement,cant-reproduce,help-wanted'
|
||||
any-of-labels: '🌚 invalid,🙋♂️ question,wont-fix,no-issue-activity,no-pr-activity,💪 enhancement,🤔 cant-reproduce,🙏 help wanted'
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@ -158,7 +158,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89
|
||||
uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/trigger-i18n-sync.yml
vendored
2
.github/workflows/trigger-i18n-sync.yml
vendored
@ -56,7 +56,7 @@ jobs:
|
||||
|
||||
- name: Trigger i18n sync workflow
|
||||
if: steps.detect.outputs.has_changes == 'true'
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
env:
|
||||
BASE_SHA: ${{ steps.detect.outputs.base_sha }}
|
||||
HEAD_SHA: ${{ steps.detect.outputs.head_sha }}
|
||||
|
||||
4
.github/workflows/web-e2e.yml
vendored
4
.github/workflows/web-e2e.yml
vendored
@ -53,7 +53,7 @@ jobs:
|
||||
|
||||
- name: Upload Cucumber report
|
||||
if: ${{ !cancelled() }}
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||
with:
|
||||
name: cucumber-report
|
||||
path: e2e/cucumber-report
|
||||
@ -61,7 +61,7 @@ jobs:
|
||||
|
||||
- name: Upload E2E logs
|
||||
if: ${{ !cancelled() }}
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||
with:
|
||||
name: e2e-logs
|
||||
path: e2e/.logs
|
||||
|
||||
2
.github/workflows/web-tests.yml
vendored
2
.github/workflows/web-tests.yml
vendored
@ -43,7 +43,7 @@ jobs:
|
||||
|
||||
- name: Upload blob report
|
||||
if: ${{ !cancelled() }}
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||
with:
|
||||
name: blob-report-${{ matrix.shardIndex }}
|
||||
path: web/.vitest-reports/*
|
||||
|
||||
@ -1,12 +1,16 @@
|
||||
from datetime import datetime
|
||||
|
||||
import flask_restx
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx import Resource
|
||||
from flask_restx._http import HTTPStatus
|
||||
from pydantic import field_validator
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
@ -16,21 +20,31 @@ from services.api_token_service import ApiTokenCache
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
api_key_fields = {
|
||||
"id": fields.String,
|
||||
"type": fields.String,
|
||||
"token": fields.String,
|
||||
"last_used_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
|
||||
|
||||
api_key_list_model = console_ns.model(
|
||||
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
|
||||
)
|
||||
class ApiKeyItem(ResponseModel):
|
||||
id: str
|
||||
type: str
|
||||
token: str
|
||||
last_used_at: int | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("last_used_at", "created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class ApiKeyList(ResponseModel):
|
||||
data: list[ApiKeyItem]
|
||||
|
||||
|
||||
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
@ -54,7 +68,6 @@ class BaseApiKeyListResource(Resource):
|
||||
token_prefix: str | None = None
|
||||
max_keys = 10
|
||||
|
||||
@marshal_with(api_key_list_model)
|
||||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
@ -66,9 +79,8 @@ class BaseApiKeyListResource(Resource):
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
)
|
||||
).all()
|
||||
return {"items": keys}
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@marshal_with(api_key_item_model)
|
||||
@edit_permission_required
|
||||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
@ -100,7 +112,7 @@ class BaseApiKeyListResource(Resource):
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
db.session.commit()
|
||||
return api_token, 201
|
||||
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
|
||||
|
||||
|
||||
class BaseApiKeyResource(Resource):
|
||||
@ -147,7 +159,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc("get_app_api_keys")
|
||||
@console_ns.doc(description="Get all API keys for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(200, "Success", api_key_list_model)
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
def get(self, resource_id): # type: ignore
|
||||
"""Get all API keys for an app"""
|
||||
return super().get(resource_id)
|
||||
@ -155,7 +167,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc("create_app_api_key")
|
||||
@console_ns.doc(description="Create a new API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(201, "API key created successfully", api_key_item_model)
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id): # type: ignore
|
||||
"""Create a new API key for an app"""
|
||||
@ -187,7 +199,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get all API keys for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Success", api_key_list_model)
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
def get(self, resource_id): # type: ignore
|
||||
"""Get all API keys for a dataset"""
|
||||
return super().get(resource_id)
|
||||
@ -195,7 +207,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc("create_dataset_api_key")
|
||||
@console_ns.doc(description="Create a new API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(201, "API key created successfully", api_key_item_model)
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id): # type: ignore
|
||||
"""Create a new API key for a dataset"""
|
||||
|
||||
@ -25,7 +25,13 @@ from fields.annotation_fields import (
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.annotation_service import (
|
||||
AppAnnotationService,
|
||||
EnableAnnotationArgs,
|
||||
UpdateAnnotationArgs,
|
||||
UpdateAnnotationSettingArgs,
|
||||
UpsertAnnotationArgs,
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
@ -120,7 +126,12 @@ class AnnotationReplyActionApi(Resource):
|
||||
args = AnnotationReplyPayload.model_validate(console_ns.payload)
|
||||
match action:
|
||||
case "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
|
||||
enable_args: EnableAnnotationArgs = {
|
||||
"score_threshold": args.score_threshold,
|
||||
"embedding_provider_name": args.embedding_provider_name,
|
||||
"embedding_model_name": args.embedding_model_name,
|
||||
}
|
||||
result = AppAnnotationService.enable_app_annotation(enable_args, app_id)
|
||||
case "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
return result, 200
|
||||
@ -161,7 +172,8 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
|
||||
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
|
||||
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args)
|
||||
return result, 200
|
||||
|
||||
|
||||
@ -237,8 +249,16 @@ class AnnotationApi(Resource):
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
args = CreateAnnotationPayload.model_validate(console_ns.payload)
|
||||
data = args.model_dump(exclude_none=True)
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
|
||||
upsert_args: UpsertAnnotationArgs = {}
|
||||
if args.answer is not None:
|
||||
upsert_args["answer"] = args.answer
|
||||
if args.content is not None:
|
||||
upsert_args["content"] = args.content
|
||||
if args.message_id is not None:
|
||||
upsert_args["message_id"] = args.message_id
|
||||
if args.question is not None:
|
||||
upsert_args["question"] = args.question
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id)
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@ -315,9 +335,12 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(
|
||||
args.model_dump(exclude_none=True), app_id, annotation_id
|
||||
)
|
||||
update_args: UpdateAnnotationArgs = {}
|
||||
if args.answer is not None:
|
||||
update_args["answer"] = args.answer
|
||||
if args.question is not None:
|
||||
update_args["question"] = args.question
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id)
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -10,35 +11,15 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import (
|
||||
app_import_check_dependencies_fields,
|
||||
app_import_fields,
|
||||
leaked_dependency_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_dsl_service import AppDslService, Import
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportStatus
|
||||
from services.entities.dsl_entities import CheckDependenciesResult, ImportStatus
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base model first
|
||||
leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields)
|
||||
|
||||
app_import_model = console_ns.model("AppImport", app_import_fields)
|
||||
|
||||
# For nested models, need to replace nested dict with registered model
|
||||
app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy()
|
||||
app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model))
|
||||
app_import_check_dependencies_model = console_ns.model(
|
||||
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppImportPayload(BaseModel):
|
||||
mode: str = Field(..., description="Import mode")
|
||||
@ -52,18 +33,18 @@ class AppImportPayload(BaseModel):
|
||||
app_id: str | None = Field(None)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
register_schema_models(console_ns, AppImportPayload, Import, CheckDependenciesResult)
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports")
|
||||
class AppImportApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
|
||||
@console_ns.response(200, "Import completed", console_ns.models[Import.__name__])
|
||||
@console_ns.response(202, "Import pending confirmation", console_ns.models[Import.__name__])
|
||||
@console_ns.response(400, "Import failed", console_ns.models[Import.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_model)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@ -104,10 +85,11 @@ class AppImportApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/imports/<string:import_id>/confirm")
|
||||
class AppImportConfirmApi(Resource):
|
||||
@console_ns.response(200, "Import confirmed", console_ns.models[Import.__name__])
|
||||
@console_ns.response(400, "Import failed", console_ns.models[Import.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_model)
|
||||
@edit_permission_required
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
@ -128,11 +110,11 @@ class AppImportConfirmApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/imports/<string:app_id>/check-dependencies")
|
||||
class AppImportCheckDependenciesApi(Resource):
|
||||
@console_ns.response(200, "Dependencies checked", console_ns.models[CheckDependenciesResult.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_check_dependencies_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
|
||||
@ -1,23 +1,27 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_server_fields
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import AppMCPServerStatus
|
||||
from models.model import AppMCPServer
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_server_model = console_ns.model("AppServer", app_server_fields)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
@ -32,8 +36,33 @@ class MCPServerUpdatePayload(BaseModel):
|
||||
status: str | None = Field(default=None, description="Server status")
|
||||
|
||||
|
||||
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
class AppMCPServerResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
server_code: str
|
||||
description: str
|
||||
status: str
|
||||
parameters: dict[str, Any] | list[Any] | str
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def _parse_json_string(cls, value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return value
|
||||
return value
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
register_schema_models(console_ns, MCPServerCreatePayload, MCPServerUpdatePayload, AppMCPServerResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/server")
|
||||
@ -41,27 +70,27 @@ class AppMCPServerController(Resource):
|
||||
@console_ns.doc("get_app_mcp_server")
|
||||
@console_ns.doc(description="Get MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model)
|
||||
@console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_model)
|
||||
def get(self, app_model):
|
||||
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
|
||||
return server
|
||||
if server is None:
|
||||
return {}
|
||||
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("create_app_mcp_server")
|
||||
@console_ns.doc(description="Create MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
|
||||
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
|
||||
@console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@login_required
|
||||
@setup_required
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@ -82,20 +111,19 @@ class AppMCPServerController(Resource):
|
||||
)
|
||||
db.session.add(server)
|
||||
db.session.commit()
|
||||
return server
|
||||
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("update_app_mcp_server")
|
||||
@console_ns.doc(description="Update MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
|
||||
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
|
||||
@console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@get_app_model
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
@ -118,7 +146,7 @@ class AppMCPServerController(Resource):
|
||||
except ValueError:
|
||||
raise ValueError("Invalid status")
|
||||
db.session.commit()
|
||||
return server
|
||||
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:server_id>/server/refresh")
|
||||
@ -126,13 +154,12 @@ class AppMCPServerRefreshController(Resource):
|
||||
@console_ns.doc("refresh_app_mcp_server")
|
||||
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
|
||||
@console_ns.doc(params={"server_id": "Server ID"})
|
||||
@console_ns.response(200, "MCP server refreshed successfully", app_server_model)
|
||||
@console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def get(self, server_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@ -145,4 +172,4 @@ class AppMCPServerRefreshController(Resource):
|
||||
raise NotFound()
|
||||
server.server_code = AppMCPServer.generate_server_code(16)
|
||||
db.session.commit()
|
||||
return server
|
||||
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
@ -15,13 +16,11 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_site_fields
|
||||
from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Site
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppSiteUpdatePayload(BaseModel):
|
||||
title: str | None = Field(default=None)
|
||||
@ -49,13 +48,26 @@ class AppSiteUpdatePayload(BaseModel):
|
||||
return supported_language(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AppSiteUpdatePayload.__name__,
|
||||
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
class AppSiteResponse(ResponseModel):
|
||||
app_id: str
|
||||
access_token: str | None = Field(default=None, validation_alias="code")
|
||||
code: str | None = None
|
||||
title: str
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
description: str | None = None
|
||||
default_language: str
|
||||
customize_domain: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
customize_token_strategy: str
|
||||
prompt_public: bool
|
||||
show_workflow_steps: bool
|
||||
use_icon_as_answer_icon: bool
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_site_model = console_ns.model("AppSite", app_site_fields)
|
||||
|
||||
register_schema_models(console_ns, AppSiteUpdatePayload, AppSiteResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site")
|
||||
@ -64,7 +76,7 @@ class AppSite(Resource):
|
||||
@console_ns.doc(description="Update application site configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
|
||||
@console_ns.response(200, "Site configuration updated successfully", console_ns.models[AppSiteResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@ -72,7 +84,6 @@ class AppSite(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -106,7 +117,7 @@ class AppSite(Resource):
|
||||
site.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
return site
|
||||
return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
|
||||
@ -114,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
|
||||
@console_ns.doc("reset_app_site_access_token")
|
||||
@console_ns.doc(description="Reset access token for application site")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Access token reset successfully", app_site_model)
|
||||
@console_ns.response(200, "Access token reset successfully", console_ns.models[AppSiteResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions (admin/owner required)")
|
||||
@console_ns.response(404, "App or site not found")
|
||||
@setup_required
|
||||
@ -122,7 +133,6 @@ class AppSiteAccessTokenReset(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
@ -135,4 +145,4 @@ class AppSiteAccessTokenReset(Resource):
|
||||
site.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
return site
|
||||
return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
@ -11,8 +12,6 @@ from libs.helper import EmailStr, timezone
|
||||
from models import AccountStatus
|
||||
from services.account_service import RegisterService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ActivateCheckQuery(BaseModel):
|
||||
workspace_id: str | None = Field(default=None)
|
||||
@ -39,8 +38,16 @@ class ActivatePayload(BaseModel):
|
||||
return timezone(value)
|
||||
|
||||
|
||||
for model in (ActivateCheckQuery, ActivatePayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
class ActivationCheckResponse(BaseModel):
|
||||
is_valid: bool = Field(description="Whether token is valid")
|
||||
data: dict | None = Field(default=None, description="Activation data if valid")
|
||||
|
||||
|
||||
class ActivationResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
|
||||
register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse)
|
||||
|
||||
|
||||
@console_ns.route("/activate/check")
|
||||
@ -51,13 +58,7 @@ class ActivateCheckApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"ActivationCheckResponse",
|
||||
{
|
||||
"is_valid": fields.Boolean(description="Whether token is valid"),
|
||||
"data": fields.Raw(description="Activation data if valid"),
|
||||
},
|
||||
),
|
||||
console_ns.models[ActivationCheckResponse.__name__],
|
||||
)
|
||||
def get(self):
|
||||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
@ -95,12 +96,7 @@ class ActivateApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Account activated successfully",
|
||||
console_ns.model(
|
||||
"ActivationResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
},
|
||||
),
|
||||
console_ns.models[ActivationResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Already activated or invalid token")
|
||||
def post(self):
|
||||
|
||||
@ -11,10 +11,7 @@ import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.apikey import (
|
||||
api_key_item_model,
|
||||
api_key_list_model,
|
||||
)
|
||||
from controllers.console.apikey import ApiKeyItem, ApiKeyList
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
|
||||
from controllers.console.wraps import (
|
||||
@ -785,23 +782,23 @@ class DatasetApiKeyApi(Resource):
|
||||
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get dataset API keys")
|
||||
@console_ns.response(200, "API keys retrieved successfully", api_key_list_model)
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_list_model)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||
).all()
|
||||
return {"items": keys}
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@console_ns.response(200, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_item_model)
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -828,7 +825,7 @@ class DatasetApiKeyApi(Resource):
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
db.session.commit()
|
||||
return api_token, 200
|
||||
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
|
||||
|
||||
@ -94,10 +94,9 @@ def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
|
||||
|
||||
def plugin_data[**P, R](
|
||||
view: Callable[P, R] | None = None,
|
||||
*,
|
||||
payload_type: type[BaseModel],
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
@ -116,7 +115,4 @@ def plugin_data[**P, R](
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
@ -12,7 +12,12 @@ from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import Annotation, AnnotationList
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.annotation_service import (
|
||||
AppAnnotationService,
|
||||
EnableAnnotationArgs,
|
||||
InsertAnnotationArgs,
|
||||
UpdateAnnotationArgs,
|
||||
)
|
||||
|
||||
|
||||
class AnnotationCreatePayload(BaseModel):
|
||||
@ -46,10 +51,15 @@ class AnnotationReplyActionApi(Resource):
|
||||
@validate_app_token
|
||||
def post(self, app_model: App, action: Literal["enable", "disable"]):
|
||||
"""Enable or disable annotation reply feature."""
|
||||
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
payload = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {})
|
||||
match action:
|
||||
case "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
|
||||
enable_args: EnableAnnotationArgs = {
|
||||
"score_threshold": payload.score_threshold,
|
||||
"embedding_provider_name": payload.embedding_provider_name,
|
||||
"embedding_model_name": payload.embedding_model_name,
|
||||
}
|
||||
result = AppAnnotationService.enable_app_annotation(enable_args, app_model.id)
|
||||
case "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_model.id)
|
||||
return result, 200
|
||||
@ -135,8 +145,9 @@ class AnnotationListApi(Resource):
|
||||
@validate_app_token
|
||||
def post(self, app_model: App):
|
||||
"""Create a new annotation."""
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
|
||||
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
insert_args: InsertAnnotationArgs = {"question": payload.question, "answer": payload.answer}
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(insert_args, app_model.id)
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json"), HTTPStatus.CREATED
|
||||
|
||||
@ -164,8 +175,9 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, annotation_id: str):
|
||||
"""Update an existing annotation."""
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
|
||||
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
|
||||
@ -41,7 +41,23 @@ class AbstractVectorFactory(ABC):
|
||||
class Vector:
|
||||
def __init__(self, dataset: Dataset, attributes: list | None = None):
|
||||
if attributes is None:
|
||||
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
|
||||
# `is_summary` and `original_chunk_id` are stored on summary vectors
|
||||
# by `SummaryIndexService` and read back by `RetrievalService` to
|
||||
# route summary hits through their original parent chunks. They
|
||||
# must be listed here so vector backends that use this list as an
|
||||
# explicit return-properties projection (notably Weaviate) actually
|
||||
# return those fields; without them, summary hits silently
|
||||
# collapse into `is_summary = False` branches and the summary
|
||||
# retrieval path is a no-op. See #34884.
|
||||
attributes = [
|
||||
"doc_id",
|
||||
"dataset_id",
|
||||
"document_id",
|
||||
"doc_hash",
|
||||
"doc_type",
|
||||
"is_summary",
|
||||
"original_chunk_id",
|
||||
]
|
||||
self._dataset = dataset
|
||||
self._embeddings = self._get_embeddings()
|
||||
self._attributes = attributes
|
||||
|
||||
@ -244,7 +244,7 @@ class DatasetDocumentStore:
|
||||
return document_segment
|
||||
|
||||
def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None):
|
||||
if multimodel_documents:
|
||||
if multimodel_documents and self._document_id is not None:
|
||||
for multimodel_document in multimodel_documents:
|
||||
binding = SegmentAttachmentBinding(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
|
||||
@ -17,7 +17,6 @@ def http_status_message(code):
|
||||
|
||||
|
||||
def register_external_error_handlers(api: Api):
|
||||
@api.errorhandler(HTTPException)
|
||||
def handle_http_exception(e: HTTPException):
|
||||
got_request_exception.send(current_app, exception=e)
|
||||
|
||||
@ -74,27 +73,18 @@ def register_external_error_handlers(api: Api):
|
||||
headers["Set-Cookie"] = build_force_logout_cookie_headers()
|
||||
return data, status_code, headers
|
||||
|
||||
_ = handle_http_exception
|
||||
|
||||
@api.errorhandler(ValueError)
|
||||
def handle_value_error(e: ValueError):
|
||||
got_request_exception.send(current_app, exception=e)
|
||||
status_code = 400
|
||||
data = {"code": "invalid_param", "message": str(e), "status": status_code}
|
||||
return data, status_code
|
||||
|
||||
_ = handle_value_error
|
||||
|
||||
@api.errorhandler(AppInvokeQuotaExceededError)
|
||||
def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
|
||||
got_request_exception.send(current_app, exception=e)
|
||||
status_code = 429
|
||||
data = {"code": "too_many_requests", "message": str(e), "status": status_code}
|
||||
return data, status_code
|
||||
|
||||
_ = handle_quota_exceeded
|
||||
|
||||
@api.errorhandler(Exception)
|
||||
def handle_general_exception(e: Exception):
|
||||
got_request_exception.send(current_app, exception=e)
|
||||
|
||||
@ -113,7 +103,10 @@ def register_external_error_handlers(api: Api):
|
||||
|
||||
return data, status_code
|
||||
|
||||
_ = handle_general_exception
|
||||
api.errorhandler(HTTPException)(handle_http_exception)
|
||||
api.errorhandler(ValueError)(handle_value_error)
|
||||
api.errorhandler(AppInvokeQuotaExceededError)(handle_quota_exceeded)
|
||||
api.errorhandler(Exception)(handle_general_exception)
|
||||
|
||||
|
||||
class ExternalApi(Api):
|
||||
|
||||
@ -1688,7 +1688,7 @@ class PipelineRecommendedPlugin(TypeBase):
|
||||
)
|
||||
|
||||
|
||||
class SegmentAttachmentBinding(Base):
|
||||
class SegmentAttachmentBinding(TypeBase):
|
||||
__tablename__ = "segment_attachment_bindings"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"),
|
||||
@ -1701,13 +1701,17 @@ class SegmentAttachmentBinding(Base):
|
||||
),
|
||||
sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"),
|
||||
)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
|
||||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
|
||||
class DocumentSegmentSummary(Base):
|
||||
|
||||
@ -838,7 +838,7 @@ class AppModelConfig(TypeBase):
|
||||
return self
|
||||
|
||||
|
||||
class RecommendedApp(Base): # bug
|
||||
class RecommendedApp(TypeBase):
|
||||
__tablename__ = "recommended_apps"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
|
||||
@ -846,20 +846,37 @@ class RecommendedApp(Base): # bug
|
||||
sa.Index("recommended_app_is_listed_idx", "is_listed", "language"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
description = mapped_column(sa.JSON, nullable=False)
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
primary_key=True,
|
||||
insert_default=lambda: str(uuid4()),
|
||||
default_factory=lambda: str(uuid4()),
|
||||
init=False,
|
||||
)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
description: Mapped[Any] = mapped_column(sa.JSON, nullable=False)
|
||||
copyright: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
custom_disclaimer: Mapped[str] = mapped_column(LongText, default="")
|
||||
category: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
custom_disclaimer: Mapped[str] = mapped_column(LongText, default="")
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
|
||||
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'"))
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
language: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'en-US'"),
|
||||
default="en-US",
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@ -4,15 +4,15 @@ version = "1.13.3"
|
||||
requires-python = "~=3.12.0"
|
||||
|
||||
dependencies = [
|
||||
"aliyun-log-python-sdk~=0.9.37",
|
||||
"aliyun-log-python-sdk~=0.9.44",
|
||||
"arize-phoenix-otel~=0.15.0",
|
||||
"azure-identity==1.25.3",
|
||||
"beautifulsoup4==4.14.3",
|
||||
"boto3==1.42.88",
|
||||
"bs4~=0.0.1",
|
||||
"cachetools~=5.3.0",
|
||||
"celery~=5.6.2",
|
||||
"charset-normalizer>=3.4.4",
|
||||
"cachetools~=7.0.5",
|
||||
"celery~=5.6.3",
|
||||
"charset-normalizer>=3.4.7",
|
||||
"flask~=3.1.3",
|
||||
"flask-compress>=1.24,<1.25",
|
||||
"flask-cors~=6.0.2",
|
||||
@ -20,7 +20,7 @@ dependencies = [
|
||||
"flask-migrate~=4.1.0",
|
||||
"flask-orjson~=2.0.0",
|
||||
"flask-sqlalchemy~=3.1.1",
|
||||
"gevent~=25.9.1",
|
||||
"gevent~=26.4.0",
|
||||
"gmpy2~=2.3.0",
|
||||
"google-api-core>=2.30.3",
|
||||
"google-api-python-client==2.194.0",
|
||||
@ -30,49 +30,49 @@ dependencies = [
|
||||
"googleapis-common-protos>=1.74.0",
|
||||
"graphon>=0.1.2",
|
||||
"gunicorn~=25.3.0",
|
||||
"httpx[socks]~=0.28.0",
|
||||
"httpx[socks]~=0.28.1",
|
||||
"jieba==0.42.1",
|
||||
"json-repair>=0.55.1",
|
||||
"json-repair>=0.59.2",
|
||||
"langfuse>=4.2.0,<5.0.0",
|
||||
"langsmith~=0.7.30",
|
||||
"markdown~=3.10.2",
|
||||
"mlflow-skinny>=3.11.1",
|
||||
"numpy~=1.26.4",
|
||||
"numpy~=2.4.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.11.2",
|
||||
"litellm==1.83.0", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.40.0",
|
||||
"opentelemetry-distro==0.61b0",
|
||||
"opentelemetry-exporter-otlp==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-common==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-grpc==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-http==1.40.0",
|
||||
"opentelemetry-instrumentation==0.61b0",
|
||||
"opentelemetry-instrumentation-celery==0.61b0",
|
||||
"opentelemetry-instrumentation-flask==0.61b0",
|
||||
"opentelemetry-instrumentation-httpx==0.61b0",
|
||||
"opentelemetry-instrumentation-redis==0.61b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.61b0",
|
||||
"opentelemetry-api==1.41.0",
|
||||
"opentelemetry-distro==0.62b0",
|
||||
"opentelemetry-exporter-otlp==1.41.0",
|
||||
"opentelemetry-exporter-otlp-proto-common==1.41.0",
|
||||
"opentelemetry-exporter-otlp-proto-grpc==1.41.0",
|
||||
"opentelemetry-exporter-otlp-proto-http==1.41.0",
|
||||
"opentelemetry-instrumentation==0.62b0",
|
||||
"opentelemetry-instrumentation-celery==0.62b0",
|
||||
"opentelemetry-instrumentation-flask==0.62b0",
|
||||
"opentelemetry-instrumentation-httpx==0.62b0",
|
||||
"opentelemetry-instrumentation-redis==0.62b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.62b0",
|
||||
"opentelemetry-propagator-b3==1.41.0",
|
||||
"opentelemetry-proto==1.40.0",
|
||||
"opentelemetry-sdk==1.40.0",
|
||||
"opentelemetry-semantic-conventions==0.61b0",
|
||||
"opentelemetry-util-http==0.61b0",
|
||||
"pandas[excel,output-formatting,performance]~=3.0.1",
|
||||
"opentelemetry-proto==1.41.0",
|
||||
"opentelemetry-sdk==1.41.0",
|
||||
"opentelemetry-semantic-conventions==0.62b0",
|
||||
"opentelemetry-util-http==0.62b0",
|
||||
"pandas[excel,output-formatting,performance]~=3.0.2",
|
||||
"psycogreen~=1.0.2",
|
||||
"psycopg2-binary~=2.9.11",
|
||||
"pycryptodome==3.23.0",
|
||||
"pydantic~=2.12.5",
|
||||
"pydantic-settings~=2.13.1",
|
||||
"pyjwt~=2.12.0",
|
||||
"pyjwt~=2.12.1",
|
||||
"pypdfium2==5.6.0",
|
||||
"python-docx~=1.2.0",
|
||||
"python-dotenv==1.2.2",
|
||||
"pyyaml~=6.0.1",
|
||||
"readabilipy~=0.3.0",
|
||||
"redis[hiredis]~=7.4.0",
|
||||
"resend~=2.26.0",
|
||||
"sentry-sdk[flask]~=2.55.0",
|
||||
"resend~=2.27.0",
|
||||
"sentry-sdk[flask]~=2.57.0",
|
||||
"sqlalchemy~=2.0.49",
|
||||
"starlette==1.0.0",
|
||||
"tiktoken~=0.12.0",
|
||||
@ -82,13 +82,13 @@ dependencies = [
|
||||
"yarl~=1.23.0",
|
||||
"sseclient-py~=1.9.0",
|
||||
"httpx-sse~=0.4.0",
|
||||
"sendgrid~=6.12.3",
|
||||
"sendgrid~=6.12.5",
|
||||
"flask-restx~=1.3.2",
|
||||
"packaging~=23.2",
|
||||
"croniter>=6.0.0",
|
||||
"packaging~=26.0",
|
||||
"croniter>=6.2.2",
|
||||
"weaviate-client==4.20.5",
|
||||
"apscheduler>=3.11.0",
|
||||
"weave>=0.52.16",
|
||||
"apscheduler>=3.11.2",
|
||||
"weave>=0.52.36",
|
||||
"fastopenapi[flask]>=0.7.0",
|
||||
"bleach~=6.3.0",
|
||||
]
|
||||
@ -120,7 +120,7 @@ dev = [
|
||||
"pytest-cov~=7.1.0",
|
||||
"pytest-env~=1.6.0",
|
||||
"pytest-mock~=3.15.1",
|
||||
"testcontainers~=4.14.1",
|
||||
"testcontainers~=4.14.2",
|
||||
"types-aiofiles~=25.1.0",
|
||||
"types-beautifulsoup4~=4.12.0",
|
||||
"types-cachetools~=6.2.0",
|
||||
@ -166,7 +166,7 @@ dev = [
|
||||
"import-linter>=2.3",
|
||||
"types-redis>=4.6.0.20241004",
|
||||
"celery-types>=0.23.0",
|
||||
"mypy~=1.20.0",
|
||||
"mypy~=1.20.1",
|
||||
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
|
||||
"sseclient-py>=1.8.0",
|
||||
"pytest-timeout>=2.4.0",
|
||||
@ -225,7 +225,7 @@ vdb = [
|
||||
"xinference-client~=2.4.0",
|
||||
"mo-vector~=0.1.13",
|
||||
"mysql-connector-python>=9.3.0",
|
||||
"holo-search-sdk>=0.4.1",
|
||||
"holo-search-sdk>=0.4.2",
|
||||
]
|
||||
|
||||
[tool.pyrefly]
|
||||
|
||||
@ -47,7 +47,6 @@
|
||||
"reportMissingTypeArgument": "hint",
|
||||
"reportUnnecessaryComparison": "hint",
|
||||
"reportUnnecessaryIsInstance": "hint",
|
||||
"reportUntypedFunctionDecorator": "hint",
|
||||
"reportUnnecessaryTypeIgnoreComment": "hint",
|
||||
"reportAttributeAccessIssue": "hint",
|
||||
"pythonVersion": "3.12",
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from typing import TypedDict
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import delete, or_, select, update
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -24,6 +21,8 @@ from tasks.annotation.disable_annotation_reply_task import disable_annotation_re
|
||||
from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task
|
||||
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnnotationJobStatusDict(TypedDict):
|
||||
job_id: str
|
||||
@ -46,9 +45,50 @@ class AnnotationSettingDisabledDict(TypedDict):
|
||||
enabled: bool
|
||||
|
||||
|
||||
class EnableAnnotationArgs(TypedDict):
|
||||
"""Expected shape of the args dict passed to enable_app_annotation."""
|
||||
|
||||
score_threshold: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class UpsertAnnotationArgs(TypedDict, total=False):
|
||||
"""Expected shape of the args dict passed to up_insert_app_annotation_from_message."""
|
||||
|
||||
answer: str
|
||||
content: str
|
||||
message_id: str
|
||||
question: str
|
||||
|
||||
|
||||
class InsertAnnotationArgs(TypedDict):
|
||||
"""Expected shape of the args dict passed to insert_app_annotation_directly."""
|
||||
|
||||
question: str
|
||||
answer: str
|
||||
|
||||
|
||||
class UpdateAnnotationArgs(TypedDict, total=False):
|
||||
"""Expected shape of the args dict passed to update_app_annotation_directly.
|
||||
|
||||
Both fields are optional at the type level; the service validates at runtime
|
||||
and raises ValueError if either is missing.
|
||||
"""
|
||||
|
||||
answer: str
|
||||
question: str
|
||||
|
||||
|
||||
class UpdateAnnotationSettingArgs(TypedDict):
|
||||
"""Expected shape of the args dict passed to update_app_annotation_setting."""
|
||||
|
||||
score_threshold: float
|
||||
|
||||
|
||||
class AppAnnotationService:
|
||||
@classmethod
|
||||
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
def up_insert_app_annotation_from_message(cls, args: UpsertAnnotationArgs, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
app = db.session.scalar(
|
||||
@ -62,8 +102,9 @@ class AppAnnotationService:
|
||||
if answer is None:
|
||||
raise ValueError("Either 'answer' or 'content' must be provided")
|
||||
|
||||
if args.get("message_id"):
|
||||
message_id = str(args["message_id"])
|
||||
raw_message_id = args.get("message_id")
|
||||
if raw_message_id:
|
||||
message_id = str(raw_message_id)
|
||||
message = db.session.scalar(
|
||||
select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1)
|
||||
)
|
||||
@ -87,9 +128,10 @@ class AppAnnotationService:
|
||||
account_id=current_user.id,
|
||||
)
|
||||
else:
|
||||
question = args.get("question")
|
||||
if not question:
|
||||
maybe_question = args.get("question")
|
||||
if not maybe_question:
|
||||
raise ValueError("'question' is required when 'message_id' is not provided")
|
||||
question = maybe_question
|
||||
|
||||
annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
|
||||
db.session.add(annotation)
|
||||
@ -110,7 +152,7 @@ class AppAnnotationService:
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict:
|
||||
def enable_app_annotation(cls, args: EnableAnnotationArgs, app_id: str) -> AnnotationJobStatusDict:
|
||||
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
|
||||
cache_result = redis_client.get(enable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
@ -217,7 +259,7 @@ class AppAnnotationService:
|
||||
return annotations
|
||||
|
||||
@classmethod
|
||||
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
def insert_app_annotation_directly(cls, args: InsertAnnotationArgs, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
app = db.session.scalar(
|
||||
@ -251,7 +293,7 @@ class AppAnnotationService:
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
|
||||
def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = db.session.scalar(
|
||||
@ -270,7 +312,11 @@ class AppAnnotationService:
|
||||
if question is None:
|
||||
raise ValueError("'question' is required")
|
||||
|
||||
annotation.content = args["answer"]
|
||||
answer = args.get("answer")
|
||||
if answer is None:
|
||||
raise ValueError("'answer' is required")
|
||||
|
||||
annotation.content = answer
|
||||
annotation.question = question
|
||||
|
||||
db.session.commit()
|
||||
@ -613,7 +659,7 @@ class AppAnnotationService:
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_setting(
|
||||
cls, app_id: str, annotation_setting_id: str, args: dict
|
||||
cls, app_id: str, annotation_setting_id: str, args: UpdateAnnotationSettingArgs
|
||||
) -> AnnotationSettingDict:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# get app info
|
||||
|
||||
@ -4,7 +4,7 @@ from datetime import datetime
|
||||
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from sqlalchemy import delete, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
@ -42,32 +42,43 @@ class WorkflowToolManageService:
|
||||
labels: list[str] | None = None,
|
||||
):
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
# name or app_id
|
||||
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
|
||||
existing_workflow_tool_provider: WorkflowToolProvider | None = None
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
# query if the name or app_id exists
|
||||
existing_workflow_tool_provider = _session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
# name or app_id
|
||||
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# if the name or app_id exists raise error
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
|
||||
|
||||
app: App | None = db.session.scalar(
|
||||
select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
# query the app
|
||||
app: App | None = None
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
app = _session.scalar(select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1))
|
||||
|
||||
# if not found raise error
|
||||
if app is None:
|
||||
raise ValueError(f"App {workflow_app_id} not found")
|
||||
|
||||
# query the workflow
|
||||
workflow: Workflow | None = app.workflow
|
||||
|
||||
# if not found raise error
|
||||
if workflow is None:
|
||||
raise ValueError(f"Workflow not found for app {workflow_app_id}")
|
||||
|
||||
# check if workflow configuration is synced
|
||||
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict)
|
||||
|
||||
# create workflow tool provider
|
||||
workflow_tool_provider = WorkflowToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
@ -87,13 +98,15 @@ class WorkflowToolManageService:
|
||||
logger.warning(e, exc_info=True)
|
||||
raise ValueError(str(e))
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
session.add(workflow_tool_provider)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
_session.add(workflow_tool_provider)
|
||||
|
||||
# keep the session open to make orm instances in the same session
|
||||
if labels is not None:
|
||||
ToolLabelManager.update_tool_labels(
|
||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
@ -112,6 +125,7 @@ class WorkflowToolManageService:
|
||||
):
|
||||
"""
|
||||
Update a workflow tool.
|
||||
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_tool_id: workflow tool id
|
||||
@ -187,28 +201,32 @@ class WorkflowToolManageService:
|
||||
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
List workflow tools.
|
||||
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tools = db.session.scalars(
|
||||
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
providers: list[WorkflowToolProvider] = []
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
providers = list(
|
||||
_session.scalars(select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)).all()
|
||||
)
|
||||
|
||||
# Create a mapping from provider_id to app_id
|
||||
provider_id_to_app_id = {provider.id: provider.app_id for provider in db_tools}
|
||||
provider_id_to_app_id = {provider.id: provider.app_id for provider in providers}
|
||||
|
||||
tools: list[WorkflowToolProviderController] = []
|
||||
for provider in db_tools:
|
||||
for provider in providers:
|
||||
try:
|
||||
tools.append(ToolTransformService.workflow_provider_to_controller(provider))
|
||||
except Exception:
|
||||
# skip deleted tools
|
||||
logger.exception("Failed to load workflow tool provider %s", provider.id)
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)])
|
||||
labels = ToolLabelManager.get_tools_labels([tool for tool in tools if isinstance(tool, ToolProviderController)])
|
||||
|
||||
result = []
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
for tool in tools:
|
||||
workflow_app_id = provider_id_to_app_id.get(tool.provider_id)
|
||||
@ -233,17 +251,18 @@ class WorkflowToolManageService:
|
||||
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
|
||||
"""
|
||||
Delete a workflow tool.
|
||||
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_tool_id: the workflow tool id
|
||||
"""
|
||||
db.session.execute(
|
||||
delete(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
|
||||
)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
with sessionmaker(db.engine).begin() as _session:
|
||||
_ = _session.execute(
|
||||
delete(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
|
||||
)
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -251,47 +270,59 @@ class WorkflowToolManageService:
|
||||
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_tool_id: the workflow tool id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider | None = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.limit(1)
|
||||
)
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
tool_provider: WorkflowToolProvider | None = None
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
tool_provider = _session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
return cls._get_workflow_tool(tenant_id, tool_provider)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider | None = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
|
||||
.limit(1)
|
||||
)
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
tool_provider: WorkflowToolProvider | None = _session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
return cls._get_workflow_tool(tenant_id, tool_provider)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
|
||||
:db_tool: the database tool
|
||||
:return: the tool
|
||||
"""
|
||||
if db_tool is None:
|
||||
raise ValueError("Tool not found")
|
||||
|
||||
workflow_app: App | None = db.session.scalar(
|
||||
select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1)
|
||||
)
|
||||
workflow_app: App | None = None
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
workflow_app = _session.scalar(
|
||||
select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1)
|
||||
)
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f"App {db_tool.app_id} not found")
|
||||
@ -331,28 +362,32 @@ class WorkflowToolManageService:
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
List workflow tool provider tools.
|
||||
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_tool_id: the workflow tool id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tool: WorkflowToolProvider | None = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if db_tool is None:
|
||||
provider: WorkflowToolProvider | None = None
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
provider = _session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
tool = ToolTransformService.workflow_provider_to_controller(provider)
|
||||
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
|
||||
if len(workflow_tools) == 0:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
return [
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool.get_tools(db_tool.tenant_id)[0],
|
||||
tool=tool.get_tools(provider.tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
@ -313,6 +313,21 @@ class TestSiteEndpoints:
|
||||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
site.app_id = "app-1"
|
||||
site.code = "test-code"
|
||||
site.title = "My Site"
|
||||
site.icon = None
|
||||
site.icon_background = None
|
||||
site.description = "Test site"
|
||||
site.default_language = "en-US"
|
||||
site.customize_domain = None
|
||||
site.copyright = None
|
||||
site.privacy_policy = None
|
||||
site.custom_disclaimer = ""
|
||||
site.customize_token_strategy = "not_allow"
|
||||
site.prompt_public = False
|
||||
site.show_workflow_steps = True
|
||||
site.use_icon_as_answer_icon = False
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
@ -328,13 +343,29 @@ class TestSiteEndpoints:
|
||||
with app.test_request_context("/", json={"title": "My Site"}):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is site
|
||||
assert isinstance(result, dict)
|
||||
assert result["title"] == "My Site"
|
||||
|
||||
def test_app_site_access_token_reset(self, app, monkeypatch):
|
||||
api = site_module.AppSiteAccessTokenReset()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
site.app_id = "app-1"
|
||||
site.code = "old-code"
|
||||
site.title = "My Site"
|
||||
site.icon = None
|
||||
site.icon_background = None
|
||||
site.description = None
|
||||
site.default_language = "en-US"
|
||||
site.customize_domain = None
|
||||
site.copyright = None
|
||||
site.privacy_policy = None
|
||||
site.custom_disclaimer = ""
|
||||
site.customize_token_strategy = "not_allow"
|
||||
site.prompt_public = False
|
||||
site.show_workflow_steps = True
|
||||
site.use_icon_as_answer_icon = False
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
@ -351,7 +382,8 @@ class TestSiteEndpoints:
|
||||
with app.test_request_context("/"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is site
|
||||
assert isinstance(result, dict)
|
||||
assert result["access_token"] == "code"
|
||||
|
||||
|
||||
class TestWorkflowEndpoints:
|
||||
|
||||
@ -148,14 +148,18 @@ def test_chat_message_list_success(
|
||||
account.id,
|
||||
created_at_offset_seconds=1,
|
||||
)
|
||||
# Capture IDs before the HTTP request detaches ORM instances from the session
|
||||
app_id = app.id
|
||||
conversation_id = conversation.id
|
||||
second_id = second.id
|
||||
|
||||
with patch(
|
||||
"controllers.console.app.message.attach_message_extra_contents",
|
||||
side_effect=_attach_message_extra_contents,
|
||||
):
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/chat-messages",
|
||||
query_string={"conversation_id": conversation.id, "limit": 1},
|
||||
f"/console/api/apps/{app_id}/chat-messages",
|
||||
query_string={"conversation_id": conversation_id, "limit": 1},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
@ -165,7 +169,7 @@ def test_chat_message_list_success(
|
||||
assert payload["limit"] == 1
|
||||
assert payload["has_more"] is True
|
||||
assert len(payload["data"]) == 1
|
||||
assert payload["data"][0]["id"] == second.id
|
||||
assert payload["data"][0]["id"] == second_id
|
||||
|
||||
|
||||
def test_message_feedback_not_found(
|
||||
|
||||
@ -1,79 +1,202 @@
|
||||
# import secrets
|
||||
"""
|
||||
Integration tests for Account and Tenant model methods that interact with the database.
|
||||
|
||||
# import pytest
|
||||
# from sqlalchemy import select
|
||||
# from sqlalchemy.orm import Session
|
||||
# from sqlalchemy.orm.exc import DetachedInstanceError
|
||||
Migrated from unit_tests/models/test_account_models.py, replacing
|
||||
@patch("models.account.db") mock patches with real PostgreSQL operations.
|
||||
|
||||
# from libs.datetime_utils import naive_utc_now
|
||||
# from models.account import Account, Tenant, TenantAccountJoin
|
||||
Covers:
|
||||
- Account.current_tenant setter (sets _current_tenant and role from TenantAccountJoin)
|
||||
- Account.set_tenant_id (resolves tenant + role from real join row)
|
||||
- Account.get_by_openid (AccountIntegrate lookup then Account fetch)
|
||||
- Tenant.get_accounts (returns accounts linked via TenantAccountJoin)
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, AccountIntegrate, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def session(db_session_with_containers):
|
||||
# with Session(db_session_with_containers.get_bind()) as session:
|
||||
# yield session
|
||||
def _cleanup_tracked_rows(db_session: Session, tracked: list) -> None:
|
||||
"""Delete rows tracked during the test so committed state does not leak into the DB.
|
||||
|
||||
Rolls back any pending (uncommitted) session state first, then issues DELETE
|
||||
statements by primary key for each tracked entity (in reverse creation order)
|
||||
and commits. This cleans up rows created via either flush() or commit().
|
||||
"""
|
||||
db_session.rollback()
|
||||
for entity in reversed(tracked):
|
||||
db_session.execute(delete(type(entity)).where(type(entity).id == entity.id))
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def account(session):
|
||||
# account = Account(
|
||||
# name="test account",
|
||||
# email=f"test_{secrets.token_hex(8)}@example.com",
|
||||
# )
|
||||
# session.add(account)
|
||||
# session.commit()
|
||||
# return account
|
||||
def _build_tenant() -> Tenant:
|
||||
return Tenant(name=f"Tenant {uuid4()}")
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def tenant(session):
|
||||
# tenant = Tenant(name="test tenant")
|
||||
# session.add(tenant)
|
||||
# session.commit()
|
||||
# return tenant
|
||||
def _build_account(email_prefix: str = "account") -> Account:
|
||||
return Account(
|
||||
name=f"Account {uuid4()}",
|
||||
email=f"{email_prefix}_{uuid4()}@example.com",
|
||||
password="hashed-password",
|
||||
password_salt="salt",
|
||||
interface_language="en-US",
|
||||
timezone="UTC",
|
||||
)
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def tenant_account_join(session, account, tenant):
|
||||
# tenant_join = TenantAccountJoin(account_id=account.id, tenant_id=tenant.id)
|
||||
# session.add(tenant_join)
|
||||
# session.commit()
|
||||
# yield tenant_join
|
||||
# session.delete(tenant_join)
|
||||
# session.commit()
|
||||
class _DBTrackingTestBase:
|
||||
"""Base class providing a tracker list and shared row factories for account/tenant tests."""
|
||||
|
||||
_tracked: list
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_cleanup(self, db_session_with_containers: Session) -> Generator[None, None, None]:
|
||||
self._tracked = []
|
||||
yield
|
||||
_cleanup_tracked_rows(db_session_with_containers, self._tracked)
|
||||
|
||||
def _create_tenant(self, db_session: Session) -> Tenant:
|
||||
tenant = _build_tenant()
|
||||
db_session.add(tenant)
|
||||
db_session.flush()
|
||||
self._tracked.append(tenant)
|
||||
return tenant
|
||||
|
||||
def _create_account(self, db_session: Session, email_prefix: str = "account") -> Account:
|
||||
account = _build_account(email_prefix)
|
||||
db_session.add(account)
|
||||
db_session.flush()
|
||||
self._tracked.append(account)
|
||||
return account
|
||||
|
||||
def _create_join(
|
||||
self, db_session: Session, tenant_id: str, account_id: str, role: TenantAccountRole, current: bool = True
|
||||
) -> TenantAccountJoin:
|
||||
join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id, role=role, current=current)
|
||||
db_session.add(join)
|
||||
db_session.flush()
|
||||
self._tracked.append(join)
|
||||
return join
|
||||
|
||||
|
||||
# class TestAccountTenant:
|
||||
# def test_set_current_tenant_should_reload_tenant(
|
||||
# self,
|
||||
# db_session_with_containers,
|
||||
# account,
|
||||
# tenant,
|
||||
# tenant_account_join,
|
||||
# ):
|
||||
# with Session(db_session_with_containers.get_bind(), expire_on_commit=True) as session:
|
||||
# scoped_tenant = session.scalars(select(Tenant).where(Tenant.id == tenant.id)).one()
|
||||
# account.current_tenant = scoped_tenant
|
||||
# scoped_tenant.created_at = naive_utc_now()
|
||||
# # session.commit()
|
||||
class TestAccountCurrentTenantSetter(_DBTrackingTestBase):
|
||||
"""Integration tests for Account.current_tenant property setter."""
|
||||
|
||||
# # Ensure the tenant used in assignment is detached.
|
||||
# with pytest.raises(DetachedInstanceError):
|
||||
# _ = scoped_tenant.name
|
||||
def test_current_tenant_property_returns_cached_tenant(self, db_session_with_containers: Session) -> None:
|
||||
"""current_tenant getter returns the in-memory _current_tenant without DB access."""
|
||||
account = self._create_account(db_session_with_containers)
|
||||
tenant = self._create_tenant(db_session_with_containers)
|
||||
account._current_tenant = tenant
|
||||
|
||||
# assert account._current_tenant.id == tenant.id
|
||||
# assert account._current_tenant.id == tenant.id
|
||||
assert account.current_tenant is tenant
|
||||
|
||||
# def test_set_tenant_id_should_load_tenant_as_not_expire(
|
||||
# self,
|
||||
# flask_app_with_containers,
|
||||
# account,
|
||||
# tenant,
|
||||
# tenant_account_join,
|
||||
# ):
|
||||
# with flask_app_with_containers.test_request_context():
|
||||
# account.set_tenant_id(tenant.id)
|
||||
def test_current_tenant_setter_sets_tenant_and_role_when_join_exists(
|
||||
self, db_session_with_containers: Session
|
||||
) -> None:
|
||||
"""Setting current_tenant loads the join row and assigns role when relationship exists."""
|
||||
tenant = self._create_tenant(db_session_with_containers)
|
||||
account = self._create_account(db_session_with_containers)
|
||||
self._create_join(db_session_with_containers, tenant.id, account.id, TenantAccountRole.OWNER)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# assert account._current_tenant.id == tenant.id
|
||||
# assert account._current_tenant.id == tenant.id
|
||||
account.current_tenant = tenant
|
||||
|
||||
assert account._current_tenant is not None
|
||||
assert account._current_tenant.id == tenant.id
|
||||
assert account.role == TenantAccountRole.OWNER
|
||||
|
||||
def test_current_tenant_setter_sets_none_when_no_join_exists(self, db_session_with_containers: Session) -> None:
|
||||
"""Setting current_tenant results in _current_tenant=None when no join row exists."""
|
||||
tenant = self._create_tenant(db_session_with_containers)
|
||||
account = self._create_account(db_session_with_containers)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.current_tenant = tenant
|
||||
|
||||
assert account._current_tenant is None
|
||||
|
||||
|
||||
class TestAccountSetTenantId(_DBTrackingTestBase):
|
||||
"""Integration tests for Account.set_tenant_id method."""
|
||||
|
||||
def test_set_tenant_id_sets_tenant_and_role_when_relationship_exists(
|
||||
self, db_session_with_containers: Session
|
||||
) -> None:
|
||||
"""set_tenant_id loads the tenant and assigns role when a join row exists."""
|
||||
tenant = self._create_tenant(db_session_with_containers)
|
||||
account = self._create_account(db_session_with_containers)
|
||||
self._create_join(db_session_with_containers, tenant.id, account.id, TenantAccountRole.ADMIN)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.set_tenant_id(tenant.id)
|
||||
|
||||
assert account._current_tenant is not None
|
||||
assert account._current_tenant.id == tenant.id
|
||||
assert account.role == TenantAccountRole.ADMIN
|
||||
|
||||
def test_set_tenant_id_does_not_set_tenant_when_no_relationship_exists(
|
||||
self, db_session_with_containers: Session
|
||||
) -> None:
|
||||
"""set_tenant_id does nothing when no join row matches the tenant."""
|
||||
tenant = self._create_tenant(db_session_with_containers)
|
||||
account = self._create_account(db_session_with_containers)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.set_tenant_id(tenant.id)
|
||||
|
||||
assert account._current_tenant is None
|
||||
|
||||
|
||||
class TestAccountGetByOpenId(_DBTrackingTestBase):
|
||||
"""Integration tests for Account.get_by_openid class method."""
|
||||
|
||||
def test_get_by_openid_returns_account_when_integrate_exists(self, db_session_with_containers: Session) -> None:
|
||||
"""get_by_openid returns the Account when a matching AccountIntegrate row exists."""
|
||||
account = self._create_account(db_session_with_containers, email_prefix="openid")
|
||||
provider = "google"
|
||||
open_id = f"google_{uuid4()}"
|
||||
|
||||
integrate = AccountIntegrate(
|
||||
account_id=account.id,
|
||||
provider=provider,
|
||||
open_id=open_id,
|
||||
encrypted_token="token",
|
||||
)
|
||||
db_session_with_containers.add(integrate)
|
||||
db_session_with_containers.flush()
|
||||
self._tracked.append(integrate)
|
||||
|
||||
result = Account.get_by_openid(provider, open_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == account.id
|
||||
|
||||
def test_get_by_openid_returns_none_when_no_integrate_exists(self, db_session_with_containers: Session) -> None:
|
||||
"""get_by_openid returns None when no AccountIntegrate row matches."""
|
||||
result = Account.get_by_openid("github", f"github_{uuid4()}")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestTenantGetAccounts(_DBTrackingTestBase):
|
||||
"""Integration tests for Tenant.get_accounts method."""
|
||||
|
||||
def test_get_accounts_returns_linked_accounts(self, db_session_with_containers: Session) -> None:
|
||||
"""get_accounts returns all accounts linked to the tenant via TenantAccountJoin."""
|
||||
tenant = self._create_tenant(db_session_with_containers)
|
||||
account1 = self._create_account(db_session_with_containers, email_prefix="tenant_member")
|
||||
account2 = self._create_account(db_session_with_containers, email_prefix="tenant_member")
|
||||
self._create_join(db_session_with_containers, tenant.id, account1.id, TenantAccountRole.OWNER, current=False)
|
||||
self._create_join(db_session_with_containers, tenant.id, account2.id, TenantAccountRole.NORMAL, current=False)
|
||||
|
||||
accounts = tenant.get_accounts()
|
||||
|
||||
assert len(accounts) == 2
|
||||
account_ids = {a.id for a in accounts}
|
||||
assert account1.id in account_ids
|
||||
assert account2.id in account_ids
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@ -363,3 +367,62 @@ class TestBillingServiceGetPlanBulkWithCache:
|
||||
assert ttl_1_new <= 600
|
||||
assert ttl_2 > 0
|
||||
assert ttl_2 <= 600
|
||||
|
||||
|
||||
class TestBillingServiceIsTenantOwnerOrAdmin:
|
||||
"""
|
||||
Integration tests for BillingService.is_tenant_owner_or_admin.
|
||||
|
||||
Verifies that non-privileged roles (EDITOR, DATASET_OPERATOR) raise ValueError
|
||||
when checked against real TenantAccountJoin rows in PostgreSQL.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
|
||||
yield
|
||||
db_session_with_containers.rollback()
|
||||
|
||||
def _create_account_with_tenant_role(self, db_session: Session, role: TenantAccountRole) -> tuple[Account, Tenant]:
|
||||
tenant = Tenant(name=f"Tenant {uuid4()}")
|
||||
db_session.add(tenant)
|
||||
db_session.flush()
|
||||
|
||||
account = Account(
|
||||
name=f"Account {uuid4()}",
|
||||
email=f"billing_{uuid4()}@example.com",
|
||||
password="hashed-password",
|
||||
password_salt="salt",
|
||||
interface_language="en-US",
|
||||
timezone="UTC",
|
||||
)
|
||||
db_session.add(account)
|
||||
db_session.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=role,
|
||||
current=True,
|
||||
)
|
||||
db_session.add(join)
|
||||
db_session.flush()
|
||||
|
||||
# Wire up in-memory reference so current_tenant_id resolves
|
||||
account._current_tenant = tenant
|
||||
return account, tenant
|
||||
|
||||
def test_is_tenant_owner_or_admin_editor_role_raises_error(self, db_session_with_containers: Session) -> None:
|
||||
"""is_tenant_owner_or_admin raises ValueError for EDITOR role."""
|
||||
account, _ = self._create_account_with_tenant_role(db_session_with_containers, TenantAccountRole.EDITOR)
|
||||
|
||||
with pytest.raises(ValueError, match="Only team owner or team admin can perform this action"):
|
||||
BillingService.is_tenant_owner_or_admin(account)
|
||||
|
||||
def test_is_tenant_owner_or_admin_dataset_operator_raises_error(self, db_session_with_containers: Session) -> None:
|
||||
"""is_tenant_owner_or_admin raises ValueError for DATASET_OPERATOR role."""
|
||||
account, _ = self._create_account_with_tenant_role(
|
||||
db_session_with_containers, TenantAccountRole.DATASET_OPERATOR
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Only team owner or team admin can perform this action"):
|
||||
BillingService.is_tenant_owner_or_admin(account)
|
||||
|
||||
@ -0,0 +1,70 @@
|
||||
import datetime
|
||||
|
||||
from controllers.console.app.mcp_server import AppMCPServerResponse
|
||||
|
||||
|
||||
class TestAppMCPServerResponse:
|
||||
def test_parameters_json_string_parsed(self):
|
||||
data = {
|
||||
"id": "s1",
|
||||
"name": "test",
|
||||
"server_code": "code",
|
||||
"description": "desc",
|
||||
"status": "active",
|
||||
"parameters": '{"key": "value"}',
|
||||
}
|
||||
resp = AppMCPServerResponse.model_validate(data)
|
||||
assert resp.parameters == {"key": "value"}
|
||||
|
||||
def test_parameters_invalid_json_returns_original(self):
|
||||
data = {
|
||||
"id": "s1",
|
||||
"name": "test",
|
||||
"server_code": "code",
|
||||
"description": "desc",
|
||||
"status": "active",
|
||||
"parameters": "not-valid-json",
|
||||
}
|
||||
resp = AppMCPServerResponse.model_validate(data)
|
||||
assert resp.parameters == "not-valid-json"
|
||||
|
||||
def test_parameters_dict_passthrough(self):
|
||||
data = {
|
||||
"id": "s1",
|
||||
"name": "test",
|
||||
"server_code": "code",
|
||||
"description": "desc",
|
||||
"status": "active",
|
||||
"parameters": {"already": "parsed"},
|
||||
}
|
||||
resp = AppMCPServerResponse.model_validate(data)
|
||||
assert resp.parameters == {"already": "parsed"}
|
||||
|
||||
def test_timestamps_normalized(self):
|
||||
dt = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
|
||||
data = {
|
||||
"id": "s1",
|
||||
"name": "test",
|
||||
"server_code": "code",
|
||||
"description": "desc",
|
||||
"status": "active",
|
||||
"parameters": {},
|
||||
"created_at": dt,
|
||||
"updated_at": dt,
|
||||
}
|
||||
resp = AppMCPServerResponse.model_validate(data)
|
||||
assert resp.created_at == int(dt.timestamp())
|
||||
assert resp.updated_at == int(dt.timestamp())
|
||||
|
||||
def test_timestamps_none(self):
|
||||
data = {
|
||||
"id": "s1",
|
||||
"name": "test",
|
||||
"server_code": "code",
|
||||
"description": "desc",
|
||||
"status": "active",
|
||||
"parameters": {},
|
||||
}
|
||||
resp = AppMCPServerResponse.model_validate(data)
|
||||
assert resp.created_at is None
|
||||
assert resp.updated_at is None
|
||||
@ -1555,7 +1555,17 @@ class TestDatasetApiKeyApi:
|
||||
method = unwrap(api.get)
|
||||
|
||||
mock_key_1 = MagicMock(spec=ApiToken)
|
||||
mock_key_1.id = "key-1"
|
||||
mock_key_1.type = "dataset"
|
||||
mock_key_1.token = "ds-abc"
|
||||
mock_key_1.last_used_at = None
|
||||
mock_key_1.created_at = None
|
||||
mock_key_2 = MagicMock(spec=ApiToken)
|
||||
mock_key_2.id = "key-2"
|
||||
mock_key_2.type = "dataset"
|
||||
mock_key_2.token = "ds-def"
|
||||
mock_key_2.last_used_at = None
|
||||
mock_key_2.created_at = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -1570,13 +1580,26 @@ class TestDatasetApiKeyApi:
|
||||
):
|
||||
response = method(api)
|
||||
|
||||
assert "items" in response
|
||||
assert response["items"] == [mock_key_1, mock_key_2]
|
||||
assert "data" in response
|
||||
assert len(response["data"]) == 2
|
||||
assert response["data"][0]["id"] == "key-1"
|
||||
assert response["data"][0]["token"] == "ds-abc"
|
||||
assert response["data"][1]["id"] == "key-2"
|
||||
assert response["data"][1]["token"] == "ds-def"
|
||||
|
||||
def test_post_create_api_key_success(self, app):
|
||||
api = DatasetApiKeyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
mock_token = MagicMock()
|
||||
mock_token.id = "new-key-id"
|
||||
mock_token.last_used_at = None
|
||||
mock_token.created_at = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
|
||||
|
||||
mock_api_token_cls = MagicMock()
|
||||
mock_api_token_cls.return_value = mock_token
|
||||
mock_api_token_cls.generate_api_key.return_value = "dataset-abc123"
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
@ -1588,8 +1611,8 @@ class TestDatasetApiKeyApi:
|
||||
return_value=3,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.ApiToken.generate_api_key",
|
||||
return_value="dataset-abc123",
|
||||
"controllers.console.datasets.datasets.ApiToken",
|
||||
mock_api_token_cls,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.add",
|
||||
@ -1603,9 +1626,11 @@ class TestDatasetApiKeyApi:
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert isinstance(response, ApiToken)
|
||||
assert response.token == "dataset-abc123"
|
||||
assert response.type == "dataset"
|
||||
assert isinstance(response, dict)
|
||||
assert response["id"] == "new-key-id"
|
||||
assert response["token"] == "dataset-abc123"
|
||||
assert response["type"] == "dataset"
|
||||
assert response["created_at"] is not None
|
||||
|
||||
def test_post_exceed_max_keys(self, app):
|
||||
api = DatasetApiKeyApi()
|
||||
|
||||
@ -121,7 +121,18 @@ def test_vector_init_uses_default_and_custom_attributes(vector_factory_module):
|
||||
default_vector = vector_factory_module.Vector(dataset)
|
||||
custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"])
|
||||
|
||||
assert default_vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
|
||||
# `is_summary` and `original_chunk_id` must be in the default return-properties
|
||||
# projection so summary index retrieval works on backends that honor the list
|
||||
# as an explicit projection (e.g. Weaviate). See #34884.
|
||||
assert default_vector._attributes == [
|
||||
"doc_id",
|
||||
"dataset_id",
|
||||
"document_id",
|
||||
"doc_hash",
|
||||
"doc_type",
|
||||
"is_summary",
|
||||
"original_chunk_id",
|
||||
]
|
||||
assert custom_vector._attributes == ["doc_id"]
|
||||
assert default_vector._embeddings == "embeddings"
|
||||
assert default_vector._vector_processor == "processor"
|
||||
|
||||
@ -721,6 +721,30 @@ class TestDatasetDocumentStoreMultimodelBinding:
|
||||
|
||||
mock_db.session.add.assert_not_called()
|
||||
|
||||
def test_add_multimodel_documents_binding_with_none_document_id(self):
|
||||
"""Test that no bindings are added when document_id is None."""
|
||||
|
||||
mock_dataset = MagicMock(spec=Dataset)
|
||||
mock_dataset.id = "test-dataset-id"
|
||||
mock_dataset.tenant_id = "tenant-1"
|
||||
|
||||
mock_attachment = MagicMock(spec=AttachmentDocument)
|
||||
mock_attachment.metadata = {"doc_id": "attachment-1"}
|
||||
|
||||
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
|
||||
mock_session = MagicMock()
|
||||
mock_db.session = mock_session
|
||||
|
||||
store = DatasetDocumentStore(
|
||||
dataset=mock_dataset,
|
||||
user_id="test-user-id",
|
||||
document_id=None,
|
||||
)
|
||||
|
||||
store.add_multimodel_documents_binding("seg-1", [mock_attachment])
|
||||
|
||||
mock_db.session.add.assert_not_called()
|
||||
|
||||
|
||||
class TestDatasetDocumentStoreAddDocumentsUpdateChild:
|
||||
"""Tests for add_documents when updating existing documents with children."""
|
||||
|
||||
@ -12,7 +12,6 @@ This test suite covers:
|
||||
import base64
|
||||
import secrets
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@ -310,90 +309,6 @@ class TestAccountStatusTransitions:
|
||||
class TestTenantRelationshipIntegrity:
|
||||
"""Test suite for tenant relationship integrity."""
|
||||
|
||||
@patch("models.account.db")
|
||||
def test_account_current_tenant_property(self, mock_db):
|
||||
"""Test the current_tenant property getter."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
account._current_tenant = tenant
|
||||
|
||||
# Act
|
||||
result = account.current_tenant
|
||||
|
||||
# Assert
|
||||
assert result == tenant
|
||||
|
||||
@patch("models.account.Session")
|
||||
@patch("models.account.db")
|
||||
def test_account_current_tenant_setter_with_valid_tenant(self, mock_db, mock_session_class):
|
||||
"""Test setting current_tenant with a valid tenant relationship."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
# Mock the session and queries
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock TenantAccountJoin query result
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
)
|
||||
mock_session.scalar.return_value = tenant_join
|
||||
|
||||
# Mock Tenant query result
|
||||
mock_session.scalars.return_value.one.return_value = tenant
|
||||
|
||||
# Act
|
||||
account.current_tenant = tenant
|
||||
|
||||
# Assert
|
||||
assert account._current_tenant == tenant
|
||||
assert account.role == TenantAccountRole.OWNER
|
||||
|
||||
@patch("models.account.Session")
|
||||
@patch("models.account.db")
|
||||
def test_account_current_tenant_setter_without_relationship(self, mock_db, mock_session_class):
|
||||
"""Test setting current_tenant when no relationship exists."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
# Mock the session and queries
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock no TenantAccountJoin found
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
account.current_tenant = tenant
|
||||
|
||||
# Assert
|
||||
assert account._current_tenant is None
|
||||
|
||||
def test_account_current_tenant_id_property(self):
|
||||
"""Test the current_tenant_id property."""
|
||||
# Arrange
|
||||
@ -418,61 +333,6 @@ class TestTenantRelationshipIntegrity:
|
||||
# Assert
|
||||
assert tenant_id_none is None
|
||||
|
||||
@patch("models.account.Session")
|
||||
@patch("models.account.db")
|
||||
def test_account_set_tenant_id_method(self, mock_db, mock_session_class):
|
||||
"""Test the set_tenant_id method."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.ADMIN,
|
||||
)
|
||||
|
||||
# Mock the session and queries
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.first.return_value = (tenant, tenant_join)
|
||||
|
||||
# Act
|
||||
account.set_tenant_id(tenant.id)
|
||||
|
||||
# Assert
|
||||
assert account._current_tenant == tenant
|
||||
assert account.role == TenantAccountRole.ADMIN
|
||||
|
||||
@patch("models.account.Session")
|
||||
@patch("models.account.db")
|
||||
def test_account_set_tenant_id_with_no_relationship(self, mock_db, mock_session_class):
|
||||
"""Test set_tenant_id when no relationship exists."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Mock the session and queries
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
account.set_tenant_id(tenant_id)
|
||||
|
||||
# Assert - should not set tenant when no relationship exists
|
||||
# The method returns early without setting _current_tenant
|
||||
|
||||
|
||||
class TestAccountRolePermissions:
|
||||
"""Test suite for account role permissions."""
|
||||
@ -605,51 +465,6 @@ class TestAccountRolePermissions:
|
||||
assert current_role == TenantAccountRole.EDITOR
|
||||
|
||||
|
||||
class TestAccountGetByOpenId:
|
||||
"""Test suite for get_by_openid class method."""
|
||||
|
||||
@patch("models.account.db")
|
||||
def test_get_by_openid_success(self, mock_db):
|
||||
"""Test successful retrieval of account by OpenID."""
|
||||
# Arrange
|
||||
provider = "google"
|
||||
open_id = "google_user_123"
|
||||
account_id = str(uuid4())
|
||||
|
||||
mock_account_integrate = MagicMock()
|
||||
mock_account_integrate.account_id = account_id
|
||||
|
||||
mock_account = Account(name="Test User", email="test@example.com")
|
||||
mock_account.id = account_id
|
||||
|
||||
# Mock db.session.execute().scalar_one_or_none() for AccountIntegrate lookup
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = mock_account_integrate
|
||||
# Mock db.session.scalar() for Account lookup
|
||||
mock_db.session.scalar.return_value = mock_account
|
||||
|
||||
# Act
|
||||
result = Account.get_by_openid(provider, open_id)
|
||||
|
||||
# Assert
|
||||
assert result == mock_account
|
||||
|
||||
@patch("models.account.db")
|
||||
def test_get_by_openid_not_found(self, mock_db):
|
||||
"""Test get_by_openid when account integrate doesn't exist."""
|
||||
# Arrange
|
||||
provider = "github"
|
||||
open_id = "github_user_456"
|
||||
|
||||
# Mock db.session.execute().scalar_one_or_none() to return None
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
# Act
|
||||
result = Account.get_by_openid(provider, open_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestTenantAccountJoinModel:
|
||||
"""Test suite for TenantAccountJoin model."""
|
||||
|
||||
@ -760,31 +575,6 @@ class TestTenantModel:
|
||||
# Assert
|
||||
assert tenant.custom_config == '{"feature1": true, "feature2": "value"}'
|
||||
|
||||
@patch("models.account.db")
|
||||
def test_tenant_get_accounts(self, mock_db):
|
||||
"""Test getting accounts associated with a tenant."""
|
||||
# Arrange
|
||||
tenant = Tenant(name="Test Workspace")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
account1 = Account(name="User 1", email="user1@example.com")
|
||||
account1.id = str(uuid4())
|
||||
account2 = Account(name="User 2", email="user2@example.com")
|
||||
account2.id = str(uuid4())
|
||||
|
||||
# Mock the query chain
|
||||
mock_scalars = MagicMock()
|
||||
mock_scalars.all.return_value = [account1, account2]
|
||||
mock_db.session.scalars.return_value = mock_scalars
|
||||
|
||||
# Act
|
||||
accounts = tenant.get_accounts()
|
||||
|
||||
# Assert
|
||||
assert len(accounts) == 2
|
||||
assert account1 in accounts
|
||||
assert account2 in accounts
|
||||
|
||||
|
||||
class TestTenantStatusEnum:
|
||||
"""Test suite for TenantStatus enum."""
|
||||
|
||||
@ -1117,42 +1117,6 @@ class TestBillingServiceEdgeCases:
|
||||
# Assert
|
||||
assert result["history_id"] == history_id
|
||||
|
||||
def test_is_tenant_owner_or_admin_editor_role_raises_error(self):
|
||||
"""Test tenant owner/admin check raises error for editor role."""
|
||||
# Arrange
|
||||
current_user = MagicMock(spec=Account)
|
||||
current_user.id = "account-123"
|
||||
current_user.current_tenant_id = "tenant-456"
|
||||
|
||||
mock_join = MagicMock(spec=TenantAccountJoin)
|
||||
mock_join.role = TenantAccountRole.EDITOR # Editor is not privileged
|
||||
|
||||
with patch("services.billing_service.db.session") as mock_session:
|
||||
mock_session.scalar.return_value = mock_join
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
assert "Only team owner or team admin can perform this action" in str(exc_info.value)
|
||||
|
||||
def test_is_tenant_owner_or_admin_dataset_operator_raises_error(self):
|
||||
"""Test tenant owner/admin check raises error for dataset operator role."""
|
||||
# Arrange
|
||||
current_user = MagicMock(spec=Account)
|
||||
current_user.id = "account-123"
|
||||
current_user.current_tenant_id = "tenant-456"
|
||||
|
||||
mock_join = MagicMock(spec=TenantAccountJoin)
|
||||
mock_join.role = TenantAccountRole.DATASET_OPERATOR # Dataset operator is not privileged
|
||||
|
||||
with patch("services.billing_service.db.session") as mock_session:
|
||||
mock_session.scalar.return_value = mock_join
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
assert "Only team owner or team admin can perform this action" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestBillingServiceSubscriptionOperations:
|
||||
"""Unit tests for subscription operations in BillingService.
|
||||
|
||||
485
api/uv.lock
generated
485
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user