Compare commits

..

10 Commits

1337 changed files with 8499 additions and 27854 deletions

View File

@ -42,7 +42,11 @@ jobs:
- name: Run Unit tests - name: Run Unit tests
run: | run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Run ty check
run: |
cd api
uv add --dev ty
uv run ty check || true
- name: Run pyrefly check - name: Run pyrefly check
run: | run: |
cd api cd api
@ -62,6 +66,15 @@ jobs:
- name: Run dify config tests - name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py run: uv run --project api dev/pytest/pytest_config_tests.py
- name: MyPy Cache
uses: actions/cache@v4
with:
path: api/.mypy_cache
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}
- name: Run MyPy Checks
run: dev/mypy-check
- name: Set up dotenvs - name: Set up dotenvs
run: | run: |
cp docker/.env.example docker/.env cp docker/.env.example docker/.env

View File

@ -20,7 +20,7 @@ jobs:
cd api cd api
uv sync --dev uv sync --dev
# Fix lint errors # Fix lint errors
uv run ruff check --fix . uv run ruff check --fix-only .
# Format code # Format code
uv run ruff format . uv run ruff format .
- name: ast-grep - name: ast-grep

View File

@ -19,23 +19,11 @@ jobs:
github.event.workflow_run.head_branch == 'deploy/enterprise' github.event.workflow_run.head_branch == 'deploy/enterprise'
steps: steps:
- name: trigger deployments - name: Deploy to server
env: uses: appleboy/ssh-action@v0.1.8
DEV_ENV_ADDRS: ${{ vars.DEV_ENV_ADDRS }} with:
DEPLOY_SECRET: ${{ secrets.DEPLOY_SECRET }} host: ${{ secrets.ENTERPRISE_SSH_HOST }}
run: | username: ${{ secrets.ENTERPRISE_SSH_USER }}
IFS=',' read -ra ENDPOINTS <<< "${DEV_ENV_ADDRS:-}" password: ${{ secrets.ENTERPRISE_SSH_PASSWORD }}
BODY='{"project":"dify-api","tag":"deploy-enterprise"}' script: |
${{ vars.ENTERPRISE_SSH_SCRIPT || secrets.ENTERPRISE_SSH_SCRIPT }}
for ENDPOINT in "${ENDPOINTS[@]}"; do
ENDPOINT="$(echo "$ENDPOINT" | xargs)"
[ -z "$ENDPOINT" ] && continue
API_SIGNATURE=$(printf '%s' "$BODY" | openssl dgst -sha256 -hmac "$DEPLOY_SECRET" | awk '{print "sha256="$2}')
curl -sSf -X POST \
-H "Content-Type: application/json" \
-H "X-Hub-Signature-256: $API_SIGNATURE" \
-d "$BODY" \
"$ENDPOINT"
done

View File

@ -44,14 +44,6 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: uv sync --project api --dev run: uv sync --project api --dev
- name: Run Basedpyright Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: dev/basedpyright-check
- name: Run Mypy Type Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
- name: Dotenv check - name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
@ -97,9 +89,7 @@ jobs:
- name: Web style check - name: Web style check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web working-directory: ./web
run: | run: pnpm run lint
pnpm run lint
pnpm run eslint
docker-compose-template: docker-compose-template:
name: Docker Compose Template name: Docker Compose Template

View File

@ -67,22 +67,12 @@ jobs:
working-directory: ./web working-directory: ./web
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }} run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
- name: Generate i18n type definitions
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm run gen:i18n-types
- name: Create Pull Request - name: Create Pull Request
if: env.FILES_CHANGED == 'true' if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6 uses: peter-evans/create-pull-request@v6
with: with:
token: ${{ secrets.GITHUB_TOKEN }} token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Update i18n files and type definitions based on en-US changes commit-message: Update i18n files based on en-US changes
title: 'chore: translate i18n files and update type definitions' title: 'chore: translate i18n files'
body: | body: This PR was automatically created to update i18n files based on changes in en-US locale.
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
**Changes included:**
- Updated translation files for all locales
- Regenerated TypeScript type definitions for type safety
branch: chore/automated-i18n-updates branch: chore/automated-i18n-updates

View File

@ -47,11 +47,6 @@ jobs:
working-directory: ./web working-directory: ./web
run: pnpm install --frozen-lockfile run: pnpm install --frozen-lockfile
- name: Check i18n types synchronization
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run check:i18n-types
- name: Run tests - name: Run tests
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web working-directory: ./web

13
.gitignore vendored
View File

@ -123,12 +123,10 @@ venv.bak/
# mkdocs documentation # mkdocs documentation
/site /site
# type checking # mypy
.mypy_cache/ .mypy_cache/
.dmypy.json .dmypy.json
dmypy.json dmypy.json
pyrightconfig.json
!api/pyrightconfig.json
# Pyre type checker # Pyre type checker
.pyre/ .pyre/
@ -197,8 +195,8 @@ sdks/python-client/dify_client.egg-info
.vscode/* .vscode/*
!.vscode/launch.json.template !.vscode/launch.json.template
!.vscode/README.md !.vscode/README.md
pyrightconfig.json
api/.vscode api/.vscode
web/.vscode
# vscode Code History Extension # vscode Code History Extension
.history .history
@ -216,13 +214,6 @@ mise.toml
# Next.js build output # Next.js build output
.next/ .next/
# PWA generated files
web/public/sw.js
web/public/sw.js.map
web/public/workbox-*.js
web/public/workbox-*.js.map
web/public/fallback-*.js
# AI Assistant # AI Assistant
.roo/ .roo/
api/.env.backup api/.env.backup

View File

@ -32,7 +32,7 @@ uv run --project api pytest tests/integration_tests/ # Integration tests
./dev/reformat # Run all formatters and linters ./dev/reformat # Run all formatters and linters
uv run --project api ruff check --fix ./ # Fix linting issues uv run --project api ruff check --fix ./ # Fix linting issues
uv run --project api ruff format ./ # Format code uv run --project api ruff format ./ # Format code
uv run --directory api basedpyright # Type checking uv run --project api mypy . # Type checking
``` ```
### Frontend (Web) ### Frontend (Web)

View File

@ -4,48 +4,6 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
API_IMAGE=$(DOCKER_REGISTRY)/dify-api API_IMAGE=$(DOCKER_REGISTRY)/dify-api
VERSION=latest VERSION=latest
# Backend Development Environment Setup
.PHONY: dev-setup prepare-docker prepare-web prepare-api
# Default dev setup target
dev-setup: prepare-docker prepare-web prepare-api
@echo "✅ Backend development environment setup complete!"
# Step 1: Prepare Docker middleware
prepare-docker:
@echo "🐳 Setting up Docker middleware..."
@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
@echo "✅ Docker middleware started"
# Step 2: Prepare web environment
prepare-web:
@echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install
@cd web && pnpm build
@echo "✅ Web environment prepared (not started)"
# Step 3: Prepare API environment
prepare-api:
@echo "🔧 Setting up API environment..."
@cp -n api/.env.example api/.env 2>/dev/null || echo "API .env already exists"
@cd api && uv sync --dev
@cd api && uv run flask db upgrade
@echo "✅ API environment prepared (not started)"
# Clean dev environment
dev-clean:
@echo "⚠️ Stopping Docker containers..."
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
@echo "🗑️ Removing volumes..."
@rm -rf docker/volumes/db
@rm -rf docker/volumes/redis
@rm -rf docker/volumes/plugin_daemon
@rm -rf docker/volumes/weaviate
@rm -rf api/storage
@echo "✅ Cleanup complete"
# Build Docker images # Build Docker images
build-web: build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..." @echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
@ -81,21 +39,5 @@ build-push-web: build-web push-web
build-push-all: build-all push-all build-push-all: build-all push-all
@echo "All Docker images have been built and pushed." @echo "All Docker images have been built and pushed."
# Help target
help:
@echo "Development Setup Targets:"
@echo " make dev-setup - Run all setup steps for backend dev environment"
@echo " make prepare-docker - Set up Docker middleware"
@echo " make prepare-web - Set up web environment"
@echo " make prepare-api - Set up API environment"
@echo " make dev-clean - Stop Docker middleware containers"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"
@echo " make build-api - Build API Docker image"
@echo " make build-all - Build all Docker images"
@echo " make push-all - Push all Docker images"
@echo " make build-push-all - Build and push all Docker images"
# Phony targets # Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help .PHONY: build-web build-api push-web push-api build-all push-all build-push-all

View File

@ -75,7 +75,6 @@ DB_PASSWORD=difyai123456
DB_HOST=localhost DB_HOST=localhost
DB_PORT=5432 DB_PORT=5432
DB_DATABASE=dify DB_DATABASE=dify
SQLALCHEMY_POOL_PRE_PING=true
# Storage configuration # Storage configuration
# use for store upload files, private keys... # use for store upload files, private keys...
@ -157,7 +156,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Vector database configuration # Vector database configuration
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`. # Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `pinecone`.
VECTOR_STORE=weaviate VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database # Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index VECTOR_INDEX_NAME_PREFIX=Vector_index
@ -362,6 +361,17 @@ PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024 CODE_GENERATION_MAX_TOKENS=1024
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
# Pinecone configuration, only available when VECTOR_STORE is `pinecone`
PINECONE_API_KEY=your-pinecone-api-key
PINECONE_ENVIRONMENT=your-pinecone-environment
PINECONE_INDEX_NAME=dify-index
PINECONE_CLIENT_TIMEOUT=30
PINECONE_BATCH_SIZE=100
PINECONE_METRIC=cosine
PINECONE_PODS=1
PINECONE_POD_TYPE=s1
# Mail configuration, support: resend, smtp, sendgrid # Mail configuration, support: resend, smtp, sendgrid
MAIL_TYPE= MAIL_TYPE=
# If using SendGrid, use the 'from' field for authentication if necessary. # If using SendGrid, use the 'from' field for authentication if necessary.
@ -569,7 +579,3 @@ QUEUE_MONITOR_INTERVAL=30
# Swagger UI configuration # Swagger UI configuration
SWAGGER_UI_ENABLED=true SWAGGER_UI_ENABLED=true
SWAGGER_UI_PATH=/swagger-ui.html SWAGGER_UI_PATH=/swagger-ui.html
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
# Set to false to export dataset IDs as plain text for easier cross-environment import
DSL_EXPORT_ENCRYPT_DATASET_ID=true

View File

@ -45,7 +45,6 @@ select = [
"G001", # don't use str format to logging messages "G001", # don't use str format to logging messages
"G003", # don't use + in logging messages "G003", # don't use + in logging messages
"G004", # don't use f-strings to format logging messages "G004", # don't use f-strings to format logging messages
"UP042", # use StrEnum
] ]
ignore = [ ignore = [

View File

@ -108,5 +108,5 @@ uv run celery -A app.celery beat
../dev/reformat # Run all formatters and linters ../dev/reformat # Run all formatters and linters
uv run ruff check --fix ./ # Fix linting issues uv run ruff check --fix ./ # Fix linting issues
uv run ruff format ./ # Format code uv run ruff format ./ # Format code
uv run basedpyright . # Type checking uv run mypy . # Type checking
``` ```

View File

@ -25,9 +25,6 @@ def create_flask_app_with_configs() -> DifyApp:
# add an unique identifier to each request # add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles() RecyclableContextVar.increment_thread_recycles()
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
return dify_app return dify_app

11
api/child_class.py Normal file
View File

@ -0,0 +1,11 @@
from tests.integration_tests.utils.parent_class import ParentClass
class ChildClass(ParentClass):
"""Test child class for module import helper tests"""
def __init__(self, name):
super().__init__(name)
def get_name(self):
return f"Child: {self.name}"

View File

@ -212,9 +212,7 @@ def migrate_annotation_vector_database():
if not dataset_collection_binding: if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}") click.echo(f"App annotation collection binding not found: {app.id}")
continue continue
annotations = db.session.scalars( annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all()
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
).all()
dataset = Dataset( dataset = Dataset(
id=app.id, id=app.id,
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
@ -369,25 +367,29 @@ def migrate_knowledge_vector_database():
) )
raise e raise e
dataset_documents = db.session.scalars( dataset_documents = (
select(DatasetDocument).where( db.session.query(DatasetDocument)
.where(
DatasetDocument.dataset_id == dataset.id, DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed", DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
DatasetDocument.archived == False, DatasetDocument.archived == False,
) )
).all() .all()
)
documents = [] documents = []
segments_count = 0 segments_count = 0
for dataset_document in dataset_documents: for dataset_document in dataset_documents:
segments = db.session.scalars( segments = (
select(DocumentSegment).where( db.session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id, DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed", DocumentSegment.status == "completed",
DocumentSegment.enabled == True, DocumentSegment.enabled == True,
) )
).all() .all()
)
for segment in segments: for segment in segments:
document = Document( document = Document(
@ -509,7 +511,7 @@ def add_qdrant_index(field: str):
from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType from qdrant_client.http.models import PayloadSchemaType
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
for binding in bindings: for binding in bindings:
if dify_config.QDRANT_URL is None: if dify_config.QDRANT_URL is None:
@ -523,21 +525,7 @@ def add_qdrant_index(field: str):
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
) )
try: try:
params = qdrant_config.to_qdrant_params() client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
# Check the type before using
if isinstance(params, PathQdrantParams):
# PathQdrantParams case
client = qdrant_client.QdrantClient(path=params.path)
else:
# UrlQdrantParams case - params is UrlQdrantParams
client = qdrant_client.QdrantClient(
url=params.url,
api_key=params.api_key,
timeout=int(params.timeout),
verify=params.verify,
grpc_port=params.grpc_port,
prefer_grpc=params.prefer_grpc,
)
# create payload index # create payload index
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
create_count += 1 create_count += 1
@ -583,7 +571,7 @@ def old_metadata_migration():
for document in documents: for document in documents:
if document.doc_metadata: if document.doc_metadata:
doc_metadata = document.doc_metadata doc_metadata = document.doc_metadata
for key in doc_metadata: for key, value in doc_metadata.items():
for field in BuiltInField: for field in BuiltInField:
if field.value == key: if field.value == key:
break break

View File

@ -796,11 +796,6 @@ class DataSetConfig(BaseSettings):
default=30, default=30,
) )
DSL_EXPORT_ENCRYPT_DATASET_ID: bool = Field(
description="Enable or disable dataset ID encryption when exporting DSL files",
default=True,
)
class WorkspaceConfig(BaseSettings): class WorkspaceConfig(BaseSettings):
""" """

View File

@ -35,6 +35,7 @@ from .vdb.opensearch_config import OpenSearchConfig
from .vdb.oracle_config import OracleConfig from .vdb.oracle_config import OracleConfig
from .vdb.pgvector_config import PGVectorConfig from .vdb.pgvector_config import PGVectorConfig
from .vdb.pgvectors_config import PGVectoRSConfig from .vdb.pgvectors_config import PGVectoRSConfig
from .vdb.pinecone_config import PineconeConfig
from .vdb.qdrant_config import QdrantConfig from .vdb.qdrant_config import QdrantConfig
from .vdb.relyt_config import RelytConfig from .vdb.relyt_config import RelytConfig
from .vdb.tablestore_config import TableStoreConfig from .vdb.tablestore_config import TableStoreConfig
@ -300,7 +301,8 @@ class DatasetQueueMonitorConfig(BaseSettings):
class MiddlewareConfig( class MiddlewareConfig(
# place the configs in alphabet order # place the configs in alphabet order
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig CeleryConfig,
DatabaseConfig,
KeywordStoreConfig, KeywordStoreConfig,
RedisConfig, RedisConfig,
# configs of storage and storage providers # configs of storage and storage providers
@ -330,6 +332,7 @@ class MiddlewareConfig(
PGVectorConfig, PGVectorConfig,
VastbaseVectorConfig, VastbaseVectorConfig,
PGVectoRSConfig, PGVectoRSConfig,
PineconeConfig,
QdrantConfig, QdrantConfig,
RelytConfig, RelytConfig,
TencentVectorDBConfig, TencentVectorDBConfig,

View File

@ -1,10 +1,9 @@
from typing import Optional from typing import Optional
from pydantic import Field from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
class ClickzettaConfig(BaseSettings): class ClickzettaConfig(BaseModel):
""" """
Clickzetta Lakehouse vector database configuration Clickzetta Lakehouse vector database configuration
""" """

View File

@ -1,8 +1,7 @@
from pydantic import Field from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
class MatrixoneConfig(BaseSettings): class MatrixoneConfig(BaseModel):
"""Matrixone vector database configuration.""" """Matrixone vector database configuration."""
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server") MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")

View File

@ -0,0 +1,41 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class PineconeConfig(BaseSettings):
"""
Configuration settings for Pinecone vector database
"""
PINECONE_API_KEY: Optional[str] = Field(
description="API key for authenticating with Pinecone service",
default=None,
)
PINECONE_ENVIRONMENT: Optional[str] = Field(
description="Pinecone environment (e.g., 'us-west1-gcp', 'us-east-1-aws')",
default=None,
)
PINECONE_INDEX_NAME: Optional[str] = Field(
description="Default Pinecone index name",
default=None,
)
PINECONE_CLIENT_TIMEOUT: PositiveInt = Field(
description="Timeout in seconds for Pinecone client operations (default is 30 seconds)",
default=30,
)
PINECONE_BATCH_SIZE: PositiveInt = Field(
description="Batch size for Pinecone operations (default is 100)",
default=100,
)
PINECONE_METRIC: str = Field(
description="Distance metric for Pinecone index (cosine, euclidean, dotproduct)",
default="cosine",
)

View File

@ -1,6 +1,6 @@
from pydantic import Field from pydantic import Field
from configs.packaging.pyproject import PyProjectTomlConfig from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig
class PackagingInfo(PyProjectTomlConfig): class PackagingInfo(PyProjectTomlConfig):

View File

@ -4,9 +4,8 @@ import logging
import os import os
import threading import threading
import time import time
from collections.abc import Callable, Mapping from collections.abc import Mapping
from pathlib import Path from pathlib import Path
from typing import Any
from .python_3x import http_request, makedirs_wrapper from .python_3x import http_request, makedirs_wrapper
from .utils import ( from .utils import (
@ -26,13 +25,13 @@ logger = logging.getLogger(__name__)
class ApolloClient: class ApolloClient:
def __init__( def __init__(
self, self,
config_url: str, config_url,
app_id: str, app_id,
cluster: str = "default", cluster="default",
secret: str = "", secret="",
start_hot_update: bool = True, start_hot_update=True,
change_listener: Callable[[str, str, str, Any], None] | None = None, change_listener=None,
_notification_map: dict[str, int] | None = None, _notification_map=None,
): ):
# Core routing parameters # Core routing parameters
self.config_url = config_url self.config_url = config_url
@ -48,17 +47,17 @@ class ApolloClient:
# Private control variables # Private control variables
self._cycle_time = 5 self._cycle_time = 5
self._stopping = False self._stopping = False
self._cache: dict[str, dict[str, Any]] = {} self._cache = {}
self._no_key: dict[str, str] = {} self._no_key = {}
self._hash: dict[str, str] = {} self._hash = {}
self._pull_timeout = 75 self._pull_timeout = 75
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/" self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
self._long_poll_thread: threading.Thread | None = None self._long_poll_thread = None
self._change_listener = change_listener # "add" "delete" "update" self._change_listener = change_listener # "add" "delete" "update"
if _notification_map is None: if _notification_map is None:
_notification_map = {"application": -1} _notification_map = {"application": -1}
self._notification_map = _notification_map self._notification_map = _notification_map
self.last_release_key: str | None = None self.last_release_key = None
# Private startup method # Private startup method
self._path_checker() self._path_checker()
if start_hot_update: if start_hot_update:
@ -69,7 +68,7 @@ class ApolloClient:
heartbeat.daemon = True heartbeat.daemon = True
heartbeat.start() heartbeat.start()
def get_json_from_net(self, namespace: str = "application") -> dict[str, Any] | None: def get_json_from_net(self, namespace="application"):
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format( url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
self.config_url, self.app_id, self.cluster, namespace, "", self.ip self.config_url, self.app_id, self.cluster, namespace, "", self.ip
) )
@ -89,7 +88,7 @@ class ApolloClient:
logger.exception("an error occurred in get_json_from_net") logger.exception("an error occurred in get_json_from_net")
return None return None
def get_value(self, key: str, default_val: Any = None, namespace: str = "application") -> Any: def get_value(self, key, default_val=None, namespace="application"):
try: try:
# read memory configuration # read memory configuration
namespace_cache = self._cache.get(namespace) namespace_cache = self._cache.get(namespace)
@ -105,8 +104,7 @@ class ApolloClient:
namespace_data = self.get_json_from_net(namespace) namespace_data = self.get_json_from_net(namespace)
val = get_value_from_dict(namespace_data, key) val = get_value_from_dict(namespace_data, key)
if val is not None: if val is not None:
if namespace_data is not None: self._update_cache_and_file(namespace_data, namespace)
self._update_cache_and_file(namespace_data, namespace)
return val return val
# read the file configuration # read the file configuration
@ -128,23 +126,23 @@ class ApolloClient:
# to ensure the real-time correctness of the function call. # to ensure the real-time correctness of the function call.
# If the user does not have the same default val twice # If the user does not have the same default val twice
# and the default val is used here, there may be a problem. # and the default val is used here, there may be a problem.
def _set_local_cache_none(self, namespace: str, key: str) -> None: def _set_local_cache_none(self, namespace, key):
no_key = no_key_cache_key(namespace, key) no_key = no_key_cache_key(namespace, key)
self._no_key[no_key] = key self._no_key[no_key] = key
def _start_hot_update(self) -> None: def _start_hot_update(self):
self._long_poll_thread = threading.Thread(target=self._listener) self._long_poll_thread = threading.Thread(target=self._listener)
# When the asynchronous thread is started, the daemon thread will automatically exit # When the asynchronous thread is started, the daemon thread will automatically exit
# when the main thread is launched. # when the main thread is launched.
self._long_poll_thread.daemon = True self._long_poll_thread.daemon = True
self._long_poll_thread.start() self._long_poll_thread.start()
def stop(self) -> None: def stop(self):
self._stopping = True self._stopping = True
logger.info("Stopping listener...") logger.info("Stopping listener...")
# Call the set callback function, and if it is abnormal, try it out # Call the set callback function, and if it is abnormal, try it out
def _call_listener(self, namespace: str, old_kv: dict[str, Any] | None, new_kv: dict[str, Any] | None) -> None: def _call_listener(self, namespace, old_kv, new_kv):
if self._change_listener is None: if self._change_listener is None:
return return
if old_kv is None: if old_kv is None:
@ -170,12 +168,12 @@ class ApolloClient:
except BaseException as e: except BaseException as e:
logger.warning(str(e)) logger.warning(str(e))
def _path_checker(self) -> None: def _path_checker(self):
if not os.path.isdir(self._cache_file_path): if not os.path.isdir(self._cache_file_path):
makedirs_wrapper(self._cache_file_path) makedirs_wrapper(self._cache_file_path)
# update the local cache and file cache # update the local cache and file cache
def _update_cache_and_file(self, namespace_data: dict[str, Any], namespace: str = "application") -> None: def _update_cache_and_file(self, namespace_data, namespace="application"):
# update the local cache # update the local cache
self._cache[namespace] = namespace_data self._cache[namespace] = namespace_data
# update the file cache # update the file cache
@ -189,7 +187,7 @@ class ApolloClient:
self._hash[namespace] = new_hash self._hash[namespace] = new_hash
# get the configuration from the local file # get the configuration from the local file
def _get_local_cache(self, namespace: str = "application") -> dict[str, Any]: def _get_local_cache(self, namespace="application"):
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt") cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
if os.path.isfile(cache_file_path): if os.path.isfile(cache_file_path):
with open(cache_file_path) as f: with open(cache_file_path) as f:
@ -197,8 +195,8 @@ class ApolloClient:
return result return result
return {} return {}
def _long_poll(self) -> None: def _long_poll(self):
notifications: list[dict[str, Any]] = [] notifications = []
for key in self._cache: for key in self._cache:
namespace_data = self._cache[key] namespace_data = self._cache[key]
notification_id = -1 notification_id = -1
@ -238,7 +236,7 @@ class ApolloClient:
except Exception as e: except Exception as e:
logger.warning(str(e)) logger.warning(str(e))
def _get_net_and_set_local(self, namespace: str, n_id: int, call_change: bool = False) -> None: def _get_net_and_set_local(self, namespace, n_id, call_change=False):
namespace_data = self.get_json_from_net(namespace) namespace_data = self.get_json_from_net(namespace)
if not namespace_data: if not namespace_data:
return return
@ -250,7 +248,7 @@ class ApolloClient:
new_kv = namespace_data.get(CONFIGURATIONS) new_kv = namespace_data.get(CONFIGURATIONS)
self._call_listener(namespace, old_kv, new_kv) self._call_listener(namespace, old_kv, new_kv)
def _listener(self) -> None: def _listener(self):
logger.info("start long_poll") logger.info("start long_poll")
while not self._stopping: while not self._stopping:
self._long_poll() self._long_poll()
@ -268,13 +266,13 @@ class ApolloClient:
headers["Timestamp"] = time_unix_now headers["Timestamp"] = time_unix_now
return headers return headers
def _heart_beat(self) -> None: def _heart_beat(self):
while not self._stopping: while not self._stopping:
for namespace in self._notification_map: for namespace in self._notification_map:
self._do_heart_beat(namespace) self._do_heart_beat(namespace)
time.sleep(60 * 10) # 10 minutes time.sleep(60 * 10) # 10 minutes
def _do_heart_beat(self, namespace: str) -> None: def _do_heart_beat(self, namespace):
url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}" url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
try: try:
code, body = http_request(url, timeout=3, headers=self._sign_headers(url)) code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
@ -294,7 +292,7 @@ class ApolloClient:
logger.exception("an error occurred in _do_heart_beat") logger.exception("an error occurred in _do_heart_beat")
return None return None
def get_all_dicts(self, namespace: str) -> dict[str, Any] | None: def get_all_dicts(self, namespace):
namespace_data = self._cache.get(namespace) namespace_data = self._cache.get(namespace)
if namespace_data is None: if namespace_data is None:
net_namespace_data = self.get_json_from_net(namespace) net_namespace_data = self.get_json_from_net(namespace)

View File

@ -2,8 +2,6 @@ import logging
import os import os
import ssl import ssl
import urllib.request import urllib.request
from collections.abc import Mapping
from typing import Any
from urllib import parse from urllib import parse
from urllib.error import HTTPError from urllib.error import HTTPError
@ -21,9 +19,9 @@ urllib.request.install_opener(opener)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}) -> tuple[int, str | None]: def http_request(url, timeout, headers={}):
try: try:
request = urllib.request.Request(url, headers=dict(headers)) request = urllib.request.Request(url, headers=headers)
res = urllib.request.urlopen(request, timeout=timeout) res = urllib.request.urlopen(request, timeout=timeout)
body = res.read().decode("utf-8") body = res.read().decode("utf-8")
return res.code, body return res.code, body
@ -35,9 +33,9 @@ def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}
raise e raise e
def url_encode(params: dict[str, Any]) -> str: def url_encode(params):
return parse.urlencode(params) return parse.urlencode(params)
def makedirs_wrapper(path: str) -> None: def makedirs_wrapper(path):
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)

View File

@ -1,6 +1,5 @@
import hashlib import hashlib
import socket import socket
from typing import Any
from .python_3x import url_encode from .python_3x import url_encode
@ -11,7 +10,7 @@ NAMESPACE_NAME = "namespaceName"
# add timestamps uris and keys # add timestamps uris and keys
def signature(timestamp: str, uri: str, secret: str) -> str: def signature(timestamp, uri, secret):
import base64 import base64
import hmac import hmac
@ -20,16 +19,16 @@ def signature(timestamp: str, uri: str, secret: str) -> str:
return base64.b64encode(hmac_code).decode() return base64.b64encode(hmac_code).decode()
def url_encode_wrapper(params: dict[str, Any]) -> str: def url_encode_wrapper(params):
return url_encode(params) return url_encode(params)
def no_key_cache_key(namespace: str, key: str) -> str: def no_key_cache_key(namespace, key):
return f"{namespace}{len(namespace)}{key}" return f"{namespace}{len(namespace)}{key}"
# Returns whether the obtained value is obtained, and None if it does not # Returns whether the obtained value is obtained, and None if it does not
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None: def get_value_from_dict(namespace_cache, key):
if namespace_cache: if namespace_cache:
kv_data = namespace_cache.get(CONFIGURATIONS) kv_data = namespace_cache.get(CONFIGURATIONS)
if kv_data is None: if kv_data is None:
@ -39,7 +38,7 @@ def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any
return None return None
def init_ip() -> str: def init_ip():
ip = "" ip = ""
s = None s = None
try: try:

View File

@ -11,5 +11,5 @@ class RemoteSettingsSource:
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError raise NotImplementedError
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool): def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
return value return value

View File

@ -11,16 +11,16 @@ logger = logging.getLogger(__name__)
from configs.remote_settings_sources.base import RemoteSettingsSource from configs.remote_settings_sources.base import RemoteSettingsSource
from .utils import parse_config from .utils import _parse_config
class NacosSettingsSource(RemoteSettingsSource): class NacosSettingsSource(RemoteSettingsSource):
def __init__(self, configs: Mapping[str, Any]): def __init__(self, configs: Mapping[str, Any]):
self.configs = configs self.configs = configs
self.remote_configs: dict[str, str] = {} self.remote_configs: dict[str, Any] = {}
self.async_init() self.async_init()
def async_init(self) -> None: def async_init(self):
data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties") data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify") group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "") tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")
@ -29,19 +29,22 @@ class NacosSettingsSource(RemoteSettingsSource):
try: try:
content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params) content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
self.remote_configs = self._parse_config(content) self.remote_configs = self._parse_config(content)
except Exception: except Exception as e:
logger.exception("[get-access-token] exception occurred") logger.exception("[get-access-token] exception occurred")
raise raise
def _parse_config(self, content: str) -> dict[str, str]: def _parse_config(self, content: str) -> dict:
if not content: if not content:
return {} return {}
try: try:
return parse_config(content) return _parse_config(self, content)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to parse config: {e}") raise RuntimeError(f"Failed to parse config: {e}")
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
field_value = self.remote_configs.get(field_name) field_value = self.remote_configs.get(field_name)
if field_value is None: if field_value is None:
return None, field_name, False return None, field_name, False

View File

@ -17,26 +17,20 @@ class NacosHttpClient:
self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY") self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY") self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848") self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
self.token: str | None = None self.token = None
self.token_ttl = 18000 self.token_ttl = 18000
self.token_expire_time: float = 0 self.token_expire_time: float = 0
def http_request( def http_request(self, url, method="GET", headers=None, params=None):
self, url: str, method: str = "GET", headers: dict[str, str] | None = None, params: dict[str, str] | None = None
) -> str:
if headers is None:
headers = {}
if params is None:
params = {}
try: try:
self._inject_auth_info(headers, params) self._inject_auth_info(headers, params)
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params) response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
response.raise_for_status() response.raise_for_status()
return response.text return response.text
except requests.RequestException as e: except requests.exceptions.RequestException as e:
return f"Request to Nacos failed: {e}" return f"Request to Nacos failed: {e}"
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None: def _inject_auth_info(self, headers, params, module="config"):
headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"}) headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})
if module == "login": if module == "login":
@ -51,17 +45,16 @@ class NacosHttpClient:
headers["timeStamp"] = ts headers["timeStamp"] = ts
if self.username and self.password: if self.username and self.password:
self.get_access_token(force_refresh=False) self.get_access_token(force_refresh=False)
if self.token is not None: params["accessToken"] = self.token
params["accessToken"] = self.token
def __do_sign(self, sign_str: str, sk: str) -> str: def __do_sign(self, sign_str, sk):
return ( return (
base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest()) base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
.decode() .decode()
.strip() .strip()
) )
def get_sign_str(self, group: str, tenant: str, ts: str) -> str: def get_sign_str(self, group, tenant, ts):
sign_str = "" sign_str = ""
if tenant: if tenant:
sign_str = tenant + "+" sign_str = tenant + "+"
@ -70,7 +63,7 @@ class NacosHttpClient:
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it. sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
return sign_str return sign_str
def get_access_token(self, force_refresh: bool = False) -> str | None: def get_access_token(self, force_refresh=False):
current_time = time.time() current_time = time.time()
if self.token and not force_refresh and self.token_expire_time > current_time: if self.token and not force_refresh and self.token_expire_time > current_time:
return self.token return self.token
@ -84,7 +77,6 @@ class NacosHttpClient:
self.token = response_data.get("accessToken") self.token = response_data.get("accessToken")
self.token_ttl = response_data.get("tokenTtl", 18000) self.token_ttl = response_data.get("tokenTtl", 18000)
self.token_expire_time = current_time + self.token_ttl - 10 self.token_expire_time = current_time + self.token_ttl - 10
return self.token except Exception as e:
except Exception:
logger.exception("[get-access-token] exception occur") logger.exception("[get-access-token] exception occur")
raise raise

View File

@ -1,4 +1,4 @@
def parse_config(content: str) -> dict[str, str]: def _parse_config(self, content: str) -> dict[str, str]:
config: dict[str, str] = {} config: dict[str, str] = {}
if not content: if not content:
return config return config

View File

@ -16,14 +16,14 @@ AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS]) AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
_doc_extensions: list[str]
if dify_config.ETL_TYPE == "Unstructured": if dify_config.ETL_TYPE == "Unstructured":
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
if dify_config.UNSTRUCTURED_API_URL: if dify_config.UNSTRUCTURED_API_URL:
_doc_extensions.append("ppt") DOCUMENT_EXTENSIONS.append("ppt")
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
else: else:
_doc_extensions = [ DOCUMENT_EXTENSIONS = [
"txt", "txt",
"markdown", "markdown",
"md", "md",
@ -38,4 +38,4 @@ else:
"vtt", "vtt",
"properties", "properties",
] ]
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions] DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])

View File

@ -19,7 +19,6 @@ language_timezone_mapping = {
"fa-IR": "Asia/Tehran", "fa-IR": "Asia/Tehran",
"sl-SI": "Europe/Ljubljana", "sl-SI": "Europe/Ljubljana",
"th-TH": "Asia/Bangkok", "th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
} }
languages = list(language_timezone_mapping.keys()) languages = list(language_timezone_mapping.keys())

View File

@ -8,6 +8,7 @@ if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.workflow.entities.variable_pool import VariablePool
""" """

View File

@ -1,5 +1,4 @@
from flask import Blueprint from flask import Blueprint
from flask_restx import Namespace
from libs.external_api import ExternalApi from libs.external_api import ExternalApi
@ -27,16 +26,7 @@ from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
bp = Blueprint("console", __name__, url_prefix="/console/api") bp = Blueprint("console", __name__, url_prefix="/console/api")
api = ExternalApi(bp)
api = ExternalApi(
bp,
version="1.0",
title="Console API",
description="Console management APIs for app configuration, monitoring, and administration",
)
# Create namespace
console_ns = Namespace("console", description="Console management API operations", path="/")
# File # File
api.add_resource(FileApi, "/files/upload") api.add_resource(FileApi, "/files/upload")
@ -53,90 +43,56 @@ api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm"
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies") api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
# Import other controllers # Import other controllers
from . import ( from . import admin, apikey, extension, feature, ping, setup, version
admin, # pyright: ignore[reportUnusedImport]
apikey, # pyright: ignore[reportUnusedImport]
extension, # pyright: ignore[reportUnusedImport]
feature, # pyright: ignore[reportUnusedImport]
init_validate, # pyright: ignore[reportUnusedImport]
ping, # pyright: ignore[reportUnusedImport]
setup, # pyright: ignore[reportUnusedImport]
version, # pyright: ignore[reportUnusedImport]
)
# Import app controllers # Import app controllers
from .app import ( from .app import (
advanced_prompt_template, # pyright: ignore[reportUnusedImport] advanced_prompt_template,
agent, # pyright: ignore[reportUnusedImport] agent,
annotation, # pyright: ignore[reportUnusedImport] annotation,
app, # pyright: ignore[reportUnusedImport] app,
audio, # pyright: ignore[reportUnusedImport] audio,
completion, # pyright: ignore[reportUnusedImport] completion,
conversation, # pyright: ignore[reportUnusedImport] conversation,
conversation_variables, # pyright: ignore[reportUnusedImport] conversation_variables,
generator, # pyright: ignore[reportUnusedImport] generator,
mcp_server, # pyright: ignore[reportUnusedImport] mcp_server,
message, # pyright: ignore[reportUnusedImport] message,
model_config, # pyright: ignore[reportUnusedImport] model_config,
ops_trace, # pyright: ignore[reportUnusedImport] ops_trace,
site, # pyright: ignore[reportUnusedImport] site,
statistic, # pyright: ignore[reportUnusedImport] statistic,
workflow, # pyright: ignore[reportUnusedImport] workflow,
workflow_app_log, # pyright: ignore[reportUnusedImport] workflow_app_log,
workflow_draft_variable, # pyright: ignore[reportUnusedImport] workflow_draft_variable,
workflow_run, # pyright: ignore[reportUnusedImport] workflow_run,
workflow_statistic, # pyright: ignore[reportUnusedImport] workflow_statistic,
) )
# Import auth controllers # Import auth controllers
from .auth import ( from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
activate, # pyright: ignore[reportUnusedImport]
data_source_bearer_auth, # pyright: ignore[reportUnusedImport]
data_source_oauth, # pyright: ignore[reportUnusedImport]
forgot_password, # pyright: ignore[reportUnusedImport]
login, # pyright: ignore[reportUnusedImport]
oauth, # pyright: ignore[reportUnusedImport]
oauth_server, # pyright: ignore[reportUnusedImport]
)
# Import billing controllers # Import billing controllers
from .billing import billing, compliance # pyright: ignore[reportUnusedImport] from .billing import billing, compliance
# Import datasets controllers # Import datasets controllers
from .datasets import ( from .datasets import (
data_source, # pyright: ignore[reportUnusedImport] data_source,
datasets, # pyright: ignore[reportUnusedImport] datasets,
datasets_document, # pyright: ignore[reportUnusedImport] datasets_document,
datasets_segments, # pyright: ignore[reportUnusedImport] datasets_segments,
external, # pyright: ignore[reportUnusedImport] external,
hit_testing, # pyright: ignore[reportUnusedImport] hit_testing,
metadata, # pyright: ignore[reportUnusedImport] metadata,
website, # pyright: ignore[reportUnusedImport] website,
) )
# Import explore controllers # Import explore controllers
from .explore import ( from .explore import (
installed_app, # pyright: ignore[reportUnusedImport] installed_app,
parameter, # pyright: ignore[reportUnusedImport] parameter,
recommended_app, # pyright: ignore[reportUnusedImport] recommended_app,
saved_message, # pyright: ignore[reportUnusedImport] saved_message,
)
# Import tag controllers
from .tag import tags # pyright: ignore[reportUnusedImport]
# Import workspace controllers
from .workspace import (
account, # pyright: ignore[reportUnusedImport]
agent_providers, # pyright: ignore[reportUnusedImport]
endpoint, # pyright: ignore[reportUnusedImport]
load_balancing_config, # pyright: ignore[reportUnusedImport]
members, # pyright: ignore[reportUnusedImport]
model_providers, # pyright: ignore[reportUnusedImport]
models, # pyright: ignore[reportUnusedImport]
plugin, # pyright: ignore[reportUnusedImport]
tool_providers, # pyright: ignore[reportUnusedImport]
workspace, # pyright: ignore[reportUnusedImport]
) )
# Explore Audio # Explore Audio
@ -210,4 +166,19 @@ api.add_resource(
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop" InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
) )
api.add_namespace(console_ns) # Import tag controllers
from .tag import tags
# Import workspace controllers
from .workspace import (
account,
agent_providers,
endpoint,
load_balancing_config,
members,
model_providers,
models,
plugin,
tool_providers,
workspace,
)

View File

@ -1,26 +1,22 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api, console_ns from controllers.console import api
from controllers.console.wraps import only_edition_cloud from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, InstalledApp, RecommendedApp from models.model import App, InstalledApp, RecommendedApp
def admin_required(view: Callable[P, R]): def admin_required(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
if not dify_config.ADMIN_API_KEY: if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.") raise Unauthorized("API key is invalid.")
@ -45,28 +41,7 @@ def admin_required(view: Callable[P, R]):
return decorated return decorated
@console_ns.route("/admin/insert-explore-apps")
class InsertExploreAppListApi(Resource): class InsertExploreAppListApi(Resource):
@api.doc("insert_explore_app")
@api.doc(description="Insert or update an app in the explore list")
@api.expect(
api.model(
"InsertExploreAppRequest",
{
"app_id": fields.String(required=True, description="Application ID"),
"desc": fields.String(description="App description"),
"copyright": fields.String(description="Copyright information"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"language": fields.String(required=True, description="Language code"),
"category": fields.String(required=True, description="App category"),
"position": fields.Integer(required=True, description="Display position"),
},
)
)
@api.response(200, "App updated successfully")
@api.response(201, "App inserted successfully")
@api.response(404, "App not found")
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def post(self): def post(self):
@ -136,12 +111,7 @@ class InsertExploreAppListApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/admin/insert-explore-apps/<uuid:app_id>")
class InsertExploreAppApi(Resource): class InsertExploreAppApi(Resource):
@api.doc("delete_explore_app")
@api.doc(description="Remove an app from the explore list")
@api.doc(params={"app_id": "Application ID to remove"})
@api.response(204, "App removed successfully")
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def delete(self, app_id): def delete(self, app_id):
@ -160,21 +130,21 @@ class InsertExploreAppApi(Resource):
app.is_public = False app.is_public = False
with Session(db.engine) as session: with Session(db.engine) as session:
installed_apps = ( installed_apps = session.execute(
session.execute( select(InstalledApp).where(
select(InstalledApp).where( InstalledApp.app_id == recommended_app.app_id,
InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
) )
.scalars() ).all()
.all()
)
for installed_app in installed_apps: for installed_app in installed_apps:
session.delete(installed_app) db.session.delete(installed_app)
db.session.delete(recommended_app) db.session.delete(recommended_app)
db.session.commit() db.session.commit()
return {"result": "success"}, 204 return {"result": "success"}, 204
api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps")
api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/<uuid:app_id>")

View File

@ -1,9 +1,8 @@
from typing import Optional from typing import Any, Optional
import flask_restx import flask_restx
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -14,7 +13,7 @@ from libs.login import login_required
from models.dataset import Dataset from models.dataset import Dataset
from models.model import ApiToken, App from models.model import ApiToken, App
from . import api, console_ns from . import api
from .wraps import account_initialization_required, setup_required from .wraps import account_initialization_required, setup_required
api_key_fields = { api_key_fields = {
@ -41,7 +40,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
).scalar_one_or_none() ).scalar_one_or_none()
if resource is None: if resource is None:
flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.") flask_restx.abort(404, message=f"{resource_model.__name__} not found.")
return resource return resource
@ -50,7 +49,7 @@ class BaseApiKeyListResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required] method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: str | None = None resource_type: str | None = None
resource_model: Optional[type] = None resource_model: Optional[Any] = None
resource_id_field: str | None = None resource_id_field: str | None = None
token_prefix: str | None = None token_prefix: str | None = None
max_keys = 10 max_keys = 10
@ -60,11 +59,11 @@ class BaseApiKeyListResource(Resource):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = db.session.scalars( keys = (
select(ApiToken).where( db.session.query(ApiToken)
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
) .all()
).all() )
return {"items": keys} return {"items": keys}
@marshal_with(api_key_fields) @marshal_with(api_key_fields)
@ -83,12 +82,12 @@ class BaseApiKeyListResource(Resource):
if current_key_count >= self.max_keys: if current_key_count >= self.max_keys:
flask_restx.abort( flask_restx.abort(
HTTPStatus.BAD_REQUEST, 400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.", message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
custom="max_keys_exceeded", code="max_keys_exceeded",
) )
key = ApiToken.generate_api_key(self.token_prefix or "", 24) key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken() api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id) setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id api_token.tenant_id = current_user.current_tenant_id
@ -103,7 +102,7 @@ class BaseApiKeyResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required] method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: str | None = None resource_type: str | None = None
resource_model: Optional[type] = None resource_model: Optional[Any] = None
resource_id_field: str | None = None resource_id_field: str | None = None
def delete(self, resource_id, api_key_id): def delete(self, resource_id, api_key_id):
@ -127,7 +126,7 @@ class BaseApiKeyResource(Resource):
) )
if key is None: if key is None:
flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found") flask_restx.abort(404, message="API key not found")
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit() db.session.commit()
@ -135,25 +134,7 @@ class BaseApiKeyResource(Resource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
class AppApiKeyListResource(BaseApiKeyListResource): class AppApiKeyListResource(BaseApiKeyListResource):
@api.doc("get_app_api_keys")
@api.doc(description="Get all API keys for an app")
@api.doc(params={"resource_id": "App ID"})
@api.response(200, "Success", api_key_list)
def get(self, resource_id):
"""Get all API keys for an app"""
return super().get(resource_id)
@api.doc("create_app_api_key")
@api.doc(description="Create a new API key for an app")
@api.doc(params={"resource_id": "App ID"})
@api.response(201, "API key created successfully", api_key_fields)
@api.response(400, "Maximum keys exceeded")
def post(self, resource_id):
"""Create a new API key for an app"""
return super().post(resource_id)
def after_request(self, resp): def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true" resp.headers["Access-Control-Allow-Credentials"] = "true"
@ -165,16 +146,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
token_prefix = "app-" token_prefix = "app-"
@console_ns.route("/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
class AppApiKeyResource(BaseApiKeyResource): class AppApiKeyResource(BaseApiKeyResource):
@api.doc("delete_app_api_key")
@api.doc(description="Delete an API key for an app")
@api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
@api.response(204, "API key deleted successfully")
def delete(self, resource_id, api_key_id):
"""Delete an API key for an app"""
return super().delete(resource_id, api_key_id)
def after_request(self, resp): def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true" resp.headers["Access-Control-Allow-Credentials"] = "true"
@ -185,25 +157,7 @@ class AppApiKeyResource(BaseApiKeyResource):
resource_id_field = "app_id" resource_id_field = "app_id"
@console_ns.route("/datasets/<uuid:resource_id>/api-keys")
class DatasetApiKeyListResource(BaseApiKeyListResource): class DatasetApiKeyListResource(BaseApiKeyListResource):
@api.doc("get_dataset_api_keys")
@api.doc(description="Get all API keys for a dataset")
@api.doc(params={"resource_id": "Dataset ID"})
@api.response(200, "Success", api_key_list)
def get(self, resource_id):
"""Get all API keys for a dataset"""
return super().get(resource_id)
@api.doc("create_dataset_api_key")
@api.doc(description="Create a new API key for a dataset")
@api.doc(params={"resource_id": "Dataset ID"})
@api.response(201, "API key created successfully", api_key_fields)
@api.response(400, "Maximum keys exceeded")
def post(self, resource_id):
"""Create a new API key for a dataset"""
return super().post(resource_id)
def after_request(self, resp): def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true" resp.headers["Access-Control-Allow-Credentials"] = "true"
@ -215,16 +169,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
token_prefix = "ds-" token_prefix = "ds-"
@console_ns.route("/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
class DatasetApiKeyResource(BaseApiKeyResource): class DatasetApiKeyResource(BaseApiKeyResource):
@api.doc("delete_dataset_api_key")
@api.doc(description="Delete an API key for a dataset")
@api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
@api.response(204, "API key deleted successfully")
def delete(self, resource_id, api_key_id):
"""Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id)
def after_request(self, resp): def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true" resp.headers["Access-Control-Allow-Credentials"] = "true"
@ -233,3 +178,9 @@ class DatasetApiKeyResource(BaseApiKeyResource):
resource_type = "dataset" resource_type = "dataset"
resource_model = Dataset resource_model = Dataset
resource_id_field = "dataset_id" resource_id_field = "dataset_id"
api.add_resource(AppApiKeyListResource, "/apps/<uuid:resource_id>/api-keys")
api.add_resource(AppApiKeyResource, "/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
api.add_resource(DatasetApiKeyListResource, "/datasets/<uuid:resource_id>/api-keys")
api.add_resource(DatasetApiKeyResource, "/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")

View File

@ -115,10 +115,6 @@ class AppListApi(Resource):
raise BadRequest("mode is required") raise BadRequest("mode is required")
app_service = AppService() app_service = AppService()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
if current_user.current_tenant_id is None:
raise ValueError("current_user.current_tenant_id cannot be None")
app = app_service.create_app(current_user.current_tenant_id, args, current_user) app = app_service.create_app(current_user.current_tenant_id, args, current_user)
return app, 201 return app, 201
@ -165,26 +161,14 @@ class AppApi(Resource):
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
# Construct ArgsDict from parsed arguments app_model = app_service.update_app(app_model, args)
from services.app_service import AppService as AppServiceType
args_dict: AppServiceType.ArgsDict = {
"name": args["name"],
"description": args.get("description", ""),
"icon_type": args.get("icon_type", ""),
"icon": args.get("icon", ""),
"icon_background": args.get("icon_background", ""),
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
"max_active_requests": args.get("max_active_requests", 0),
}
app_model = app_service.update_app(app_model, args_dict)
return app_model return app_model
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def delete(self, app_model): def delete(self, app_model):
"""Delete app""" """Delete app"""
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
@ -240,10 +224,10 @@ class AppCopyApi(Resource):
class AppExportApi(Resource): class AppExportApi(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
"""Export app""" """Export app"""
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
@ -253,14 +237,9 @@ class AppExportApi(Resource):
# Add include_secret params # Add include_secret params
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
parser.add_argument("workflow_id", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
return { return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
"data": AppDslService.export_dsl(
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
)
}
class AppNameApi(Resource): class AppNameApi(Resource):
@ -279,7 +258,7 @@ class AppNameApi(Resource):
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_name(app_model, args["name"]) app_model = app_service.update_app_name(app_model, args.get("name"))
return app_model return app_model
@ -301,7 +280,7 @@ class AppIconApi(Resource):
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "") app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background"))
return app_model return app_model
@ -322,7 +301,7 @@ class AppSiteStatus(Resource):
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args["enable_site"]) app_model = app_service.update_app_site_status(app_model, args.get("enable_site"))
return app_model return app_model
@ -343,7 +322,7 @@ class AppApiStatus(Resource):
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args["enable_api"]) app_model = app_service.update_app_api_status(app_model, args.get("enable_api"))
return app_model return app_model

View File

@ -77,10 +77,10 @@ class ChatMessageAudioApi(Resource):
class ChatMessageTextApi(Resource): class ChatMessageTextApi(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def post(self, app_model: App): def post(self, app_model: App):
try: try:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -125,10 +125,10 @@ class ChatMessageTextApi(Resource):
class TextModesApi(Resource): class TextModesApi(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
try: try:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()

View File

@ -1,8 +1,9 @@
import logging import logging
import flask_login
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.console import api from controllers.console import api
@ -28,8 +29,7 @@ from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs import helper from libs import helper
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import current_user, login_required from libs.login import login_required
from models import Account
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
@ -56,11 +56,11 @@ class CompletionMessageApi(Resource):
streaming = args["response_mode"] != "blocking" streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False args["auto_generate_name"] = False
account = flask_login.current_user
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account or EndUser instance")
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
@ -92,9 +92,9 @@ class CompletionMessageStopApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model, task_id): def post(self, app_model, task_id):
if not isinstance(current_user, Account): account = flask_login.current_user
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -105,12 +105,6 @@ class ChatMessageApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
def post(self, app_model): def post(self, app_model):
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json") parser.add_argument("query", type=str, required=True, location="json")
@ -129,11 +123,11 @@ class ChatMessageApi(Resource):
if external_trace_id: if external_trace_id:
args["external_trace_id"] = external_trace_id args["external_trace_id"] = external_trace_id
account = flask_login.current_user
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account or EndUser instance")
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
@ -167,9 +161,9 @@ class ChatMessageStopApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model, task_id): def post(self, app_model, task_id):
if not isinstance(current_user, Account): account = flask_login.current_user
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@ -22,7 +22,7 @@ from fields.conversation_fields import (
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import DatetimeString from libs.helper import DatetimeString
from libs.login import login_required from libs.login import login_required
from models import Account, Conversation, EndUser, Message, MessageAnnotation from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode from models.model import AppMode
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
@ -117,15 +117,13 @@ class CompletionConversationDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def delete(self, app_model, conversation_id): def delete(self, app_model, conversation_id):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user) ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -284,8 +282,6 @@ class ChatConversationDetailApi(Resource):
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user) ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")

View File

@ -207,7 +207,7 @@ class InstructionGenerationTemplateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self) -> dict:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, default=False, location="json") parser.add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,5 +1,6 @@
import logging import logging
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from sqlalchemy import exists, select from sqlalchemy import exists, select
@ -26,8 +27,7 @@ from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_user, login_required from libs.login import login_required
from models.account import Account
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
@ -118,14 +118,11 @@ class ChatMessageListApi(Resource):
class MessageFeedbackApi(Resource): class MessageFeedbackApi(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def post(self, app_model): def post(self, app_model):
if current_user is None:
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("message_id", required=True, type=uuid_value, location="json") parser.add_argument("message_id", required=True, type=uuid_value, location="json")
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
@ -170,9 +167,7 @@ class MessageAnnotationApi(Resource):
@get_app_model @get_app_model
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_model): def post(self, app_model):
if not isinstance(current_user, Account): if not current_user.is_editor:
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -187,10 +182,10 @@ class MessageAnnotationApi(Resource):
class MessageAnnotationCountApi(Resource): class MessageAnnotationCountApi(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()

View File

@ -2,8 +2,8 @@ import json
from typing import cast from typing import cast
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource from flask_restx import Resource
from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import api
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -13,8 +13,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated from events.app_event import app_model_config_was_updated
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user, login_required from libs.login import login_required
from models.account import Account
from models.model import AppMode, AppModelConfig from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService from services.app_model_config_service import AppModelConfigService
@ -26,13 +25,6 @@ class ModelConfigResource(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model): def post(self, app_model):
"""Modify app model config""" """Modify app model config"""
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
# validate config # validate config
model_configuration = AppModelConfigService.validate_configuration( model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,

View File

@ -10,7 +10,7 @@ from extensions.ext_database import db
from fields.app_fields import app_site_fields from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import login_required
from models import Account, Site from models import Site
def parse_app_site_args(): def parse_app_site_args():
@ -75,8 +75,6 @@ class AppSite(Resource):
if value is not None: if value is not None:
setattr(site, attr_name, value) setattr(site, attr_name, value)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = naive_utc_now() site.updated_at = naive_utc_now()
db.session.commit() db.session.commit()
@ -101,8 +99,6 @@ class AppSiteAccessTokenReset(Resource):
raise NotFound raise NotFound
site.code = Site.generate_code(16) site.code = Site.generate_code(16)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = naive_utc_now() site.updated_at = naive_utc_now()
db.session.commit() db.session.commit()

View File

@ -18,10 +18,10 @@ from models import AppMode, Message
class DailyMessageStatistic(Resource): class DailyMessageStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
@ -75,10 +75,10 @@ WHERE
class DailyConversationStatistic(Resource): class DailyConversationStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
@ -127,10 +127,10 @@ class DailyConversationStatistic(Resource):
class DailyTerminalsStatistic(Resource): class DailyTerminalsStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
@ -184,10 +184,10 @@ WHERE
class DailyTokenCostStatistic(Resource): class DailyTokenCostStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
@ -320,10 +320,10 @@ ORDER BY
class UserSatisfactionRateStatistic(Resource): class UserSatisfactionRateStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
@ -443,10 +443,10 @@ WHERE
class TokensPerSecondStatistic(Resource): class TokensPerSecondStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user

View File

@ -11,7 +11,11 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.error import (
ConversationCompletedError,
DraftWorkflowNotExist,
DraftWorkflowNotSync,
)
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
@ -69,7 +73,7 @@ class DraftWorkflowApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
# fetch draft workflow by app_model # fetch draft workflow by app_model
@ -92,7 +96,7 @@ class DraftWorkflowApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
@ -170,7 +174,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
@ -220,7 +224,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -256,7 +260,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -293,7 +297,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -330,7 +334,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -367,7 +371,7 @@ class DraftWorkflowRunApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -406,7 +410,7 @@ class WorkflowTaskStopApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
@ -428,7 +432,7 @@ class DraftWorkflowNodeRunApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -476,7 +480,7 @@ class PublishedWorkflowApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
# fetch published workflow by app_model # fetch published workflow by app_model
@ -497,7 +501,7 @@ class PublishedWorkflowApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -522,7 +526,7 @@ class PublishedWorkflowApi(Resource):
) )
app_model.workflow_id = workflow.id app_model.workflow_id = workflow.id
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id db.session.commit()
workflow_created_at = TimestampField().format(workflow.created_at) workflow_created_at = TimestampField().format(workflow.created_at)
@ -547,7 +551,7 @@ class DefaultBlockConfigsApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
# Get default block configs # Get default block configs
@ -567,7 +571,7 @@ class DefaultBlockConfigApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -602,7 +606,7 @@ class ConvertToWorkflowApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if request.data: if request.data:
@ -651,7 +655,7 @@ class PublishedAllWorkflowApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -702,7 +706,7 @@ class WorkflowByIdApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# Check permission # Check permission
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -758,7 +762,7 @@ class WorkflowByIdApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# Check permission # Check permission
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
workflow_service = WorkflowService() workflow_service = WorkflowService()

View File

@ -27,9 +27,7 @@ class WorkflowAppLogApi(Resource):
""" """
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") parser.add_argument("keyword", type=str, location="args")
parser.add_argument( parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
parser.add_argument( parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp" "created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
) )

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import NoReturn from typing import Any, NoReturn
from flask import Response from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
@ -29,7 +29,7 @@ from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _convert_values_to_json_serializable_object(value: Segment): def _convert_values_to_json_serializable_object(value: Segment) -> Any:
if isinstance(value, FileSegment): if isinstance(value, FileSegment):
return value.value.model_dump() return value.value.model_dump()
elif isinstance(value, ArrayFileSegment): elif isinstance(value, ArrayFileSegment):
@ -40,7 +40,7 @@ def _convert_values_to_json_serializable_object(value: Segment):
return value.value return value.value
def _serialize_var_value(variable: WorkflowDraftVariable): def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
value = variable.get_value() value = variable.get_value()
# create a copy of the value to avoid affecting the model cache. # create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True) value = value.model_copy(deep=True)
@ -137,7 +137,7 @@ def _api_prerequisite(f):
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
return f(*args, **kwargs) return f(*args, **kwargs)

View File

@ -18,10 +18,10 @@ from models.model import AppMode
class WorkflowDailyRunsStatistic(Resource): class WorkflowDailyRunsStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
@ -80,10 +80,10 @@ WHERE
class WorkflowDailyTerminalsStatistic(Resource): class WorkflowDailyTerminalsStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
@ -142,10 +142,10 @@ WHERE
class WorkflowDailyTokenCostStatistic(Resource): class WorkflowDailyTokenCostStatistic(Resource):
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user

View File

@ -1,6 +1,6 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Optional, ParamSpec, TypeVar, Union from typing import Optional, Union
from controllers.console.app.error import AppNotFoundError from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db from extensions.ext_database import db
@ -8,9 +8,6 @@ from libs.login import current_user
from models import App, AppMode from models import App, AppMode
from models.account import Account from models.account import Account
P = ParamSpec("P")
R = TypeVar("R")
def _load_app_model(app_id: str) -> Optional[App]: def _load_app_model(app_id: str) -> Optional[App]:
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
@ -22,10 +19,10 @@ def _load_app_model(app_id: str) -> Optional[App]:
return app_model return app_model
def get_app_model(view: Optional[Callable[P, R]] = None, *, mode: Union[AppMode, list[AppMode], None] = None): def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]): def decorator(view_func):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs): def decorated_view(*args, **kwargs):
if not kwargs.get("app_id"): if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters") raise ValueError("missing app_id in path parameters")

View File

@ -1,8 +1,8 @@
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, reqparse
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api, console_ns from controllers.console import api
from controllers.console.error import AlreadyActivateError from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
@ -10,36 +10,14 @@ from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus from models.account import AccountStatus
from services.account_service import AccountService, RegisterService from services.account_service import AccountService, RegisterService
active_check_parser = reqparse.RequestParser()
active_check_parser.add_argument(
"workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID"
)
active_check_parser.add_argument(
"email", type=email, required=False, nullable=True, location="args", help="Email address"
)
active_check_parser.add_argument(
"token", type=str, required=True, nullable=False, location="args", help="Activation token"
)
@console_ns.route("/activate/check")
class ActivateCheckApi(Resource): class ActivateCheckApi(Resource):
@api.doc("check_activation_token")
@api.doc(description="Check if activation token is valid")
@api.expect(active_check_parser)
@api.response(
200,
"Success",
api.model(
"ActivationCheckResponse",
{
"is_valid": fields.Boolean(description="Whether token is valid"),
"data": fields.Raw(description="Activation data if valid"),
},
),
)
def get(self): def get(self):
args = active_check_parser.parse_args() parser = reqparse.RequestParser()
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args")
parser.add_argument("email", type=email, required=False, nullable=True, location="args")
parser.add_argument("token", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
workspaceId = args["workspace_id"] workspaceId = args["workspace_id"]
reg_email = args["email"] reg_email = args["email"]
@ -60,36 +38,18 @@ class ActivateCheckApi(Resource):
return {"is_valid": False} return {"is_valid": False}
active_parser = reqparse.RequestParser()
active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
active_parser.add_argument("email", type=email, required=False, nullable=True, location="json")
active_parser.add_argument("token", type=str, required=True, nullable=False, location="json")
active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
active_parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json"
)
active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
@console_ns.route("/activate")
class ActivateApi(Resource): class ActivateApi(Resource):
@api.doc("activate_account")
@api.doc(description="Activate account with invitation token")
@api.expect(active_parser)
@api.response(
200,
"Account activated successfully",
api.model(
"ActivationResponse",
{
"result": fields.String(description="Operation result"),
"data": fields.Raw(description="Login token data"),
},
),
)
@api.response(400, "Already activated or invalid token")
def post(self): def post(self):
args = active_parser.parse_args() parser = reqparse.RequestParser()
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json"
)
parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
args = parser.parse_args()
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
if invitation is None: if invitation is None:
@ -110,3 +70,7 @@ class ActivateApi(Resource):
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
return {"result": "success", "data": token_pair.model_dump()} return {"result": "success", "data": token_pair.model_dump()}
api.add_resource(ActivateCheckApi, "/activate/check")
api.add_resource(ActivateApi, "/activate")

View File

@ -3,11 +3,11 @@ import logging
import requests import requests
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields from flask_restx import Resource
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api, console_ns from controllers.console import api
from libs.login import login_required from libs.login import login_required
from libs.oauth_data_source import NotionOAuth from libs.oauth_data_source import NotionOAuth
@ -28,21 +28,7 @@ def get_oauth_providers():
return OAUTH_PROVIDERS return OAUTH_PROVIDERS
@console_ns.route("/oauth/data-source/<string:provider>")
class OAuthDataSource(Resource): class OAuthDataSource(Resource):
@api.doc("oauth_data_source")
@api.doc(description="Get OAuth authorization URL for data source provider")
@api.doc(params={"provider": "Data source provider name (notion)"})
@api.response(
200,
"Authorization URL or internal setup success",
api.model(
"OAuthDataSourceResponse",
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
),
)
@api.response(400, "Invalid provider")
@api.response(403, "Admin privileges required")
def get(self, provider: str): def get(self, provider: str):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
@ -63,19 +49,7 @@ class OAuthDataSource(Resource):
return {"data": auth_url}, 200 return {"data": auth_url}, 200
@console_ns.route("/oauth/data-source/callback/<string:provider>")
class OAuthDataSourceCallback(Resource): class OAuthDataSourceCallback(Resource):
@api.doc("oauth_data_source_callback")
@api.doc(description="Handle OAuth callback from data source provider")
@api.doc(
params={
"provider": "Data source provider name (notion)",
"code": "Authorization code from OAuth provider",
"error": "Error message from OAuth provider",
}
)
@api.response(302, "Redirect to console with result")
@api.response(400, "Invalid provider")
def get(self, provider: str): def get(self, provider: str):
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context(): with current_app.app_context():
@ -94,19 +68,7 @@ class OAuthDataSourceCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied") return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied")
@console_ns.route("/oauth/data-source/binding/<string:provider>")
class OAuthDataSourceBinding(Resource): class OAuthDataSourceBinding(Resource):
@api.doc("oauth_data_source_binding")
@api.doc(description="Bind OAuth data source with authorization code")
@api.doc(
params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"}
)
@api.response(
200,
"Data source binding success",
api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
)
@api.response(400, "Invalid provider or code")
def get(self, provider: str): def get(self, provider: str):
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context(): with current_app.app_context():
@ -119,7 +81,7 @@ class OAuthDataSourceBinding(Resource):
return {"error": "Invalid code"}, 400 return {"error": "Invalid code"}, 400
try: try:
oauth_provider.get_access_token(code) oauth_provider.get_access_token(code)
except requests.HTTPError as e: except requests.exceptions.HTTPError as e:
logger.exception( logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
) )
@ -128,17 +90,7 @@ class OAuthDataSourceBinding(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
class OAuthDataSourceSync(Resource): class OAuthDataSourceSync(Resource):
@api.doc("oauth_data_source_sync")
@api.doc(description="Sync data from OAuth data source")
@api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
@api.response(
200,
"Data source sync success",
api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
)
@api.response(400, "Invalid provider or sync failed")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -152,10 +104,16 @@ class OAuthDataSourceSync(Resource):
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
try: try:
oauth_provider.sync_data_source(binding_id) oauth_provider.sync_data_source(binding_id)
except requests.HTTPError as e: except requests.exceptions.HTTPError as e:
logger.exception( logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
) )
return {"error": "OAuth data source process failed"}, 400 return {"error": "OAuth data source process failed"}, 400
return {"result": "success"}, 200 return {"result": "success"}, 200
api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")
api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")

View File

@ -2,12 +2,12 @@ import base64
import secrets import secrets
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from constants.languages import languages from constants.languages import languages
from controllers.console import api, console_ns from controllers.console import api
from controllers.console.auth.error import ( from controllers.console.auth.error import (
EmailCodeError, EmailCodeError,
EmailPasswordResetLimitError, EmailPasswordResetLimitError,
@ -28,32 +28,7 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces
from services.feature_service import FeatureService from services.feature_service import FeatureService
@console_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordSendEmailApi(Resource):
@api.doc("send_forgot_password_email")
@api.doc(description="Send password reset email")
@api.expect(
api.model(
"ForgotPasswordEmailRequest",
{
"email": fields.String(required=True, description="Email address"),
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
},
)
)
@api.response(
200,
"Email sent successfully",
api.model(
"ForgotPasswordEmailResponse",
{
"result": fields.String(description="Operation result"),
"data": fields.String(description="Reset token"),
"code": fields.String(description="Error code if account not found"),
},
),
)
@api.response(400, "Invalid email or rate limit exceeded")
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
@ -86,33 +61,7 @@ class ForgotPasswordSendEmailApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
@console_ns.route("/forgot-password/validity")
class ForgotPasswordCheckApi(Resource): class ForgotPasswordCheckApi(Resource):
@api.doc("check_forgot_password_code")
@api.doc(description="Verify password reset code")
@api.expect(
api.model(
"ForgotPasswordCheckRequest",
{
"email": fields.String(required=True, description="Email address"),
"code": fields.String(required=True, description="Verification code"),
"token": fields.String(required=True, description="Reset token"),
},
)
)
@api.response(
200,
"Code verified successfully",
api.model(
"ForgotPasswordCheckResponse",
{
"is_valid": fields.Boolean(description="Whether code is valid"),
"email": fields.String(description="Email address"),
"token": fields.String(description="New reset token"),
},
),
)
@api.response(400, "Invalid code or token")
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
@ -151,26 +100,7 @@ class ForgotPasswordCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/forgot-password/resets")
class ForgotPasswordResetApi(Resource): class ForgotPasswordResetApi(Resource):
@api.doc("reset_password")
@api.doc(description="Reset password with verification token")
@api.expect(
api.model(
"ForgotPasswordResetRequest",
{
"token": fields.String(required=True, description="Verification token"),
"new_password": fields.String(required=True, description="New password"),
"password_confirm": fields.String(required=True, description="Password confirmation"),
},
)
)
@api.response(
200,
"Password reset successfully",
api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
)
@api.response(400, "Invalid token or password mismatch")
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
@ -242,3 +172,8 @@ class ForgotPasswordResetApi(Resource):
pass pass
except AccountRegisterError: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")

View File

@ -130,7 +130,7 @@ class ResetPasswordSendEmailApi(Resource):
language = "en-US" language = "en-US"
try: try:
account = AccountService.get_user_through_email(args["email"]) account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError: except AccountRegisterError as are:
raise AccountInFreezeError() raise AccountInFreezeError()
if account is None: if account is None:
@ -162,7 +162,7 @@ class EmailCodeLoginSendEmailApi(Resource):
language = "en-US" language = "en-US"
try: try:
account = AccountService.get_user_through_email(args["email"]) account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError: except AccountRegisterError as are:
raise AccountInFreezeError() raise AccountInFreezeError()
if account is None: if account is None:
@ -200,7 +200,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args["token"]) AccountService.revoke_email_code_login_token(args["token"])
try: try:
account = AccountService.get_user_through_email(user_email) account = AccountService.get_user_through_email(user_email)
except AccountRegisterError: except AccountRegisterError as are:
raise AccountInFreezeError() raise AccountInFreezeError()
if account: if account:
tenants = TenantService.get_join_tenants(account) tenants = TenantService.get_join_tenants(account)
@ -223,7 +223,7 @@ class EmailCodeLoginApi(Resource):
) )
except WorkSpaceNotAllowedCreateError: except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace() raise NotAllowedCreateWorkspace()
except AccountRegisterError: except AccountRegisterError as are:
raise AccountInFreezeError() raise AccountInFreezeError()
except WorkspacesLimitExceededError: except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded() raise WorkspacesLimitExceeded()

View File

@ -22,7 +22,7 @@ from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
from services.feature_service import FeatureService from services.feature_service import FeatureService
from .. import api, console_ns from .. import api
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -50,13 +50,7 @@ def get_oauth_providers():
return OAUTH_PROVIDERS return OAUTH_PROVIDERS
@console_ns.route("/oauth/login/<provider>")
class OAuthLogin(Resource): class OAuthLogin(Resource):
@api.doc("oauth_login")
@api.doc(description="Initiate OAuth login process")
@api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"})
@api.response(302, "Redirect to OAuth authorization URL")
@api.response(400, "Invalid provider")
def get(self, provider: str): def get(self, provider: str):
invite_token = request.args.get("invite_token") or None invite_token = request.args.get("invite_token") or None
OAUTH_PROVIDERS = get_oauth_providers() OAUTH_PROVIDERS = get_oauth_providers()
@ -69,19 +63,7 @@ class OAuthLogin(Resource):
return redirect(auth_url) return redirect(auth_url)
@console_ns.route("/oauth/authorize/<provider>")
class OAuthCallback(Resource): class OAuthCallback(Resource):
@api.doc("oauth_callback")
@api.doc(description="Handle OAuth callback and complete login process")
@api.doc(
params={
"provider": "OAuth provider name (github/google)",
"code": "Authorization code from OAuth provider",
"state": "Optional state parameter (used for invite token)",
}
)
@api.response(302, "Redirect to console with access token")
@api.response(400, "OAuth process failed")
def get(self, provider: str): def get(self, provider: str):
OAUTH_PROVIDERS = get_oauth_providers() OAUTH_PROVIDERS = get_oauth_providers()
with current_app.app_context(): with current_app.app_context():
@ -95,19 +77,16 @@ class OAuthCallback(Resource):
if state: if state:
invite_token = state invite_token = state
if not code:
return {"error": "Authorization code is required"}, 400
try: try:
token = oauth_provider.get_access_token(code) token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token) user_info = oauth_provider.get_user_info(token)
except requests.RequestException as e: except requests.exceptions.RequestException as e:
error_text = e.response.text if e.response else str(e) error_text = e.response.text if e.response else str(e)
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text) logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400 return {"error": "OAuth process failed"}, 400
if invite_token and RegisterService.is_valid_invite_token(invite_token): if invite_token and RegisterService.is_valid_invite_token(invite_token):
invitation = RegisterService.get_invitation_by_token(token=invite_token) invitation = RegisterService._get_invitation_by_token(token=invite_token)
if invitation: if invitation:
invitation_email = invitation.get("email", None) invitation_email = invitation.get("email", None)
if invitation_email != user_info.email: if invitation_email != user_info.email:
@ -202,3 +181,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
AccountService.link_account_integrate(provider, user_info.id, account) AccountService.link_account_integrate(provider, user_info.id, account)
return account return account
api.add_resource(OAuthLogin, "/oauth/login/<provider>")
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")

View File

@ -1,9 +1,8 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast from typing import cast
import flask_login import flask_login
from flask import jsonify, request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, NotFound from werkzeug.exceptions import BadRequest, NotFound
@ -16,14 +15,10 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
from .. import api from .. import api
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def oauth_server_client_id_required(view):
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view) @wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("client_id", type=str, required=True, location="json") parser.add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args() parsed_args = parser.parse_args()
@ -35,53 +30,43 @@ def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderA
if not oauth_provider_app: if not oauth_provider_app:
raise NotFound("client_id is invalid") raise NotFound("client_id is invalid")
return view(self, oauth_provider_app, *args, **kwargs) kwargs["oauth_provider_app"] = oauth_provider_app
return view(*args, **kwargs)
return decorated return decorated
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]): def oauth_server_access_token_required(view):
@wraps(view) @wraps(view)
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
if not isinstance(oauth_provider_app, OAuthProviderApp): oauth_provider_app = kwargs.get("oauth_provider_app")
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app") raise BadRequest("Invalid oauth_provider_app")
authorization_header = request.headers.get("Authorization") authorization_header = request.headers.get("Authorization")
if not authorization_header: if not authorization_header:
response = jsonify({"error": "Authorization header is required"}) raise BadRequest("Authorization header is required")
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response
parts = authorization_header.strip().split(None, 1) parts = authorization_header.strip().split(" ")
if len(parts) != 2: if len(parts) != 2:
response = jsonify({"error": "Invalid Authorization header format"}) raise BadRequest("Invalid Authorization header format")
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response
token_type = parts[0].strip() token_type = parts[0].strip()
if token_type.lower() != "bearer": if token_type.lower() != "bearer":
response = jsonify({"error": "token_type is invalid"}) raise BadRequest("token_type is invalid")
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response
access_token = parts[1].strip() access_token = parts[1].strip()
if not access_token: if not access_token:
response = jsonify({"error": "access_token is required"}) raise BadRequest("access_token is required")
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response
account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token) account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
if not account: if not account:
response = jsonify({"error": "access_token or client_id is invalid"}) raise BadRequest("access_token or client_id is invalid")
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response
return view(self, oauth_provider_app, account, *args, **kwargs) kwargs["account"] = account
return view(*args, **kwargs)
return decorated return decorated

View File

@ -1,9 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_user, login_required from libs.login import login_required
from models.model import Account
from services.billing_service import BillingService from services.billing_service import BillingService
@ -17,10 +17,9 @@ class Subscription(Resource):
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
args = parser.parse_args() args = parser.parse_args()
assert isinstance(current_user, Account)
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_subscription( return BillingService.get_subscription(
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
) )
@ -32,9 +31,7 @@ class Invoices(Resource):
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
assert isinstance(current_user, Account)
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)

View File

@ -10,7 +10,6 @@ from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db from extensions.ext_database import db
@ -29,12 +28,14 @@ class DataSourceApi(Resource):
@marshal_with(integrate_list_fields) @marshal_with(integrate_list_fields)
def get(self): def get(self):
# get workspace data source integrates # get workspace data source integrates
data_source_integrates = db.session.scalars( data_source_integrates = (
select(DataSourceOauthBinding).where( db.session.query(DataSourceOauthBinding)
.where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
) )
).all() .all()
)
base_url = request.url_root.rstrip("/") base_url = request.url_root.rstrip("/")
data_source_oauth_base_path = "/console/api/oauth/data-source" data_source_oauth_base_path = "/console/api/oauth/data-source"
@ -213,7 +214,7 @@ class DataSourceNotionApi(Resource):
workspace_id = notion_info["workspace_id"] workspace_id = notion_info["workspace_id"]
for page in notion_info["pages"]: for page in notion_info["pages"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type="notion_import",
notion_info={ notion_info={
"notion_workspace_id": workspace_id, "notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"], "notion_obj_id": page["page_id"],
@ -247,7 +248,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
documents = DocumentService.get_document_by_dataset_id(dataset_id_str) documents = DocumentService.get_document_by_dataset_id(dataset_id_str)
for document in documents: for document in documents:
document_indexing_sync_task.delay(dataset_id_str, document.id) document_indexing_sync_task.delay(dataset_id_str, document.id)
return {"result": "success"}, 200 return 200
class DataSourceNotionDocumentSyncApi(Resource): class DataSourceNotionDocumentSyncApi(Resource):
@ -265,7 +266,7 @@ class DataSourceNotionDocumentSyncApi(Resource):
if document is None: if document is None:
raise NotFound("Document not found.") raise NotFound("Document not found.")
document_indexing_sync_task.delay(dataset_id_str, document_id_str) document_indexing_sync_task.delay(dataset_id_str, document_id_str)
return {"result": "success"}, 200 return 200
api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>") api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")

View File

@ -2,7 +2,6 @@ import flask_restx
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, marshal, marshal_with, reqparse from flask_restx import Resource, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
@ -23,7 +22,6 @@ from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
@ -412,11 +410,11 @@ class DatasetIndexingEstimateApi(Resource):
extract_settings = [] extract_settings = []
if args["info_list"]["data_source_type"] == "upload_file": if args["info_list"]["data_source_type"] == "upload_file":
file_ids = args["info_list"]["file_info_list"]["file_ids"] file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = db.session.scalars( file_details = (
select(UploadFile).where( db.session.query(UploadFile)
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids) .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
) .all()
).all() )
if file_details is None: if file_details is None:
raise NotFound("File not found.") raise NotFound("File not found.")
@ -424,9 +422,7 @@ class DatasetIndexingEstimateApi(Resource):
if file_details: if file_details:
for file_detail in file_details: for file_detail in file_details:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
upload_file=file_detail,
document_model=args["doc_form"],
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
elif args["info_list"]["data_source_type"] == "notion_import": elif args["info_list"]["data_source_type"] == "notion_import":
@ -435,7 +431,7 @@ class DatasetIndexingEstimateApi(Resource):
workspace_id = notion_info["workspace_id"] workspace_id = notion_info["workspace_id"]
for page in notion_info["pages"]: for page in notion_info["pages"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type="notion_import",
notion_info={ notion_info={
"notion_workspace_id": workspace_id, "notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"], "notion_obj_id": page["page_id"],
@ -449,7 +445,7 @@ class DatasetIndexingEstimateApi(Resource):
website_info_list = args["info_list"]["website_info_list"] website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]: for url in website_info_list["urls"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value, datasource_type="website_crawl",
website_info={ website_info={
"provider": website_info_list["provider"], "provider": website_info_list["provider"],
"job_id": website_info_list["job_id"], "job_id": website_info_list["job_id"],
@ -519,11 +515,11 @@ class DatasetIndexingStatusApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
documents = db.session.scalars( documents = (
select(Document).where( db.session.query(Document)
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id .where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
) .all()
).all() )
documents_status = [] documents_status = []
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
@ -570,11 +566,11 @@ class DatasetApiKeyApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(api_key_list) @marshal_with(api_key_list)
def get(self): def get(self):
keys = db.session.scalars( keys = (
select(ApiToken).where( db.session.query(ApiToken)
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
) .all()
).all() )
return {"items": keys} return {"items": keys}
@setup_required @setup_required
@ -664,6 +660,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.BAIDU | VectorType.BAIDU
| VectorType.VIKINGDB | VectorType.VIKINGDB
| VectorType.UPSTASH | VectorType.UPSTASH
| VectorType.PINECONE
): ):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case ( case (
@ -715,6 +712,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.BAIDU | VectorType.BAIDU
| VectorType.VIKINGDB | VectorType.VIKINGDB
| VectorType.UPSTASH | VectorType.UPSTASH
| VectorType.PINECONE
): ):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case ( case (

View File

@ -1,6 +1,5 @@
import logging import logging
from argparse import ArgumentTypeError from argparse import ArgumentTypeError
from collections.abc import Sequence
from typing import Literal, cast from typing import Literal, cast
from flask import request from flask import request
@ -41,7 +40,6 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db from extensions.ext_database import db
from fields.document_fields import ( from fields.document_fields import (
@ -80,7 +78,7 @@ class DocumentResource(Resource):
return document return document
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]: def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -356,6 +354,9 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig(**args)
if knowledge_config.indexing_technique == "high_quality": if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
@ -427,7 +428,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.") raise NotFound("File not found.")
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form datasource_type="upload_file", upload_file=file, document_model=document.doc_form
) )
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
@ -476,8 +477,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
data_source_info = document.data_source_info_dict data_source_info = document.data_source_info_dict
if document.data_source_type == "upload_file": if document.data_source_type == "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
file_detail = ( file_detail = (
db.session.query(UploadFile) db.session.query(UploadFile)
@ -489,15 +488,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.") raise NotFound("File not found.")
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import": elif document.data_source_type == "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type="notion_import",
notion_info={ notion_info={
"notion_workspace_id": data_source_info["notion_workspace_id"], "notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"], "notion_obj_id": data_source_info["notion_page_id"],
@ -508,10 +505,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl": elif document.data_source_type == "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value, datasource_type="website_crawl",
website_info={ website_info={
"provider": data_source_info["provider"], "provider": data_source_info["provider"],
"job_id": data_source_info["job_id"], "job_id": data_source_info["job_id"],

View File

@ -113,7 +113,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
MetadataService.enable_built_in_field(dataset) MetadataService.enable_built_in_field(dataset)
elif action == "disable": elif action == "disable":
MetadataService.disable_built_in_field(dataset) MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200 return 200
class DocumentMetadataEditApi(Resource): class DocumentMetadataEditApi(Resource):
@ -135,7 +135,7 @@ class DocumentMetadataEditApi(Resource):
MetadataService.update_documents_metadata(dataset, metadata_args) MetadataService.update_documents_metadata(dataset, metadata_args)
return {"result": "success"}, 200 return 200
api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata") api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")

View File

@ -1,5 +1,6 @@
import logging import logging
from flask_login import current_user
from flask_restx import reqparse from flask_restx import reqparse
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
@ -27,8 +28,6 @@ from extensions.ext_database import db
from libs import helper from libs import helper
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
@ -58,8 +57,6 @@ class CompletionApi(InstalledAppResource):
db.session.commit() db.session.commit()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
) )
@ -93,8 +90,6 @@ class CompletionStopApi(InstalledAppResource):
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -122,8 +117,6 @@ class ChatApi(InstalledAppResource):
db.session.commit() db.session.commit()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
) )
@ -160,8 +153,6 @@ class ChatStopApi(InstalledAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@ -1,3 +1,4 @@
from flask_login import current_user
from flask_restx import marshal_with, reqparse from flask_restx import marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -9,8 +10,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode from models.model import AppMode
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
@ -36,8 +35,6 @@ class ConversationListApi(InstalledAppResource):
pinned = args["pinned"] == "true" pinned = args["pinned"] == "true"
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
with Session(db.engine) as session: with Session(db.engine) as session:
return WebConversationService.pagination_by_last_id( return WebConversationService.pagination_by_last_id(
session=session, session=session,
@ -61,11 +58,10 @@ class ConversationApi(InstalledAppResource):
conversation_id = str(c_id) conversation_id = str(c_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user) ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}, 204 return {"result": "success"}, 204
@ -86,8 +82,6 @@ class ConversationRenameApi(InstalledAppResource):
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return ConversationService.rename( return ConversationService.rename(
app_model, conversation_id, current_user, args["name"], args["auto_generate"] app_model, conversation_id, current_user, args["name"], args["auto_generate"]
) )
@ -105,8 +99,6 @@ class ConversationPinApi(InstalledAppResource):
conversation_id = str(c_id) conversation_id = str(c_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
WebConversationService.pin(app_model, conversation_id, current_user) WebConversationService.pin(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -122,8 +114,6 @@ class ConversationUnPinApi(InstalledAppResource):
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
WebConversationService.unpin(app_model, conversation_id, current_user) WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"} return {"result": "success"}

View File

@ -2,8 +2,9 @@ import logging
from typing import Any from typing import Any
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, inputs, marshal_with, reqparse from flask_restx import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_, select from sqlalchemy import and_
from werkzeug.exceptions import BadRequest, Forbidden, NotFound from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.console import api from controllers.console import api
@ -12,8 +13,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
from extensions.ext_database import db from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields from fields.installed_app_fields import installed_app_list_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import current_user, login_required from libs.login import login_required
from models import Account, App, InstalledApp, RecommendedApp from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService from services.account_service import TenantService
from services.app_service import AppService from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
@ -28,23 +29,17 @@ class InstalledAppsListApi(Resource):
@marshal_with(installed_app_list_fields) @marshal_with(installed_app_list_fields)
def get(self): def get(self):
app_id = request.args.get("app_id", default=None, type=str) app_id = request.args.get("app_id", default=None, type=str)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id current_tenant_id = current_user.current_tenant_id
if app_id: if app_id:
installed_apps = db.session.scalars( installed_apps = (
select(InstalledApp).where( db.session.query(InstalledApp)
and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id) .where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
) .all()
).all() )
else: else:
installed_apps = db.session.scalars( installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()
select(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id)
).all()
if current_user.current_tenant is None:
raise ValueError("current_user.current_tenant must not be None")
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_app_list: list[dict[str, Any]] = [ installed_app_list: list[dict[str, Any]] = [
{ {
@ -120,8 +115,6 @@ class InstalledAppsListApi(Resource):
if recommended_app is None: if recommended_app is None:
raise NotFound("App not found") raise NotFound("App not found")
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id current_tenant_id = current_user.current_tenant_id
app = db.session.query(App).where(App.id == args["app_id"]).first() app = db.session.query(App).where(App.id == args["app_id"]).first()
@ -161,8 +154,6 @@ class InstalledAppApi(InstalledAppResource):
""" """
def delete(self, installed_app): def delete(self, installed_app):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
if installed_app.app_owner_tenant_id == current_user.current_tenant_id: if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant") raise BadRequest("You can't uninstall an app owned by the current tenant")

View File

@ -1,5 +1,6 @@
import logging import logging
from flask_login import current_user
from flask_restx import marshal_with, reqparse from flask_restx import marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
@ -23,8 +24,6 @@ from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper from libs import helper
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError from services.errors.app import MoreLikeThisDisabledError
@ -55,8 +54,6 @@ class MessageListApi(InstalledAppResource):
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return MessageService.pagination_by_first_id( return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
) )
@ -78,8 +75,6 @@ class MessageFeedbackApi(InstalledAppResource):
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
MessageService.create_feedback( MessageService.create_feedback(
app_model=app_model, app_model=app_model,
message_id=message_id, message_id=message_id,
@ -110,8 +105,6 @@ class MessageMoreLikeThisApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate_more_like_this( response = AppGenerateService.generate_more_like_this(
app_model=app_model, app_model=app_model,
user=current_user, user=current_user,
@ -149,8 +142,6 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
message_id = str(message_id) message_id = str(message_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer( questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
) )

View File

@ -43,8 +43,6 @@ class ExploreAppMetaApi(InstalledAppResource):
def get(self, installed_app: InstalledApp): def get(self, installed_app: InstalledApp):
"""Get app meta""" """Get app meta"""
app_model = installed_app.app app_model = installed_app.app
if not app_model:
raise ValueError("App not found")
return AppService().get_app_meta(app_model) return AppService().get_app_meta(app_model)

View File

@ -1,10 +1,11 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from constants.languages import languages from constants.languages import languages
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField from libs.helper import AppIconUrlField
from libs.login import current_user, login_required from libs.login import login_required
from services.recommended_app_service import RecommendedAppService from services.recommended_app_service import RecommendedAppService
app_fields = { app_fields = {
@ -45,9 +46,8 @@ class RecommendedAppListApi(Resource):
parser.add_argument("language", type=str, location="args") parser.add_argument("language", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
language = args.get("language") if args.get("language") and args.get("language") in languages:
if language and language in languages: language_prefix = args.get("language")
language_prefix = language
elif current_user and current_user.interface_language: elif current_user and current_user.interface_language:
language_prefix = current_user.interface_language language_prefix = current_user.interface_language
else: else:

View File

@ -1,3 +1,4 @@
from flask_login import current_user
from flask_restx import fields, marshal_with, reqparse from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@ -7,8 +8,6 @@ from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from libs.login import current_user
from models import Account
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService from services.saved_message_service import SavedMessageService
@ -43,8 +42,6 @@ class SavedMessageListApi(InstalledAppResource):
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
def post(self, installed_app): def post(self, installed_app):
@ -57,8 +54,6 @@ class SavedMessageListApi(InstalledAppResource):
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.save(app_model, current_user, args["message_id"]) SavedMessageService.save(app_model, current_user, args["message_id"])
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
@ -75,8 +70,6 @@ class SavedMessageApi(InstalledAppResource):
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.delete(app_model, current_user, message_id) SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204 return {"result": "success"}, 204

View File

@ -35,8 +35,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
Run workflow Run workflow
""" """
app_model = installed_app.app app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()
@ -75,8 +73,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
Stop workflow task Stop workflow task
""" """
app_model = installed_app.app app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()

View File

@ -1,6 +1,4 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Concatenate, Optional, ParamSpec, TypeVar
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource from flask_restx import Resource
@ -15,15 +13,19 @@ from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def installed_app_required(view=None):
def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): def decorator(view):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
if not kwargs.get("installed_app_id"):
raise ValueError("missing installed_app_id in path parameters")
installed_app_id = kwargs.get("installed_app_id")
installed_app_id = str(installed_app_id)
del kwargs["installed_app_id"]
installed_app = ( installed_app = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.where( .where(
@ -50,10 +52,10 @@ def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P],
return decorator return decorator
def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): def user_allowed_to_access_app(view=None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]): def decorator(view):
@wraps(view) @wraps(view)
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): def decorated(installed_app: InstalledApp, *args, **kwargs):
feature = FeatureService.get_system_features() feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled: if feature.webapp_auth.enabled:
app_id = installed_app.app_id app_id = installed_app.app_id

View File

@ -1,8 +1,8 @@
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from controllers.console import api, console_ns from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import login_required from libs.login import login_required
@ -11,21 +11,7 @@ from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService
@console_ns.route("/code-based-extension")
class CodeBasedExtensionAPI(Resource): class CodeBasedExtensionAPI(Resource):
@api.doc("get_code_based_extension")
@api.doc(description="Get code-based extension data by module name")
@api.expect(
api.parser().add_argument("module", type=str, required=True, location="args", help="Extension module name")
)
@api.response(
200,
"Success",
api.model(
"CodeBasedExtensionResponse",
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
),
)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -37,11 +23,7 @@ class CodeBasedExtensionAPI(Resource):
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
@console_ns.route("/api-based-extension")
class APIBasedExtensionAPI(Resource): class APIBasedExtensionAPI(Resource):
@api.doc("get_api_based_extensions")
@api.doc(description="Get all API-based extensions for current tenant")
@api.response(200, "Success", fields.List(fields.Nested(api_based_extension_fields)))
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -50,19 +32,6 @@ class APIBasedExtensionAPI(Resource):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@api.doc("create_api_based_extension")
@api.doc(description="Create a new API-based extension")
@api.expect(
api.model(
"CreateAPIBasedExtensionRequest",
{
"name": fields.String(required=True, description="Extension name"),
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
"api_key": fields.String(required=True, description="API key for authentication"),
},
)
)
@api.response(201, "Extension created successfully", api_based_extension_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -84,12 +53,7 @@ class APIBasedExtensionAPI(Resource):
return APIBasedExtensionService.save(extension_data) return APIBasedExtensionService.save(extension_data)
@console_ns.route("/api-based-extension/<uuid:id>")
class APIBasedExtensionDetailAPI(Resource): class APIBasedExtensionDetailAPI(Resource):
@api.doc("get_api_based_extension")
@api.doc(description="Get API-based extension by ID")
@api.doc(params={"id": "Extension ID"})
@api.response(200, "Success", api_based_extension_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -100,20 +64,6 @@ class APIBasedExtensionDetailAPI(Resource):
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
@api.doc("update_api_based_extension")
@api.doc(description="Update API-based extension")
@api.doc(params={"id": "Extension ID"})
@api.expect(
api.model(
"UpdateAPIBasedExtensionRequest",
{
"name": fields.String(required=True, description="Extension name"),
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
"api_key": fields.String(required=True, description="API key for authentication"),
},
)
)
@api.response(200, "Extension updated successfully", api_based_extension_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -138,10 +88,6 @@ class APIBasedExtensionDetailAPI(Resource):
return APIBasedExtensionService.save(extension_data_from_db) return APIBasedExtensionService.save(extension_data_from_db)
@api.doc("delete_api_based_extension")
@api.doc(description="Delete API-based extension")
@api.doc(params={"id": "Extension ID"})
@api.response(204, "Extension deleted successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -154,3 +100,9 @@ class APIBasedExtensionDetailAPI(Resource):
APIBasedExtensionService.delete(extension_data_from_db) APIBasedExtensionService.delete(extension_data_from_db)
return {"result": "success"}, 204 return {"result": "success"}, 204
api.add_resource(CodeBasedExtensionAPI, "/code-based-extension")
api.add_resource(APIBasedExtensionAPI, "/api-based-extension")
api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/<uuid:id>")

View File

@ -1,40 +1,26 @@
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields from flask_restx import Resource
from libs.login import login_required from libs.login import login_required
from services.feature_service import FeatureService from services.feature_service import FeatureService
from . import api, console_ns from . import api
from .wraps import account_initialization_required, cloud_utm_record, setup_required from .wraps import account_initialization_required, cloud_utm_record, setup_required
@console_ns.route("/features")
class FeatureApi(Resource): class FeatureApi(Resource):
@api.doc("get_tenant_features")
@api.doc(description="Get feature configuration for current tenant")
@api.response(
200,
"Success",
api.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_utm_record @cloud_utm_record
def get(self): def get(self):
"""Get feature configuration for current tenant"""
return FeatureService.get_features(current_user.current_tenant_id).model_dump() return FeatureService.get_features(current_user.current_tenant_id).model_dump()
@console_ns.route("/system-features")
class SystemFeatureApi(Resource): class SystemFeatureApi(Resource):
@api.doc("get_system_features")
@api.doc(description="Get system-wide feature configuration")
@api.response(
200,
"Success",
api.model("SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}),
)
def get(self): def get(self):
"""Get system-wide feature configuration"""
return FeatureService.get_system_features().model_dump() return FeatureService.get_system_features().model_dump()
api.add_resource(FeatureApi, "/features")
api.add_resource(SystemFeatureApi, "/system-features")

View File

@ -22,7 +22,6 @@ from controllers.console.wraps import (
) )
from fields.file_fields import file_fields, upload_config_fields from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required from libs.login import login_required
from models import Account
from services.file_service import FileService from services.file_service import FileService
PREVIEW_WORDS_LIMIT = 3000 PREVIEW_WORDS_LIMIT = 3000
@ -69,8 +68,6 @@ class FileApi(Resource):
source = None source = None
try: try:
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
upload_file = FileService.upload_file( upload_file = FileService.upload_file(
filename=file.filename, filename=file.filename,
content=file.read(), content=file.read(),

View File

@ -1,7 +1,7 @@
import os import os
from flask import session from flask import session
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -11,47 +11,20 @@ from libs.helper import StrLen
from models.model import DifySetup from models.model import DifySetup
from services.account_service import TenantService from services.account_service import TenantService
from . import api, console_ns from . import api
from .error import AlreadySetupError, InitValidateFailedError from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted from .wraps import only_edition_self_hosted
@console_ns.route("/init")
class InitValidateAPI(Resource): class InitValidateAPI(Resource):
@api.doc("get_init_status")
@api.doc(description="Get initialization validation status")
@api.response(
200,
"Success",
model=api.model(
"InitStatusResponse",
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
),
)
def get(self): def get(self):
"""Get initialization validation status"""
init_status = get_init_validate_status() init_status = get_init_validate_status()
if init_status: if init_status:
return {"status": "finished"} return {"status": "finished"}
return {"status": "not_started"} return {"status": "not_started"}
@api.doc("validate_init_password")
@api.doc(description="Validate initialization password for self-hosted edition")
@api.expect(
api.model(
"InitValidateRequest",
{"password": fields.String(required=True, description="Initialization password", max_length=30)},
)
)
@api.response(
201,
"Success",
model=api.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
)
@api.response(400, "Already setup or validation failed")
@only_edition_self_hosted @only_edition_self_hosted
def post(self): def post(self):
"""Validate initialization password"""
# is tenant created # is tenant created
tenant_count = TenantService.get_tenant_count() tenant_count = TenantService.get_tenant_count()
if tenant_count > 0: if tenant_count > 0:
@ -79,3 +52,6 @@ def get_init_validate_status():
return db_session.execute(select(DifySetup)).scalar_one_or_none() return db_session.execute(select(DifySetup)).scalar_one_or_none()
return True return True
api.add_resource(InitValidateAPI, "/init")

View File

@ -1,17 +1,14 @@
from flask_restx import Resource, fields from flask_restx import Resource
from . import api, console_ns from controllers.console import api
@console_ns.route("/ping")
class PingApi(Resource): class PingApi(Resource):
@api.doc("health_check")
@api.doc(description="Health check endpoint for connection testing")
@api.response(
200,
"Success",
api.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
)
def get(self): def get(self):
"""Health check endpoint for connection testing""" """
For connection health check
"""
return {"result": "pong"} return {"result": "pong"}
api.add_resource(PingApi, "/ping")

View File

@ -1,5 +1,5 @@
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, reqparse
from configs import dify_config from configs import dify_config
from libs.helper import StrLen, email, extract_remote_ip from libs.helper import StrLen, email, extract_remote_ip
@ -7,56 +7,23 @@ from libs.password import valid_password
from models.model import DifySetup, db from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService from services.account_service import RegisterService, TenantService
from . import api, console_ns from . import api
from .error import AlreadySetupError, NotInitValidateError from .error import AlreadySetupError, NotInitValidateError
from .init_validate import get_init_validate_status from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted from .wraps import only_edition_self_hosted
@console_ns.route("/setup")
class SetupApi(Resource): class SetupApi(Resource):
@api.doc("get_setup_status")
@api.doc(description="Get system setup status")
@api.response(
200,
"Success",
api.model(
"SetupStatusResponse",
{
"step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
"setup_at": fields.String(description="Setup completion time (ISO format)", required=False),
},
),
)
def get(self): def get(self):
"""Get system setup status"""
if dify_config.EDITION == "SELF_HOSTED": if dify_config.EDITION == "SELF_HOSTED":
setup_status = get_setup_status() setup_status = get_setup_status()
# Check if setup_status is a DifySetup object rather than a bool if setup_status:
if setup_status and not isinstance(setup_status, bool):
return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()} return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
elif setup_status:
return {"step": "finished"}
return {"step": "not_started"} return {"step": "not_started"}
return {"step": "finished"} return {"step": "finished"}
@api.doc("setup_system")
@api.doc(description="Initialize system setup with admin account")
@api.expect(
api.model(
"SetupRequest",
{
"email": fields.String(required=True, description="Admin email address"),
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
"password": fields.String(required=True, description="Admin password"),
},
)
)
@api.response(201, "Success", api.model("SetupResponse", {"result": fields.String(description="Setup result")}))
@api.response(400, "Already setup or validation failed")
@only_edition_self_hosted @only_edition_self_hosted
def post(self): def post(self):
"""Initialize system setup with admin account"""
# is set up # is set up
if get_setup_status(): if get_setup_status():
raise AlreadySetupError() raise AlreadySetupError()
@ -88,3 +55,6 @@ def get_setup_status():
return db.session.query(DifySetup).first() return db.session.query(DifySetup).first()
else: else:
return True return True
api.add_resource(SetupApi, "/setup")

View File

@ -111,7 +111,7 @@ class TagBindingCreateApi(Resource):
args = parser.parse_args() args = parser.parse_args()
TagService.save_tag_binding(args) TagService.save_tag_binding(args)
return {"result": "success"}, 200 return 200
class TagBindingDeleteApi(Resource): class TagBindingDeleteApi(Resource):
@ -132,7 +132,7 @@ class TagBindingDeleteApi(Resource):
args = parser.parse_args() args = parser.parse_args()
TagService.delete_tag_binding(args) TagService.delete_tag_binding(args)
return {"result": "success"}, 200 return 200
api.add_resource(TagListApi, "/tags") api.add_resource(TagListApi, "/tags")

View File

@ -2,41 +2,18 @@ import json
import logging import logging
import requests import requests
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, reqparse
from packaging import version from packaging import version
from configs import dify_config from configs import dify_config
from . import api, console_ns from . import api
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route("/version")
class VersionApi(Resource): class VersionApi(Resource):
@api.doc("check_version_update")
@api.doc(description="Check for application version updates")
@api.expect(
api.parser().add_argument(
"current_version", type=str, required=True, location="args", help="Current application version"
)
)
@api.response(
200,
"Success",
api.model(
"VersionResponse",
{
"version": fields.String(description="Latest version number"),
"release_date": fields.String(description="Release date of latest version"),
"release_notes": fields.String(description="Release notes for latest version"),
"can_auto_update": fields.Boolean(description="Whether auto-update is supported"),
"features": fields.Raw(description="Feature flags and capabilities"),
},
),
)
def get(self): def get(self):
"""Check for application version updates"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("current_version", type=str, required=True, location="args") parser.add_argument("current_version", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
@ -57,14 +34,14 @@ class VersionApi(Resource):
return result return result
try: try:
response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10)) response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10))
except Exception as error: except Exception as error:
logger.warning("Check update version error: %s.", str(error)) logger.warning("Check update version error: %s.", str(error))
result["version"] = args["current_version"] result["version"] = args.get("current_version")
return result return result
content = json.loads(response.content) content = json.loads(response.content)
if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"): if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"):
result["version"] = content["version"] result["version"] = content["version"]
result["release_date"] = content["releaseDate"] result["release_date"] = content["releaseDate"]
result["release_notes"] = content["releaseNotes"] result["release_notes"] = content["releaseNotes"]
@ -82,3 +59,6 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool:
except version.InvalidVersion: except version.InvalidVersion:
logger.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version) logger.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version)
return False return False
api.add_resource(VersionApi, "/version")

View File

@ -1,6 +1,4 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask_login import current_user from flask_login import current_user
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -9,17 +7,14 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db from extensions.ext_database import db
from models.account import TenantPluginPermission from models.account import TenantPluginPermission
P = ParamSpec("P")
R = TypeVar("R")
def plugin_permission_required( def plugin_permission_required(
install_required: bool = False, install_required: bool = False,
debug_required: bool = False, debug_required: bool = False,
): ):
def interceptor(view: Callable[P, R]): def interceptor(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
user = current_user user = current_user
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id

View File

@ -49,8 +49,6 @@ class AccountInitApi(Resource):
@setup_required @setup_required
@login_required @login_required
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user account = current_user
if account.status == "active": if account.status == "active":
@ -104,8 +102,6 @@ class AccountProfileApi(Resource):
@marshal_with(account_fields) @marshal_with(account_fields)
@enterprise_license_required @enterprise_license_required
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
return current_user return current_user
@ -115,8 +111,6 @@ class AccountNameApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -136,8 +130,6 @@ class AccountAvatarApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("avatar", type=str, required=True, location="json") parser.add_argument("avatar", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -153,8 +145,6 @@ class AccountInterfaceLanguageApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("interface_language", type=supported_language, required=True, location="json") parser.add_argument("interface_language", type=supported_language, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -170,8 +160,6 @@ class AccountInterfaceThemeApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -187,8 +175,6 @@ class AccountTimezoneApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("timezone", type=str, required=True, location="json") parser.add_argument("timezone", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -208,8 +194,6 @@ class AccountPasswordApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("password", type=str, required=False, location="json") parser.add_argument("password", type=str, required=False, location="json")
parser.add_argument("new_password", type=str, required=True, location="json") parser.add_argument("new_password", type=str, required=True, location="json")
@ -244,13 +228,9 @@ class AccountIntegrateApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(integrate_list_fields) @marshal_with(integrate_list_fields)
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user account = current_user
account_integrates = db.session.scalars( account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
).all()
base_url = request.url_root.rstrip("/") base_url = request.url_root.rstrip("/")
oauth_base_path = "/console/api/oauth/login" oauth_base_path = "/console/api/oauth/login"
@ -288,8 +268,6 @@ class AccountDeleteVerifyApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user account = current_user
token, code = AccountService.generate_account_deletion_verification_code(account) token, code = AccountService.generate_account_deletion_verification_code(account)
@ -303,8 +281,6 @@ class AccountDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -345,8 +321,6 @@ class EducationVerifyApi(Resource):
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
@marshal_with(verify_fields) @marshal_with(verify_fields)
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user account = current_user
return BillingService.EducationIdentity.verify(account.id, account.email) return BillingService.EducationIdentity.verify(account.id, account.email)
@ -366,8 +340,6 @@ class EducationApi(Resource):
@only_edition_cloud @only_edition_cloud
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -385,8 +357,6 @@ class EducationApi(Resource):
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
@marshal_with(status_fields) @marshal_with(status_fields)
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user account = current_user
res = BillingService.EducationIdentity.status(account.id) res = BillingService.EducationIdentity.status(account.id)
@ -451,8 +421,6 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError() raise InvalidTokenError()
user_email = reset_data.get("email", "") user_email = reset_data.get("email", "")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if user_email != current_user.email: if user_email != current_user.email:
raise InvalidEmailError() raise InvalidEmailError()
else: else:
@ -533,8 +501,6 @@ class ChangeEmailResetApi(Resource):
AccountService.revoke_change_email_token(args["token"]) AccountService.revoke_change_email_token(args["token"])
old_email = reset_data.get("old_email", "") old_email = reset_data.get("old_email", "")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if current_user.email != old_email: if current_user.email != old_email:
raise AccountNotFound() raise AccountNotFound()

View File

@ -1,22 +1,14 @@
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields from flask_restx import Resource
from controllers.console import api, console_ns from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required from libs.login import login_required
from services.agent_service import AgentService from services.agent_service import AgentService
@console_ns.route("/workspaces/current/agent-providers")
class AgentProviderListApi(Resource): class AgentProviderListApi(Resource):
@api.doc("list_agent_providers")
@api.doc(description="Get list of available agent providers")
@api.response(
200,
"Success",
fields.List(fields.Raw(description="Agent provider information")),
)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -29,16 +21,7 @@ class AgentProviderListApi(Resource):
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id)) return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
class AgentProviderApi(Resource): class AgentProviderApi(Resource):
@api.doc("get_agent_provider")
@api.doc(description="Get specific agent provider details")
@api.doc(params={"provider_name": "Agent provider name"})
@api.response(
200,
"Success",
fields.Raw(description="Agent provider details"),
)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -47,3 +30,7 @@ class AgentProviderApi(Resource):
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name)) return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers")
api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/<path:provider_name>")

View File

@ -1,8 +1,8 @@
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api, console_ns from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError from core.plugin.impl.exc import PluginPermissionDeniedError
@ -10,26 +10,7 @@ from libs.login import login_required
from services.plugin.endpoint_service import EndpointService from services.plugin.endpoint_service import EndpointService
@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource): class EndpointCreateApi(Resource):
@api.doc("create_endpoint")
@api.doc(description="Create a new plugin endpoint")
@api.expect(
api.model(
"EndpointCreateRequest",
{
"plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
"settings": fields.Raw(required=True, description="Endpoint settings"),
"name": fields.String(required=True, description="Endpoint name"),
},
)
)
@api.response(
200,
"Endpoint created successfully",
api.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
)
@api.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -62,20 +43,7 @@ class EndpointCreateApi(Resource):
raise ValueError(e.description) from e raise ValueError(e.description) from e
@console_ns.route("/workspaces/current/endpoints/list")
class EndpointListApi(Resource): class EndpointListApi(Resource):
@api.doc("list_endpoints")
@api.doc(description="List plugin endpoints with pagination")
@api.expect(
api.parser()
.add_argument("page", type=int, required=True, location="args", help="Page number")
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
)
@api.response(
200,
"Success",
api.model("EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}),
)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -102,23 +70,7 @@ class EndpointListApi(Resource):
) )
@console_ns.route("/workspaces/current/endpoints/list/plugin")
class EndpointListForSinglePluginApi(Resource): class EndpointListForSinglePluginApi(Resource):
@api.doc("list_plugin_endpoints")
@api.doc(description="List endpoints for a specific plugin")
@api.expect(
api.parser()
.add_argument("page", type=int, required=True, location="args", help="Page number")
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
.add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
)
@api.response(
200,
"Success",
api.model(
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -148,19 +100,7 @@ class EndpointListForSinglePluginApi(Resource):
) )
@console_ns.route("/workspaces/current/endpoints/delete")
class EndpointDeleteApi(Resource): class EndpointDeleteApi(Resource):
@api.doc("delete_endpoint")
@api.doc(description="Delete a plugin endpoint")
@api.expect(
api.model("EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
)
@api.response(
200,
"Endpoint deleted successfully",
api.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
)
@api.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -183,26 +123,7 @@ class EndpointDeleteApi(Resource):
} }
@console_ns.route("/workspaces/current/endpoints/update")
class EndpointUpdateApi(Resource): class EndpointUpdateApi(Resource):
@api.doc("update_endpoint")
@api.doc(description="Update a plugin endpoint")
@api.expect(
api.model(
"EndpointUpdateRequest",
{
"endpoint_id": fields.String(required=True, description="Endpoint ID"),
"settings": fields.Raw(required=True, description="Updated settings"),
"name": fields.String(required=True, description="Updated name"),
},
)
)
@api.response(
200,
"Endpoint updated successfully",
api.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
)
@api.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -233,19 +154,7 @@ class EndpointUpdateApi(Resource):
} }
@console_ns.route("/workspaces/current/endpoints/enable")
class EndpointEnableApi(Resource): class EndpointEnableApi(Resource):
@api.doc("enable_endpoint")
@api.doc(description="Enable a plugin endpoint")
@api.expect(
api.model("EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
)
@api.response(
200,
"Endpoint enabled successfully",
api.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
)
@api.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -268,19 +177,7 @@ class EndpointEnableApi(Resource):
} }
@console_ns.route("/workspaces/current/endpoints/disable")
class EndpointDisableApi(Resource): class EndpointDisableApi(Resource):
@api.doc("disable_endpoint")
@api.doc(description="Disable a plugin endpoint")
@api.expect(
api.model("EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
)
@api.response(
200,
"Endpoint disabled successfully",
api.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
)
@api.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -301,3 +198,12 @@ class EndpointDisableApi(Resource):
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
) )
} }
api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create")
api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list")
api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin")
api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete")
api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update")
api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable")
api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable")

View File

@ -1,8 +1,8 @@
from urllib import parse from urllib import parse
from flask import abort, request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, abort, marshal_with, reqparse
import services import services
from configs import dify_config from configs import dify_config
@ -41,10 +41,6 @@ class MemberListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_with_role_list_fields) @marshal_with(account_with_role_list_fields)
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant) members = TenantService.get_tenant_members(current_user.current_tenant)
return {"result": "success", "accounts": members}, 200 return {"result": "success", "accounts": members}, 200
@ -69,11 +65,7 @@ class MemberInviteEmailApi(Resource):
if not TenantAccountRole.is_non_owner_role(invitee_role): if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400 return {"code": "invalid-role", "message": "Invalid role"}, 400
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
inviter = current_user inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
invitation_results = [] invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL console_web_url = dify_config.CONSOLE_WEB_URL
@ -84,8 +76,6 @@ class MemberInviteEmailApi(Resource):
for invitee_email in invitee_emails: for invitee_email in invitee_emails:
try: try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member( token = RegisterService.invite_new_member(
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
) )
@ -107,7 +97,7 @@ class MemberInviteEmailApi(Resource):
return { return {
"result": "success", "result": "success",
"invitation_results": invitation_results, "invitation_results": invitation_results,
"tenant_id": str(inviter.current_tenant.id) if inviter.current_tenant else "", "tenant_id": str(current_user.current_tenant.id),
}, 201 }, 201
@ -118,10 +108,6 @@ class MemberCancelInviteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, member_id): def delete(self, member_id):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.query(Account).where(Account.id == str(member_id)).first() member = db.session.query(Account).where(Account.id == str(member_id)).first()
if member is None: if member is None:
abort(404) abort(404)
@ -137,10 +123,7 @@ class MemberCancelInviteApi(Resource):
except Exception as e: except Exception as e:
raise ValueError(str(e)) raise ValueError(str(e))
return { return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200
"result": "success",
"tenant_id": str(current_user.current_tenant.id) if current_user.current_tenant else "",
}, 200
class MemberUpdateRoleApi(Resource): class MemberUpdateRoleApi(Resource):
@ -158,10 +141,6 @@ class MemberUpdateRoleApi(Resource):
if not TenantAccountRole.is_valid_role(new_role): if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400 return {"code": "invalid-role", "message": "Invalid role"}, 400
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id)) member = db.session.get(Account, str(member_id))
if not member: if not member:
abort(404) abort(404)
@ -185,10 +164,6 @@ class DatasetOperatorMemberListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_with_role_list_fields) @marshal_with(account_with_role_list_fields)
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant) members = TenantService.get_dataset_operator_members(current_user.current_tenant)
return {"result": "success", "accounts": members}, 200 return {"result": "success", "accounts": members}, 200
@ -209,10 +184,6 @@ class SendOwnerTransferEmailApi(Resource):
raise EmailSendIpLimitError() raise EmailSendIpLimitError()
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant): if not TenantService.is_owner(current_user, current_user.current_tenant):
raise NotOwnerError() raise NotOwnerError()
@ -227,7 +198,7 @@ class SendOwnerTransferEmailApi(Resource):
account=current_user, account=current_user,
email=email, email=email,
language=language, language=language,
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", workspace_name=current_user.current_tenant.name,
) )
return {"result": "success", "data": token} return {"result": "success", "data": token}
@ -244,10 +215,6 @@ class OwnerTransferCheckApi(Resource):
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant): if not TenantService.is_owner(current_user, current_user.current_tenant):
raise NotOwnerError() raise NotOwnerError()
@ -289,10 +256,6 @@ class OwnerTransfer(Resource):
args = parser.parse_args() args = parser.parse_args()
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant): if not TenantService.is_owner(current_user, current_user.current_tenant):
raise NotOwnerError() raise NotOwnerError()
@ -311,11 +274,9 @@ class OwnerTransfer(Resource):
member = db.session.get(Account, str(member_id)) member = db.session.get(Account, str(member_id))
if not member: if not member:
abort(404) abort(404)
return # Never reached, but helps type checker else:
member_account = member
if not current_user.current_tenant: if not TenantService.is_member(member_account, current_user.current_tenant):
raise ValueError("No current tenant")
if not TenantService.is_member(member, current_user.current_tenant):
raise MemberNotInTenantError() raise MemberNotInTenantError()
try: try:
@ -325,13 +286,13 @@ class OwnerTransfer(Resource):
AccountService.send_new_owner_transfer_notify_email( AccountService.send_new_owner_transfer_notify_email(
account=member, account=member,
email=member.email, email=member.email,
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", workspace_name=current_user.current_tenant.name,
) )
AccountService.send_old_owner_transfer_notify_email( AccountService.send_old_owner_transfer_notify_email(
account=current_user, account=current_user,
email=current_user.email, email=current_user.email,
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", workspace_name=current_user.current_tenant.name,
new_owner_email=member.email, new_owner_email=member.email,
) )

View File

@ -12,7 +12,6 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value from libs.helper import StrLen, uuid_value
from libs.login import login_required from libs.login import login_required
from models.account import Account
from services.billing_service import BillingService from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService from services.model_provider_service import ModelProviderService
@ -22,10 +21,6 @@ class ModelProviderListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -50,10 +45,6 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider: str): def get(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
# if credential_id is not provided, return current used credential # if credential_id is not provided, return current used credential
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -71,20 +62,16 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try: try:
model_provider_service.create_provider_credential( model_provider_service.create_provider_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
@ -101,21 +88,17 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self, provider: str): def put(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try: try:
model_provider_service.update_provider_credential( model_provider_service.update_provider_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
@ -133,16 +116,12 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential( model_provider_service.remove_provider_credential(
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
@ -156,16 +135,12 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
service = ModelProviderService() service = ModelProviderService()
service.switch_active_provider_credential( service.switch_active_provider_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
@ -175,35 +150,15 @@ class ModelProviderCredentialSwitchApi(Resource):
return {"result": "success"} return {"result": "success"}
class ModelProviderCredentialCancelApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
if not current_user.is_admin_or_owner:
raise Forbidden()
service = ModelProviderService()
service.cancel_provider_credential(
tenant_id=current_user.current_tenant_id,
provider=provider,
)
return {"result": "success"}
class ModelProviderValidateApi(Resource): class ModelProviderValidateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -250,13 +205,9 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -285,11 +236,7 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
def get(self, provider: str): def get(self, provider: str):
if provider != "anthropic": if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid") raise ValueError(f"provider name {provider} is invalid")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
data = BillingService.get_model_provider_payment_link( data = BillingService.get_model_provider_payment_link(
provider_name=provider, provider_name=provider,
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
@ -305,9 +252,6 @@ api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-provider
api.add_resource( api.add_resource(
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch" ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
) )
api.add_resource(
ModelProviderCredentialCancelApi, "/workspaces/current/model-providers/<path:provider>/credentials/cancel"
)
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate") api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
api.add_resource( api.add_resource(

View File

@ -219,11 +219,7 @@ class ModelProviderModelCredentialApi(Resource):
model_load_balancing_service = ModelLoadBalancingService() model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id, tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
provider=provider,
model=args["model"],
model_type=args["model_type"],
config_from=args.get("config_from", ""),
) )
if args.get("config_from", "") == "predefined-model": if args.get("config_from", "") == "predefined-model":
@ -267,7 +263,7 @@ class ModelProviderModelCredentialApi(Resource):
choices=[mt.value for mt in ModelType], choices=[mt.value for mt in ModelType],
location="json", location="json",
) )
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -313,7 +309,7 @@ class ModelProviderModelCredentialApi(Resource):
) )
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()

View File

@ -865,7 +865,6 @@ class ToolProviderMCPApi(Resource):
parser.add_argument( parser.add_argument(
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300 "sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
) )
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
args = parser.parse_args() args = parser.parse_args()
user = current_user user = current_user
if not is_valid_url(args["server_url"]): if not is_valid_url(args["server_url"]):
@ -882,7 +881,6 @@ class ToolProviderMCPApi(Resource):
server_identifier=args["server_identifier"], server_identifier=args["server_identifier"],
timeout=args["timeout"], timeout=args["timeout"],
sse_read_timeout=args["sse_read_timeout"], sse_read_timeout=args["sse_read_timeout"],
headers=args["headers"],
) )
) )
@ -900,7 +898,6 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json") parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json") parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
if not is_valid_url(args["server_url"]): if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]: if "[__HIDDEN__]" in args["server_url"]:
@ -918,7 +915,6 @@ class ToolProviderMCPApi(Resource):
server_identifier=args["server_identifier"], server_identifier=args["server_identifier"],
timeout=args.get("timeout"), timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"), sse_read_timeout=args.get("sse_read_timeout"),
headers=args.get("headers"),
) )
return {"result": "success"} return {"result": "success"}
@ -955,9 +951,6 @@ class ToolMCPAuthApi(Resource):
authed=False, authed=False,
authorization_code=args["authorization_code"], authorization_code=args["authorization_code"],
for_list=True, for_list=True,
headers=provider.decrypted_headers,
timeout=provider.timeout,
sse_read_timeout=provider.sse_read_timeout,
): ):
MCPToolManageService.update_mcp_provider_credentials( MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=provider, mcp_provider=provider,

View File

@ -25,7 +25,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import TimestampField from libs.helper import TimestampField
from libs.login import login_required from libs.login import login_required
from models.account import Account, Tenant, TenantStatus from models.account import Tenant, TenantStatus
from services.account_service import TenantService from services.account_service import TenantService
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.file_service import FileService from services.file_service import FileService
@ -70,8 +70,6 @@ class TenantListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
tenants = TenantService.get_join_tenants(current_user) tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = [] tenant_dicts = []
@ -85,7 +83,7 @@ class TenantListApi(Resource):
"status": tenant.status, "status": tenant.status,
"created_at": tenant.created_at, "created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox", "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
"current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False, "current": tenant.id == current_user.current_tenant_id,
} }
tenant_dicts.append(tenant_dict) tenant_dicts.append(tenant_dict)
@ -127,11 +125,7 @@ class TenantApi(Resource):
if request.path == "/info": if request.path == "/info":
logger.warning("Deprecated URL /info was used.") logger.warning("Deprecated URL /info was used.")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
tenant = current_user.current_tenant tenant = current_user.current_tenant
if not tenant:
raise ValueError("No current tenant")
if tenant.status == TenantStatus.ARCHIVE: if tenant.status == TenantStatus.ARCHIVE:
tenants = TenantService.get_join_tenants(current_user) tenants = TenantService.get_join_tenants(current_user)
@ -143,8 +137,6 @@ class TenantApi(Resource):
else: else:
raise Unauthorized("workspace is archived") raise Unauthorized("workspace is archived")
if not tenant:
raise ValueError("No tenant available")
return WorkspaceService.get_tenant_info(tenant), 200 return WorkspaceService.get_tenant_info(tenant), 200
@ -153,8 +145,6 @@ class SwitchWorkspaceApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json") parser.add_argument("tenant_id", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -178,15 +168,11 @@ class CustomConfigWorkspaceApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom") @cloud_edition_billing_resource_check("workspace_custom")
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("remove_webapp_brand", type=bool, location="json") parser.add_argument("remove_webapp_brand", type=bool, location="json")
parser.add_argument("replace_webapp_logo", type=str, location="json") parser.add_argument("replace_webapp_logo", type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id) tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
custom_config_dict = { custom_config_dict = {
@ -208,8 +194,6 @@ class WebappLogoWorkspaceApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom") @cloud_edition_billing_resource_check("workspace_custom")
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
# check file # check file
if "file" not in request.files: if "file" not in request.files:
raise NoFileUploadedError() raise NoFileUploadedError()
@ -248,14 +232,10 @@ class WorkspaceInfoApi(Resource):
@account_initialization_required @account_initialization_required
# Change workspace name # Change workspace name
def post(self): def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id) tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
tenant.name = args["name"] tenant.name = args["name"]
db.session.commit() db.session.commit()

View File

@ -2,9 +2,7 @@ import contextlib
import json import json
import os import os
import time import time
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request from flask import abort, request
from flask_login import current_user from flask_login import current_user
@ -21,13 +19,10 @@ from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
P = ParamSpec("P")
R = TypeVar("R")
def account_initialization_required(view):
def account_initialization_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
# check account initialization # check account initialization
account = current_user account = current_user
@ -39,9 +34,9 @@ def account_initialization_required(view: Callable[P, R]):
return decorated return decorated
def only_edition_cloud(view: Callable[P, R]): def only_edition_cloud(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
if dify_config.EDITION != "CLOUD": if dify_config.EDITION != "CLOUD":
abort(404) abort(404)
@ -50,9 +45,9 @@ def only_edition_cloud(view: Callable[P, R]):
return decorated return decorated
def only_edition_enterprise(view: Callable[P, R]): def only_edition_enterprise(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
if not dify_config.ENTERPRISE_ENABLED: if not dify_config.ENTERPRISE_ENABLED:
abort(404) abort(404)
@ -61,9 +56,9 @@ def only_edition_enterprise(view: Callable[P, R]):
return decorated return decorated
def only_edition_self_hosted(view: Callable[P, R]): def only_edition_self_hosted(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
if dify_config.EDITION != "SELF_HOSTED": if dify_config.EDITION != "SELF_HOSTED":
abort(404) abort(404)
@ -72,9 +67,9 @@ def only_edition_self_hosted(view: Callable[P, R]):
return decorated return decorated
def cloud_edition_billing_enabled(view: Callable[P, R]): def cloud_edition_billing_enabled(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled: if not features.billing.enabled:
abort(403, "Billing feature is not enabled.") abort(403, "Billing feature is not enabled.")
@ -84,9 +79,9 @@ def cloud_edition_billing_enabled(view: Callable[P, R]):
def cloud_edition_billing_resource_check(resource: str): def cloud_edition_billing_resource_check(resource: str):
def interceptor(view: Callable[P, R]): def interceptor(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled: if features.billing.enabled:
members = features.members members = features.members
@ -125,9 +120,9 @@ def cloud_edition_billing_resource_check(resource: str):
def cloud_edition_billing_knowledge_limit_check(resource: str): def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view: Callable[P, R]): def interceptor(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled: if features.billing.enabled:
if resource == "add_segment": if resource == "add_segment":
@ -147,9 +142,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def cloud_edition_billing_rate_limit_check(resource: str): def cloud_edition_billing_rate_limit_check(resource: str):
def interceptor(view: Callable[P, R]): def interceptor(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
if resource == "knowledge": if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id) knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
if knowledge_rate_limit.enabled: if knowledge_rate_limit.enabled:
@ -181,9 +176,9 @@ def cloud_edition_billing_rate_limit_check(resource: str):
return interceptor return interceptor
def cloud_utm_record(view: Callable[P, R]): def cloud_utm_record(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
@ -199,9 +194,9 @@ def cloud_utm_record(view: Callable[P, R]):
return decorated return decorated
def setup_required(view: Callable[P, R]): def setup_required(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
# check setup # check setup
if ( if (
dify_config.EDITION == "SELF_HOSTED" dify_config.EDITION == "SELF_HOSTED"
@ -217,9 +212,9 @@ def setup_required(view: Callable[P, R]):
return decorated return decorated
def enterprise_license_required(view: Callable[P, R]): def enterprise_license_required(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
settings = FeatureService.get_system_features() settings = FeatureService.get_system_features()
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]: if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.") raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
@ -229,9 +224,9 @@ def enterprise_license_required(view: Callable[P, R]):
return decorated return decorated
def email_password_login_enabled(view: Callable[P, R]): def email_password_login_enabled(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
if features.enable_email_password_login: if features.enable_email_password_login:
return view(*args, **kwargs) return view(*args, **kwargs)
@ -242,9 +237,9 @@ def email_password_login_enabled(view: Callable[P, R]):
return decorated return decorated
def enable_change_email(view: Callable[P, R]): def enable_change_email(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
if features.enable_change_email: if features.enable_change_email:
return view(*args, **kwargs) return view(*args, **kwargs)
@ -255,9 +250,9 @@ def enable_change_email(view: Callable[P, R]):
return decorated return decorated
def is_allow_transfer_owner(view: Callable[P, R]): def is_allow_transfer_owner(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.is_allow_transfer_workspace: if features.is_allow_transfer_workspace:
return view(*args, **kwargs) return view(*args, **kwargs)

View File

@ -10,10 +10,11 @@ api = ExternalApi(
version="1.0", version="1.0",
title="Files API", title="Files API",
description="API for file operations including upload and preview", description="API for file operations including upload and preview",
doc="/docs", # Enable Swagger UI at /files/docs
) )
files_ns = Namespace("files", description="File operations", path="/") files_ns = Namespace("files", description="File operations", path="/")
from . import image_preview, tool_files, upload # pyright: ignore[reportUnusedImport] from . import image_preview, tool_files, upload
api.add_namespace(files_ns) api.add_namespace(files_ns)

View File

@ -10,13 +10,14 @@ api = ExternalApi(
version="1.0", version="1.0",
title="Inner API", title="Inner API",
description="Internal APIs for enterprise features, billing, and plugin communication", description="Internal APIs for enterprise features, billing, and plugin communication",
doc="/docs", # Enable Swagger UI at /inner/api/docs
) )
# Create namespace # Create namespace
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/") inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")
from . import mail as _mail # pyright: ignore[reportUnusedImport] from . import mail
from .plugin import plugin as _plugin # pyright: ignore[reportUnusedImport] from .plugin import plugin
from .workspace import workspace as _workspace # pyright: ignore[reportUnusedImport] from .workspace import workspace
api.add_namespace(inner_api_ns) api.add_namespace(inner_api_ns)

View File

@ -37,9 +37,9 @@ from models.model import EndUser
@inner_api_ns.route("/invoke/llm") @inner_api_ns.route("/invoke/llm")
class PluginInvokeLLMApi(Resource): class PluginInvokeLLMApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeLLM) @plugin_data(payload_type=RequestInvokeLLM)
@inner_api_ns.doc("plugin_invoke_llm") @inner_api_ns.doc("plugin_invoke_llm")
@inner_api_ns.doc(description="Invoke LLM models through plugin interface") @inner_api_ns.doc(description="Invoke LLM models through plugin interface")
@ -60,9 +60,9 @@ class PluginInvokeLLMApi(Resource):
@inner_api_ns.route("/invoke/llm/structured-output") @inner_api_ns.route("/invoke/llm/structured-output")
class PluginInvokeLLMWithStructuredOutputApi(Resource): class PluginInvokeLLMWithStructuredOutputApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput) @plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
@inner_api_ns.doc("plugin_invoke_llm_structured") @inner_api_ns.doc("plugin_invoke_llm_structured")
@inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface") @inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface")
@ -85,9 +85,9 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource):
@inner_api_ns.route("/invoke/text-embedding") @inner_api_ns.route("/invoke/text-embedding")
class PluginInvokeTextEmbeddingApi(Resource): class PluginInvokeTextEmbeddingApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTextEmbedding) @plugin_data(payload_type=RequestInvokeTextEmbedding)
@inner_api_ns.doc("plugin_invoke_text_embedding") @inner_api_ns.doc("plugin_invoke_text_embedding")
@inner_api_ns.doc(description="Invoke text embedding models through plugin interface") @inner_api_ns.doc(description="Invoke text embedding models through plugin interface")
@ -115,9 +115,9 @@ class PluginInvokeTextEmbeddingApi(Resource):
@inner_api_ns.route("/invoke/rerank") @inner_api_ns.route("/invoke/rerank")
class PluginInvokeRerankApi(Resource): class PluginInvokeRerankApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeRerank) @plugin_data(payload_type=RequestInvokeRerank)
@inner_api_ns.doc("plugin_invoke_rerank") @inner_api_ns.doc("plugin_invoke_rerank")
@inner_api_ns.doc(description="Invoke rerank models through plugin interface") @inner_api_ns.doc(description="Invoke rerank models through plugin interface")
@ -141,9 +141,9 @@ class PluginInvokeRerankApi(Resource):
@inner_api_ns.route("/invoke/tts") @inner_api_ns.route("/invoke/tts")
class PluginInvokeTTSApi(Resource): class PluginInvokeTTSApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTTS) @plugin_data(payload_type=RequestInvokeTTS)
@inner_api_ns.doc("plugin_invoke_tts") @inner_api_ns.doc("plugin_invoke_tts")
@inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface") @inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface")
@ -168,9 +168,9 @@ class PluginInvokeTTSApi(Resource):
@inner_api_ns.route("/invoke/speech2text") @inner_api_ns.route("/invoke/speech2text")
class PluginInvokeSpeech2TextApi(Resource): class PluginInvokeSpeech2TextApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeSpeech2Text) @plugin_data(payload_type=RequestInvokeSpeech2Text)
@inner_api_ns.doc("plugin_invoke_speech2text") @inner_api_ns.doc("plugin_invoke_speech2text")
@inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface") @inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface")
@ -194,9 +194,9 @@ class PluginInvokeSpeech2TextApi(Resource):
@inner_api_ns.route("/invoke/moderation") @inner_api_ns.route("/invoke/moderation")
class PluginInvokeModerationApi(Resource): class PluginInvokeModerationApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeModeration) @plugin_data(payload_type=RequestInvokeModeration)
@inner_api_ns.doc("plugin_invoke_moderation") @inner_api_ns.doc("plugin_invoke_moderation")
@inner_api_ns.doc(description="Invoke moderation models through plugin interface") @inner_api_ns.doc(description="Invoke moderation models through plugin interface")
@ -220,9 +220,9 @@ class PluginInvokeModerationApi(Resource):
@inner_api_ns.route("/invoke/tool") @inner_api_ns.route("/invoke/tool")
class PluginInvokeToolApi(Resource): class PluginInvokeToolApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTool) @plugin_data(payload_type=RequestInvokeTool)
@inner_api_ns.doc("plugin_invoke_tool") @inner_api_ns.doc("plugin_invoke_tool")
@inner_api_ns.doc(description="Invoke tools through plugin interface") @inner_api_ns.doc(description="Invoke tools through plugin interface")
@ -252,9 +252,9 @@ class PluginInvokeToolApi(Resource):
@inner_api_ns.route("/invoke/parameter-extractor") @inner_api_ns.route("/invoke/parameter-extractor")
class PluginInvokeParameterExtractorNodeApi(Resource): class PluginInvokeParameterExtractorNodeApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeParameterExtractorNode) @plugin_data(payload_type=RequestInvokeParameterExtractorNode)
@inner_api_ns.doc("plugin_invoke_parameter_extractor") @inner_api_ns.doc("plugin_invoke_parameter_extractor")
@inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface") @inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface")
@ -285,9 +285,9 @@ class PluginInvokeParameterExtractorNodeApi(Resource):
@inner_api_ns.route("/invoke/question-classifier") @inner_api_ns.route("/invoke/question-classifier")
class PluginInvokeQuestionClassifierNodeApi(Resource): class PluginInvokeQuestionClassifierNodeApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode) @plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
@inner_api_ns.doc("plugin_invoke_question_classifier") @inner_api_ns.doc("plugin_invoke_question_classifier")
@inner_api_ns.doc(description="Invoke question classifier node through plugin interface") @inner_api_ns.doc(description="Invoke question classifier node through plugin interface")
@ -318,9 +318,9 @@ class PluginInvokeQuestionClassifierNodeApi(Resource):
@inner_api_ns.route("/invoke/app") @inner_api_ns.route("/invoke/app")
class PluginInvokeAppApi(Resource): class PluginInvokeAppApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeApp) @plugin_data(payload_type=RequestInvokeApp)
@inner_api_ns.doc("plugin_invoke_app") @inner_api_ns.doc("plugin_invoke_app")
@inner_api_ns.doc(description="Invoke application through plugin interface") @inner_api_ns.doc(description="Invoke application through plugin interface")
@ -348,9 +348,9 @@ class PluginInvokeAppApi(Resource):
@inner_api_ns.route("/invoke/encrypt") @inner_api_ns.route("/invoke/encrypt")
class PluginInvokeEncryptApi(Resource): class PluginInvokeEncryptApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeEncrypt) @plugin_data(payload_type=RequestInvokeEncrypt)
@inner_api_ns.doc("plugin_invoke_encrypt") @inner_api_ns.doc("plugin_invoke_encrypt")
@inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface") @inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface")
@ -375,9 +375,9 @@ class PluginInvokeEncryptApi(Resource):
@inner_api_ns.route("/invoke/summary") @inner_api_ns.route("/invoke/summary")
class PluginInvokeSummaryApi(Resource): class PluginInvokeSummaryApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeSummary) @plugin_data(payload_type=RequestInvokeSummary)
@inner_api_ns.doc("plugin_invoke_summary") @inner_api_ns.doc("plugin_invoke_summary")
@inner_api_ns.doc(description="Invoke summary functionality through plugin interface") @inner_api_ns.doc(description="Invoke summary functionality through plugin interface")
@ -405,9 +405,9 @@ class PluginInvokeSummaryApi(Resource):
@inner_api_ns.route("/upload/file/request") @inner_api_ns.route("/upload/file/request")
class PluginUploadFileRequestApi(Resource): class PluginUploadFileRequestApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestRequestUploadFile) @plugin_data(payload_type=RequestRequestUploadFile)
@inner_api_ns.doc("plugin_upload_file_request") @inner_api_ns.doc("plugin_upload_file_request")
@inner_api_ns.doc(description="Request signed URL for file upload through plugin interface") @inner_api_ns.doc(description="Request signed URL for file upload through plugin interface")
@ -426,9 +426,9 @@ class PluginUploadFileRequestApi(Resource):
@inner_api_ns.route("/fetch/app/info") @inner_api_ns.route("/fetch/app/info")
class PluginFetchAppInfoApi(Resource): class PluginFetchAppInfoApi(Resource):
@get_user_tenant
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestFetchAppInfo) @plugin_data(payload_type=RequestFetchAppInfo)
@inner_api_ns.doc("plugin_fetch_app_info") @inner_api_ns.doc("plugin_fetch_app_info")
@inner_api_ns.doc(description="Fetch application information through plugin interface") @inner_api_ns.doc(description="Fetch application information through plugin interface")

View File

@ -1,6 +1,6 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Optional, ParamSpec, TypeVar, cast from typing import Optional
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
@ -8,72 +8,65 @@ from flask_restx import reqparse
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.file.constants import DEFAULT_SERVICE_API_USER_ID
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user from libs.login import _get_user
from models.account import Tenant from models.account import Account, Tenant
from models.model import EndUser from models.model import EndUser
from services.account_service import AccountService
P = ParamSpec("P")
R = TypeVar("R")
def get_user(tenant_id: str, user_id: str | None) -> EndUser: def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
"""
Get current user
NOTE: user_id is not trusted, it could be maliciously set to any value.
As a result, it could only be considered as an end user id.
"""
try: try:
with Session(db.engine) as session: with Session(db.engine) as session:
if not user_id: if not user_id:
user_id = DEFAULT_SERVICE_API_USER_ID user_id = "DEFAULT-USER"
user_model = (
session.query(EndUser)
.where(
EndUser.session_id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
if not user_model:
user_model = EndUser(
tenant_id=tenant_id,
type="service_api",
is_anonymous=user_id == DEFAULT_SERVICE_API_USER_ID,
session_id=user_id,
)
session.add(user_model)
session.commit()
session.refresh(user_model)
if user_id == "DEFAULT-USER":
user_model = session.query(EndUser).where(EndUser.session_id == "DEFAULT-USER").first()
if not user_model:
user_model = EndUser(
tenant_id=tenant_id,
type="service_api",
is_anonymous=True if user_id == "DEFAULT-USER" else False,
session_id=user_id,
)
session.add(user_model)
session.commit()
session.refresh(user_model)
else:
user_model = AccountService.load_user(user_id)
if not user_model:
user_model = session.query(EndUser).where(EndUser.id == user_id).first()
if not user_model:
raise ValueError("user not found")
except Exception: except Exception:
raise ValueError("user not found") raise ValueError("user not found")
return user_model return user_model
def get_user_tenant(view: Optional[Callable[P, R]] = None): def get_user_tenant(view: Optional[Callable] = None):
def decorator(view_func: Callable[P, R]): def decorator(view_func):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs): def decorated_view(*args, **kwargs):
# fetch json body # fetch json body
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json") parser.add_argument("tenant_id", type=str, required=True, location="json")
parser.add_argument("user_id", type=str, required=True, location="json") parser.add_argument("user_id", type=str, required=True, location="json")
p = parser.parse_args() kwargs = parser.parse_args()
user_id = cast(str, p.get("user_id")) user_id = kwargs.get("user_id")
tenant_id = cast(str, p.get("tenant_id")) tenant_id = kwargs.get("tenant_id")
if not tenant_id: if not tenant_id:
raise ValueError("tenant_id is required") raise ValueError("tenant_id is required")
if not user_id: if not user_id:
user_id = DEFAULT_SERVICE_API_USER_ID user_id = "DEFAULT-USER"
del kwargs["tenant_id"]
del kwargs["user_id"]
try: try:
tenant_model = ( tenant_model = (
@ -95,7 +88,7 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
kwargs["user_model"] = user kwargs["user_model"] = user
current_app.login_manager._update_request_context_with_user(user) # type: ignore current_app.login_manager._update_request_context_with_user(user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
return view_func(*args, **kwargs) return view_func(*args, **kwargs)
@ -107,9 +100,9 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
return decorator(view) return decorator(view)
def plugin_data(view: Optional[Callable[P, R]] = None, *, payload_type: type[BaseModel]): def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
def decorator(view_func: Callable[P, R]): def decorator(view_func):
def decorated_view(*args: P.args, **kwargs: P.kwargs): def decorated_view(*args, **kwargs):
try: try:
data = request.get_json() data = request.get_json()
except Exception: except Exception:

View File

@ -46,9 +46,9 @@ def enterprise_inner_api_only(view: Callable[P, R]):
return decorated return decorated
def enterprise_inner_api_user_auth(view: Callable[P, R]): def enterprise_inner_api_user_auth(view):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args, **kwargs):
if not dify_config.INNER_API: if not dify_config.INNER_API:
return view(*args, **kwargs) return view(*args, **kwargs)

View File

@ -10,10 +10,11 @@ api = ExternalApi(
version="1.0", version="1.0",
title="MCP API", title="MCP API",
description="API for Model Context Protocol operations", description="API for Model Context Protocol operations",
doc="/docs", # Enable Swagger UI at /mcp/docs
) )
mcp_ns = Namespace("mcp", description="MCP operations", path="/") mcp_ns = Namespace("mcp", description="MCP operations", path="/")
from . import mcp # pyright: ignore[reportUnusedImport] from . import mcp
api.add_namespace(mcp_ns) api.add_namespace(mcp_ns)

View File

@ -99,7 +99,7 @@ class MCPAppApi(Resource):
return mcp_server, app return mcp_server, app
def _validate_server_status(self, mcp_server: AppMCPServer): def _validate_server_status(self, mcp_server: AppMCPServer) -> None:
"""Validate MCP server status""" """Validate MCP server status"""
if mcp_server.status != AppMCPServerStatus.ACTIVE: if mcp_server.status != AppMCPServerStatus.ACTIVE:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active") raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")

View File

@ -10,31 +10,14 @@ api = ExternalApi(
version="1.0", version="1.0",
title="Service API", title="Service API",
description="API for application services", description="API for application services",
doc="/docs", # Enable Swagger UI at /v1/docs
) )
service_api_ns = Namespace("service_api", description="Service operations", path="/") service_api_ns = Namespace("service_api", description="Service operations", path="/")
from . import index # pyright: ignore[reportUnusedImport] from . import index
from .app import ( from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow
annotation, # pyright: ignore[reportUnusedImport] from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
app, # pyright: ignore[reportUnusedImport] from .workspace import models
audio, # pyright: ignore[reportUnusedImport]
completion, # pyright: ignore[reportUnusedImport]
conversation, # pyright: ignore[reportUnusedImport]
file, # pyright: ignore[reportUnusedImport]
file_preview, # pyright: ignore[reportUnusedImport]
message, # pyright: ignore[reportUnusedImport]
site, # pyright: ignore[reportUnusedImport]
workflow, # pyright: ignore[reportUnusedImport]
)
from .dataset import (
dataset, # pyright: ignore[reportUnusedImport]
document, # pyright: ignore[reportUnusedImport]
hit_testing, # pyright: ignore[reportUnusedImport]
metadata, # pyright: ignore[reportUnusedImport]
segment, # pyright: ignore[reportUnusedImport]
upload_file, # pyright: ignore[reportUnusedImport]
)
from .workspace import models # pyright: ignore[reportUnusedImport]
api.add_namespace(service_api_ns) api.add_namespace(service_api_ns)

View File

@ -165,7 +165,7 @@ class AnnotationUpdateDeleteApi(Resource):
def put(self, app_model: App, annotation_id): def put(self, app_model: App, annotation_id):
"""Update an existing annotation.""" """Update an existing annotation."""
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
annotation_id = str(annotation_id) annotation_id = str(annotation_id)
@ -189,7 +189,7 @@ class AnnotationUpdateDeleteApi(Resource):
"""Delete an annotation.""" """Delete an annotation."""
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
annotation_id = str(annotation_id) annotation_id = str(annotation_id)

View File

@ -55,7 +55,7 @@ class AudioApi(Resource):
file = request.files["file"] file = request.files["file"]
try: try:
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.id) response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
return response return response
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:

View File

@ -1,5 +1,4 @@
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from flask_restx._http import HTTPStatus
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound from werkzeug.exceptions import BadRequest, NotFound
@ -122,7 +121,7 @@ class ConversationDetailApi(Resource):
} }
) )
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT) @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204)
def delete(self, app_model: App, end_user: EndUser, c_id): def delete(self, app_model: App, end_user: EndUser, c_id):
"""Delete a specific conversation.""" """Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)

View File

@ -59,7 +59,7 @@ class FilePreviewApi(Resource):
args = file_preview_parser.parse_args() args = file_preview_parser.parse_args()
# Validate file ownership and get file objects # Validate file ownership and get file objects
_, upload_file = self._validate_file_ownership(file_id, app_model.id) message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)
# Get file content generator # Get file content generator
try: try:

View File

@ -559,7 +559,7 @@ class DatasetTagsApi(DatasetApiResource):
def post(self, _, dataset_id): def post(self, _, dataset_id):
"""Add a knowledge type tag.""" """Add a knowledge type tag."""
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = tag_create_parser.parse_args() args = tag_create_parser.parse_args()
@ -583,7 +583,7 @@ class DatasetTagsApi(DatasetApiResource):
@validate_dataset_token @validate_dataset_token
def patch(self, _, dataset_id): def patch(self, _, dataset_id):
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = tag_update_parser.parse_args() args = tag_update_parser.parse_args()
@ -610,7 +610,7 @@ class DatasetTagsApi(DatasetApiResource):
def delete(self, _, dataset_id): def delete(self, _, dataset_id):
"""Delete a knowledge type tag.""" """Delete a knowledge type tag."""
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not current_user.has_edit_permission: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
args = tag_delete_parser.parse_args() args = tag_delete_parser.parse_args()
TagService.delete_tag(args.get("tag_id")) TagService.delete_tag(args.get("tag_id"))
@ -634,7 +634,7 @@ class DatasetTagBindingApi(DatasetApiResource):
def post(self, _, dataset_id): def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = tag_binding_parser.parse_args() args = tag_binding_parser.parse_args()
@ -660,7 +660,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
def post(self, _, dataset_id): def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = tag_unbinding_parser.parse_args() args = tag_unbinding_parser.parse_args()

View File

@ -30,7 +30,6 @@ from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from models.model import EndUser
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService from services.file_service import FileService
@ -299,9 +298,6 @@ class DocumentAddByFileApi(DatasetApiResource):
if not file.filename: if not file.filename:
raise FilenameNotExistsError raise FilenameNotExistsError
if not isinstance(current_user, EndUser):
raise ValueError("Invalid user account")
upload_file = FileService.upload_file( upload_file = FileService.upload_file(
filename=file.filename, filename=file.filename,
content=file.read(), content=file.read(),
@ -391,8 +387,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
raise FilenameNotExistsError raise FilenameNotExistsError
try: try:
if not isinstance(current_user, EndUser):
raise ValueError("Invalid user account")
upload_file = FileService.upload_file( upload_file = FileService.upload_file(
filename=file.filename, filename=file.filename,
content=file.read(), content=file.read(),
@ -416,7 +410,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
DocumentService.document_create_args_validate(knowledge_config) DocumentService.document_create_args_validate(knowledge_config)
try: try:
documents, _ = DocumentService.save_document_with_dataset_id( documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset, dataset=dataset,
knowledge_config=knowledge_config, knowledge_config=knowledge_config,
account=dataset.created_by_account, account=dataset.created_by_account,

Some files were not shown because too many files have changed in this diff Show More