mirror of
https://github.com/langgenius/dify.git
synced 2026-04-21 03:07:39 +08:00
Compare commits
10 Commits
fix/unsele
...
pinecone
| Author | SHA1 | Date | |
|---|---|---|---|
| 594906c1ff | |||
| 80f8245f2e | |||
| a12b437c16 | |||
| 12de554313 | |||
| 1f36c0c1c5 | |||
| 8b9297563c | |||
| 1cbe9eedb6 | |||
| 90fc5a1f12 | |||
| 41dfdf1ac0 | |||
| dd7de74aa6 |
15
.github/workflows/api-tests.yml
vendored
15
.github/workflows/api-tests.yml
vendored
@ -42,7 +42,11 @@ jobs:
|
|||||||
- name: Run Unit tests
|
- name: Run Unit tests
|
||||||
run: |
|
run: |
|
||||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||||
|
- name: Run ty check
|
||||||
|
run: |
|
||||||
|
cd api
|
||||||
|
uv add --dev ty
|
||||||
|
uv run ty check || true
|
||||||
- name: Run pyrefly check
|
- name: Run pyrefly check
|
||||||
run: |
|
run: |
|
||||||
cd api
|
cd api
|
||||||
@ -62,6 +66,15 @@ jobs:
|
|||||||
- name: Run dify config tests
|
- name: Run dify config tests
|
||||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||||
|
|
||||||
|
- name: MyPy Cache
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: api/.mypy_cache
|
||||||
|
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}
|
||||||
|
|
||||||
|
- name: Run MyPy Checks
|
||||||
|
run: dev/mypy-check
|
||||||
|
|
||||||
- name: Set up dotenvs
|
- name: Set up dotenvs
|
||||||
run: |
|
run: |
|
||||||
cp docker/.env.example docker/.env
|
cp docker/.env.example docker/.env
|
||||||
|
|||||||
2
.github/workflows/autofix.yml
vendored
2
.github/workflows/autofix.yml
vendored
@ -20,7 +20,7 @@ jobs:
|
|||||||
cd api
|
cd api
|
||||||
uv sync --dev
|
uv sync --dev
|
||||||
# Fix lint errors
|
# Fix lint errors
|
||||||
uv run ruff check --fix .
|
uv run ruff check --fix-only .
|
||||||
# Format code
|
# Format code
|
||||||
uv run ruff format .
|
uv run ruff format .
|
||||||
- name: ast-grep
|
- name: ast-grep
|
||||||
|
|||||||
28
.github/workflows/deploy-enterprise.yml
vendored
28
.github/workflows/deploy-enterprise.yml
vendored
@ -19,23 +19,11 @@ jobs:
|
|||||||
github.event.workflow_run.head_branch == 'deploy/enterprise'
|
github.event.workflow_run.head_branch == 'deploy/enterprise'
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: trigger deployments
|
- name: Deploy to server
|
||||||
env:
|
uses: appleboy/ssh-action@v0.1.8
|
||||||
DEV_ENV_ADDRS: ${{ vars.DEV_ENV_ADDRS }}
|
with:
|
||||||
DEPLOY_SECRET: ${{ secrets.DEPLOY_SECRET }}
|
host: ${{ secrets.ENTERPRISE_SSH_HOST }}
|
||||||
run: |
|
username: ${{ secrets.ENTERPRISE_SSH_USER }}
|
||||||
IFS=',' read -ra ENDPOINTS <<< "${DEV_ENV_ADDRS:-}"
|
password: ${{ secrets.ENTERPRISE_SSH_PASSWORD }}
|
||||||
BODY='{"project":"dify-api","tag":"deploy-enterprise"}'
|
script: |
|
||||||
|
${{ vars.ENTERPRISE_SSH_SCRIPT || secrets.ENTERPRISE_SSH_SCRIPT }}
|
||||||
for ENDPOINT in "${ENDPOINTS[@]}"; do
|
|
||||||
ENDPOINT="$(echo "$ENDPOINT" | xargs)"
|
|
||||||
[ -z "$ENDPOINT" ] && continue
|
|
||||||
|
|
||||||
API_SIGNATURE=$(printf '%s' "$BODY" | openssl dgst -sha256 -hmac "$DEPLOY_SECRET" | awk '{print "sha256="$2}')
|
|
||||||
|
|
||||||
curl -sSf -X POST \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-H "X-Hub-Signature-256: $API_SIGNATURE" \
|
|
||||||
-d "$BODY" \
|
|
||||||
"$ENDPOINT"
|
|
||||||
done
|
|
||||||
|
|||||||
12
.github/workflows/style.yml
vendored
12
.github/workflows/style.yml
vendored
@ -44,14 +44,6 @@ jobs:
|
|||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
run: uv sync --project api --dev
|
run: uv sync --project api --dev
|
||||||
|
|
||||||
- name: Run Basedpyright Checks
|
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
|
||||||
run: dev/basedpyright-check
|
|
||||||
|
|
||||||
- name: Run Mypy Type Checks
|
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
|
||||||
run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
|
||||||
|
|
||||||
- name: Dotenv check
|
- name: Dotenv check
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
|
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
|
||||||
@ -97,9 +89,7 @@ jobs:
|
|||||||
- name: Web style check
|
- name: Web style check
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
working-directory: ./web
|
working-directory: ./web
|
||||||
run: |
|
run: pnpm run lint
|
||||||
pnpm run lint
|
|
||||||
pnpm run eslint
|
|
||||||
|
|
||||||
docker-compose-template:
|
docker-compose-template:
|
||||||
name: Docker Compose Template
|
name: Docker Compose Template
|
||||||
|
|||||||
@ -67,22 +67,12 @@ jobs:
|
|||||||
working-directory: ./web
|
working-directory: ./web
|
||||||
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
|
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
|
||||||
|
|
||||||
- name: Generate i18n type definitions
|
|
||||||
if: env.FILES_CHANGED == 'true'
|
|
||||||
working-directory: ./web
|
|
||||||
run: pnpm run gen:i18n-types
|
|
||||||
|
|
||||||
- name: Create Pull Request
|
- name: Create Pull Request
|
||||||
if: env.FILES_CHANGED == 'true'
|
if: env.FILES_CHANGED == 'true'
|
||||||
uses: peter-evans/create-pull-request@v6
|
uses: peter-evans/create-pull-request@v6
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.GITHUB_TOKEN }}
|
token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
commit-message: Update i18n files and type definitions based on en-US changes
|
commit-message: Update i18n files based on en-US changes
|
||||||
title: 'chore: translate i18n files and update type definitions'
|
title: 'chore: translate i18n files'
|
||||||
body: |
|
body: This PR was automatically created to update i18n files based on changes in en-US locale.
|
||||||
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
|
|
||||||
|
|
||||||
**Changes included:**
|
|
||||||
- Updated translation files for all locales
|
|
||||||
- Regenerated TypeScript type definitions for type safety
|
|
||||||
branch: chore/automated-i18n-updates
|
branch: chore/automated-i18n-updates
|
||||||
|
|||||||
5
.github/workflows/web-tests.yml
vendored
5
.github/workflows/web-tests.yml
vendored
@ -47,11 +47,6 @@ jobs:
|
|||||||
working-directory: ./web
|
working-directory: ./web
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
- name: Check i18n types synchronization
|
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
|
||||||
working-directory: ./web
|
|
||||||
run: pnpm run check:i18n-types
|
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
working-directory: ./web
|
working-directory: ./web
|
||||||
|
|||||||
13
.gitignore
vendored
13
.gitignore
vendored
@ -123,12 +123,10 @@ venv.bak/
|
|||||||
# mkdocs documentation
|
# mkdocs documentation
|
||||||
/site
|
/site
|
||||||
|
|
||||||
# type checking
|
# mypy
|
||||||
.mypy_cache/
|
.mypy_cache/
|
||||||
.dmypy.json
|
.dmypy.json
|
||||||
dmypy.json
|
dmypy.json
|
||||||
pyrightconfig.json
|
|
||||||
!api/pyrightconfig.json
|
|
||||||
|
|
||||||
# Pyre type checker
|
# Pyre type checker
|
||||||
.pyre/
|
.pyre/
|
||||||
@ -197,8 +195,8 @@ sdks/python-client/dify_client.egg-info
|
|||||||
.vscode/*
|
.vscode/*
|
||||||
!.vscode/launch.json.template
|
!.vscode/launch.json.template
|
||||||
!.vscode/README.md
|
!.vscode/README.md
|
||||||
|
pyrightconfig.json
|
||||||
api/.vscode
|
api/.vscode
|
||||||
web/.vscode
|
|
||||||
# vscode Code History Extension
|
# vscode Code History Extension
|
||||||
.history
|
.history
|
||||||
|
|
||||||
@ -216,13 +214,6 @@ mise.toml
|
|||||||
# Next.js build output
|
# Next.js build output
|
||||||
.next/
|
.next/
|
||||||
|
|
||||||
# PWA generated files
|
|
||||||
web/public/sw.js
|
|
||||||
web/public/sw.js.map
|
|
||||||
web/public/workbox-*.js
|
|
||||||
web/public/workbox-*.js.map
|
|
||||||
web/public/fallback-*.js
|
|
||||||
|
|
||||||
# AI Assistant
|
# AI Assistant
|
||||||
.roo/
|
.roo/
|
||||||
api/.env.backup
|
api/.env.backup
|
||||||
|
|||||||
@ -32,7 +32,7 @@ uv run --project api pytest tests/integration_tests/ # Integration tests
|
|||||||
./dev/reformat # Run all formatters and linters
|
./dev/reformat # Run all formatters and linters
|
||||||
uv run --project api ruff check --fix ./ # Fix linting issues
|
uv run --project api ruff check --fix ./ # Fix linting issues
|
||||||
uv run --project api ruff format ./ # Format code
|
uv run --project api ruff format ./ # Format code
|
||||||
uv run --directory api basedpyright # Type checking
|
uv run --project api mypy . # Type checking
|
||||||
```
|
```
|
||||||
|
|
||||||
### Frontend (Web)
|
### Frontend (Web)
|
||||||
|
|||||||
60
Makefile
60
Makefile
@ -4,48 +4,6 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
|
|||||||
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
|
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
|
||||||
VERSION=latest
|
VERSION=latest
|
||||||
|
|
||||||
# Backend Development Environment Setup
|
|
||||||
.PHONY: dev-setup prepare-docker prepare-web prepare-api
|
|
||||||
|
|
||||||
# Default dev setup target
|
|
||||||
dev-setup: prepare-docker prepare-web prepare-api
|
|
||||||
@echo "✅ Backend development environment setup complete!"
|
|
||||||
|
|
||||||
# Step 1: Prepare Docker middleware
|
|
||||||
prepare-docker:
|
|
||||||
@echo "🐳 Setting up Docker middleware..."
|
|
||||||
@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
|
|
||||||
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
|
|
||||||
@echo "✅ Docker middleware started"
|
|
||||||
|
|
||||||
# Step 2: Prepare web environment
|
|
||||||
prepare-web:
|
|
||||||
@echo "🌐 Setting up web environment..."
|
|
||||||
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
|
||||||
@cd web && pnpm install
|
|
||||||
@cd web && pnpm build
|
|
||||||
@echo "✅ Web environment prepared (not started)"
|
|
||||||
|
|
||||||
# Step 3: Prepare API environment
|
|
||||||
prepare-api:
|
|
||||||
@echo "🔧 Setting up API environment..."
|
|
||||||
@cp -n api/.env.example api/.env 2>/dev/null || echo "API .env already exists"
|
|
||||||
@cd api && uv sync --dev
|
|
||||||
@cd api && uv run flask db upgrade
|
|
||||||
@echo "✅ API environment prepared (not started)"
|
|
||||||
|
|
||||||
# Clean dev environment
|
|
||||||
dev-clean:
|
|
||||||
@echo "⚠️ Stopping Docker containers..."
|
|
||||||
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
|
|
||||||
@echo "🗑️ Removing volumes..."
|
|
||||||
@rm -rf docker/volumes/db
|
|
||||||
@rm -rf docker/volumes/redis
|
|
||||||
@rm -rf docker/volumes/plugin_daemon
|
|
||||||
@rm -rf docker/volumes/weaviate
|
|
||||||
@rm -rf api/storage
|
|
||||||
@echo "✅ Cleanup complete"
|
|
||||||
|
|
||||||
# Build Docker images
|
# Build Docker images
|
||||||
build-web:
|
build-web:
|
||||||
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
|
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
|
||||||
@ -81,21 +39,5 @@ build-push-web: build-web push-web
|
|||||||
build-push-all: build-all push-all
|
build-push-all: build-all push-all
|
||||||
@echo "All Docker images have been built and pushed."
|
@echo "All Docker images have been built and pushed."
|
||||||
|
|
||||||
# Help target
|
|
||||||
help:
|
|
||||||
@echo "Development Setup Targets:"
|
|
||||||
@echo " make dev-setup - Run all setup steps for backend dev environment"
|
|
||||||
@echo " make prepare-docker - Set up Docker middleware"
|
|
||||||
@echo " make prepare-web - Set up web environment"
|
|
||||||
@echo " make prepare-api - Set up API environment"
|
|
||||||
@echo " make dev-clean - Stop Docker middleware containers"
|
|
||||||
@echo ""
|
|
||||||
@echo "Docker Build Targets:"
|
|
||||||
@echo " make build-web - Build web Docker image"
|
|
||||||
@echo " make build-api - Build API Docker image"
|
|
||||||
@echo " make build-all - Build all Docker images"
|
|
||||||
@echo " make push-all - Push all Docker images"
|
|
||||||
@echo " make build-push-all - Build and push all Docker images"
|
|
||||||
|
|
||||||
# Phony targets
|
# Phony targets
|
||||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help
|
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all
|
||||||
|
|||||||
@ -75,7 +75,6 @@ DB_PASSWORD=difyai123456
|
|||||||
DB_HOST=localhost
|
DB_HOST=localhost
|
||||||
DB_PORT=5432
|
DB_PORT=5432
|
||||||
DB_DATABASE=dify
|
DB_DATABASE=dify
|
||||||
SQLALCHEMY_POOL_PRE_PING=true
|
|
||||||
|
|
||||||
# Storage configuration
|
# Storage configuration
|
||||||
# use for store upload files, private keys...
|
# use for store upload files, private keys...
|
||||||
@ -157,7 +156,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
|||||||
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||||
|
|
||||||
# Vector database configuration
|
# Vector database configuration
|
||||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `pinecone`.
|
||||||
VECTOR_STORE=weaviate
|
VECTOR_STORE=weaviate
|
||||||
# Prefix used to create collection name in vector database
|
# Prefix used to create collection name in vector database
|
||||||
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||||
@ -362,6 +361,17 @@ PROMPT_GENERATION_MAX_TOKENS=512
|
|||||||
CODE_GENERATION_MAX_TOKENS=1024
|
CODE_GENERATION_MAX_TOKENS=1024
|
||||||
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
|
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
|
||||||
|
|
||||||
|
|
||||||
|
# Pinecone configuration, only available when VECTOR_STORE is `pinecone`
|
||||||
|
PINECONE_API_KEY=your-pinecone-api-key
|
||||||
|
PINECONE_ENVIRONMENT=your-pinecone-environment
|
||||||
|
PINECONE_INDEX_NAME=dify-index
|
||||||
|
PINECONE_CLIENT_TIMEOUT=30
|
||||||
|
PINECONE_BATCH_SIZE=100
|
||||||
|
PINECONE_METRIC=cosine
|
||||||
|
PINECONE_PODS=1
|
||||||
|
PINECONE_POD_TYPE=s1
|
||||||
|
|
||||||
# Mail configuration, support: resend, smtp, sendgrid
|
# Mail configuration, support: resend, smtp, sendgrid
|
||||||
MAIL_TYPE=
|
MAIL_TYPE=
|
||||||
# If using SendGrid, use the 'from' field for authentication if necessary.
|
# If using SendGrid, use the 'from' field for authentication if necessary.
|
||||||
@ -569,7 +579,3 @@ QUEUE_MONITOR_INTERVAL=30
|
|||||||
# Swagger UI configuration
|
# Swagger UI configuration
|
||||||
SWAGGER_UI_ENABLED=true
|
SWAGGER_UI_ENABLED=true
|
||||||
SWAGGER_UI_PATH=/swagger-ui.html
|
SWAGGER_UI_PATH=/swagger-ui.html
|
||||||
|
|
||||||
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
|
|
||||||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
|
||||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
|
||||||
|
|||||||
@ -45,7 +45,6 @@ select = [
|
|||||||
"G001", # don't use str format to logging messages
|
"G001", # don't use str format to logging messages
|
||||||
"G003", # don't use + in logging messages
|
"G003", # don't use + in logging messages
|
||||||
"G004", # don't use f-strings to format logging messages
|
"G004", # don't use f-strings to format logging messages
|
||||||
"UP042", # use StrEnum
|
|
||||||
]
|
]
|
||||||
|
|
||||||
ignore = [
|
ignore = [
|
||||||
|
|||||||
@ -108,5 +108,5 @@ uv run celery -A app.celery beat
|
|||||||
../dev/reformat # Run all formatters and linters
|
../dev/reformat # Run all formatters and linters
|
||||||
uv run ruff check --fix ./ # Fix linting issues
|
uv run ruff check --fix ./ # Fix linting issues
|
||||||
uv run ruff format ./ # Format code
|
uv run ruff format ./ # Format code
|
||||||
uv run basedpyright . # Type checking
|
uv run mypy . # Type checking
|
||||||
```
|
```
|
||||||
|
|||||||
@ -25,9 +25,6 @@ def create_flask_app_with_configs() -> DifyApp:
|
|||||||
# add an unique identifier to each request
|
# add an unique identifier to each request
|
||||||
RecyclableContextVar.increment_thread_recycles()
|
RecyclableContextVar.increment_thread_recycles()
|
||||||
|
|
||||||
# Capture the decorator's return value to avoid pyright reportUnusedFunction
|
|
||||||
_ = before_request
|
|
||||||
|
|
||||||
return dify_app
|
return dify_app
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
11
api/child_class.py
Normal file
11
api/child_class.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from tests.integration_tests.utils.parent_class import ParentClass
|
||||||
|
|
||||||
|
|
||||||
|
class ChildClass(ParentClass):
|
||||||
|
"""Test child class for module import helper tests"""
|
||||||
|
|
||||||
|
def __init__(self, name):
|
||||||
|
super().__init__(name)
|
||||||
|
|
||||||
|
def get_name(self):
|
||||||
|
return f"Child: {self.name}"
|
||||||
@ -212,9 +212,7 @@ def migrate_annotation_vector_database():
|
|||||||
if not dataset_collection_binding:
|
if not dataset_collection_binding:
|
||||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||||
continue
|
continue
|
||||||
annotations = db.session.scalars(
|
annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all()
|
||||||
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
|
||||||
).all()
|
|
||||||
dataset = Dataset(
|
dataset = Dataset(
|
||||||
id=app.id,
|
id=app.id,
|
||||||
tenant_id=app.tenant_id,
|
tenant_id=app.tenant_id,
|
||||||
@ -369,25 +367,29 @@ def migrate_knowledge_vector_database():
|
|||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
dataset_documents = db.session.scalars(
|
dataset_documents = (
|
||||||
select(DatasetDocument).where(
|
db.session.query(DatasetDocument)
|
||||||
|
.where(
|
||||||
DatasetDocument.dataset_id == dataset.id,
|
DatasetDocument.dataset_id == dataset.id,
|
||||||
DatasetDocument.indexing_status == "completed",
|
DatasetDocument.indexing_status == "completed",
|
||||||
DatasetDocument.enabled == True,
|
DatasetDocument.enabled == True,
|
||||||
DatasetDocument.archived == False,
|
DatasetDocument.archived == False,
|
||||||
)
|
)
|
||||||
).all()
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
documents = []
|
documents = []
|
||||||
segments_count = 0
|
segments_count = 0
|
||||||
for dataset_document in dataset_documents:
|
for dataset_document in dataset_documents:
|
||||||
segments = db.session.scalars(
|
segments = (
|
||||||
select(DocumentSegment).where(
|
db.session.query(DocumentSegment)
|
||||||
|
.where(
|
||||||
DocumentSegment.document_id == dataset_document.id,
|
DocumentSegment.document_id == dataset_document.id,
|
||||||
DocumentSegment.status == "completed",
|
DocumentSegment.status == "completed",
|
||||||
DocumentSegment.enabled == True,
|
DocumentSegment.enabled == True,
|
||||||
)
|
)
|
||||||
).all()
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
document = Document(
|
document = Document(
|
||||||
@ -509,7 +511,7 @@ def add_qdrant_index(field: str):
|
|||||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||||
from qdrant_client.http.models import PayloadSchemaType
|
from qdrant_client.http.models import PayloadSchemaType
|
||||||
|
|
||||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
|
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
|
||||||
|
|
||||||
for binding in bindings:
|
for binding in bindings:
|
||||||
if dify_config.QDRANT_URL is None:
|
if dify_config.QDRANT_URL is None:
|
||||||
@ -523,21 +525,7 @@ def add_qdrant_index(field: str):
|
|||||||
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
|
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
params = qdrant_config.to_qdrant_params()
|
client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
|
||||||
# Check the type before using
|
|
||||||
if isinstance(params, PathQdrantParams):
|
|
||||||
# PathQdrantParams case
|
|
||||||
client = qdrant_client.QdrantClient(path=params.path)
|
|
||||||
else:
|
|
||||||
# UrlQdrantParams case - params is UrlQdrantParams
|
|
||||||
client = qdrant_client.QdrantClient(
|
|
||||||
url=params.url,
|
|
||||||
api_key=params.api_key,
|
|
||||||
timeout=int(params.timeout),
|
|
||||||
verify=params.verify,
|
|
||||||
grpc_port=params.grpc_port,
|
|
||||||
prefer_grpc=params.prefer_grpc,
|
|
||||||
)
|
|
||||||
# create payload index
|
# create payload index
|
||||||
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
|
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
|
||||||
create_count += 1
|
create_count += 1
|
||||||
@ -583,7 +571,7 @@ def old_metadata_migration():
|
|||||||
for document in documents:
|
for document in documents:
|
||||||
if document.doc_metadata:
|
if document.doc_metadata:
|
||||||
doc_metadata = document.doc_metadata
|
doc_metadata = document.doc_metadata
|
||||||
for key in doc_metadata:
|
for key, value in doc_metadata.items():
|
||||||
for field in BuiltInField:
|
for field in BuiltInField:
|
||||||
if field.value == key:
|
if field.value == key:
|
||||||
break
|
break
|
||||||
|
|||||||
@ -796,11 +796,6 @@ class DataSetConfig(BaseSettings):
|
|||||||
default=30,
|
default=30,
|
||||||
)
|
)
|
||||||
|
|
||||||
DSL_EXPORT_ENCRYPT_DATASET_ID: bool = Field(
|
|
||||||
description="Enable or disable dataset ID encryption when exporting DSL files",
|
|
||||||
default=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceConfig(BaseSettings):
|
class WorkspaceConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -35,6 +35,7 @@ from .vdb.opensearch_config import OpenSearchConfig
|
|||||||
from .vdb.oracle_config import OracleConfig
|
from .vdb.oracle_config import OracleConfig
|
||||||
from .vdb.pgvector_config import PGVectorConfig
|
from .vdb.pgvector_config import PGVectorConfig
|
||||||
from .vdb.pgvectors_config import PGVectoRSConfig
|
from .vdb.pgvectors_config import PGVectoRSConfig
|
||||||
|
from .vdb.pinecone_config import PineconeConfig
|
||||||
from .vdb.qdrant_config import QdrantConfig
|
from .vdb.qdrant_config import QdrantConfig
|
||||||
from .vdb.relyt_config import RelytConfig
|
from .vdb.relyt_config import RelytConfig
|
||||||
from .vdb.tablestore_config import TableStoreConfig
|
from .vdb.tablestore_config import TableStoreConfig
|
||||||
@ -300,7 +301,8 @@ class DatasetQueueMonitorConfig(BaseSettings):
|
|||||||
|
|
||||||
class MiddlewareConfig(
|
class MiddlewareConfig(
|
||||||
# place the configs in alphabet order
|
# place the configs in alphabet order
|
||||||
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
|
CeleryConfig,
|
||||||
|
DatabaseConfig,
|
||||||
KeywordStoreConfig,
|
KeywordStoreConfig,
|
||||||
RedisConfig,
|
RedisConfig,
|
||||||
# configs of storage and storage providers
|
# configs of storage and storage providers
|
||||||
@ -330,6 +332,7 @@ class MiddlewareConfig(
|
|||||||
PGVectorConfig,
|
PGVectorConfig,
|
||||||
VastbaseVectorConfig,
|
VastbaseVectorConfig,
|
||||||
PGVectoRSConfig,
|
PGVectoRSConfig,
|
||||||
|
PineconeConfig,
|
||||||
QdrantConfig,
|
QdrantConfig,
|
||||||
RelytConfig,
|
RelytConfig,
|
||||||
TencentVectorDBConfig,
|
TencentVectorDBConfig,
|
||||||
|
|||||||
@ -1,10 +1,9 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import BaseModel, Field
|
||||||
from pydantic_settings import BaseSettings
|
|
||||||
|
|
||||||
|
|
||||||
class ClickzettaConfig(BaseSettings):
|
class ClickzettaConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
Clickzetta Lakehouse vector database configuration
|
Clickzetta Lakehouse vector database configuration
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
from pydantic import Field
|
from pydantic import BaseModel, Field
|
||||||
from pydantic_settings import BaseSettings
|
|
||||||
|
|
||||||
|
|
||||||
class MatrixoneConfig(BaseSettings):
|
class MatrixoneConfig(BaseModel):
|
||||||
"""Matrixone vector database configuration."""
|
"""Matrixone vector database configuration."""
|
||||||
|
|
||||||
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")
|
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")
|
||||||
|
|||||||
41
api/configs/middleware/vdb/pinecone_config.py
Normal file
41
api/configs/middleware/vdb/pinecone_config.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field, PositiveInt
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class PineconeConfig(BaseSettings):
|
||||||
|
"""
|
||||||
|
Configuration settings for Pinecone vector database
|
||||||
|
"""
|
||||||
|
|
||||||
|
PINECONE_API_KEY: Optional[str] = Field(
|
||||||
|
description="API key for authenticating with Pinecone service",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
PINECONE_ENVIRONMENT: Optional[str] = Field(
|
||||||
|
description="Pinecone environment (e.g., 'us-west1-gcp', 'us-east-1-aws')",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
PINECONE_INDEX_NAME: Optional[str] = Field(
|
||||||
|
description="Default Pinecone index name",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
PINECONE_CLIENT_TIMEOUT: PositiveInt = Field(
|
||||||
|
description="Timeout in seconds for Pinecone client operations (default is 30 seconds)",
|
||||||
|
default=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
PINECONE_BATCH_SIZE: PositiveInt = Field(
|
||||||
|
description="Batch size for Pinecone operations (default is 100)",
|
||||||
|
default=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
PINECONE_METRIC: str = Field(
|
||||||
|
description="Distance metric for Pinecone index (cosine, euclidean, dotproduct)",
|
||||||
|
default="cosine",
|
||||||
|
)
|
||||||
|
|
||||||
@ -1,6 +1,6 @@
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from configs.packaging.pyproject import PyProjectTomlConfig
|
from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig
|
||||||
|
|
||||||
|
|
||||||
class PackagingInfo(PyProjectTomlConfig):
|
class PackagingInfo(PyProjectTomlConfig):
|
||||||
|
|||||||
@ -4,9 +4,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable, Mapping
|
from collections.abc import Mapping
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from .python_3x import http_request, makedirs_wrapper
|
from .python_3x import http_request, makedirs_wrapper
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@ -26,13 +25,13 @@ logger = logging.getLogger(__name__)
|
|||||||
class ApolloClient:
|
class ApolloClient:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config_url: str,
|
config_url,
|
||||||
app_id: str,
|
app_id,
|
||||||
cluster: str = "default",
|
cluster="default",
|
||||||
secret: str = "",
|
secret="",
|
||||||
start_hot_update: bool = True,
|
start_hot_update=True,
|
||||||
change_listener: Callable[[str, str, str, Any], None] | None = None,
|
change_listener=None,
|
||||||
_notification_map: dict[str, int] | None = None,
|
_notification_map=None,
|
||||||
):
|
):
|
||||||
# Core routing parameters
|
# Core routing parameters
|
||||||
self.config_url = config_url
|
self.config_url = config_url
|
||||||
@ -48,17 +47,17 @@ class ApolloClient:
|
|||||||
# Private control variables
|
# Private control variables
|
||||||
self._cycle_time = 5
|
self._cycle_time = 5
|
||||||
self._stopping = False
|
self._stopping = False
|
||||||
self._cache: dict[str, dict[str, Any]] = {}
|
self._cache = {}
|
||||||
self._no_key: dict[str, str] = {}
|
self._no_key = {}
|
||||||
self._hash: dict[str, str] = {}
|
self._hash = {}
|
||||||
self._pull_timeout = 75
|
self._pull_timeout = 75
|
||||||
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
|
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
|
||||||
self._long_poll_thread: threading.Thread | None = None
|
self._long_poll_thread = None
|
||||||
self._change_listener = change_listener # "add" "delete" "update"
|
self._change_listener = change_listener # "add" "delete" "update"
|
||||||
if _notification_map is None:
|
if _notification_map is None:
|
||||||
_notification_map = {"application": -1}
|
_notification_map = {"application": -1}
|
||||||
self._notification_map = _notification_map
|
self._notification_map = _notification_map
|
||||||
self.last_release_key: str | None = None
|
self.last_release_key = None
|
||||||
# Private startup method
|
# Private startup method
|
||||||
self._path_checker()
|
self._path_checker()
|
||||||
if start_hot_update:
|
if start_hot_update:
|
||||||
@ -69,7 +68,7 @@ class ApolloClient:
|
|||||||
heartbeat.daemon = True
|
heartbeat.daemon = True
|
||||||
heartbeat.start()
|
heartbeat.start()
|
||||||
|
|
||||||
def get_json_from_net(self, namespace: str = "application") -> dict[str, Any] | None:
|
def get_json_from_net(self, namespace="application"):
|
||||||
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
|
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
|
||||||
self.config_url, self.app_id, self.cluster, namespace, "", self.ip
|
self.config_url, self.app_id, self.cluster, namespace, "", self.ip
|
||||||
)
|
)
|
||||||
@ -89,7 +88,7 @@ class ApolloClient:
|
|||||||
logger.exception("an error occurred in get_json_from_net")
|
logger.exception("an error occurred in get_json_from_net")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_value(self, key: str, default_val: Any = None, namespace: str = "application") -> Any:
|
def get_value(self, key, default_val=None, namespace="application"):
|
||||||
try:
|
try:
|
||||||
# read memory configuration
|
# read memory configuration
|
||||||
namespace_cache = self._cache.get(namespace)
|
namespace_cache = self._cache.get(namespace)
|
||||||
@ -105,8 +104,7 @@ class ApolloClient:
|
|||||||
namespace_data = self.get_json_from_net(namespace)
|
namespace_data = self.get_json_from_net(namespace)
|
||||||
val = get_value_from_dict(namespace_data, key)
|
val = get_value_from_dict(namespace_data, key)
|
||||||
if val is not None:
|
if val is not None:
|
||||||
if namespace_data is not None:
|
self._update_cache_and_file(namespace_data, namespace)
|
||||||
self._update_cache_and_file(namespace_data, namespace)
|
|
||||||
return val
|
return val
|
||||||
|
|
||||||
# read the file configuration
|
# read the file configuration
|
||||||
@ -128,23 +126,23 @@ class ApolloClient:
|
|||||||
# to ensure the real-time correctness of the function call.
|
# to ensure the real-time correctness of the function call.
|
||||||
# If the user does not have the same default val twice
|
# If the user does not have the same default val twice
|
||||||
# and the default val is used here, there may be a problem.
|
# and the default val is used here, there may be a problem.
|
||||||
def _set_local_cache_none(self, namespace: str, key: str) -> None:
|
def _set_local_cache_none(self, namespace, key):
|
||||||
no_key = no_key_cache_key(namespace, key)
|
no_key = no_key_cache_key(namespace, key)
|
||||||
self._no_key[no_key] = key
|
self._no_key[no_key] = key
|
||||||
|
|
||||||
def _start_hot_update(self) -> None:
|
def _start_hot_update(self):
|
||||||
self._long_poll_thread = threading.Thread(target=self._listener)
|
self._long_poll_thread = threading.Thread(target=self._listener)
|
||||||
# When the asynchronous thread is started, the daemon thread will automatically exit
|
# When the asynchronous thread is started, the daemon thread will automatically exit
|
||||||
# when the main thread is launched.
|
# when the main thread is launched.
|
||||||
self._long_poll_thread.daemon = True
|
self._long_poll_thread.daemon = True
|
||||||
self._long_poll_thread.start()
|
self._long_poll_thread.start()
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self):
|
||||||
self._stopping = True
|
self._stopping = True
|
||||||
logger.info("Stopping listener...")
|
logger.info("Stopping listener...")
|
||||||
|
|
||||||
# Call the set callback function, and if it is abnormal, try it out
|
# Call the set callback function, and if it is abnormal, try it out
|
||||||
def _call_listener(self, namespace: str, old_kv: dict[str, Any] | None, new_kv: dict[str, Any] | None) -> None:
|
def _call_listener(self, namespace, old_kv, new_kv):
|
||||||
if self._change_listener is None:
|
if self._change_listener is None:
|
||||||
return
|
return
|
||||||
if old_kv is None:
|
if old_kv is None:
|
||||||
@ -170,12 +168,12 @@ class ApolloClient:
|
|||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.warning(str(e))
|
logger.warning(str(e))
|
||||||
|
|
||||||
def _path_checker(self) -> None:
|
def _path_checker(self):
|
||||||
if not os.path.isdir(self._cache_file_path):
|
if not os.path.isdir(self._cache_file_path):
|
||||||
makedirs_wrapper(self._cache_file_path)
|
makedirs_wrapper(self._cache_file_path)
|
||||||
|
|
||||||
# update the local cache and file cache
|
# update the local cache and file cache
|
||||||
def _update_cache_and_file(self, namespace_data: dict[str, Any], namespace: str = "application") -> None:
|
def _update_cache_and_file(self, namespace_data, namespace="application"):
|
||||||
# update the local cache
|
# update the local cache
|
||||||
self._cache[namespace] = namespace_data
|
self._cache[namespace] = namespace_data
|
||||||
# update the file cache
|
# update the file cache
|
||||||
@ -189,7 +187,7 @@ class ApolloClient:
|
|||||||
self._hash[namespace] = new_hash
|
self._hash[namespace] = new_hash
|
||||||
|
|
||||||
# get the configuration from the local file
|
# get the configuration from the local file
|
||||||
def _get_local_cache(self, namespace: str = "application") -> dict[str, Any]:
|
def _get_local_cache(self, namespace="application"):
|
||||||
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
|
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
|
||||||
if os.path.isfile(cache_file_path):
|
if os.path.isfile(cache_file_path):
|
||||||
with open(cache_file_path) as f:
|
with open(cache_file_path) as f:
|
||||||
@ -197,8 +195,8 @@ class ApolloClient:
|
|||||||
return result
|
return result
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _long_poll(self) -> None:
|
def _long_poll(self):
|
||||||
notifications: list[dict[str, Any]] = []
|
notifications = []
|
||||||
for key in self._cache:
|
for key in self._cache:
|
||||||
namespace_data = self._cache[key]
|
namespace_data = self._cache[key]
|
||||||
notification_id = -1
|
notification_id = -1
|
||||||
@ -238,7 +236,7 @@ class ApolloClient:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(str(e))
|
logger.warning(str(e))
|
||||||
|
|
||||||
def _get_net_and_set_local(self, namespace: str, n_id: int, call_change: bool = False) -> None:
|
def _get_net_and_set_local(self, namespace, n_id, call_change=False):
|
||||||
namespace_data = self.get_json_from_net(namespace)
|
namespace_data = self.get_json_from_net(namespace)
|
||||||
if not namespace_data:
|
if not namespace_data:
|
||||||
return
|
return
|
||||||
@ -250,7 +248,7 @@ class ApolloClient:
|
|||||||
new_kv = namespace_data.get(CONFIGURATIONS)
|
new_kv = namespace_data.get(CONFIGURATIONS)
|
||||||
self._call_listener(namespace, old_kv, new_kv)
|
self._call_listener(namespace, old_kv, new_kv)
|
||||||
|
|
||||||
def _listener(self) -> None:
|
def _listener(self):
|
||||||
logger.info("start long_poll")
|
logger.info("start long_poll")
|
||||||
while not self._stopping:
|
while not self._stopping:
|
||||||
self._long_poll()
|
self._long_poll()
|
||||||
@ -268,13 +266,13 @@ class ApolloClient:
|
|||||||
headers["Timestamp"] = time_unix_now
|
headers["Timestamp"] = time_unix_now
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
def _heart_beat(self) -> None:
|
def _heart_beat(self):
|
||||||
while not self._stopping:
|
while not self._stopping:
|
||||||
for namespace in self._notification_map:
|
for namespace in self._notification_map:
|
||||||
self._do_heart_beat(namespace)
|
self._do_heart_beat(namespace)
|
||||||
time.sleep(60 * 10) # 10 minutes
|
time.sleep(60 * 10) # 10 minutes
|
||||||
|
|
||||||
def _do_heart_beat(self, namespace: str) -> None:
|
def _do_heart_beat(self, namespace):
|
||||||
url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
|
url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
|
||||||
try:
|
try:
|
||||||
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
|
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
|
||||||
@ -294,7 +292,7 @@ class ApolloClient:
|
|||||||
logger.exception("an error occurred in _do_heart_beat")
|
logger.exception("an error occurred in _do_heart_beat")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_all_dicts(self, namespace: str) -> dict[str, Any] | None:
|
def get_all_dicts(self, namespace):
|
||||||
namespace_data = self._cache.get(namespace)
|
namespace_data = self._cache.get(namespace)
|
||||||
if namespace_data is None:
|
if namespace_data is None:
|
||||||
net_namespace_data = self.get_json_from_net(namespace)
|
net_namespace_data = self.get_json_from_net(namespace)
|
||||||
|
|||||||
@ -2,8 +2,6 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import ssl
|
import ssl
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from collections.abc import Mapping
|
|
||||||
from typing import Any
|
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
from urllib.error import HTTPError
|
from urllib.error import HTTPError
|
||||||
|
|
||||||
@ -21,9 +19,9 @@ urllib.request.install_opener(opener)
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}) -> tuple[int, str | None]:
|
def http_request(url, timeout, headers={}):
|
||||||
try:
|
try:
|
||||||
request = urllib.request.Request(url, headers=dict(headers))
|
request = urllib.request.Request(url, headers=headers)
|
||||||
res = urllib.request.urlopen(request, timeout=timeout)
|
res = urllib.request.urlopen(request, timeout=timeout)
|
||||||
body = res.read().decode("utf-8")
|
body = res.read().decode("utf-8")
|
||||||
return res.code, body
|
return res.code, body
|
||||||
@ -35,9 +33,9 @@ def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def url_encode(params: dict[str, Any]) -> str:
|
def url_encode(params):
|
||||||
return parse.urlencode(params)
|
return parse.urlencode(params)
|
||||||
|
|
||||||
|
|
||||||
def makedirs_wrapper(path: str) -> None:
|
def makedirs_wrapper(path):
|
||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import socket
|
import socket
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from .python_3x import url_encode
|
from .python_3x import url_encode
|
||||||
|
|
||||||
@ -11,7 +10,7 @@ NAMESPACE_NAME = "namespaceName"
|
|||||||
|
|
||||||
|
|
||||||
# add timestamps uris and keys
|
# add timestamps uris and keys
|
||||||
def signature(timestamp: str, uri: str, secret: str) -> str:
|
def signature(timestamp, uri, secret):
|
||||||
import base64
|
import base64
|
||||||
import hmac
|
import hmac
|
||||||
|
|
||||||
@ -20,16 +19,16 @@ def signature(timestamp: str, uri: str, secret: str) -> str:
|
|||||||
return base64.b64encode(hmac_code).decode()
|
return base64.b64encode(hmac_code).decode()
|
||||||
|
|
||||||
|
|
||||||
def url_encode_wrapper(params: dict[str, Any]) -> str:
|
def url_encode_wrapper(params):
|
||||||
return url_encode(params)
|
return url_encode(params)
|
||||||
|
|
||||||
|
|
||||||
def no_key_cache_key(namespace: str, key: str) -> str:
|
def no_key_cache_key(namespace, key):
|
||||||
return f"{namespace}{len(namespace)}{key}"
|
return f"{namespace}{len(namespace)}{key}"
|
||||||
|
|
||||||
|
|
||||||
# Returns whether the obtained value is obtained, and None if it does not
|
# Returns whether the obtained value is obtained, and None if it does not
|
||||||
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
|
def get_value_from_dict(namespace_cache, key):
|
||||||
if namespace_cache:
|
if namespace_cache:
|
||||||
kv_data = namespace_cache.get(CONFIGURATIONS)
|
kv_data = namespace_cache.get(CONFIGURATIONS)
|
||||||
if kv_data is None:
|
if kv_data is None:
|
||||||
@ -39,7 +38,7 @@ def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def init_ip() -> str:
|
def init_ip():
|
||||||
ip = ""
|
ip = ""
|
||||||
s = None
|
s = None
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -11,5 +11,5 @@ class RemoteSettingsSource:
|
|||||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool):
|
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
|
||||||
return value
|
return value
|
||||||
|
|||||||
@ -11,16 +11,16 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
from configs.remote_settings_sources.base import RemoteSettingsSource
|
from configs.remote_settings_sources.base import RemoteSettingsSource
|
||||||
|
|
||||||
from .utils import parse_config
|
from .utils import _parse_config
|
||||||
|
|
||||||
|
|
||||||
class NacosSettingsSource(RemoteSettingsSource):
|
class NacosSettingsSource(RemoteSettingsSource):
|
||||||
def __init__(self, configs: Mapping[str, Any]):
|
def __init__(self, configs: Mapping[str, Any]):
|
||||||
self.configs = configs
|
self.configs = configs
|
||||||
self.remote_configs: dict[str, str] = {}
|
self.remote_configs: dict[str, Any] = {}
|
||||||
self.async_init()
|
self.async_init()
|
||||||
|
|
||||||
def async_init(self) -> None:
|
def async_init(self):
|
||||||
data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
|
data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
|
||||||
group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
|
group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
|
||||||
tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")
|
tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")
|
||||||
@ -29,19 +29,22 @@ class NacosSettingsSource(RemoteSettingsSource):
|
|||||||
try:
|
try:
|
||||||
content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
|
content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
|
||||||
self.remote_configs = self._parse_config(content)
|
self.remote_configs = self._parse_config(content)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
logger.exception("[get-access-token] exception occurred")
|
logger.exception("[get-access-token] exception occurred")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _parse_config(self, content: str) -> dict[str, str]:
|
def _parse_config(self, content: str) -> dict:
|
||||||
if not content:
|
if not content:
|
||||||
return {}
|
return {}
|
||||||
try:
|
try:
|
||||||
return parse_config(content)
|
return _parse_config(self, content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to parse config: {e}")
|
raise RuntimeError(f"Failed to parse config: {e}")
|
||||||
|
|
||||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||||
|
if not isinstance(self.remote_configs, dict):
|
||||||
|
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
|
||||||
|
|
||||||
field_value = self.remote_configs.get(field_name)
|
field_value = self.remote_configs.get(field_name)
|
||||||
if field_value is None:
|
if field_value is None:
|
||||||
return None, field_name, False
|
return None, field_name, False
|
||||||
|
|||||||
@ -17,26 +17,20 @@ class NacosHttpClient:
|
|||||||
self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
|
self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
|
||||||
self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
|
self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
|
||||||
self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
|
self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
|
||||||
self.token: str | None = None
|
self.token = None
|
||||||
self.token_ttl = 18000
|
self.token_ttl = 18000
|
||||||
self.token_expire_time: float = 0
|
self.token_expire_time: float = 0
|
||||||
|
|
||||||
def http_request(
|
def http_request(self, url, method="GET", headers=None, params=None):
|
||||||
self, url: str, method: str = "GET", headers: dict[str, str] | None = None, params: dict[str, str] | None = None
|
|
||||||
) -> str:
|
|
||||||
if headers is None:
|
|
||||||
headers = {}
|
|
||||||
if params is None:
|
|
||||||
params = {}
|
|
||||||
try:
|
try:
|
||||||
self._inject_auth_info(headers, params)
|
self._inject_auth_info(headers, params)
|
||||||
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.text
|
return response.text
|
||||||
except requests.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
return f"Request to Nacos failed: {e}"
|
return f"Request to Nacos failed: {e}"
|
||||||
|
|
||||||
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
|
def _inject_auth_info(self, headers, params, module="config"):
|
||||||
headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})
|
headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})
|
||||||
|
|
||||||
if module == "login":
|
if module == "login":
|
||||||
@ -51,17 +45,16 @@ class NacosHttpClient:
|
|||||||
headers["timeStamp"] = ts
|
headers["timeStamp"] = ts
|
||||||
if self.username and self.password:
|
if self.username and self.password:
|
||||||
self.get_access_token(force_refresh=False)
|
self.get_access_token(force_refresh=False)
|
||||||
if self.token is not None:
|
params["accessToken"] = self.token
|
||||||
params["accessToken"] = self.token
|
|
||||||
|
|
||||||
def __do_sign(self, sign_str: str, sk: str) -> str:
|
def __do_sign(self, sign_str, sk):
|
||||||
return (
|
return (
|
||||||
base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
|
base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
|
||||||
.decode()
|
.decode()
|
||||||
.strip()
|
.strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_sign_str(self, group: str, tenant: str, ts: str) -> str:
|
def get_sign_str(self, group, tenant, ts):
|
||||||
sign_str = ""
|
sign_str = ""
|
||||||
if tenant:
|
if tenant:
|
||||||
sign_str = tenant + "+"
|
sign_str = tenant + "+"
|
||||||
@ -70,7 +63,7 @@ class NacosHttpClient:
|
|||||||
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
|
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
|
||||||
return sign_str
|
return sign_str
|
||||||
|
|
||||||
def get_access_token(self, force_refresh: bool = False) -> str | None:
|
def get_access_token(self, force_refresh=False):
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
if self.token and not force_refresh and self.token_expire_time > current_time:
|
if self.token and not force_refresh and self.token_expire_time > current_time:
|
||||||
return self.token
|
return self.token
|
||||||
@ -84,7 +77,6 @@ class NacosHttpClient:
|
|||||||
self.token = response_data.get("accessToken")
|
self.token = response_data.get("accessToken")
|
||||||
self.token_ttl = response_data.get("tokenTtl", 18000)
|
self.token_ttl = response_data.get("tokenTtl", 18000)
|
||||||
self.token_expire_time = current_time + self.token_ttl - 10
|
self.token_expire_time = current_time + self.token_ttl - 10
|
||||||
return self.token
|
except Exception as e:
|
||||||
except Exception:
|
|
||||||
logger.exception("[get-access-token] exception occur")
|
logger.exception("[get-access-token] exception occur")
|
||||||
raise
|
raise
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
def parse_config(content: str) -> dict[str, str]:
|
def _parse_config(self, content: str) -> dict[str, str]:
|
||||||
config: dict[str, str] = {}
|
config: dict[str, str] = {}
|
||||||
if not content:
|
if not content:
|
||||||
return config
|
return config
|
||||||
|
|||||||
@ -16,14 +16,14 @@ AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
|
|||||||
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||||
|
|
||||||
|
|
||||||
_doc_extensions: list[str]
|
|
||||||
if dify_config.ETL_TYPE == "Unstructured":
|
if dify_config.ETL_TYPE == "Unstructured":
|
||||||
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
|
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
|
||||||
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||||
if dify_config.UNSTRUCTURED_API_URL:
|
if dify_config.UNSTRUCTURED_API_URL:
|
||||||
_doc_extensions.append("ppt")
|
DOCUMENT_EXTENSIONS.append("ppt")
|
||||||
|
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||||
else:
|
else:
|
||||||
_doc_extensions = [
|
DOCUMENT_EXTENSIONS = [
|
||||||
"txt",
|
"txt",
|
||||||
"markdown",
|
"markdown",
|
||||||
"md",
|
"md",
|
||||||
@ -38,4 +38,4 @@ else:
|
|||||||
"vtt",
|
"vtt",
|
||||||
"properties",
|
"properties",
|
||||||
]
|
]
|
||||||
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]
|
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||||
|
|||||||
@ -19,7 +19,6 @@ language_timezone_mapping = {
|
|||||||
"fa-IR": "Asia/Tehran",
|
"fa-IR": "Asia/Tehran",
|
||||||
"sl-SI": "Europe/Ljubljana",
|
"sl-SI": "Europe/Ljubljana",
|
||||||
"th-TH": "Asia/Bangkok",
|
"th-TH": "Asia/Bangkok",
|
||||||
"id-ID": "Asia/Jakarta",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
languages = list(language_timezone_mapping.keys())
|
languages = list(language_timezone_mapping.keys())
|
||||||
|
|||||||
@ -8,6 +8,7 @@ if TYPE_CHECKING:
|
|||||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
from flask_restx import Namespace
|
|
||||||
|
|
||||||
from libs.external_api import ExternalApi
|
from libs.external_api import ExternalApi
|
||||||
|
|
||||||
@ -27,16 +26,7 @@ from .files import FileApi, FilePreviewApi, FileSupportTypeApi
|
|||||||
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||||
|
|
||||||
bp = Blueprint("console", __name__, url_prefix="/console/api")
|
bp = Blueprint("console", __name__, url_prefix="/console/api")
|
||||||
|
api = ExternalApi(bp)
|
||||||
api = ExternalApi(
|
|
||||||
bp,
|
|
||||||
version="1.0",
|
|
||||||
title="Console API",
|
|
||||||
description="Console management APIs for app configuration, monitoring, and administration",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create namespace
|
|
||||||
console_ns = Namespace("console", description="Console management API operations", path="/")
|
|
||||||
|
|
||||||
# File
|
# File
|
||||||
api.add_resource(FileApi, "/files/upload")
|
api.add_resource(FileApi, "/files/upload")
|
||||||
@ -53,90 +43,56 @@ api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm"
|
|||||||
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
|
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
|
||||||
|
|
||||||
# Import other controllers
|
# Import other controllers
|
||||||
from . import (
|
from . import admin, apikey, extension, feature, ping, setup, version
|
||||||
admin, # pyright: ignore[reportUnusedImport]
|
|
||||||
apikey, # pyright: ignore[reportUnusedImport]
|
|
||||||
extension, # pyright: ignore[reportUnusedImport]
|
|
||||||
feature, # pyright: ignore[reportUnusedImport]
|
|
||||||
init_validate, # pyright: ignore[reportUnusedImport]
|
|
||||||
ping, # pyright: ignore[reportUnusedImport]
|
|
||||||
setup, # pyright: ignore[reportUnusedImport]
|
|
||||||
version, # pyright: ignore[reportUnusedImport]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Import app controllers
|
# Import app controllers
|
||||||
from .app import (
|
from .app import (
|
||||||
advanced_prompt_template, # pyright: ignore[reportUnusedImport]
|
advanced_prompt_template,
|
||||||
agent, # pyright: ignore[reportUnusedImport]
|
agent,
|
||||||
annotation, # pyright: ignore[reportUnusedImport]
|
annotation,
|
||||||
app, # pyright: ignore[reportUnusedImport]
|
app,
|
||||||
audio, # pyright: ignore[reportUnusedImport]
|
audio,
|
||||||
completion, # pyright: ignore[reportUnusedImport]
|
completion,
|
||||||
conversation, # pyright: ignore[reportUnusedImport]
|
conversation,
|
||||||
conversation_variables, # pyright: ignore[reportUnusedImport]
|
conversation_variables,
|
||||||
generator, # pyright: ignore[reportUnusedImport]
|
generator,
|
||||||
mcp_server, # pyright: ignore[reportUnusedImport]
|
mcp_server,
|
||||||
message, # pyright: ignore[reportUnusedImport]
|
message,
|
||||||
model_config, # pyright: ignore[reportUnusedImport]
|
model_config,
|
||||||
ops_trace, # pyright: ignore[reportUnusedImport]
|
ops_trace,
|
||||||
site, # pyright: ignore[reportUnusedImport]
|
site,
|
||||||
statistic, # pyright: ignore[reportUnusedImport]
|
statistic,
|
||||||
workflow, # pyright: ignore[reportUnusedImport]
|
workflow,
|
||||||
workflow_app_log, # pyright: ignore[reportUnusedImport]
|
workflow_app_log,
|
||||||
workflow_draft_variable, # pyright: ignore[reportUnusedImport]
|
workflow_draft_variable,
|
||||||
workflow_run, # pyright: ignore[reportUnusedImport]
|
workflow_run,
|
||||||
workflow_statistic, # pyright: ignore[reportUnusedImport]
|
workflow_statistic,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import auth controllers
|
# Import auth controllers
|
||||||
from .auth import (
|
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
|
||||||
activate, # pyright: ignore[reportUnusedImport]
|
|
||||||
data_source_bearer_auth, # pyright: ignore[reportUnusedImport]
|
|
||||||
data_source_oauth, # pyright: ignore[reportUnusedImport]
|
|
||||||
forgot_password, # pyright: ignore[reportUnusedImport]
|
|
||||||
login, # pyright: ignore[reportUnusedImport]
|
|
||||||
oauth, # pyright: ignore[reportUnusedImport]
|
|
||||||
oauth_server, # pyright: ignore[reportUnusedImport]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Import billing controllers
|
# Import billing controllers
|
||||||
from .billing import billing, compliance # pyright: ignore[reportUnusedImport]
|
from .billing import billing, compliance
|
||||||
|
|
||||||
# Import datasets controllers
|
# Import datasets controllers
|
||||||
from .datasets import (
|
from .datasets import (
|
||||||
data_source, # pyright: ignore[reportUnusedImport]
|
data_source,
|
||||||
datasets, # pyright: ignore[reportUnusedImport]
|
datasets,
|
||||||
datasets_document, # pyright: ignore[reportUnusedImport]
|
datasets_document,
|
||||||
datasets_segments, # pyright: ignore[reportUnusedImport]
|
datasets_segments,
|
||||||
external, # pyright: ignore[reportUnusedImport]
|
external,
|
||||||
hit_testing, # pyright: ignore[reportUnusedImport]
|
hit_testing,
|
||||||
metadata, # pyright: ignore[reportUnusedImport]
|
metadata,
|
||||||
website, # pyright: ignore[reportUnusedImport]
|
website,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import explore controllers
|
# Import explore controllers
|
||||||
from .explore import (
|
from .explore import (
|
||||||
installed_app, # pyright: ignore[reportUnusedImport]
|
installed_app,
|
||||||
parameter, # pyright: ignore[reportUnusedImport]
|
parameter,
|
||||||
recommended_app, # pyright: ignore[reportUnusedImport]
|
recommended_app,
|
||||||
saved_message, # pyright: ignore[reportUnusedImport]
|
saved_message,
|
||||||
)
|
|
||||||
|
|
||||||
# Import tag controllers
|
|
||||||
from .tag import tags # pyright: ignore[reportUnusedImport]
|
|
||||||
|
|
||||||
# Import workspace controllers
|
|
||||||
from .workspace import (
|
|
||||||
account, # pyright: ignore[reportUnusedImport]
|
|
||||||
agent_providers, # pyright: ignore[reportUnusedImport]
|
|
||||||
endpoint, # pyright: ignore[reportUnusedImport]
|
|
||||||
load_balancing_config, # pyright: ignore[reportUnusedImport]
|
|
||||||
members, # pyright: ignore[reportUnusedImport]
|
|
||||||
model_providers, # pyright: ignore[reportUnusedImport]
|
|
||||||
models, # pyright: ignore[reportUnusedImport]
|
|
||||||
plugin, # pyright: ignore[reportUnusedImport]
|
|
||||||
tool_providers, # pyright: ignore[reportUnusedImport]
|
|
||||||
workspace, # pyright: ignore[reportUnusedImport]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Explore Audio
|
# Explore Audio
|
||||||
@ -210,4 +166,19 @@ api.add_resource(
|
|||||||
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
|
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
|
||||||
)
|
)
|
||||||
|
|
||||||
api.add_namespace(console_ns)
|
# Import tag controllers
|
||||||
|
from .tag import tags
|
||||||
|
|
||||||
|
# Import workspace controllers
|
||||||
|
from .workspace import (
|
||||||
|
account,
|
||||||
|
agent_providers,
|
||||||
|
endpoint,
|
||||||
|
load_balancing_config,
|
||||||
|
members,
|
||||||
|
model_providers,
|
||||||
|
models,
|
||||||
|
plugin,
|
||||||
|
tool_providers,
|
||||||
|
workspace,
|
||||||
|
)
|
||||||
|
|||||||
@ -1,26 +1,22 @@
|
|||||||
from collections.abc import Callable
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar
|
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants.languages import supported_language
|
from constants.languages import supported_language
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api
|
||||||
from controllers.console.wraps import only_edition_cloud
|
from controllers.console.wraps import only_edition_cloud
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App, InstalledApp, RecommendedApp
|
from models.model import App, InstalledApp, RecommendedApp
|
||||||
|
|
||||||
|
|
||||||
def admin_required(view: Callable[P, R]):
|
def admin_required(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
if not dify_config.ADMIN_API_KEY:
|
if not dify_config.ADMIN_API_KEY:
|
||||||
raise Unauthorized("API key is invalid.")
|
raise Unauthorized("API key is invalid.")
|
||||||
|
|
||||||
@ -45,28 +41,7 @@ def admin_required(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/admin/insert-explore-apps")
|
|
||||||
class InsertExploreAppListApi(Resource):
|
class InsertExploreAppListApi(Resource):
|
||||||
@api.doc("insert_explore_app")
|
|
||||||
@api.doc(description="Insert or update an app in the explore list")
|
|
||||||
@api.expect(
|
|
||||||
api.model(
|
|
||||||
"InsertExploreAppRequest",
|
|
||||||
{
|
|
||||||
"app_id": fields.String(required=True, description="Application ID"),
|
|
||||||
"desc": fields.String(description="App description"),
|
|
||||||
"copyright": fields.String(description="Copyright information"),
|
|
||||||
"privacy_policy": fields.String(description="Privacy policy"),
|
|
||||||
"custom_disclaimer": fields.String(description="Custom disclaimer"),
|
|
||||||
"language": fields.String(required=True, description="Language code"),
|
|
||||||
"category": fields.String(required=True, description="App category"),
|
|
||||||
"position": fields.Integer(required=True, description="Display position"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@api.response(200, "App updated successfully")
|
|
||||||
@api.response(201, "App inserted successfully")
|
|
||||||
@api.response(404, "App not found")
|
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
@admin_required
|
@admin_required
|
||||||
def post(self):
|
def post(self):
|
||||||
@ -136,12 +111,7 @@ class InsertExploreAppListApi(Resource):
|
|||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/admin/insert-explore-apps/<uuid:app_id>")
|
|
||||||
class InsertExploreAppApi(Resource):
|
class InsertExploreAppApi(Resource):
|
||||||
@api.doc("delete_explore_app")
|
|
||||||
@api.doc(description="Remove an app from the explore list")
|
|
||||||
@api.doc(params={"app_id": "Application ID to remove"})
|
|
||||||
@api.response(204, "App removed successfully")
|
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
@admin_required
|
@admin_required
|
||||||
def delete(self, app_id):
|
def delete(self, app_id):
|
||||||
@ -160,21 +130,21 @@ class InsertExploreAppApi(Resource):
|
|||||||
app.is_public = False
|
app.is_public = False
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
installed_apps = (
|
installed_apps = session.execute(
|
||||||
session.execute(
|
select(InstalledApp).where(
|
||||||
select(InstalledApp).where(
|
InstalledApp.app_id == recommended_app.app_id,
|
||||||
InstalledApp.app_id == recommended_app.app_id,
|
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
|
||||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
.scalars()
|
).all()
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
for installed_app in installed_apps:
|
for installed_app in installed_apps:
|
||||||
session.delete(installed_app)
|
db.session.delete(installed_app)
|
||||||
|
|
||||||
db.session.delete(recommended_app)
|
db.session.delete(recommended_app)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps")
|
||||||
|
api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/<uuid:app_id>")
|
||||||
|
|||||||
@ -1,9 +1,8 @@
|
|||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import flask_restx
|
import flask_restx
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, fields, marshal_with
|
from flask_restx import Resource, fields, marshal_with
|
||||||
from flask_restx._http import HTTPStatus
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
@ -14,7 +13,7 @@ from libs.login import login_required
|
|||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
from models.model import ApiToken, App
|
from models.model import ApiToken, App
|
||||||
|
|
||||||
from . import api, console_ns
|
from . import api
|
||||||
from .wraps import account_initialization_required, setup_required
|
from .wraps import account_initialization_required, setup_required
|
||||||
|
|
||||||
api_key_fields = {
|
api_key_fields = {
|
||||||
@ -41,7 +40,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
|
|||||||
).scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
|
|
||||||
if resource is None:
|
if resource is None:
|
||||||
flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.")
|
flask_restx.abort(404, message=f"{resource_model.__name__} not found.")
|
||||||
|
|
||||||
return resource
|
return resource
|
||||||
|
|
||||||
@ -50,7 +49,7 @@ class BaseApiKeyListResource(Resource):
|
|||||||
method_decorators = [account_initialization_required, login_required, setup_required]
|
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||||
|
|
||||||
resource_type: str | None = None
|
resource_type: str | None = None
|
||||||
resource_model: Optional[type] = None
|
resource_model: Optional[Any] = None
|
||||||
resource_id_field: str | None = None
|
resource_id_field: str | None = None
|
||||||
token_prefix: str | None = None
|
token_prefix: str | None = None
|
||||||
max_keys = 10
|
max_keys = 10
|
||||||
@ -60,11 +59,11 @@ class BaseApiKeyListResource(Resource):
|
|||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||||
keys = db.session.scalars(
|
keys = (
|
||||||
select(ApiToken).where(
|
db.session.query(ApiToken)
|
||||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
||||||
)
|
.all()
|
||||||
).all()
|
)
|
||||||
return {"items": keys}
|
return {"items": keys}
|
||||||
|
|
||||||
@marshal_with(api_key_fields)
|
@marshal_with(api_key_fields)
|
||||||
@ -83,12 +82,12 @@ class BaseApiKeyListResource(Resource):
|
|||||||
|
|
||||||
if current_key_count >= self.max_keys:
|
if current_key_count >= self.max_keys:
|
||||||
flask_restx.abort(
|
flask_restx.abort(
|
||||||
HTTPStatus.BAD_REQUEST,
|
400,
|
||||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||||
custom="max_keys_exceeded",
|
code="max_keys_exceeded",
|
||||||
)
|
)
|
||||||
|
|
||||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||||
api_token = ApiToken()
|
api_token = ApiToken()
|
||||||
setattr(api_token, self.resource_id_field, resource_id)
|
setattr(api_token, self.resource_id_field, resource_id)
|
||||||
api_token.tenant_id = current_user.current_tenant_id
|
api_token.tenant_id = current_user.current_tenant_id
|
||||||
@ -103,7 +102,7 @@ class BaseApiKeyResource(Resource):
|
|||||||
method_decorators = [account_initialization_required, login_required, setup_required]
|
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||||
|
|
||||||
resource_type: str | None = None
|
resource_type: str | None = None
|
||||||
resource_model: Optional[type] = None
|
resource_model: Optional[Any] = None
|
||||||
resource_id_field: str | None = None
|
resource_id_field: str | None = None
|
||||||
|
|
||||||
def delete(self, resource_id, api_key_id):
|
def delete(self, resource_id, api_key_id):
|
||||||
@ -127,7 +126,7 @@ class BaseApiKeyResource(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if key is None:
|
if key is None:
|
||||||
flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")
|
flask_restx.abort(404, message="API key not found")
|
||||||
|
|
||||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@ -135,25 +134,7 @@ class BaseApiKeyResource(Resource):
|
|||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
|
|
||||||
class AppApiKeyListResource(BaseApiKeyListResource):
|
class AppApiKeyListResource(BaseApiKeyListResource):
|
||||||
@api.doc("get_app_api_keys")
|
|
||||||
@api.doc(description="Get all API keys for an app")
|
|
||||||
@api.doc(params={"resource_id": "App ID"})
|
|
||||||
@api.response(200, "Success", api_key_list)
|
|
||||||
def get(self, resource_id):
|
|
||||||
"""Get all API keys for an app"""
|
|
||||||
return super().get(resource_id)
|
|
||||||
|
|
||||||
@api.doc("create_app_api_key")
|
|
||||||
@api.doc(description="Create a new API key for an app")
|
|
||||||
@api.doc(params={"resource_id": "App ID"})
|
|
||||||
@api.response(201, "API key created successfully", api_key_fields)
|
|
||||||
@api.response(400, "Maximum keys exceeded")
|
|
||||||
def post(self, resource_id):
|
|
||||||
"""Create a new API key for an app"""
|
|
||||||
return super().post(resource_id)
|
|
||||||
|
|
||||||
def after_request(self, resp):
|
def after_request(self, resp):
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||||
@ -165,16 +146,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
|||||||
token_prefix = "app-"
|
token_prefix = "app-"
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
|
||||||
class AppApiKeyResource(BaseApiKeyResource):
|
class AppApiKeyResource(BaseApiKeyResource):
|
||||||
@api.doc("delete_app_api_key")
|
|
||||||
@api.doc(description="Delete an API key for an app")
|
|
||||||
@api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
|
|
||||||
@api.response(204, "API key deleted successfully")
|
|
||||||
def delete(self, resource_id, api_key_id):
|
|
||||||
"""Delete an API key for an app"""
|
|
||||||
return super().delete(resource_id, api_key_id)
|
|
||||||
|
|
||||||
def after_request(self, resp):
|
def after_request(self, resp):
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||||
@ -185,25 +157,7 @@ class AppApiKeyResource(BaseApiKeyResource):
|
|||||||
resource_id_field = "app_id"
|
resource_id_field = "app_id"
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/datasets/<uuid:resource_id>/api-keys")
|
|
||||||
class DatasetApiKeyListResource(BaseApiKeyListResource):
|
class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||||
@api.doc("get_dataset_api_keys")
|
|
||||||
@api.doc(description="Get all API keys for a dataset")
|
|
||||||
@api.doc(params={"resource_id": "Dataset ID"})
|
|
||||||
@api.response(200, "Success", api_key_list)
|
|
||||||
def get(self, resource_id):
|
|
||||||
"""Get all API keys for a dataset"""
|
|
||||||
return super().get(resource_id)
|
|
||||||
|
|
||||||
@api.doc("create_dataset_api_key")
|
|
||||||
@api.doc(description="Create a new API key for a dataset")
|
|
||||||
@api.doc(params={"resource_id": "Dataset ID"})
|
|
||||||
@api.response(201, "API key created successfully", api_key_fields)
|
|
||||||
@api.response(400, "Maximum keys exceeded")
|
|
||||||
def post(self, resource_id):
|
|
||||||
"""Create a new API key for a dataset"""
|
|
||||||
return super().post(resource_id)
|
|
||||||
|
|
||||||
def after_request(self, resp):
|
def after_request(self, resp):
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||||
@ -215,16 +169,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
|||||||
token_prefix = "ds-"
|
token_prefix = "ds-"
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
|
||||||
class DatasetApiKeyResource(BaseApiKeyResource):
|
class DatasetApiKeyResource(BaseApiKeyResource):
|
||||||
@api.doc("delete_dataset_api_key")
|
|
||||||
@api.doc(description="Delete an API key for a dataset")
|
|
||||||
@api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
|
|
||||||
@api.response(204, "API key deleted successfully")
|
|
||||||
def delete(self, resource_id, api_key_id):
|
|
||||||
"""Delete an API key for a dataset"""
|
|
||||||
return super().delete(resource_id, api_key_id)
|
|
||||||
|
|
||||||
def after_request(self, resp):
|
def after_request(self, resp):
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||||
@ -233,3 +178,9 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
|||||||
resource_type = "dataset"
|
resource_type = "dataset"
|
||||||
resource_model = Dataset
|
resource_model = Dataset
|
||||||
resource_id_field = "dataset_id"
|
resource_id_field = "dataset_id"
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(AppApiKeyListResource, "/apps/<uuid:resource_id>/api-keys")
|
||||||
|
api.add_resource(AppApiKeyResource, "/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||||
|
api.add_resource(DatasetApiKeyListResource, "/datasets/<uuid:resource_id>/api-keys")
|
||||||
|
api.add_resource(DatasetApiKeyResource, "/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||||
|
|||||||
@ -115,10 +115,6 @@ class AppListApi(Resource):
|
|||||||
raise BadRequest("mode is required")
|
raise BadRequest("mode is required")
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
if current_user.current_tenant_id is None:
|
|
||||||
raise ValueError("current_user.current_tenant_id cannot be None")
|
|
||||||
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
|
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
|
||||||
|
|
||||||
return app, 201
|
return app, 201
|
||||||
@ -165,26 +161,14 @@ class AppApi(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
# Construct ArgsDict from parsed arguments
|
app_model = app_service.update_app(app_model, args)
|
||||||
from services.app_service import AppService as AppServiceType
|
|
||||||
|
|
||||||
args_dict: AppServiceType.ArgsDict = {
|
|
||||||
"name": args["name"],
|
|
||||||
"description": args.get("description", ""),
|
|
||||||
"icon_type": args.get("icon_type", ""),
|
|
||||||
"icon": args.get("icon", ""),
|
|
||||||
"icon_background": args.get("icon_background", ""),
|
|
||||||
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
|
|
||||||
"max_active_requests": args.get("max_active_requests", 0),
|
|
||||||
}
|
|
||||||
app_model = app_service.update_app(app_model, args_dict)
|
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def delete(self, app_model):
|
def delete(self, app_model):
|
||||||
"""Delete app"""
|
"""Delete app"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
@ -240,10 +224,10 @@ class AppCopyApi(Resource):
|
|||||||
|
|
||||||
|
|
||||||
class AppExportApi(Resource):
|
class AppExportApi(Resource):
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
"""Export app"""
|
"""Export app"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
@ -253,14 +237,9 @@ class AppExportApi(Resource):
|
|||||||
# Add include_secret params
|
# Add include_secret params
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||||
parser.add_argument("workflow_id", type=str, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return {
|
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
|
||||||
"data": AppDslService.export_dsl(
|
|
||||||
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class AppNameApi(Resource):
|
class AppNameApi(Resource):
|
||||||
@ -279,7 +258,7 @@ class AppNameApi(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_name(app_model, args["name"])
|
app_model = app_service.update_app_name(app_model, args.get("name"))
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
@ -301,7 +280,7 @@ class AppIconApi(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
|
app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background"))
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
@ -322,7 +301,7 @@ class AppSiteStatus(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_site_status(app_model, args["enable_site"])
|
app_model = app_service.update_app_site_status(app_model, args.get("enable_site"))
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
@ -343,7 +322,7 @@ class AppApiStatus(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_api_status(app_model, args["enable_api"])
|
app_model = app_service.update_app_api_status(app_model, args.get("enable_api"))
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|||||||
@ -77,10 +77,10 @@ class ChatMessageAudioApi(Resource):
|
|||||||
|
|
||||||
|
|
||||||
class ChatMessageTextApi(Resource):
|
class ChatMessageTextApi(Resource):
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -125,10 +125,10 @@ class ChatMessageTextApi(Resource):
|
|||||||
|
|
||||||
|
|
||||||
class TextModesApi(Resource):
|
class TextModesApi(Resource):
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import flask_login
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
@ -28,8 +29,7 @@ from core.helper.trace_id_helper import get_external_trace_id
|
|||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.login import current_user, login_required
|
from libs.login import login_required
|
||||||
from models import Account
|
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.llm import InvokeRateLimitError
|
from services.errors.llm import InvokeRateLimitError
|
||||||
@ -56,11 +56,11 @@ class CompletionMessageApi(Resource):
|
|||||||
streaming = args["response_mode"] != "blocking"
|
streaming = args["response_mode"] != "blocking"
|
||||||
args["auto_generate_name"] = False
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
|
account = flask_login.current_user
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account or EndUser instance")
|
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
@ -92,9 +92,9 @@ class CompletionMessageStopApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
def post(self, app_model, task_id):
|
def post(self, app_model, task_id):
|
||||||
if not isinstance(current_user, Account):
|
account = flask_login.current_user
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
@ -105,12 +105,6 @@ class ChatMessageApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument("query", type=str, required=True, location="json")
|
parser.add_argument("query", type=str, required=True, location="json")
|
||||||
@ -129,11 +123,11 @@ class ChatMessageApi(Resource):
|
|||||||
if external_trace_id:
|
if external_trace_id:
|
||||||
args["external_trace_id"] = external_trace_id
|
args["external_trace_id"] = external_trace_id
|
||||||
|
|
||||||
|
account = flask_login.current_user
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account or EndUser instance")
|
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
@ -167,9 +161,9 @@ class ChatMessageStopApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
def post(self, app_model, task_id):
|
def post(self, app_model, task_id):
|
||||||
if not isinstance(current_user, Account):
|
account = flask_login.current_user
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from fields.conversation_fields import (
|
|||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import DatetimeString
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models import Account, Conversation, EndUser, Message, MessageAnnotation
|
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
@ -117,15 +117,13 @@ class CompletionConversationDetailApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
def delete(self, app_model, conversation_id):
|
def delete(self, app_model, conversation_id):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
ConversationService.delete(app_model, conversation_id, current_user)
|
ConversationService.delete(app_model, conversation_id, current_user)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
@ -284,8 +282,6 @@ class ChatConversationDetailApi(Resource):
|
|||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
ConversationService.delete(app_model, conversation_id, current_user)
|
ConversationService.delete(app_model, conversation_id, current_user)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|||||||
@ -207,7 +207,7 @@ class InstructionGenerationTemplateApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self) -> dict:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("type", type=str, required=True, default=False, location="json")
|
parser.add_argument("type", type=str, required=True, default=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from sqlalchemy import exists, select
|
from sqlalchemy import exists, select
|
||||||
@ -26,8 +27,7 @@ from extensions.ext_database import db
|
|||||||
from fields.conversation_fields import annotation_fields, message_detail_fields
|
from fields.conversation_fields import annotation_fields, message_detail_fields
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
from libs.login import current_user, login_required
|
from libs.login import login_required
|
||||||
from models.account import Account
|
|
||||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||||
from services.annotation_service import AppAnnotationService
|
from services.annotation_service import AppAnnotationService
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
@ -118,14 +118,11 @@ class ChatMessageListApi(Resource):
|
|||||||
|
|
||||||
|
|
||||||
class MessageFeedbackApi(Resource):
|
class MessageFeedbackApi(Resource):
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
if current_user is None:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
|
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||||
@ -170,9 +167,7 @@ class MessageAnnotationApi(Resource):
|
|||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(annotation_fields)
|
@marshal_with(annotation_fields)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
if not isinstance(current_user, Account):
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -187,10 +182,10 @@ class MessageAnnotationApi(Resource):
|
|||||||
|
|
||||||
|
|
||||||
class MessageAnnotationCountApi(Resource):
|
class MessageAnnotationCountApi(Resource):
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
|
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
|
||||||
|
|
||||||
|
|||||||
@ -2,8 +2,8 @@ import json
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
|
from flask_login import current_user
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from werkzeug.exceptions import Forbidden
|
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
@ -13,8 +13,7 @@ from core.tools.tool_manager import ToolManager
|
|||||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||||
from events.app_event import app_model_config_was_updated
|
from events.app_event import app_model_config_was_updated
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import current_user, login_required
|
from libs.login import login_required
|
||||||
from models.account import Account
|
|
||||||
from models.model import AppMode, AppModelConfig
|
from models.model import AppMode, AppModelConfig
|
||||||
from services.app_model_config_service import AppModelConfigService
|
from services.app_model_config_service import AppModelConfigService
|
||||||
|
|
||||||
@ -26,13 +25,6 @@ class ModelConfigResource(Resource):
|
|||||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
"""Modify app model config"""
|
"""Modify app model config"""
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
|
||||||
# validate config
|
# validate config
|
||||||
model_configuration = AppModelConfigService.validate_configuration(
|
model_configuration = AppModelConfigService.validate_configuration(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from extensions.ext_database import db
|
|||||||
from fields.app_fields import app_site_fields
|
from fields.app_fields import app_site_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models import Account, Site
|
from models import Site
|
||||||
|
|
||||||
|
|
||||||
def parse_app_site_args():
|
def parse_app_site_args():
|
||||||
@ -75,8 +75,6 @@ class AppSite(Resource):
|
|||||||
if value is not None:
|
if value is not None:
|
||||||
setattr(site, attr_name, value)
|
setattr(site, attr_name, value)
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
site.updated_by = current_user.id
|
site.updated_by = current_user.id
|
||||||
site.updated_at = naive_utc_now()
|
site.updated_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@ -101,8 +99,6 @@ class AppSiteAccessTokenReset(Resource):
|
|||||||
raise NotFound
|
raise NotFound
|
||||||
|
|
||||||
site.code = Site.generate_code(16)
|
site.code = Site.generate_code(16)
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
site.updated_by = current_user.id
|
site.updated_by = current_user.id
|
||||||
site.updated_at = naive_utc_now()
|
site.updated_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|||||||
@ -18,10 +18,10 @@ from models import AppMode, Message
|
|||||||
|
|
||||||
|
|
||||||
class DailyMessageStatistic(Resource):
|
class DailyMessageStatistic(Resource):
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
@ -75,10 +75,10 @@ WHERE
|
|||||||
|
|
||||||
|
|
||||||
class DailyConversationStatistic(Resource):
|
class DailyConversationStatistic(Resource):
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
@ -127,10 +127,10 @@ class DailyConversationStatistic(Resource):
|
|||||||
|
|
||||||
|
|
||||||
class DailyTerminalsStatistic(Resource):
|
class DailyTerminalsStatistic(Resource):
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
@ -184,10 +184,10 @@ WHERE
|
|||||||
|
|
||||||
|
|
||||||
class DailyTokenCostStatistic(Resource):
|
class DailyTokenCostStatistic(Resource):
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
@ -320,10 +320,10 @@ ORDER BY
|
|||||||
|
|
||||||
|
|
||||||
class UserSatisfactionRateStatistic(Resource):
|
class UserSatisfactionRateStatistic(Resource):
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
@ -443,10 +443,10 @@ WHERE
|
|||||||
|
|
||||||
|
|
||||||
class TokensPerSecondStatistic(Resource):
|
class TokensPerSecondStatistic(Resource):
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,11 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
|||||||
import services
|
import services
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
from controllers.console.app.error import (
|
||||||
|
ConversationCompletedError,
|
||||||
|
DraftWorkflowNotExist,
|
||||||
|
DraftWorkflowNotSync,
|
||||||
|
)
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||||
@ -69,7 +73,7 @@ class DraftWorkflowApi(Resource):
|
|||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
# fetch draft workflow by app_model
|
# fetch draft workflow by app_model
|
||||||
@ -92,7 +96,7 @@ class DraftWorkflowApi(Resource):
|
|||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
content_type = request.headers.get("Content-Type", "")
|
content_type = request.headers.get("Content-Type", "")
|
||||||
@ -170,7 +174,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
|||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
@ -220,7 +224,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -256,7 +260,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
|||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -293,7 +297,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -330,7 +334,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -367,7 +371,7 @@ class DraftWorkflowRunApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -406,7 +410,7 @@ class WorkflowTaskStopApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||||
@ -428,7 +432,7 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -476,7 +480,7 @@ class PublishedWorkflowApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
# fetch published workflow by app_model
|
# fetch published workflow by app_model
|
||||||
@ -497,7 +501,7 @@ class PublishedWorkflowApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -522,7 +526,7 @@ class PublishedWorkflowApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
app_model.workflow_id = workflow.id
|
app_model.workflow_id = workflow.id
|
||||||
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
|
db.session.commit()
|
||||||
|
|
||||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||||
|
|
||||||
@ -547,7 +551,7 @@ class DefaultBlockConfigsApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
# Get default block configs
|
# Get default block configs
|
||||||
@ -567,7 +571,7 @@ class DefaultBlockConfigApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -602,7 +606,7 @@ class ConvertToWorkflowApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
if request.data:
|
if request.data:
|
||||||
@ -651,7 +655,7 @@ class PublishedAllWorkflowApi(Resource):
|
|||||||
|
|
||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -702,7 +706,7 @@ class WorkflowByIdApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
# Check permission
|
# Check permission
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -758,7 +762,7 @@ class WorkflowByIdApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
# Check permission
|
# Check permission
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
|
|||||||
@ -27,9 +27,7 @@ class WorkflowAppLogApi(Resource):
|
|||||||
"""
|
"""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("keyword", type=str, location="args")
|
parser.add_argument("keyword", type=str, location="args")
|
||||||
parser.add_argument(
|
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import NoReturn
|
from typing import Any, NoReturn
|
||||||
|
|
||||||
from flask import Response
|
from flask import Response
|
||||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||||
@ -29,7 +29,7 @@ from services.workflow_service import WorkflowService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _convert_values_to_json_serializable_object(value: Segment):
|
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||||
if isinstance(value, FileSegment):
|
if isinstance(value, FileSegment):
|
||||||
return value.value.model_dump()
|
return value.value.model_dump()
|
||||||
elif isinstance(value, ArrayFileSegment):
|
elif isinstance(value, ArrayFileSegment):
|
||||||
@ -40,7 +40,7 @@ def _convert_values_to_json_serializable_object(value: Segment):
|
|||||||
return value.value
|
return value.value
|
||||||
|
|
||||||
|
|
||||||
def _serialize_var_value(variable: WorkflowDraftVariable):
|
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||||
value = variable.get_value()
|
value = variable.get_value()
|
||||||
# create a copy of the value to avoid affecting the model cache.
|
# create a copy of the value to avoid affecting the model cache.
|
||||||
value = value.model_copy(deep=True)
|
value = value.model_copy(deep=True)
|
||||||
@ -137,7 +137,7 @@ def _api_prerequisite(f):
|
|||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@ -18,10 +18,10 @@ from models.model import AppMode
|
|||||||
|
|
||||||
|
|
||||||
class WorkflowDailyRunsStatistic(Resource):
|
class WorkflowDailyRunsStatistic(Resource):
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
@ -80,10 +80,10 @@ WHERE
|
|||||||
|
|
||||||
|
|
||||||
class WorkflowDailyTerminalsStatistic(Resource):
|
class WorkflowDailyTerminalsStatistic(Resource):
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
@ -142,10 +142,10 @@ WHERE
|
|||||||
|
|
||||||
|
|
||||||
class WorkflowDailyTokenCostStatistic(Resource):
|
class WorkflowDailyTokenCostStatistic(Resource):
|
||||||
@get_app_model
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@get_app_model
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Optional, ParamSpec, TypeVar, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from controllers.console.app.error import AppNotFoundError
|
from controllers.console.app.error import AppNotFoundError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -8,9 +8,6 @@ from libs.login import current_user
|
|||||||
from models import App, AppMode
|
from models import App, AppMode
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
|
||||||
def _load_app_model(app_id: str) -> Optional[App]:
|
def _load_app_model(app_id: str) -> Optional[App]:
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
@ -22,10 +19,10 @@ def _load_app_model(app_id: str) -> Optional[App]:
|
|||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
||||||
def get_app_model(view: Optional[Callable[P, R]] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||||
def decorator(view_func: Callable[P, R]):
|
def decorator(view_func):
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
def decorated_view(*args, **kwargs):
|
||||||
if not kwargs.get("app_id"):
|
if not kwargs.get("app_id"):
|
||||||
raise ValueError("missing app_id in path parameters")
|
raise ValueError("missing app_id in path parameters")
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
|
||||||
from constants.languages import supported_language
|
from constants.languages import supported_language
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api
|
||||||
from controllers.console.error import AlreadyActivateError
|
from controllers.console.error import AlreadyActivateError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
@ -10,36 +10,14 @@ from libs.helper import StrLen, email, extract_remote_ip, timezone
|
|||||||
from models.account import AccountStatus
|
from models.account import AccountStatus
|
||||||
from services.account_service import AccountService, RegisterService
|
from services.account_service import AccountService, RegisterService
|
||||||
|
|
||||||
active_check_parser = reqparse.RequestParser()
|
|
||||||
active_check_parser.add_argument(
|
|
||||||
"workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID"
|
|
||||||
)
|
|
||||||
active_check_parser.add_argument(
|
|
||||||
"email", type=email, required=False, nullable=True, location="args", help="Email address"
|
|
||||||
)
|
|
||||||
active_check_parser.add_argument(
|
|
||||||
"token", type=str, required=True, nullable=False, location="args", help="Activation token"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/activate/check")
|
|
||||||
class ActivateCheckApi(Resource):
|
class ActivateCheckApi(Resource):
|
||||||
@api.doc("check_activation_token")
|
|
||||||
@api.doc(description="Check if activation token is valid")
|
|
||||||
@api.expect(active_check_parser)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Success",
|
|
||||||
api.model(
|
|
||||||
"ActivationCheckResponse",
|
|
||||||
{
|
|
||||||
"is_valid": fields.Boolean(description="Whether token is valid"),
|
|
||||||
"data": fields.Raw(description="Activation data if valid"),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
def get(self):
|
def get(self):
|
||||||
args = active_check_parser.parse_args()
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args")
|
||||||
|
parser.add_argument("email", type=email, required=False, nullable=True, location="args")
|
||||||
|
parser.add_argument("token", type=str, required=True, nullable=False, location="args")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
workspaceId = args["workspace_id"]
|
workspaceId = args["workspace_id"]
|
||||||
reg_email = args["email"]
|
reg_email = args["email"]
|
||||||
@ -60,36 +38,18 @@ class ActivateCheckApi(Resource):
|
|||||||
return {"is_valid": False}
|
return {"is_valid": False}
|
||||||
|
|
||||||
|
|
||||||
active_parser = reqparse.RequestParser()
|
|
||||||
active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
|
||||||
active_parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
|
||||||
active_parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
|
||||||
active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
|
||||||
active_parser.add_argument(
|
|
||||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
|
||||||
)
|
|
||||||
active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/activate")
|
|
||||||
class ActivateApi(Resource):
|
class ActivateApi(Resource):
|
||||||
@api.doc("activate_account")
|
|
||||||
@api.doc(description="Activate account with invitation token")
|
|
||||||
@api.expect(active_parser)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Account activated successfully",
|
|
||||||
api.model(
|
|
||||||
"ActivationResponse",
|
|
||||||
{
|
|
||||||
"result": fields.String(description="Operation result"),
|
|
||||||
"data": fields.Raw(description="Login token data"),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@api.response(400, "Already activated or invalid token")
|
|
||||||
def post(self):
|
def post(self):
|
||||||
args = active_parser.parse_args()
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||||
|
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||||
|
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
|
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||||
|
parser.add_argument(
|
||||||
|
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
|
parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
|
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
|
||||||
if invitation is None:
|
if invitation is None:
|
||||||
@ -110,3 +70,7 @@ class ActivateApi(Resource):
|
|||||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||||
|
|
||||||
return {"result": "success", "data": token_pair.model_dump()}
|
return {"result": "success", "data": token_pair.model_dump()}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(ActivateCheckApi, "/activate/check")
|
||||||
|
api.add_resource(ActivateApi, "/activate")
|
||||||
|
|||||||
@ -3,11 +3,11 @@ import logging
|
|||||||
import requests
|
import requests
|
||||||
from flask import current_app, redirect, request
|
from flask import current_app, redirect, request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from libs.oauth_data_source import NotionOAuth
|
from libs.oauth_data_source import NotionOAuth
|
||||||
|
|
||||||
@ -28,21 +28,7 @@ def get_oauth_providers():
|
|||||||
return OAUTH_PROVIDERS
|
return OAUTH_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/oauth/data-source/<string:provider>")
|
|
||||||
class OAuthDataSource(Resource):
|
class OAuthDataSource(Resource):
|
||||||
@api.doc("oauth_data_source")
|
|
||||||
@api.doc(description="Get OAuth authorization URL for data source provider")
|
|
||||||
@api.doc(params={"provider": "Data source provider name (notion)"})
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Authorization URL or internal setup success",
|
|
||||||
api.model(
|
|
||||||
"OAuthDataSourceResponse",
|
|
||||||
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@api.response(400, "Invalid provider")
|
|
||||||
@api.response(403, "Admin privileges required")
|
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
# The role of the current user in the table must be admin or owner
|
# The role of the current user in the table must be admin or owner
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
@ -63,19 +49,7 @@ class OAuthDataSource(Resource):
|
|||||||
return {"data": auth_url}, 200
|
return {"data": auth_url}, 200
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/oauth/data-source/callback/<string:provider>")
|
|
||||||
class OAuthDataSourceCallback(Resource):
|
class OAuthDataSourceCallback(Resource):
|
||||||
@api.doc("oauth_data_source_callback")
|
|
||||||
@api.doc(description="Handle OAuth callback from data source provider")
|
|
||||||
@api.doc(
|
|
||||||
params={
|
|
||||||
"provider": "Data source provider name (notion)",
|
|
||||||
"code": "Authorization code from OAuth provider",
|
|
||||||
"error": "Error message from OAuth provider",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
@api.response(302, "Redirect to console with result")
|
|
||||||
@api.response(400, "Invalid provider")
|
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||||
with current_app.app_context():
|
with current_app.app_context():
|
||||||
@ -94,19 +68,7 @@ class OAuthDataSourceCallback(Resource):
|
|||||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied")
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied")
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/oauth/data-source/binding/<string:provider>")
|
|
||||||
class OAuthDataSourceBinding(Resource):
|
class OAuthDataSourceBinding(Resource):
|
||||||
@api.doc("oauth_data_source_binding")
|
|
||||||
@api.doc(description="Bind OAuth data source with authorization code")
|
|
||||||
@api.doc(
|
|
||||||
params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"}
|
|
||||||
)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Data source binding success",
|
|
||||||
api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
|
|
||||||
)
|
|
||||||
@api.response(400, "Invalid provider or code")
|
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||||
with current_app.app_context():
|
with current_app.app_context():
|
||||||
@ -119,7 +81,7 @@ class OAuthDataSourceBinding(Resource):
|
|||||||
return {"error": "Invalid code"}, 400
|
return {"error": "Invalid code"}, 400
|
||||||
try:
|
try:
|
||||||
oauth_provider.get_access_token(code)
|
oauth_provider.get_access_token(code)
|
||||||
except requests.HTTPError as e:
|
except requests.exceptions.HTTPError as e:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||||
)
|
)
|
||||||
@ -128,17 +90,7 @@ class OAuthDataSourceBinding(Resource):
|
|||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
|
|
||||||
class OAuthDataSourceSync(Resource):
|
class OAuthDataSourceSync(Resource):
|
||||||
@api.doc("oauth_data_source_sync")
|
|
||||||
@api.doc(description="Sync data from OAuth data source")
|
|
||||||
@api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Data source sync success",
|
|
||||||
api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
|
|
||||||
)
|
|
||||||
@api.response(400, "Invalid provider or sync failed")
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -152,10 +104,16 @@ class OAuthDataSourceSync(Resource):
|
|||||||
return {"error": "Invalid provider"}, 400
|
return {"error": "Invalid provider"}, 400
|
||||||
try:
|
try:
|
||||||
oauth_provider.sync_data_source(binding_id)
|
oauth_provider.sync_data_source(binding_id)
|
||||||
except requests.HTTPError as e:
|
except requests.exceptions.HTTPError as e:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||||
)
|
)
|
||||||
return {"error": "OAuth data source process failed"}, 400
|
return {"error": "OAuth data source process failed"}, 400
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
|
||||||
|
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
|
||||||
|
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")
|
||||||
|
api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
|
||||||
|
|||||||
@ -2,12 +2,12 @@ import base64
|
|||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from constants.languages import languages
|
from constants.languages import languages
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api
|
||||||
from controllers.console.auth.error import (
|
from controllers.console.auth.error import (
|
||||||
EmailCodeError,
|
EmailCodeError,
|
||||||
EmailPasswordResetLimitError,
|
EmailPasswordResetLimitError,
|
||||||
@ -28,32 +28,7 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces
|
|||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/forgot-password")
|
|
||||||
class ForgotPasswordSendEmailApi(Resource):
|
class ForgotPasswordSendEmailApi(Resource):
|
||||||
@api.doc("send_forgot_password_email")
|
|
||||||
@api.doc(description="Send password reset email")
|
|
||||||
@api.expect(
|
|
||||||
api.model(
|
|
||||||
"ForgotPasswordEmailRequest",
|
|
||||||
{
|
|
||||||
"email": fields.String(required=True, description="Email address"),
|
|
||||||
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Email sent successfully",
|
|
||||||
api.model(
|
|
||||||
"ForgotPasswordEmailResponse",
|
|
||||||
{
|
|
||||||
"result": fields.String(description="Operation result"),
|
|
||||||
"data": fields.String(description="Reset token"),
|
|
||||||
"code": fields.String(description="Error code if account not found"),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@api.response(400, "Invalid email or rate limit exceeded")
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
@ -86,33 +61,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||||||
return {"result": "success", "data": token}
|
return {"result": "success", "data": token}
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/forgot-password/validity")
|
|
||||||
class ForgotPasswordCheckApi(Resource):
|
class ForgotPasswordCheckApi(Resource):
|
||||||
@api.doc("check_forgot_password_code")
|
|
||||||
@api.doc(description="Verify password reset code")
|
|
||||||
@api.expect(
|
|
||||||
api.model(
|
|
||||||
"ForgotPasswordCheckRequest",
|
|
||||||
{
|
|
||||||
"email": fields.String(required=True, description="Email address"),
|
|
||||||
"code": fields.String(required=True, description="Verification code"),
|
|
||||||
"token": fields.String(required=True, description="Reset token"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Code verified successfully",
|
|
||||||
api.model(
|
|
||||||
"ForgotPasswordCheckResponse",
|
|
||||||
{
|
|
||||||
"is_valid": fields.Boolean(description="Whether code is valid"),
|
|
||||||
"email": fields.String(description="Email address"),
|
|
||||||
"token": fields.String(description="New reset token"),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@api.response(400, "Invalid code or token")
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
@ -151,26 +100,7 @@ class ForgotPasswordCheckApi(Resource):
|
|||||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/forgot-password/resets")
|
|
||||||
class ForgotPasswordResetApi(Resource):
|
class ForgotPasswordResetApi(Resource):
|
||||||
@api.doc("reset_password")
|
|
||||||
@api.doc(description="Reset password with verification token")
|
|
||||||
@api.expect(
|
|
||||||
api.model(
|
|
||||||
"ForgotPasswordResetRequest",
|
|
||||||
{
|
|
||||||
"token": fields.String(required=True, description="Verification token"),
|
|
||||||
"new_password": fields.String(required=True, description="New password"),
|
|
||||||
"password_confirm": fields.String(required=True, description="Password confirmation"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Password reset successfully",
|
|
||||||
api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
|
|
||||||
)
|
|
||||||
@api.response(400, "Invalid token or password mismatch")
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
@ -242,3 +172,8 @@ class ForgotPasswordResetApi(Resource):
|
|||||||
pass
|
pass
|
||||||
except AccountRegisterError:
|
except AccountRegisterError:
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
|
||||||
|
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
|
||||||
|
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
|
||||||
|
|||||||
@ -130,7 +130,7 @@ class ResetPasswordSendEmailApi(Resource):
|
|||||||
language = "en-US"
|
language = "en-US"
|
||||||
try:
|
try:
|
||||||
account = AccountService.get_user_through_email(args["email"])
|
account = AccountService.get_user_through_email(args["email"])
|
||||||
except AccountRegisterError:
|
except AccountRegisterError as are:
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|
||||||
if account is None:
|
if account is None:
|
||||||
@ -162,7 +162,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||||||
language = "en-US"
|
language = "en-US"
|
||||||
try:
|
try:
|
||||||
account = AccountService.get_user_through_email(args["email"])
|
account = AccountService.get_user_through_email(args["email"])
|
||||||
except AccountRegisterError:
|
except AccountRegisterError as are:
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|
||||||
if account is None:
|
if account is None:
|
||||||
@ -200,7 +200,7 @@ class EmailCodeLoginApi(Resource):
|
|||||||
AccountService.revoke_email_code_login_token(args["token"])
|
AccountService.revoke_email_code_login_token(args["token"])
|
||||||
try:
|
try:
|
||||||
account = AccountService.get_user_through_email(user_email)
|
account = AccountService.get_user_through_email(user_email)
|
||||||
except AccountRegisterError:
|
except AccountRegisterError as are:
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
if account:
|
if account:
|
||||||
tenants = TenantService.get_join_tenants(account)
|
tenants = TenantService.get_join_tenants(account)
|
||||||
@ -223,7 +223,7 @@ class EmailCodeLoginApi(Resource):
|
|||||||
)
|
)
|
||||||
except WorkSpaceNotAllowedCreateError:
|
except WorkSpaceNotAllowedCreateError:
|
||||||
raise NotAllowedCreateWorkspace()
|
raise NotAllowedCreateWorkspace()
|
||||||
except AccountRegisterError:
|
except AccountRegisterError as are:
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
except WorkspacesLimitExceededError:
|
except WorkspacesLimitExceededError:
|
||||||
raise WorkspacesLimitExceeded()
|
raise WorkspacesLimitExceeded()
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from services.errors.account import AccountNotFoundError, AccountRegisterError
|
|||||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
|
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
from .. import api, console_ns
|
from .. import api
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -50,13 +50,7 @@ def get_oauth_providers():
|
|||||||
return OAUTH_PROVIDERS
|
return OAUTH_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/oauth/login/<provider>")
|
|
||||||
class OAuthLogin(Resource):
|
class OAuthLogin(Resource):
|
||||||
@api.doc("oauth_login")
|
|
||||||
@api.doc(description="Initiate OAuth login process")
|
|
||||||
@api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"})
|
|
||||||
@api.response(302, "Redirect to OAuth authorization URL")
|
|
||||||
@api.response(400, "Invalid provider")
|
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
invite_token = request.args.get("invite_token") or None
|
invite_token = request.args.get("invite_token") or None
|
||||||
OAUTH_PROVIDERS = get_oauth_providers()
|
OAUTH_PROVIDERS = get_oauth_providers()
|
||||||
@ -69,19 +63,7 @@ class OAuthLogin(Resource):
|
|||||||
return redirect(auth_url)
|
return redirect(auth_url)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/oauth/authorize/<provider>")
|
|
||||||
class OAuthCallback(Resource):
|
class OAuthCallback(Resource):
|
||||||
@api.doc("oauth_callback")
|
|
||||||
@api.doc(description="Handle OAuth callback and complete login process")
|
|
||||||
@api.doc(
|
|
||||||
params={
|
|
||||||
"provider": "OAuth provider name (github/google)",
|
|
||||||
"code": "Authorization code from OAuth provider",
|
|
||||||
"state": "Optional state parameter (used for invite token)",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
@api.response(302, "Redirect to console with access token")
|
|
||||||
@api.response(400, "OAuth process failed")
|
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
OAUTH_PROVIDERS = get_oauth_providers()
|
OAUTH_PROVIDERS = get_oauth_providers()
|
||||||
with current_app.app_context():
|
with current_app.app_context():
|
||||||
@ -95,19 +77,16 @@ class OAuthCallback(Resource):
|
|||||||
if state:
|
if state:
|
||||||
invite_token = state
|
invite_token = state
|
||||||
|
|
||||||
if not code:
|
|
||||||
return {"error": "Authorization code is required"}, 400
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
token = oauth_provider.get_access_token(code)
|
token = oauth_provider.get_access_token(code)
|
||||||
user_info = oauth_provider.get_user_info(token)
|
user_info = oauth_provider.get_user_info(token)
|
||||||
except requests.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
error_text = e.response.text if e.response else str(e)
|
error_text = e.response.text if e.response else str(e)
|
||||||
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||||
return {"error": "OAuth process failed"}, 400
|
return {"error": "OAuth process failed"}, 400
|
||||||
|
|
||||||
if invite_token and RegisterService.is_valid_invite_token(invite_token):
|
if invite_token and RegisterService.is_valid_invite_token(invite_token):
|
||||||
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
invitation = RegisterService._get_invitation_by_token(token=invite_token)
|
||||||
if invitation:
|
if invitation:
|
||||||
invitation_email = invitation.get("email", None)
|
invitation_email = invitation.get("email", None)
|
||||||
if invitation_email != user_info.email:
|
if invitation_email != user_info.email:
|
||||||
@ -202,3 +181,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||||||
AccountService.link_account_integrate(provider, user_info.id, account)
|
AccountService.link_account_integrate(provider, user_info.id, account)
|
||||||
|
|
||||||
return account
|
return account
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(OAuthLogin, "/oauth/login/<provider>")
|
||||||
|
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")
|
||||||
|
|||||||
@ -1,9 +1,8 @@
|
|||||||
from collections.abc import Callable
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
from typing import cast
|
||||||
|
|
||||||
import flask_login
|
import flask_login
|
||||||
from flask import jsonify, request
|
from flask import request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
||||||
@ -16,14 +15,10 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
|
|||||||
|
|
||||||
from .. import api
|
from .. import api
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
def oauth_server_client_id_required(view):
|
||||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||||
parsed_args = parser.parse_args()
|
parsed_args = parser.parse_args()
|
||||||
@ -35,53 +30,43 @@ def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderA
|
|||||||
if not oauth_provider_app:
|
if not oauth_provider_app:
|
||||||
raise NotFound("client_id is invalid")
|
raise NotFound("client_id is invalid")
|
||||||
|
|
||||||
return view(self, oauth_provider_app, *args, **kwargs)
|
kwargs["oauth_provider_app"] = oauth_provider_app
|
||||||
|
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
|
def oauth_server_access_token_required(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
if not isinstance(oauth_provider_app, OAuthProviderApp):
|
oauth_provider_app = kwargs.get("oauth_provider_app")
|
||||||
|
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||||
raise BadRequest("Invalid oauth_provider_app")
|
raise BadRequest("Invalid oauth_provider_app")
|
||||||
|
|
||||||
authorization_header = request.headers.get("Authorization")
|
authorization_header = request.headers.get("Authorization")
|
||||||
if not authorization_header:
|
if not authorization_header:
|
||||||
response = jsonify({"error": "Authorization header is required"})
|
raise BadRequest("Authorization header is required")
|
||||||
response.status_code = 401
|
|
||||||
response.headers["WWW-Authenticate"] = "Bearer"
|
|
||||||
return response
|
|
||||||
|
|
||||||
parts = authorization_header.strip().split(None, 1)
|
parts = authorization_header.strip().split(" ")
|
||||||
if len(parts) != 2:
|
if len(parts) != 2:
|
||||||
response = jsonify({"error": "Invalid Authorization header format"})
|
raise BadRequest("Invalid Authorization header format")
|
||||||
response.status_code = 401
|
|
||||||
response.headers["WWW-Authenticate"] = "Bearer"
|
|
||||||
return response
|
|
||||||
|
|
||||||
token_type = parts[0].strip()
|
token_type = parts[0].strip()
|
||||||
if token_type.lower() != "bearer":
|
if token_type.lower() != "bearer":
|
||||||
response = jsonify({"error": "token_type is invalid"})
|
raise BadRequest("token_type is invalid")
|
||||||
response.status_code = 401
|
|
||||||
response.headers["WWW-Authenticate"] = "Bearer"
|
|
||||||
return response
|
|
||||||
|
|
||||||
access_token = parts[1].strip()
|
access_token = parts[1].strip()
|
||||||
if not access_token:
|
if not access_token:
|
||||||
response = jsonify({"error": "access_token is required"})
|
raise BadRequest("access_token is required")
|
||||||
response.status_code = 401
|
|
||||||
response.headers["WWW-Authenticate"] = "Bearer"
|
|
||||||
return response
|
|
||||||
|
|
||||||
account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
|
account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
|
||||||
if not account:
|
if not account:
|
||||||
response = jsonify({"error": "access_token or client_id is invalid"})
|
raise BadRequest("access_token or client_id is invalid")
|
||||||
response.status_code = 401
|
|
||||||
response.headers["WWW-Authenticate"] = "Bearer"
|
|
||||||
return response
|
|
||||||
|
|
||||||
return view(self, oauth_provider_app, account, *args, **kwargs)
|
kwargs["account"] = account
|
||||||
|
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||||
from libs.login import current_user, login_required
|
from libs.login import login_required
|
||||||
from models.model import Account
|
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
|
|
||||||
@ -17,10 +17,9 @@ class Subscription(Resource):
|
|||||||
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
||||||
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
|
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
return BillingService.get_subscription(
|
return BillingService.get_subscription(
|
||||||
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
|
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
|
||||||
)
|
)
|
||||||
@ -32,9 +31,7 @@ class Invoices(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def get(self):
|
def get(self):
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,6 @@ from werkzeug.exceptions import NotFound
|
|||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.indexing_runner import IndexingRunner
|
from core.indexing_runner import IndexingRunner
|
||||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -29,12 +28,14 @@ class DataSourceApi(Resource):
|
|||||||
@marshal_with(integrate_list_fields)
|
@marshal_with(integrate_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
# get workspace data source integrates
|
# get workspace data source integrates
|
||||||
data_source_integrates = db.session.scalars(
|
data_source_integrates = (
|
||||||
select(DataSourceOauthBinding).where(
|
db.session.query(DataSourceOauthBinding)
|
||||||
|
.where(
|
||||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||||
DataSourceOauthBinding.disabled == False,
|
DataSourceOauthBinding.disabled == False,
|
||||||
)
|
)
|
||||||
).all()
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
base_url = request.url_root.rstrip("/")
|
base_url = request.url_root.rstrip("/")
|
||||||
data_source_oauth_base_path = "/console/api/oauth/data-source"
|
data_source_oauth_base_path = "/console/api/oauth/data-source"
|
||||||
@ -213,7 +214,7 @@ class DataSourceNotionApi(Resource):
|
|||||||
workspace_id = notion_info["workspace_id"]
|
workspace_id = notion_info["workspace_id"]
|
||||||
for page in notion_info["pages"]:
|
for page in notion_info["pages"]:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.NOTION.value,
|
datasource_type="notion_import",
|
||||||
notion_info={
|
notion_info={
|
||||||
"notion_workspace_id": workspace_id,
|
"notion_workspace_id": workspace_id,
|
||||||
"notion_obj_id": page["page_id"],
|
"notion_obj_id": page["page_id"],
|
||||||
@ -247,7 +248,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
|
|||||||
documents = DocumentService.get_document_by_dataset_id(dataset_id_str)
|
documents = DocumentService.get_document_by_dataset_id(dataset_id_str)
|
||||||
for document in documents:
|
for document in documents:
|
||||||
document_indexing_sync_task.delay(dataset_id_str, document.id)
|
document_indexing_sync_task.delay(dataset_id_str, document.id)
|
||||||
return {"result": "success"}, 200
|
return 200
|
||||||
|
|
||||||
|
|
||||||
class DataSourceNotionDocumentSyncApi(Resource):
|
class DataSourceNotionDocumentSyncApi(Resource):
|
||||||
@ -265,7 +266,7 @@ class DataSourceNotionDocumentSyncApi(Resource):
|
|||||||
if document is None:
|
if document is None:
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
document_indexing_sync_task.delay(dataset_id_str, document_id_str)
|
document_indexing_sync_task.delay(dataset_id_str, document_id_str)
|
||||||
return {"result": "success"}, 200
|
return 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
|
api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import flask_restx
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, marshal, marshal_with, reqparse
|
from flask_restx import Resource, marshal, marshal_with, reqparse
|
||||||
from sqlalchemy import select
|
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
@ -23,7 +22,6 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||||||
from core.plugin.entities.plugin import ModelProviderID
|
from core.plugin.entities.plugin import ModelProviderID
|
||||||
from core.provider_manager import ProviderManager
|
from core.provider_manager import ProviderManager
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -412,11 +410,11 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
extract_settings = []
|
extract_settings = []
|
||||||
if args["info_list"]["data_source_type"] == "upload_file":
|
if args["info_list"]["data_source_type"] == "upload_file":
|
||||||
file_ids = args["info_list"]["file_info_list"]["file_ids"]
|
file_ids = args["info_list"]["file_info_list"]["file_ids"]
|
||||||
file_details = db.session.scalars(
|
file_details = (
|
||||||
select(UploadFile).where(
|
db.session.query(UploadFile)
|
||||||
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
|
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
|
||||||
)
|
.all()
|
||||||
).all()
|
)
|
||||||
|
|
||||||
if file_details is None:
|
if file_details is None:
|
||||||
raise NotFound("File not found.")
|
raise NotFound("File not found.")
|
||||||
@ -424,9 +422,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
if file_details:
|
if file_details:
|
||||||
for file_detail in file_details:
|
for file_detail in file_details:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.FILE.value,
|
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
|
||||||
upload_file=file_detail,
|
|
||||||
document_model=args["doc_form"],
|
|
||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
elif args["info_list"]["data_source_type"] == "notion_import":
|
elif args["info_list"]["data_source_type"] == "notion_import":
|
||||||
@ -435,7 +431,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
workspace_id = notion_info["workspace_id"]
|
workspace_id = notion_info["workspace_id"]
|
||||||
for page in notion_info["pages"]:
|
for page in notion_info["pages"]:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.NOTION.value,
|
datasource_type="notion_import",
|
||||||
notion_info={
|
notion_info={
|
||||||
"notion_workspace_id": workspace_id,
|
"notion_workspace_id": workspace_id,
|
||||||
"notion_obj_id": page["page_id"],
|
"notion_obj_id": page["page_id"],
|
||||||
@ -449,7 +445,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
website_info_list = args["info_list"]["website_info_list"]
|
website_info_list = args["info_list"]["website_info_list"]
|
||||||
for url in website_info_list["urls"]:
|
for url in website_info_list["urls"]:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.WEBSITE.value,
|
datasource_type="website_crawl",
|
||||||
website_info={
|
website_info={
|
||||||
"provider": website_info_list["provider"],
|
"provider": website_info_list["provider"],
|
||||||
"job_id": website_info_list["job_id"],
|
"job_id": website_info_list["job_id"],
|
||||||
@ -519,11 +515,11 @@ class DatasetIndexingStatusApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id):
|
def get(self, dataset_id):
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
documents = db.session.scalars(
|
documents = (
|
||||||
select(Document).where(
|
db.session.query(Document)
|
||||||
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
|
.where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
|
||||||
)
|
.all()
|
||||||
).all()
|
)
|
||||||
documents_status = []
|
documents_status = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
completed_segments = (
|
completed_segments = (
|
||||||
@ -570,11 +566,11 @@ class DatasetApiKeyApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_key_list)
|
@marshal_with(api_key_list)
|
||||||
def get(self):
|
def get(self):
|
||||||
keys = db.session.scalars(
|
keys = (
|
||||||
select(ApiToken).where(
|
db.session.query(ApiToken)
|
||||||
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
|
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
|
||||||
)
|
.all()
|
||||||
).all()
|
)
|
||||||
return {"items": keys}
|
return {"items": keys}
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@ -664,6 +660,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||||||
| VectorType.BAIDU
|
| VectorType.BAIDU
|
||||||
| VectorType.VIKINGDB
|
| VectorType.VIKINGDB
|
||||||
| VectorType.UPSTASH
|
| VectorType.UPSTASH
|
||||||
|
| VectorType.PINECONE
|
||||||
):
|
):
|
||||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||||
case (
|
case (
|
||||||
@ -715,6 +712,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||||||
| VectorType.BAIDU
|
| VectorType.BAIDU
|
||||||
| VectorType.VIKINGDB
|
| VectorType.VIKINGDB
|
||||||
| VectorType.UPSTASH
|
| VectorType.UPSTASH
|
||||||
|
| VectorType.PINECONE
|
||||||
):
|
):
|
||||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||||
case (
|
case (
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from argparse import ArgumentTypeError
|
from argparse import ArgumentTypeError
|
||||||
from collections.abc import Sequence
|
|
||||||
from typing import Literal, cast
|
from typing import Literal, cast
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
@ -41,7 +40,6 @@ from core.model_manager import ModelManager
|
|||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.document_fields import (
|
from fields.document_fields import (
|
||||||
@ -80,7 +78,7 @@ class DocumentResource(Resource):
|
|||||||
|
|
||||||
return document
|
return document
|
||||||
|
|
||||||
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
|
def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
@ -356,6 +354,9 @@ class DatasetInitApi(Resource):
|
|||||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
|
if not current_user.is_dataset_editor:
|
||||||
|
raise Forbidden()
|
||||||
knowledge_config = KnowledgeConfig(**args)
|
knowledge_config = KnowledgeConfig(**args)
|
||||||
if knowledge_config.indexing_technique == "high_quality":
|
if knowledge_config.indexing_technique == "high_quality":
|
||||||
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
||||||
@ -427,7 +428,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||||||
raise NotFound("File not found.")
|
raise NotFound("File not found.")
|
||||||
|
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
|
datasource_type="upload_file", upload_file=file, document_model=document.doc_form
|
||||||
)
|
)
|
||||||
|
|
||||||
indexing_runner = IndexingRunner()
|
indexing_runner = IndexingRunner()
|
||||||
@ -476,8 +477,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
data_source_info = document.data_source_info_dict
|
data_source_info = document.data_source_info_dict
|
||||||
|
|
||||||
if document.data_source_type == "upload_file":
|
if document.data_source_type == "upload_file":
|
||||||
if not data_source_info:
|
|
||||||
continue
|
|
||||||
file_id = data_source_info["upload_file_id"]
|
file_id = data_source_info["upload_file_id"]
|
||||||
file_detail = (
|
file_detail = (
|
||||||
db.session.query(UploadFile)
|
db.session.query(UploadFile)
|
||||||
@ -489,15 +488,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
raise NotFound("File not found.")
|
raise NotFound("File not found.")
|
||||||
|
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
|
datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
|
||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
|
|
||||||
elif document.data_source_type == "notion_import":
|
elif document.data_source_type == "notion_import":
|
||||||
if not data_source_info:
|
|
||||||
continue
|
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.NOTION.value,
|
datasource_type="notion_import",
|
||||||
notion_info={
|
notion_info={
|
||||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||||
"notion_obj_id": data_source_info["notion_page_id"],
|
"notion_obj_id": data_source_info["notion_page_id"],
|
||||||
@ -508,10 +505,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
elif document.data_source_type == "website_crawl":
|
elif document.data_source_type == "website_crawl":
|
||||||
if not data_source_info:
|
|
||||||
continue
|
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.WEBSITE.value,
|
datasource_type="website_crawl",
|
||||||
website_info={
|
website_info={
|
||||||
"provider": data_source_info["provider"],
|
"provider": data_source_info["provider"],
|
||||||
"job_id": data_source_info["job_id"],
|
"job_id": data_source_info["job_id"],
|
||||||
|
|||||||
@ -113,7 +113,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
|||||||
MetadataService.enable_built_in_field(dataset)
|
MetadataService.enable_built_in_field(dataset)
|
||||||
elif action == "disable":
|
elif action == "disable":
|
||||||
MetadataService.disable_built_in_field(dataset)
|
MetadataService.disable_built_in_field(dataset)
|
||||||
return {"result": "success"}, 200
|
return 200
|
||||||
|
|
||||||
|
|
||||||
class DocumentMetadataEditApi(Resource):
|
class DocumentMetadataEditApi(Resource):
|
||||||
@ -135,7 +135,7 @@ class DocumentMetadataEditApi(Resource):
|
|||||||
|
|
||||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
|
api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from flask_login import current_user
|
||||||
from flask_restx import reqparse
|
from flask_restx import reqparse
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
@ -27,8 +28,6 @@ from extensions.ext_database import db
|
|||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.login import current_user
|
|
||||||
from models import Account
|
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.llm import InvokeRateLimitError
|
from services.errors.llm import InvokeRateLimitError
|
||||||
@ -58,8 +57,6 @@ class CompletionApi(InstalledAppResource):
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||||
)
|
)
|
||||||
@ -93,8 +90,6 @@ class CompletionStopApi(InstalledAppResource):
|
|||||||
if app_model.mode != "completion":
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
@ -122,8 +117,6 @@ class ChatApi(InstalledAppResource):
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||||
)
|
)
|
||||||
@ -160,8 +153,6 @@ class ChatStopApi(InstalledAppResource):
|
|||||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from flask_login import current_user
|
||||||
from flask_restx import marshal_with, reqparse
|
from flask_restx import marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -9,8 +10,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.login import current_user
|
|
||||||
from models import Account
|
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||||
@ -36,8 +35,6 @@ class ConversationListApi(InstalledAppResource):
|
|||||||
pinned = args["pinned"] == "true"
|
pinned = args["pinned"] == "true"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
return WebConversationService.pagination_by_last_id(
|
return WebConversationService.pagination_by_last_id(
|
||||||
session=session,
|
session=session,
|
||||||
@ -61,11 +58,10 @@ class ConversationApi(InstalledAppResource):
|
|||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
ConversationService.delete(app_model, conversation_id, current_user)
|
ConversationService.delete(app_model, conversation_id, current_user)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||||
|
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
@ -86,8 +82,6 @@ class ConversationRenameApi(InstalledAppResource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
return ConversationService.rename(
|
return ConversationService.rename(
|
||||||
app_model, conversation_id, current_user, args["name"], args["auto_generate"]
|
app_model, conversation_id, current_user, args["name"], args["auto_generate"]
|
||||||
)
|
)
|
||||||
@ -105,8 +99,6 @@ class ConversationPinApi(InstalledAppResource):
|
|||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
WebConversationService.pin(app_model, conversation_id, current_user)
|
WebConversationService.pin(app_model, conversation_id, current_user)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
@ -122,8 +114,6 @@ class ConversationUnPinApi(InstalledAppResource):
|
|||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|||||||
@ -2,8 +2,9 @@ import logging
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, inputs, marshal_with, reqparse
|
from flask_restx import Resource, inputs, marshal_with, reqparse
|
||||||
from sqlalchemy import and_, select
|
from sqlalchemy import and_
|
||||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
@ -12,8 +13,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.installed_app_fields import installed_app_list_fields
|
from fields.installed_app_fields import installed_app_list_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import current_user, login_required
|
from libs.login import login_required
|
||||||
from models import Account, App, InstalledApp, RecommendedApp
|
from models import App, InstalledApp, RecommendedApp
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
@ -28,23 +29,17 @@ class InstalledAppsListApi(Resource):
|
|||||||
@marshal_with(installed_app_list_fields)
|
@marshal_with(installed_app_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
app_id = request.args.get("app_id", default=None, type=str)
|
app_id = request.args.get("app_id", default=None, type=str)
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
current_tenant_id = current_user.current_tenant_id
|
current_tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
if app_id:
|
if app_id:
|
||||||
installed_apps = db.session.scalars(
|
installed_apps = (
|
||||||
select(InstalledApp).where(
|
db.session.query(InstalledApp)
|
||||||
and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)
|
.where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
|
||||||
)
|
.all()
|
||||||
).all()
|
)
|
||||||
else:
|
else:
|
||||||
installed_apps = db.session.scalars(
|
installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()
|
||||||
select(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id)
|
|
||||||
).all()
|
|
||||||
|
|
||||||
if current_user.current_tenant is None:
|
|
||||||
raise ValueError("current_user.current_tenant must not be None")
|
|
||||||
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
|
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
|
||||||
installed_app_list: list[dict[str, Any]] = [
|
installed_app_list: list[dict[str, Any]] = [
|
||||||
{
|
{
|
||||||
@ -120,8 +115,6 @@ class InstalledAppsListApi(Resource):
|
|||||||
if recommended_app is None:
|
if recommended_app is None:
|
||||||
raise NotFound("App not found")
|
raise NotFound("App not found")
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
current_tenant_id = current_user.current_tenant_id
|
current_tenant_id = current_user.current_tenant_id
|
||||||
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
||||||
|
|
||||||
@ -161,8 +154,6 @@ class InstalledAppApi(InstalledAppResource):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def delete(self, installed_app):
|
def delete(self, installed_app):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
|
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
|
||||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from flask_login import current_user
|
||||||
from flask_restx import marshal_with, reqparse
|
from flask_restx import marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
@ -23,8 +24,6 @@ from core.model_runtime.errors.invoke import InvokeError
|
|||||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.login import current_user
|
|
||||||
from models import Account
|
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.app import MoreLikeThisDisabledError
|
from services.errors.app import MoreLikeThisDisabledError
|
||||||
@ -55,8 +54,6 @@ class MessageListApi(InstalledAppResource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
return MessageService.pagination_by_first_id(
|
return MessageService.pagination_by_first_id(
|
||||||
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||||
)
|
)
|
||||||
@ -78,8 +75,6 @@ class MessageFeedbackApi(InstalledAppResource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
MessageService.create_feedback(
|
MessageService.create_feedback(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
@ -110,8 +105,6 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
|||||||
streaming = args["response_mode"] == "streaming"
|
streaming = args["response_mode"] == "streaming"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
response = AppGenerateService.generate_more_like_this(
|
response = AppGenerateService.generate_more_like_this(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
user=current_user,
|
user=current_user,
|
||||||
@ -149,8 +142,6 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
|||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
questions = MessageService.get_suggested_questions_after_answer(
|
questions = MessageService.get_suggested_questions_after_answer(
|
||||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||||
)
|
)
|
||||||
|
|||||||
@ -43,8 +43,6 @@ class ExploreAppMetaApi(InstalledAppResource):
|
|||||||
def get(self, installed_app: InstalledApp):
|
def get(self, installed_app: InstalledApp):
|
||||||
"""Get app meta"""
|
"""Get app meta"""
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if not app_model:
|
|
||||||
raise ValueError("App not found")
|
|
||||||
return AppService().get_app_meta(app_model)
|
return AppService().get_app_meta(app_model)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
|
|
||||||
from constants.languages import languages
|
from constants.languages import languages
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from libs.helper import AppIconUrlField
|
from libs.helper import AppIconUrlField
|
||||||
from libs.login import current_user, login_required
|
from libs.login import login_required
|
||||||
from services.recommended_app_service import RecommendedAppService
|
from services.recommended_app_service import RecommendedAppService
|
||||||
|
|
||||||
app_fields = {
|
app_fields = {
|
||||||
@ -45,9 +46,8 @@ class RecommendedAppListApi(Resource):
|
|||||||
parser.add_argument("language", type=str, location="args")
|
parser.add_argument("language", type=str, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
language = args.get("language")
|
if args.get("language") and args.get("language") in languages:
|
||||||
if language and language in languages:
|
language_prefix = args.get("language")
|
||||||
language_prefix = language
|
|
||||||
elif current_user and current_user.interface_language:
|
elif current_user and current_user.interface_language:
|
||||||
language_prefix = current_user.interface_language
|
language_prefix = current_user.interface_language
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from flask_login import current_user
|
||||||
from flask_restx import fields, marshal_with, reqparse
|
from flask_restx import fields, marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
@ -7,8 +8,6 @@ from controllers.console.explore.error import NotCompletionAppError
|
|||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from fields.conversation_fields import message_file_fields
|
from fields.conversation_fields import message_file_fields
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
from libs.login import current_user
|
|
||||||
from models import Account
|
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
from services.saved_message_service import SavedMessageService
|
from services.saved_message_service import SavedMessageService
|
||||||
|
|
||||||
@ -43,8 +42,6 @@ class SavedMessageListApi(InstalledAppResource):
|
|||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
|
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
|
||||||
|
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
@ -57,8 +54,6 @@ class SavedMessageListApi(InstalledAppResource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
SavedMessageService.save(app_model, current_user, args["message_id"])
|
SavedMessageService.save(app_model, current_user, args["message_id"])
|
||||||
except MessageNotExistsError:
|
except MessageNotExistsError:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
@ -75,8 +70,6 @@ class SavedMessageApi(InstalledAppResource):
|
|||||||
if app_model.mode != "completion":
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
SavedMessageService.delete(app_model, current_user, message_id)
|
SavedMessageService.delete(app_model, current_user, message_id)
|
||||||
|
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|||||||
@ -35,8 +35,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
|||||||
Run workflow
|
Run workflow
|
||||||
"""
|
"""
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if not app_model:
|
|
||||||
raise NotWorkflowAppError()
|
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode != AppMode.WORKFLOW:
|
if app_mode != AppMode.WORKFLOW:
|
||||||
raise NotWorkflowAppError()
|
raise NotWorkflowAppError()
|
||||||
@ -75,8 +73,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
|||||||
Stop workflow task
|
Stop workflow task
|
||||||
"""
|
"""
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if not app_model:
|
|
||||||
raise NotWorkflowAppError()
|
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode != AppMode.WORKFLOW:
|
if app_mode != AppMode.WORKFLOW:
|
||||||
raise NotWorkflowAppError()
|
raise NotWorkflowAppError()
|
||||||
|
|||||||
@ -1,6 +1,4 @@
|
|||||||
from collections.abc import Callable
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Concatenate, Optional, ParamSpec, TypeVar
|
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@ -15,15 +13,19 @@ from services.app_service import AppService
|
|||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
def installed_app_required(view=None):
|
||||||
def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
|
def decorator(view):
|
||||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
|
if not kwargs.get("installed_app_id"):
|
||||||
|
raise ValueError("missing installed_app_id in path parameters")
|
||||||
|
|
||||||
|
installed_app_id = kwargs.get("installed_app_id")
|
||||||
|
installed_app_id = str(installed_app_id)
|
||||||
|
|
||||||
|
del kwargs["installed_app_id"]
|
||||||
|
|
||||||
installed_app = (
|
installed_app = (
|
||||||
db.session.query(InstalledApp)
|
db.session.query(InstalledApp)
|
||||||
.where(
|
.where(
|
||||||
@ -50,10 +52,10 @@ def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P],
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
|
def user_allowed_to_access_app(view=None):
|
||||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
def decorator(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
def decorated(installed_app: InstalledApp, *args, **kwargs):
|
||||||
feature = FeatureService.get_system_features()
|
feature = FeatureService.get_system_features()
|
||||||
if feature.webapp_auth.enabled:
|
if feature.webapp_auth.enabled:
|
||||||
app_id = installed_app.app_id
|
app_id = installed_app.app_id
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
|
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from fields.api_based_extension_fields import api_based_extension_fields
|
from fields.api_based_extension_fields import api_based_extension_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
@ -11,21 +11,7 @@ from services.api_based_extension_service import APIBasedExtensionService
|
|||||||
from services.code_based_extension_service import CodeBasedExtensionService
|
from services.code_based_extension_service import CodeBasedExtensionService
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/code-based-extension")
|
|
||||||
class CodeBasedExtensionAPI(Resource):
|
class CodeBasedExtensionAPI(Resource):
|
||||||
@api.doc("get_code_based_extension")
|
|
||||||
@api.doc(description="Get code-based extension data by module name")
|
|
||||||
@api.expect(
|
|
||||||
api.parser().add_argument("module", type=str, required=True, location="args", help="Extension module name")
|
|
||||||
)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Success",
|
|
||||||
api.model(
|
|
||||||
"CodeBasedExtensionResponse",
|
|
||||||
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -37,11 +23,7 @@ class CodeBasedExtensionAPI(Resource):
|
|||||||
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
|
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/api-based-extension")
|
|
||||||
class APIBasedExtensionAPI(Resource):
|
class APIBasedExtensionAPI(Resource):
|
||||||
@api.doc("get_api_based_extensions")
|
|
||||||
@api.doc(description="Get all API-based extensions for current tenant")
|
|
||||||
@api.response(200, "Success", fields.List(fields.Nested(api_based_extension_fields)))
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -50,19 +32,6 @@ class APIBasedExtensionAPI(Resource):
|
|||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||||
|
|
||||||
@api.doc("create_api_based_extension")
|
|
||||||
@api.doc(description="Create a new API-based extension")
|
|
||||||
@api.expect(
|
|
||||||
api.model(
|
|
||||||
"CreateAPIBasedExtensionRequest",
|
|
||||||
{
|
|
||||||
"name": fields.String(required=True, description="Extension name"),
|
|
||||||
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
|
|
||||||
"api_key": fields.String(required=True, description="API key for authentication"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@api.response(201, "Extension created successfully", api_based_extension_fields)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -84,12 +53,7 @@ class APIBasedExtensionAPI(Resource):
|
|||||||
return APIBasedExtensionService.save(extension_data)
|
return APIBasedExtensionService.save(extension_data)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/api-based-extension/<uuid:id>")
|
|
||||||
class APIBasedExtensionDetailAPI(Resource):
|
class APIBasedExtensionDetailAPI(Resource):
|
||||||
@api.doc("get_api_based_extension")
|
|
||||||
@api.doc(description="Get API-based extension by ID")
|
|
||||||
@api.doc(params={"id": "Extension ID"})
|
|
||||||
@api.response(200, "Success", api_based_extension_fields)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -100,20 +64,6 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
|
|
||||||
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||||
|
|
||||||
@api.doc("update_api_based_extension")
|
|
||||||
@api.doc(description="Update API-based extension")
|
|
||||||
@api.doc(params={"id": "Extension ID"})
|
|
||||||
@api.expect(
|
|
||||||
api.model(
|
|
||||||
"UpdateAPIBasedExtensionRequest",
|
|
||||||
{
|
|
||||||
"name": fields.String(required=True, description="Extension name"),
|
|
||||||
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
|
|
||||||
"api_key": fields.String(required=True, description="API key for authentication"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@api.response(200, "Extension updated successfully", api_based_extension_fields)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -138,10 +88,6 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
|
|
||||||
return APIBasedExtensionService.save(extension_data_from_db)
|
return APIBasedExtensionService.save(extension_data_from_db)
|
||||||
|
|
||||||
@api.doc("delete_api_based_extension")
|
|
||||||
@api.doc(description="Delete API-based extension")
|
|
||||||
@api.doc(params={"id": "Extension ID"})
|
|
||||||
@api.response(204, "Extension deleted successfully")
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -154,3 +100,9 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
APIBasedExtensionService.delete(extension_data_from_db)
|
APIBasedExtensionService.delete(extension_data_from_db)
|
||||||
|
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(CodeBasedExtensionAPI, "/code-based-extension")
|
||||||
|
|
||||||
|
api.add_resource(APIBasedExtensionAPI, "/api-based-extension")
|
||||||
|
api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/<uuid:id>")
|
||||||
|
|||||||
@ -1,40 +1,26 @@
|
|||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource
|
||||||
|
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
from . import api, console_ns
|
from . import api
|
||||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/features")
|
|
||||||
class FeatureApi(Resource):
|
class FeatureApi(Resource):
|
||||||
@api.doc("get_tenant_features")
|
|
||||||
@api.doc(description="Get feature configuration for current tenant")
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Success",
|
|
||||||
api.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
|
|
||||||
)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_utm_record
|
@cloud_utm_record
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Get feature configuration for current tenant"""
|
|
||||||
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
|
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/system-features")
|
|
||||||
class SystemFeatureApi(Resource):
|
class SystemFeatureApi(Resource):
|
||||||
@api.doc("get_system_features")
|
|
||||||
@api.doc(description="Get system-wide feature configuration")
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Success",
|
|
||||||
api.model("SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}),
|
|
||||||
)
|
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Get system-wide feature configuration"""
|
|
||||||
return FeatureService.get_system_features().model_dump()
|
return FeatureService.get_system_features().model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(FeatureApi, "/features")
|
||||||
|
api.add_resource(SystemFeatureApi, "/system-features")
|
||||||
|
|||||||
@ -22,7 +22,6 @@ from controllers.console.wraps import (
|
|||||||
)
|
)
|
||||||
from fields.file_fields import file_fields, upload_config_fields
|
from fields.file_fields import file_fields, upload_config_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models import Account
|
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
|
|
||||||
PREVIEW_WORDS_LIMIT = 3000
|
PREVIEW_WORDS_LIMIT = 3000
|
||||||
@ -69,8 +68,6 @@ class FileApi(Resource):
|
|||||||
source = None
|
source = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
upload_file = FileService.upload_file(
|
upload_file = FileService.upload_file(
|
||||||
filename=file.filename,
|
filename=file.filename,
|
||||||
content=file.read(),
|
content=file.read(),
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from flask import session
|
from flask import session
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@ -11,47 +11,20 @@ from libs.helper import StrLen
|
|||||||
from models.model import DifySetup
|
from models.model import DifySetup
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
|
|
||||||
from . import api, console_ns
|
from . import api
|
||||||
from .error import AlreadySetupError, InitValidateFailedError
|
from .error import AlreadySetupError, InitValidateFailedError
|
||||||
from .wraps import only_edition_self_hosted
|
from .wraps import only_edition_self_hosted
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/init")
|
|
||||||
class InitValidateAPI(Resource):
|
class InitValidateAPI(Resource):
|
||||||
@api.doc("get_init_status")
|
|
||||||
@api.doc(description="Get initialization validation status")
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Success",
|
|
||||||
model=api.model(
|
|
||||||
"InitStatusResponse",
|
|
||||||
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Get initialization validation status"""
|
|
||||||
init_status = get_init_validate_status()
|
init_status = get_init_validate_status()
|
||||||
if init_status:
|
if init_status:
|
||||||
return {"status": "finished"}
|
return {"status": "finished"}
|
||||||
return {"status": "not_started"}
|
return {"status": "not_started"}
|
||||||
|
|
||||||
@api.doc("validate_init_password")
|
|
||||||
@api.doc(description="Validate initialization password for self-hosted edition")
|
|
||||||
@api.expect(
|
|
||||||
api.model(
|
|
||||||
"InitValidateRequest",
|
|
||||||
{"password": fields.String(required=True, description="Initialization password", max_length=30)},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@api.response(
|
|
||||||
201,
|
|
||||||
"Success",
|
|
||||||
model=api.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
|
|
||||||
)
|
|
||||||
@api.response(400, "Already setup or validation failed")
|
|
||||||
@only_edition_self_hosted
|
@only_edition_self_hosted
|
||||||
def post(self):
|
def post(self):
|
||||||
"""Validate initialization password"""
|
|
||||||
# is tenant created
|
# is tenant created
|
||||||
tenant_count = TenantService.get_tenant_count()
|
tenant_count = TenantService.get_tenant_count()
|
||||||
if tenant_count > 0:
|
if tenant_count > 0:
|
||||||
@ -79,3 +52,6 @@ def get_init_validate_status():
|
|||||||
return db_session.execute(select(DifySetup)).scalar_one_or_none()
|
return db_session.execute(select(DifySetup)).scalar_one_or_none()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(InitValidateAPI, "/init")
|
||||||
|
|||||||
@ -1,17 +1,14 @@
|
|||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource
|
||||||
|
|
||||||
from . import api, console_ns
|
from controllers.console import api
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/ping")
|
|
||||||
class PingApi(Resource):
|
class PingApi(Resource):
|
||||||
@api.doc("health_check")
|
|
||||||
@api.doc(description="Health check endpoint for connection testing")
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Success",
|
|
||||||
api.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
|
|
||||||
)
|
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Health check endpoint for connection testing"""
|
"""
|
||||||
|
For connection health check
|
||||||
|
"""
|
||||||
return {"result": "pong"}
|
return {"result": "pong"}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(PingApi, "/ping")
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from libs.helper import StrLen, email, extract_remote_ip
|
from libs.helper import StrLen, email, extract_remote_ip
|
||||||
@ -7,56 +7,23 @@ from libs.password import valid_password
|
|||||||
from models.model import DifySetup, db
|
from models.model import DifySetup, db
|
||||||
from services.account_service import RegisterService, TenantService
|
from services.account_service import RegisterService, TenantService
|
||||||
|
|
||||||
from . import api, console_ns
|
from . import api
|
||||||
from .error import AlreadySetupError, NotInitValidateError
|
from .error import AlreadySetupError, NotInitValidateError
|
||||||
from .init_validate import get_init_validate_status
|
from .init_validate import get_init_validate_status
|
||||||
from .wraps import only_edition_self_hosted
|
from .wraps import only_edition_self_hosted
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/setup")
|
|
||||||
class SetupApi(Resource):
|
class SetupApi(Resource):
|
||||||
@api.doc("get_setup_status")
|
|
||||||
@api.doc(description="Get system setup status")
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Success",
|
|
||||||
api.model(
|
|
||||||
"SetupStatusResponse",
|
|
||||||
{
|
|
||||||
"step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
|
|
||||||
"setup_at": fields.String(description="Setup completion time (ISO format)", required=False),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Get system setup status"""
|
|
||||||
if dify_config.EDITION == "SELF_HOSTED":
|
if dify_config.EDITION == "SELF_HOSTED":
|
||||||
setup_status = get_setup_status()
|
setup_status = get_setup_status()
|
||||||
# Check if setup_status is a DifySetup object rather than a bool
|
if setup_status:
|
||||||
if setup_status and not isinstance(setup_status, bool):
|
|
||||||
return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
|
return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
|
||||||
elif setup_status:
|
|
||||||
return {"step": "finished"}
|
|
||||||
return {"step": "not_started"}
|
return {"step": "not_started"}
|
||||||
return {"step": "finished"}
|
return {"step": "finished"}
|
||||||
|
|
||||||
@api.doc("setup_system")
|
|
||||||
@api.doc(description="Initialize system setup with admin account")
|
|
||||||
@api.expect(
|
|
||||||
api.model(
|
|
||||||
"SetupRequest",
|
|
||||||
{
|
|
||||||
"email": fields.String(required=True, description="Admin email address"),
|
|
||||||
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
|
|
||||||
"password": fields.String(required=True, description="Admin password"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@api.response(201, "Success", api.model("SetupResponse", {"result": fields.String(description="Setup result")}))
|
|
||||||
@api.response(400, "Already setup or validation failed")
|
|
||||||
@only_edition_self_hosted
|
@only_edition_self_hosted
|
||||||
def post(self):
|
def post(self):
|
||||||
"""Initialize system setup with admin account"""
|
|
||||||
# is set up
|
# is set up
|
||||||
if get_setup_status():
|
if get_setup_status():
|
||||||
raise AlreadySetupError()
|
raise AlreadySetupError()
|
||||||
@ -88,3 +55,6 @@ def get_setup_status():
|
|||||||
return db.session.query(DifySetup).first()
|
return db.session.query(DifySetup).first()
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(SetupApi, "/setup")
|
||||||
|
|||||||
@ -111,7 +111,7 @@ class TagBindingCreateApi(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
TagService.save_tag_binding(args)
|
TagService.save_tag_binding(args)
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return 200
|
||||||
|
|
||||||
|
|
||||||
class TagBindingDeleteApi(Resource):
|
class TagBindingDeleteApi(Resource):
|
||||||
@ -132,7 +132,7 @@ class TagBindingDeleteApi(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
TagService.delete_tag_binding(args)
|
TagService.delete_tag_binding(args)
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(TagListApi, "/tags")
|
api.add_resource(TagListApi, "/tags")
|
||||||
|
|||||||
@ -2,41 +2,18 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
|
||||||
from . import api, console_ns
|
from . import api
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/version")
|
|
||||||
class VersionApi(Resource):
|
class VersionApi(Resource):
|
||||||
@api.doc("check_version_update")
|
|
||||||
@api.doc(description="Check for application version updates")
|
|
||||||
@api.expect(
|
|
||||||
api.parser().add_argument(
|
|
||||||
"current_version", type=str, required=True, location="args", help="Current application version"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Success",
|
|
||||||
api.model(
|
|
||||||
"VersionResponse",
|
|
||||||
{
|
|
||||||
"version": fields.String(description="Latest version number"),
|
|
||||||
"release_date": fields.String(description="Release date of latest version"),
|
|
||||||
"release_notes": fields.String(description="Release notes for latest version"),
|
|
||||||
"can_auto_update": fields.Boolean(description="Whether auto-update is supported"),
|
|
||||||
"features": fields.Raw(description="Feature flags and capabilities"),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Check for application version updates"""
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("current_version", type=str, required=True, location="args")
|
parser.add_argument("current_version", type=str, required=True, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -57,14 +34,14 @@ class VersionApi(Resource):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10))
|
response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10))
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.warning("Check update version error: %s.", str(error))
|
logger.warning("Check update version error: %s.", str(error))
|
||||||
result["version"] = args["current_version"]
|
result["version"] = args.get("current_version")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
content = json.loads(response.content)
|
content = json.loads(response.content)
|
||||||
if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"):
|
if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"):
|
||||||
result["version"] = content["version"]
|
result["version"] = content["version"]
|
||||||
result["release_date"] = content["releaseDate"]
|
result["release_date"] = content["releaseDate"]
|
||||||
result["release_notes"] = content["releaseNotes"]
|
result["release_notes"] = content["releaseNotes"]
|
||||||
@ -82,3 +59,6 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool:
|
|||||||
except version.InvalidVersion:
|
except version.InvalidVersion:
|
||||||
logger.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version)
|
logger.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(VersionApi, "/version")
|
||||||
|
|||||||
@ -1,6 +1,4 @@
|
|||||||
from collections.abc import Callable
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar
|
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -9,17 +7,14 @@ from werkzeug.exceptions import Forbidden
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import TenantPluginPermission
|
from models.account import TenantPluginPermission
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
|
||||||
def plugin_permission_required(
|
def plugin_permission_required(
|
||||||
install_required: bool = False,
|
install_required: bool = False,
|
||||||
debug_required: bool = False,
|
debug_required: bool = False,
|
||||||
):
|
):
|
||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
user = current_user
|
user = current_user
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
|
|||||||
@ -49,8 +49,6 @@ class AccountInitApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
if account.status == "active":
|
if account.status == "active":
|
||||||
@ -104,8 +102,6 @@ class AccountProfileApi(Resource):
|
|||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
@ -115,8 +111,6 @@ class AccountNameApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
parser.add_argument("name", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -136,8 +130,6 @@ class AccountAvatarApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("avatar", type=str, required=True, location="json")
|
parser.add_argument("avatar", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -153,8 +145,6 @@ class AccountInterfaceLanguageApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
|
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -170,8 +160,6 @@ class AccountInterfaceThemeApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
|
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -187,8 +175,6 @@ class AccountTimezoneApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("timezone", type=str, required=True, location="json")
|
parser.add_argument("timezone", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -208,8 +194,6 @@ class AccountPasswordApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("password", type=str, required=False, location="json")
|
parser.add_argument("password", type=str, required=False, location="json")
|
||||||
parser.add_argument("new_password", type=str, required=True, location="json")
|
parser.add_argument("new_password", type=str, required=True, location="json")
|
||||||
@ -244,13 +228,9 @@ class AccountIntegrateApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(integrate_list_fields)
|
@marshal_with(integrate_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
account_integrates = db.session.scalars(
|
account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()
|
||||||
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
|
|
||||||
).all()
|
|
||||||
|
|
||||||
base_url = request.url_root.rstrip("/")
|
base_url = request.url_root.rstrip("/")
|
||||||
oauth_base_path = "/console/api/oauth/login"
|
oauth_base_path = "/console/api/oauth/login"
|
||||||
@ -288,8 +268,6 @@ class AccountDeleteVerifyApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
token, code = AccountService.generate_account_deletion_verification_code(account)
|
token, code = AccountService.generate_account_deletion_verification_code(account)
|
||||||
@ -303,8 +281,6 @@ class AccountDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -345,8 +321,6 @@ class EducationVerifyApi(Resource):
|
|||||||
@cloud_edition_billing_enabled
|
@cloud_edition_billing_enabled
|
||||||
@marshal_with(verify_fields)
|
@marshal_with(verify_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
return BillingService.EducationIdentity.verify(account.id, account.email)
|
return BillingService.EducationIdentity.verify(account.id, account.email)
|
||||||
@ -366,8 +340,6 @@ class EducationApi(Resource):
|
|||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
@cloud_edition_billing_enabled
|
@cloud_edition_billing_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -385,8 +357,6 @@ class EducationApi(Resource):
|
|||||||
@cloud_edition_billing_enabled
|
@cloud_edition_billing_enabled
|
||||||
@marshal_with(status_fields)
|
@marshal_with(status_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
res = BillingService.EducationIdentity.status(account.id)
|
res = BillingService.EducationIdentity.status(account.id)
|
||||||
@ -451,8 +421,6 @@ class ChangeEmailSendEmailApi(Resource):
|
|||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
user_email = reset_data.get("email", "")
|
user_email = reset_data.get("email", "")
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if user_email != current_user.email:
|
if user_email != current_user.email:
|
||||||
raise InvalidEmailError()
|
raise InvalidEmailError()
|
||||||
else:
|
else:
|
||||||
@ -533,8 +501,6 @@ class ChangeEmailResetApi(Resource):
|
|||||||
AccountService.revoke_change_email_token(args["token"])
|
AccountService.revoke_change_email_token(args["token"])
|
||||||
|
|
||||||
old_email = reset_data.get("old_email", "")
|
old_email = reset_data.get("old_email", "")
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if current_user.email != old_email:
|
if current_user.email != old_email:
|
||||||
raise AccountNotFound()
|
raise AccountNotFound()
|
||||||
|
|
||||||
|
|||||||
@ -1,22 +1,14 @@
|
|||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.agent_service import AgentService
|
from services.agent_service import AgentService
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/agent-providers")
|
|
||||||
class AgentProviderListApi(Resource):
|
class AgentProviderListApi(Resource):
|
||||||
@api.doc("list_agent_providers")
|
|
||||||
@api.doc(description="Get list of available agent providers")
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Success",
|
|
||||||
fields.List(fields.Raw(description="Agent provider information")),
|
|
||||||
)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -29,16 +21,7 @@ class AgentProviderListApi(Resource):
|
|||||||
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
|
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
|
|
||||||
class AgentProviderApi(Resource):
|
class AgentProviderApi(Resource):
|
||||||
@api.doc("get_agent_provider")
|
|
||||||
@api.doc(description="Get specific agent provider details")
|
|
||||||
@api.doc(params={"provider_name": "Agent provider name"})
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Success",
|
|
||||||
fields.Raw(description="Agent provider details"),
|
|
||||||
)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -47,3 +30,7 @@ class AgentProviderApi(Resource):
|
|||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = user.current_tenant_id
|
||||||
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
|
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers")
|
||||||
|
api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/<path:provider_name>")
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||||
@ -10,26 +10,7 @@ from libs.login import login_required
|
|||||||
from services.plugin.endpoint_service import EndpointService
|
from services.plugin.endpoint_service import EndpointService
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/endpoints/create")
|
|
||||||
class EndpointCreateApi(Resource):
|
class EndpointCreateApi(Resource):
|
||||||
@api.doc("create_endpoint")
|
|
||||||
@api.doc(description="Create a new plugin endpoint")
|
|
||||||
@api.expect(
|
|
||||||
api.model(
|
|
||||||
"EndpointCreateRequest",
|
|
||||||
{
|
|
||||||
"plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
|
|
||||||
"settings": fields.Raw(required=True, description="Endpoint settings"),
|
|
||||||
"name": fields.String(required=True, description="Endpoint name"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Endpoint created successfully",
|
|
||||||
api.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
|
|
||||||
)
|
|
||||||
@api.response(403, "Admin privileges required")
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -62,20 +43,7 @@ class EndpointCreateApi(Resource):
|
|||||||
raise ValueError(e.description) from e
|
raise ValueError(e.description) from e
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/endpoints/list")
|
|
||||||
class EndpointListApi(Resource):
|
class EndpointListApi(Resource):
|
||||||
@api.doc("list_endpoints")
|
|
||||||
@api.doc(description="List plugin endpoints with pagination")
|
|
||||||
@api.expect(
|
|
||||||
api.parser()
|
|
||||||
.add_argument("page", type=int, required=True, location="args", help="Page number")
|
|
||||||
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
|
|
||||||
)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Success",
|
|
||||||
api.model("EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}),
|
|
||||||
)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -102,23 +70,7 @@ class EndpointListApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/endpoints/list/plugin")
|
|
||||||
class EndpointListForSinglePluginApi(Resource):
|
class EndpointListForSinglePluginApi(Resource):
|
||||||
@api.doc("list_plugin_endpoints")
|
|
||||||
@api.doc(description="List endpoints for a specific plugin")
|
|
||||||
@api.expect(
|
|
||||||
api.parser()
|
|
||||||
.add_argument("page", type=int, required=True, location="args", help="Page number")
|
|
||||||
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
|
|
||||||
.add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
|
|
||||||
)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Success",
|
|
||||||
api.model(
|
|
||||||
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -148,19 +100,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
|
||||||
class EndpointDeleteApi(Resource):
|
class EndpointDeleteApi(Resource):
|
||||||
@api.doc("delete_endpoint")
|
|
||||||
@api.doc(description="Delete a plugin endpoint")
|
|
||||||
@api.expect(
|
|
||||||
api.model("EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
|
|
||||||
)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Endpoint deleted successfully",
|
|
||||||
api.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
|
|
||||||
)
|
|
||||||
@api.response(403, "Admin privileges required")
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -183,26 +123,7 @@ class EndpointDeleteApi(Resource):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/endpoints/update")
|
|
||||||
class EndpointUpdateApi(Resource):
|
class EndpointUpdateApi(Resource):
|
||||||
@api.doc("update_endpoint")
|
|
||||||
@api.doc(description="Update a plugin endpoint")
|
|
||||||
@api.expect(
|
|
||||||
api.model(
|
|
||||||
"EndpointUpdateRequest",
|
|
||||||
{
|
|
||||||
"endpoint_id": fields.String(required=True, description="Endpoint ID"),
|
|
||||||
"settings": fields.Raw(required=True, description="Updated settings"),
|
|
||||||
"name": fields.String(required=True, description="Updated name"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Endpoint updated successfully",
|
|
||||||
api.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
|
|
||||||
)
|
|
||||||
@api.response(403, "Admin privileges required")
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -233,19 +154,7 @@ class EndpointUpdateApi(Resource):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/endpoints/enable")
|
|
||||||
class EndpointEnableApi(Resource):
|
class EndpointEnableApi(Resource):
|
||||||
@api.doc("enable_endpoint")
|
|
||||||
@api.doc(description="Enable a plugin endpoint")
|
|
||||||
@api.expect(
|
|
||||||
api.model("EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
|
|
||||||
)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Endpoint enabled successfully",
|
|
||||||
api.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
|
|
||||||
)
|
|
||||||
@api.response(403, "Admin privileges required")
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -268,19 +177,7 @@ class EndpointEnableApi(Resource):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/endpoints/disable")
|
|
||||||
class EndpointDisableApi(Resource):
|
class EndpointDisableApi(Resource):
|
||||||
@api.doc("disable_endpoint")
|
|
||||||
@api.doc(description="Disable a plugin endpoint")
|
|
||||||
@api.expect(
|
|
||||||
api.model("EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
|
|
||||||
)
|
|
||||||
@api.response(
|
|
||||||
200,
|
|
||||||
"Endpoint disabled successfully",
|
|
||||||
api.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
|
|
||||||
)
|
|
||||||
@api.response(403, "Admin privileges required")
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -301,3 +198,12 @@ class EndpointDisableApi(Resource):
|
|||||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create")
|
||||||
|
api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list")
|
||||||
|
api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin")
|
||||||
|
api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete")
|
||||||
|
api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update")
|
||||||
|
api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable")
|
||||||
|
api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable")
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
from urllib import parse
|
from urllib import parse
|
||||||
|
|
||||||
from flask import abort, request
|
from flask import request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, abort, marshal_with, reqparse
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -41,10 +41,6 @@ class MemberListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_with_role_list_fields)
|
@marshal_with(account_with_role_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||||
return {"result": "success", "accounts": members}, 200
|
return {"result": "success", "accounts": members}, 200
|
||||||
|
|
||||||
@ -69,11 +65,7 @@ class MemberInviteEmailApi(Resource):
|
|||||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
inviter = current_user
|
inviter = current_user
|
||||||
if not inviter.current_tenant:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
invitation_results = []
|
invitation_results = []
|
||||||
console_web_url = dify_config.CONSOLE_WEB_URL
|
console_web_url = dify_config.CONSOLE_WEB_URL
|
||||||
|
|
||||||
@ -84,8 +76,6 @@ class MemberInviteEmailApi(Resource):
|
|||||||
|
|
||||||
for invitee_email in invitee_emails:
|
for invitee_email in invitee_emails:
|
||||||
try:
|
try:
|
||||||
if not inviter.current_tenant:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
token = RegisterService.invite_new_member(
|
token = RegisterService.invite_new_member(
|
||||||
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
|
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
|
||||||
)
|
)
|
||||||
@ -107,7 +97,7 @@ class MemberInviteEmailApi(Resource):
|
|||||||
return {
|
return {
|
||||||
"result": "success",
|
"result": "success",
|
||||||
"invitation_results": invitation_results,
|
"invitation_results": invitation_results,
|
||||||
"tenant_id": str(inviter.current_tenant.id) if inviter.current_tenant else "",
|
"tenant_id": str(current_user.current_tenant.id),
|
||||||
}, 201
|
}, 201
|
||||||
|
|
||||||
|
|
||||||
@ -118,10 +108,6 @@ class MemberCancelInviteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, member_id):
|
def delete(self, member_id):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
member = db.session.query(Account).where(Account.id == str(member_id)).first()
|
member = db.session.query(Account).where(Account.id == str(member_id)).first()
|
||||||
if member is None:
|
if member is None:
|
||||||
abort(404)
|
abort(404)
|
||||||
@ -137,10 +123,7 @@ class MemberCancelInviteApi(Resource):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(str(e))
|
raise ValueError(str(e))
|
||||||
|
|
||||||
return {
|
return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200
|
||||||
"result": "success",
|
|
||||||
"tenant_id": str(current_user.current_tenant.id) if current_user.current_tenant else "",
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
|
|
||||||
class MemberUpdateRoleApi(Resource):
|
class MemberUpdateRoleApi(Resource):
|
||||||
@ -158,10 +141,6 @@ class MemberUpdateRoleApi(Resource):
|
|||||||
if not TenantAccountRole.is_valid_role(new_role):
|
if not TenantAccountRole.is_valid_role(new_role):
|
||||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
member = db.session.get(Account, str(member_id))
|
member = db.session.get(Account, str(member_id))
|
||||||
if not member:
|
if not member:
|
||||||
abort(404)
|
abort(404)
|
||||||
@ -185,10 +164,6 @@ class DatasetOperatorMemberListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_with_role_list_fields)
|
@marshal_with(account_with_role_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||||
return {"result": "success", "accounts": members}, 200
|
return {"result": "success", "accounts": members}, 200
|
||||||
|
|
||||||
@ -209,10 +184,6 @@ class SendOwnerTransferEmailApi(Resource):
|
|||||||
raise EmailSendIpLimitError()
|
raise EmailSendIpLimitError()
|
||||||
|
|
||||||
# check if the current user is the owner of the workspace
|
# check if the current user is the owner of the workspace
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||||
raise NotOwnerError()
|
raise NotOwnerError()
|
||||||
|
|
||||||
@ -227,7 +198,7 @@ class SendOwnerTransferEmailApi(Resource):
|
|||||||
account=current_user,
|
account=current_user,
|
||||||
email=email,
|
email=email,
|
||||||
language=language,
|
language=language,
|
||||||
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
|
workspace_name=current_user.current_tenant.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"result": "success", "data": token}
|
return {"result": "success", "data": token}
|
||||||
@ -244,10 +215,6 @@ class OwnerTransferCheckApi(Resource):
|
|||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# check if the current user is the owner of the workspace
|
# check if the current user is the owner of the workspace
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||||
raise NotOwnerError()
|
raise NotOwnerError()
|
||||||
|
|
||||||
@ -289,10 +256,6 @@ class OwnerTransfer(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# check if the current user is the owner of the workspace
|
# check if the current user is the owner of the workspace
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||||
raise NotOwnerError()
|
raise NotOwnerError()
|
||||||
|
|
||||||
@ -311,11 +274,9 @@ class OwnerTransfer(Resource):
|
|||||||
member = db.session.get(Account, str(member_id))
|
member = db.session.get(Account, str(member_id))
|
||||||
if not member:
|
if not member:
|
||||||
abort(404)
|
abort(404)
|
||||||
return # Never reached, but helps type checker
|
else:
|
||||||
|
member_account = member
|
||||||
if not current_user.current_tenant:
|
if not TenantService.is_member(member_account, current_user.current_tenant):
|
||||||
raise ValueError("No current tenant")
|
|
||||||
if not TenantService.is_member(member, current_user.current_tenant):
|
|
||||||
raise MemberNotInTenantError()
|
raise MemberNotInTenantError()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -325,13 +286,13 @@ class OwnerTransfer(Resource):
|
|||||||
AccountService.send_new_owner_transfer_notify_email(
|
AccountService.send_new_owner_transfer_notify_email(
|
||||||
account=member,
|
account=member,
|
||||||
email=member.email,
|
email=member.email,
|
||||||
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
|
workspace_name=current_user.current_tenant.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
AccountService.send_old_owner_transfer_notify_email(
|
AccountService.send_old_owner_transfer_notify_email(
|
||||||
account=current_user,
|
account=current_user,
|
||||||
email=current_user.email,
|
email=current_user.email,
|
||||||
workspace_name=current_user.current_tenant.name if current_user.current_tenant else "",
|
workspace_name=current_user.current_tenant.name,
|
||||||
new_owner_email=member.email,
|
new_owner_email=member.email,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,6 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from libs.helper import StrLen, uuid_value
|
from libs.helper import StrLen, uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.account import Account
|
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
from services.model_provider_service import ModelProviderService
|
from services.model_provider_service import ModelProviderService
|
||||||
|
|
||||||
@ -22,10 +21,6 @@ class ModelProviderListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -50,10 +45,6 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
# if credential_id is not provided, return current used credential
|
# if credential_id is not provided, return current used credential
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -71,20 +62,16 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
try:
|
try:
|
||||||
model_provider_service.create_provider_credential(
|
model_provider_service.create_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
@ -101,21 +88,17 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def put(self, provider: str):
|
def put(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
try:
|
try:
|
||||||
model_provider_service.update_provider_credential(
|
model_provider_service.update_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
@ -133,16 +116,12 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, provider: str):
|
def delete(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
model_provider_service.remove_provider_credential(
|
model_provider_service.remove_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
||||||
@ -156,16 +135,12 @@ class ModelProviderCredentialSwitchApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
service = ModelProviderService()
|
service = ModelProviderService()
|
||||||
service.switch_active_provider_credential(
|
service.switch_active_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
@ -175,35 +150,15 @@ class ModelProviderCredentialSwitchApi(Resource):
|
|||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderCredentialCancelApi(Resource):
|
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
def post(self, provider: str):
|
|
||||||
if not current_user.is_admin_or_owner:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
service = ModelProviderService()
|
|
||||||
service.cancel_provider_credential(
|
|
||||||
tenant_id=current_user.current_tenant_id,
|
|
||||||
provider=provider,
|
|
||||||
)
|
|
||||||
return {"result": "success"}
|
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderValidateApi(Resource):
|
class ModelProviderValidateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
@ -250,13 +205,9 @@ class PreferredProviderTypeUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -285,11 +236,7 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
|||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
if provider != "anthropic":
|
if provider != "anthropic":
|
||||||
raise ValueError(f"provider name {provider} is invalid")
|
raise ValueError(f"provider name {provider} is invalid")
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
data = BillingService.get_model_provider_payment_link(
|
data = BillingService.get_model_provider_payment_link(
|
||||||
provider_name=provider,
|
provider_name=provider,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
@ -305,9 +252,6 @@ api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-provider
|
|||||||
api.add_resource(
|
api.add_resource(
|
||||||
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
|
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
|
||||||
)
|
)
|
||||||
api.add_resource(
|
|
||||||
ModelProviderCredentialCancelApi, "/workspaces/current/model-providers/<path:provider>/credentials/cancel"
|
|
||||||
)
|
|
||||||
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||||
|
|
||||||
api.add_resource(
|
api.add_resource(
|
||||||
|
|||||||
@ -219,11 +219,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
|
|
||||||
model_load_balancing_service = ModelLoadBalancingService()
|
model_load_balancing_service = ModelLoadBalancingService()
|
||||||
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
|
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||||
provider=provider,
|
|
||||||
model=args["model"],
|
|
||||||
model_type=args["model_type"],
|
|
||||||
config_from=args.get("config_from", ""),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.get("config_from", "") == "predefined-model":
|
if args.get("config_from", "") == "predefined-model":
|
||||||
@ -267,7 +263,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
choices=[mt.value for mt in ModelType],
|
choices=[mt.value for mt in ModelType],
|
||||||
location="json",
|
location="json",
|
||||||
)
|
)
|
||||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -313,7 +309,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
)
|
)
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|||||||
@ -865,7 +865,6 @@ class ToolProviderMCPApi(Resource):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
|
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
|
||||||
)
|
)
|
||||||
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
user = current_user
|
user = current_user
|
||||||
if not is_valid_url(args["server_url"]):
|
if not is_valid_url(args["server_url"]):
|
||||||
@ -882,7 +881,6 @@ class ToolProviderMCPApi(Resource):
|
|||||||
server_identifier=args["server_identifier"],
|
server_identifier=args["server_identifier"],
|
||||||
timeout=args["timeout"],
|
timeout=args["timeout"],
|
||||||
sse_read_timeout=args["sse_read_timeout"],
|
sse_read_timeout=args["sse_read_timeout"],
|
||||||
headers=args["headers"],
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -900,7 +898,6 @@ class ToolProviderMCPApi(Resource):
|
|||||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
|
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
|
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if not is_valid_url(args["server_url"]):
|
if not is_valid_url(args["server_url"]):
|
||||||
if "[__HIDDEN__]" in args["server_url"]:
|
if "[__HIDDEN__]" in args["server_url"]:
|
||||||
@ -918,7 +915,6 @@ class ToolProviderMCPApi(Resource):
|
|||||||
server_identifier=args["server_identifier"],
|
server_identifier=args["server_identifier"],
|
||||||
timeout=args.get("timeout"),
|
timeout=args.get("timeout"),
|
||||||
sse_read_timeout=args.get("sse_read_timeout"),
|
sse_read_timeout=args.get("sse_read_timeout"),
|
||||||
headers=args.get("headers"),
|
|
||||||
)
|
)
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@ -955,9 +951,6 @@ class ToolMCPAuthApi(Resource):
|
|||||||
authed=False,
|
authed=False,
|
||||||
authorization_code=args["authorization_code"],
|
authorization_code=args["authorization_code"],
|
||||||
for_list=True,
|
for_list=True,
|
||||||
headers=provider.decrypted_headers,
|
|
||||||
timeout=provider.timeout,
|
|
||||||
sse_read_timeout=provider.sse_read_timeout,
|
|
||||||
):
|
):
|
||||||
MCPToolManageService.update_mcp_provider_credentials(
|
MCPToolManageService.update_mcp_provider_credentials(
|
||||||
mcp_provider=provider,
|
mcp_provider=provider,
|
||||||
|
|||||||
@ -25,7 +25,7 @@ from controllers.console.wraps import (
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.account import Account, Tenant, TenantStatus
|
from models.account import Tenant, TenantStatus
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
@ -70,8 +70,6 @@ class TenantListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
tenants = TenantService.get_join_tenants(current_user)
|
tenants = TenantService.get_join_tenants(current_user)
|
||||||
tenant_dicts = []
|
tenant_dicts = []
|
||||||
|
|
||||||
@ -85,7 +83,7 @@ class TenantListApi(Resource):
|
|||||||
"status": tenant.status,
|
"status": tenant.status,
|
||||||
"created_at": tenant.created_at,
|
"created_at": tenant.created_at,
|
||||||
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
|
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
|
||||||
"current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False,
|
"current": tenant.id == current_user.current_tenant_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
tenant_dicts.append(tenant_dict)
|
tenant_dicts.append(tenant_dict)
|
||||||
@ -127,11 +125,7 @@ class TenantApi(Resource):
|
|||||||
if request.path == "/info":
|
if request.path == "/info":
|
||||||
logger.warning("Deprecated URL /info was used.")
|
logger.warning("Deprecated URL /info was used.")
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
tenant = current_user.current_tenant
|
tenant = current_user.current_tenant
|
||||||
if not tenant:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
|
|
||||||
if tenant.status == TenantStatus.ARCHIVE:
|
if tenant.status == TenantStatus.ARCHIVE:
|
||||||
tenants = TenantService.get_join_tenants(current_user)
|
tenants = TenantService.get_join_tenants(current_user)
|
||||||
@ -143,8 +137,6 @@ class TenantApi(Resource):
|
|||||||
else:
|
else:
|
||||||
raise Unauthorized("workspace is archived")
|
raise Unauthorized("workspace is archived")
|
||||||
|
|
||||||
if not tenant:
|
|
||||||
raise ValueError("No tenant available")
|
|
||||||
return WorkspaceService.get_tenant_info(tenant), 200
|
return WorkspaceService.get_tenant_info(tenant), 200
|
||||||
|
|
||||||
|
|
||||||
@ -153,8 +145,6 @@ class SwitchWorkspaceApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("tenant_id", type=str, required=True, location="json")
|
parser.add_argument("tenant_id", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -178,15 +168,11 @@ class CustomConfigWorkspaceApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("workspace_custom")
|
@cloud_edition_billing_resource_check("workspace_custom")
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("remove_webapp_brand", type=bool, location="json")
|
parser.add_argument("remove_webapp_brand", type=bool, location="json")
|
||||||
parser.add_argument("replace_webapp_logo", type=str, location="json")
|
parser.add_argument("replace_webapp_logo", type=str, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
|
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
|
||||||
|
|
||||||
custom_config_dict = {
|
custom_config_dict = {
|
||||||
@ -208,8 +194,6 @@ class WebappLogoWorkspaceApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("workspace_custom")
|
@cloud_edition_billing_resource_check("workspace_custom")
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
# check file
|
# check file
|
||||||
if "file" not in request.files:
|
if "file" not in request.files:
|
||||||
raise NoFileUploadedError()
|
raise NoFileUploadedError()
|
||||||
@ -248,14 +232,10 @@ class WorkspaceInfoApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
# Change workspace name
|
# Change workspace name
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
parser.add_argument("name", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
|
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
|
||||||
tenant.name = args["name"]
|
tenant.name = args["name"]
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|||||||
@ -2,9 +2,7 @@ import contextlib
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar
|
|
||||||
|
|
||||||
from flask import abort, request
|
from flask import abort, request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
@ -21,13 +19,10 @@ from services.operation_service import OperationService
|
|||||||
|
|
||||||
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
def account_initialization_required(view):
|
||||||
def account_initialization_required(view: Callable[P, R]):
|
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
# check account initialization
|
# check account initialization
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
@ -39,9 +34,9 @@ def account_initialization_required(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def only_edition_cloud(view: Callable[P, R]):
|
def only_edition_cloud(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
if dify_config.EDITION != "CLOUD":
|
if dify_config.EDITION != "CLOUD":
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
@ -50,9 +45,9 @@ def only_edition_cloud(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def only_edition_enterprise(view: Callable[P, R]):
|
def only_edition_enterprise(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
if not dify_config.ENTERPRISE_ENABLED:
|
if not dify_config.ENTERPRISE_ENABLED:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
@ -61,9 +56,9 @@ def only_edition_enterprise(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def only_edition_self_hosted(view: Callable[P, R]):
|
def only_edition_self_hosted(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
if dify_config.EDITION != "SELF_HOSTED":
|
if dify_config.EDITION != "SELF_HOSTED":
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
@ -72,9 +67,9 @@ def only_edition_self_hosted(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
def cloud_edition_billing_enabled(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
if not features.billing.enabled:
|
if not features.billing.enabled:
|
||||||
abort(403, "Billing feature is not enabled.")
|
abort(403, "Billing feature is not enabled.")
|
||||||
@ -84,9 +79,9 @@ def cloud_edition_billing_enabled(view: Callable[P, R]):
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_resource_check(resource: str):
|
def cloud_edition_billing_resource_check(resource: str):
|
||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
members = features.members
|
members = features.members
|
||||||
@ -125,9 +120,9 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
if resource == "add_segment":
|
if resource == "add_segment":
|
||||||
@ -147,9 +142,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_rate_limit_check(resource: str):
|
def cloud_edition_billing_rate_limit_check(resource: str):
|
||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
if resource == "knowledge":
|
if resource == "knowledge":
|
||||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
||||||
if knowledge_rate_limit.enabled:
|
if knowledge_rate_limit.enabled:
|
||||||
@ -181,9 +176,9 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||||||
return interceptor
|
return interceptor
|
||||||
|
|
||||||
|
|
||||||
def cloud_utm_record(view: Callable[P, R]):
|
def cloud_utm_record(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
|
|
||||||
@ -199,9 +194,9 @@ def cloud_utm_record(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def setup_required(view: Callable[P, R]):
|
def setup_required(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
# check setup
|
# check setup
|
||||||
if (
|
if (
|
||||||
dify_config.EDITION == "SELF_HOSTED"
|
dify_config.EDITION == "SELF_HOSTED"
|
||||||
@ -217,9 +212,9 @@ def setup_required(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def enterprise_license_required(view: Callable[P, R]):
|
def enterprise_license_required(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
settings = FeatureService.get_system_features()
|
settings = FeatureService.get_system_features()
|
||||||
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
|
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
|
||||||
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
|
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
|
||||||
@ -229,9 +224,9 @@ def enterprise_license_required(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def email_password_login_enabled(view: Callable[P, R]):
|
def email_password_login_enabled(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
features = FeatureService.get_system_features()
|
features = FeatureService.get_system_features()
|
||||||
if features.enable_email_password_login:
|
if features.enable_email_password_login:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
@ -242,9 +237,9 @@ def email_password_login_enabled(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def enable_change_email(view: Callable[P, R]):
|
def enable_change_email(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
features = FeatureService.get_system_features()
|
features = FeatureService.get_system_features()
|
||||||
if features.enable_change_email:
|
if features.enable_change_email:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
@ -255,9 +250,9 @@ def enable_change_email(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def is_allow_transfer_owner(view: Callable[P, R]):
|
def is_allow_transfer_owner(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
if features.is_allow_transfer_workspace:
|
if features.is_allow_transfer_workspace:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|||||||
@ -10,10 +10,11 @@ api = ExternalApi(
|
|||||||
version="1.0",
|
version="1.0",
|
||||||
title="Files API",
|
title="Files API",
|
||||||
description="API for file operations including upload and preview",
|
description="API for file operations including upload and preview",
|
||||||
|
doc="/docs", # Enable Swagger UI at /files/docs
|
||||||
)
|
)
|
||||||
|
|
||||||
files_ns = Namespace("files", description="File operations", path="/")
|
files_ns = Namespace("files", description="File operations", path="/")
|
||||||
|
|
||||||
from . import image_preview, tool_files, upload # pyright: ignore[reportUnusedImport]
|
from . import image_preview, tool_files, upload
|
||||||
|
|
||||||
api.add_namespace(files_ns)
|
api.add_namespace(files_ns)
|
||||||
|
|||||||
@ -10,13 +10,14 @@ api = ExternalApi(
|
|||||||
version="1.0",
|
version="1.0",
|
||||||
title="Inner API",
|
title="Inner API",
|
||||||
description="Internal APIs for enterprise features, billing, and plugin communication",
|
description="Internal APIs for enterprise features, billing, and plugin communication",
|
||||||
|
doc="/docs", # Enable Swagger UI at /inner/api/docs
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create namespace
|
# Create namespace
|
||||||
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")
|
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")
|
||||||
|
|
||||||
from . import mail as _mail # pyright: ignore[reportUnusedImport]
|
from . import mail
|
||||||
from .plugin import plugin as _plugin # pyright: ignore[reportUnusedImport]
|
from .plugin import plugin
|
||||||
from .workspace import workspace as _workspace # pyright: ignore[reportUnusedImport]
|
from .workspace import workspace
|
||||||
|
|
||||||
api.add_namespace(inner_api_ns)
|
api.add_namespace(inner_api_ns)
|
||||||
|
|||||||
@ -37,9 +37,9 @@ from models.model import EndUser
|
|||||||
|
|
||||||
@inner_api_ns.route("/invoke/llm")
|
@inner_api_ns.route("/invoke/llm")
|
||||||
class PluginInvokeLLMApi(Resource):
|
class PluginInvokeLLMApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestInvokeLLM)
|
@plugin_data(payload_type=RequestInvokeLLM)
|
||||||
@inner_api_ns.doc("plugin_invoke_llm")
|
@inner_api_ns.doc("plugin_invoke_llm")
|
||||||
@inner_api_ns.doc(description="Invoke LLM models through plugin interface")
|
@inner_api_ns.doc(description="Invoke LLM models through plugin interface")
|
||||||
@ -60,9 +60,9 @@ class PluginInvokeLLMApi(Resource):
|
|||||||
|
|
||||||
@inner_api_ns.route("/invoke/llm/structured-output")
|
@inner_api_ns.route("/invoke/llm/structured-output")
|
||||||
class PluginInvokeLLMWithStructuredOutputApi(Resource):
|
class PluginInvokeLLMWithStructuredOutputApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
|
@plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
|
||||||
@inner_api_ns.doc("plugin_invoke_llm_structured")
|
@inner_api_ns.doc("plugin_invoke_llm_structured")
|
||||||
@inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface")
|
@inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface")
|
||||||
@ -85,9 +85,9 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource):
|
|||||||
|
|
||||||
@inner_api_ns.route("/invoke/text-embedding")
|
@inner_api_ns.route("/invoke/text-embedding")
|
||||||
class PluginInvokeTextEmbeddingApi(Resource):
|
class PluginInvokeTextEmbeddingApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestInvokeTextEmbedding)
|
@plugin_data(payload_type=RequestInvokeTextEmbedding)
|
||||||
@inner_api_ns.doc("plugin_invoke_text_embedding")
|
@inner_api_ns.doc("plugin_invoke_text_embedding")
|
||||||
@inner_api_ns.doc(description="Invoke text embedding models through plugin interface")
|
@inner_api_ns.doc(description="Invoke text embedding models through plugin interface")
|
||||||
@ -115,9 +115,9 @@ class PluginInvokeTextEmbeddingApi(Resource):
|
|||||||
|
|
||||||
@inner_api_ns.route("/invoke/rerank")
|
@inner_api_ns.route("/invoke/rerank")
|
||||||
class PluginInvokeRerankApi(Resource):
|
class PluginInvokeRerankApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestInvokeRerank)
|
@plugin_data(payload_type=RequestInvokeRerank)
|
||||||
@inner_api_ns.doc("plugin_invoke_rerank")
|
@inner_api_ns.doc("plugin_invoke_rerank")
|
||||||
@inner_api_ns.doc(description="Invoke rerank models through plugin interface")
|
@inner_api_ns.doc(description="Invoke rerank models through plugin interface")
|
||||||
@ -141,9 +141,9 @@ class PluginInvokeRerankApi(Resource):
|
|||||||
|
|
||||||
@inner_api_ns.route("/invoke/tts")
|
@inner_api_ns.route("/invoke/tts")
|
||||||
class PluginInvokeTTSApi(Resource):
|
class PluginInvokeTTSApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestInvokeTTS)
|
@plugin_data(payload_type=RequestInvokeTTS)
|
||||||
@inner_api_ns.doc("plugin_invoke_tts")
|
@inner_api_ns.doc("plugin_invoke_tts")
|
||||||
@inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface")
|
@inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface")
|
||||||
@ -168,9 +168,9 @@ class PluginInvokeTTSApi(Resource):
|
|||||||
|
|
||||||
@inner_api_ns.route("/invoke/speech2text")
|
@inner_api_ns.route("/invoke/speech2text")
|
||||||
class PluginInvokeSpeech2TextApi(Resource):
|
class PluginInvokeSpeech2TextApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestInvokeSpeech2Text)
|
@plugin_data(payload_type=RequestInvokeSpeech2Text)
|
||||||
@inner_api_ns.doc("plugin_invoke_speech2text")
|
@inner_api_ns.doc("plugin_invoke_speech2text")
|
||||||
@inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface")
|
@inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface")
|
||||||
@ -194,9 +194,9 @@ class PluginInvokeSpeech2TextApi(Resource):
|
|||||||
|
|
||||||
@inner_api_ns.route("/invoke/moderation")
|
@inner_api_ns.route("/invoke/moderation")
|
||||||
class PluginInvokeModerationApi(Resource):
|
class PluginInvokeModerationApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestInvokeModeration)
|
@plugin_data(payload_type=RequestInvokeModeration)
|
||||||
@inner_api_ns.doc("plugin_invoke_moderation")
|
@inner_api_ns.doc("plugin_invoke_moderation")
|
||||||
@inner_api_ns.doc(description="Invoke moderation models through plugin interface")
|
@inner_api_ns.doc(description="Invoke moderation models through plugin interface")
|
||||||
@ -220,9 +220,9 @@ class PluginInvokeModerationApi(Resource):
|
|||||||
|
|
||||||
@inner_api_ns.route("/invoke/tool")
|
@inner_api_ns.route("/invoke/tool")
|
||||||
class PluginInvokeToolApi(Resource):
|
class PluginInvokeToolApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestInvokeTool)
|
@plugin_data(payload_type=RequestInvokeTool)
|
||||||
@inner_api_ns.doc("plugin_invoke_tool")
|
@inner_api_ns.doc("plugin_invoke_tool")
|
||||||
@inner_api_ns.doc(description="Invoke tools through plugin interface")
|
@inner_api_ns.doc(description="Invoke tools through plugin interface")
|
||||||
@ -252,9 +252,9 @@ class PluginInvokeToolApi(Resource):
|
|||||||
|
|
||||||
@inner_api_ns.route("/invoke/parameter-extractor")
|
@inner_api_ns.route("/invoke/parameter-extractor")
|
||||||
class PluginInvokeParameterExtractorNodeApi(Resource):
|
class PluginInvokeParameterExtractorNodeApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestInvokeParameterExtractorNode)
|
@plugin_data(payload_type=RequestInvokeParameterExtractorNode)
|
||||||
@inner_api_ns.doc("plugin_invoke_parameter_extractor")
|
@inner_api_ns.doc("plugin_invoke_parameter_extractor")
|
||||||
@inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface")
|
@inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface")
|
||||||
@ -285,9 +285,9 @@ class PluginInvokeParameterExtractorNodeApi(Resource):
|
|||||||
|
|
||||||
@inner_api_ns.route("/invoke/question-classifier")
|
@inner_api_ns.route("/invoke/question-classifier")
|
||||||
class PluginInvokeQuestionClassifierNodeApi(Resource):
|
class PluginInvokeQuestionClassifierNodeApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
|
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
|
||||||
@inner_api_ns.doc("plugin_invoke_question_classifier")
|
@inner_api_ns.doc("plugin_invoke_question_classifier")
|
||||||
@inner_api_ns.doc(description="Invoke question classifier node through plugin interface")
|
@inner_api_ns.doc(description="Invoke question classifier node through plugin interface")
|
||||||
@ -318,9 +318,9 @@ class PluginInvokeQuestionClassifierNodeApi(Resource):
|
|||||||
|
|
||||||
@inner_api_ns.route("/invoke/app")
|
@inner_api_ns.route("/invoke/app")
|
||||||
class PluginInvokeAppApi(Resource):
|
class PluginInvokeAppApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestInvokeApp)
|
@plugin_data(payload_type=RequestInvokeApp)
|
||||||
@inner_api_ns.doc("plugin_invoke_app")
|
@inner_api_ns.doc("plugin_invoke_app")
|
||||||
@inner_api_ns.doc(description="Invoke application through plugin interface")
|
@inner_api_ns.doc(description="Invoke application through plugin interface")
|
||||||
@ -348,9 +348,9 @@ class PluginInvokeAppApi(Resource):
|
|||||||
|
|
||||||
@inner_api_ns.route("/invoke/encrypt")
|
@inner_api_ns.route("/invoke/encrypt")
|
||||||
class PluginInvokeEncryptApi(Resource):
|
class PluginInvokeEncryptApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestInvokeEncrypt)
|
@plugin_data(payload_type=RequestInvokeEncrypt)
|
||||||
@inner_api_ns.doc("plugin_invoke_encrypt")
|
@inner_api_ns.doc("plugin_invoke_encrypt")
|
||||||
@inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface")
|
@inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface")
|
||||||
@ -375,9 +375,9 @@ class PluginInvokeEncryptApi(Resource):
|
|||||||
|
|
||||||
@inner_api_ns.route("/invoke/summary")
|
@inner_api_ns.route("/invoke/summary")
|
||||||
class PluginInvokeSummaryApi(Resource):
|
class PluginInvokeSummaryApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestInvokeSummary)
|
@plugin_data(payload_type=RequestInvokeSummary)
|
||||||
@inner_api_ns.doc("plugin_invoke_summary")
|
@inner_api_ns.doc("plugin_invoke_summary")
|
||||||
@inner_api_ns.doc(description="Invoke summary functionality through plugin interface")
|
@inner_api_ns.doc(description="Invoke summary functionality through plugin interface")
|
||||||
@ -405,9 +405,9 @@ class PluginInvokeSummaryApi(Resource):
|
|||||||
|
|
||||||
@inner_api_ns.route("/upload/file/request")
|
@inner_api_ns.route("/upload/file/request")
|
||||||
class PluginUploadFileRequestApi(Resource):
|
class PluginUploadFileRequestApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestRequestUploadFile)
|
@plugin_data(payload_type=RequestRequestUploadFile)
|
||||||
@inner_api_ns.doc("plugin_upload_file_request")
|
@inner_api_ns.doc("plugin_upload_file_request")
|
||||||
@inner_api_ns.doc(description="Request signed URL for file upload through plugin interface")
|
@inner_api_ns.doc(description="Request signed URL for file upload through plugin interface")
|
||||||
@ -426,9 +426,9 @@ class PluginUploadFileRequestApi(Resource):
|
|||||||
|
|
||||||
@inner_api_ns.route("/fetch/app/info")
|
@inner_api_ns.route("/fetch/app/info")
|
||||||
class PluginFetchAppInfoApi(Resource):
|
class PluginFetchAppInfoApi(Resource):
|
||||||
@get_user_tenant
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
@plugin_data(payload_type=RequestFetchAppInfo)
|
@plugin_data(payload_type=RequestFetchAppInfo)
|
||||||
@inner_api_ns.doc("plugin_fetch_app_info")
|
@inner_api_ns.doc("plugin_fetch_app_info")
|
||||||
@inner_api_ns.doc(description="Fetch application information through plugin interface")
|
@inner_api_ns.doc(description="Fetch application information through plugin interface")
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Optional, ParamSpec, TypeVar, cast
|
from typing import Optional
|
||||||
|
|
||||||
from flask import current_app, request
|
from flask import current_app, request
|
||||||
from flask_login import user_logged_in
|
from flask_login import user_logged_in
|
||||||
@ -8,72 +8,65 @@ from flask_restx import reqparse
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.file.constants import DEFAULT_SERVICE_API_USER_ID
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import current_user
|
from libs.login import _get_user
|
||||||
from models.account import Tenant
|
from models.account import Account, Tenant
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
from services.account_service import AccountService
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
|
||||||
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
|
||||||
"""
|
|
||||||
Get current user
|
|
||||||
|
|
||||||
NOTE: user_id is not trusted, it could be maliciously set to any value.
|
|
||||||
As a result, it could only be considered as an end user id.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
if not user_id:
|
if not user_id:
|
||||||
user_id = DEFAULT_SERVICE_API_USER_ID
|
user_id = "DEFAULT-USER"
|
||||||
|
|
||||||
user_model = (
|
|
||||||
session.query(EndUser)
|
|
||||||
.where(
|
|
||||||
EndUser.session_id == user_id,
|
|
||||||
EndUser.tenant_id == tenant_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not user_model:
|
|
||||||
user_model = EndUser(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
type="service_api",
|
|
||||||
is_anonymous=user_id == DEFAULT_SERVICE_API_USER_ID,
|
|
||||||
session_id=user_id,
|
|
||||||
)
|
|
||||||
session.add(user_model)
|
|
||||||
session.commit()
|
|
||||||
session.refresh(user_model)
|
|
||||||
|
|
||||||
|
if user_id == "DEFAULT-USER":
|
||||||
|
user_model = session.query(EndUser).where(EndUser.session_id == "DEFAULT-USER").first()
|
||||||
|
if not user_model:
|
||||||
|
user_model = EndUser(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
type="service_api",
|
||||||
|
is_anonymous=True if user_id == "DEFAULT-USER" else False,
|
||||||
|
session_id=user_id,
|
||||||
|
)
|
||||||
|
session.add(user_model)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(user_model)
|
||||||
|
else:
|
||||||
|
user_model = AccountService.load_user(user_id)
|
||||||
|
if not user_model:
|
||||||
|
user_model = session.query(EndUser).where(EndUser.id == user_id).first()
|
||||||
|
if not user_model:
|
||||||
|
raise ValueError("user not found")
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError("user not found")
|
raise ValueError("user not found")
|
||||||
|
|
||||||
return user_model
|
return user_model
|
||||||
|
|
||||||
|
|
||||||
def get_user_tenant(view: Optional[Callable[P, R]] = None):
|
def get_user_tenant(view: Optional[Callable] = None):
|
||||||
def decorator(view_func: Callable[P, R]):
|
def decorator(view_func):
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
def decorated_view(*args, **kwargs):
|
||||||
# fetch json body
|
# fetch json body
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("tenant_id", type=str, required=True, location="json")
|
parser.add_argument("tenant_id", type=str, required=True, location="json")
|
||||||
parser.add_argument("user_id", type=str, required=True, location="json")
|
parser.add_argument("user_id", type=str, required=True, location="json")
|
||||||
|
|
||||||
p = parser.parse_args()
|
kwargs = parser.parse_args()
|
||||||
|
|
||||||
user_id = cast(str, p.get("user_id"))
|
user_id = kwargs.get("user_id")
|
||||||
tenant_id = cast(str, p.get("tenant_id"))
|
tenant_id = kwargs.get("tenant_id")
|
||||||
|
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
raise ValueError("tenant_id is required")
|
raise ValueError("tenant_id is required")
|
||||||
|
|
||||||
if not user_id:
|
if not user_id:
|
||||||
user_id = DEFAULT_SERVICE_API_USER_ID
|
user_id = "DEFAULT-USER"
|
||||||
|
|
||||||
|
del kwargs["tenant_id"]
|
||||||
|
del kwargs["user_id"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tenant_model = (
|
tenant_model = (
|
||||||
@ -95,7 +88,7 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
|
|||||||
kwargs["user_model"] = user
|
kwargs["user_model"] = user
|
||||||
|
|
||||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore
|
current_app.login_manager._update_request_context_with_user(user) # type: ignore
|
||||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
|
||||||
|
|
||||||
return view_func(*args, **kwargs)
|
return view_func(*args, **kwargs)
|
||||||
|
|
||||||
@ -107,9 +100,9 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
|
|||||||
return decorator(view)
|
return decorator(view)
|
||||||
|
|
||||||
|
|
||||||
def plugin_data(view: Optional[Callable[P, R]] = None, *, payload_type: type[BaseModel]):
|
def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
|
||||||
def decorator(view_func: Callable[P, R]):
|
def decorator(view_func):
|
||||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
def decorated_view(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@ -46,9 +46,9 @@ def enterprise_inner_api_only(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def enterprise_inner_api_user_auth(view: Callable[P, R]):
|
def enterprise_inner_api_user_auth(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args, **kwargs):
|
||||||
if not dify_config.INNER_API:
|
if not dify_config.INNER_API:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@ -10,10 +10,11 @@ api = ExternalApi(
|
|||||||
version="1.0",
|
version="1.0",
|
||||||
title="MCP API",
|
title="MCP API",
|
||||||
description="API for Model Context Protocol operations",
|
description="API for Model Context Protocol operations",
|
||||||
|
doc="/docs", # Enable Swagger UI at /mcp/docs
|
||||||
)
|
)
|
||||||
|
|
||||||
mcp_ns = Namespace("mcp", description="MCP operations", path="/")
|
mcp_ns = Namespace("mcp", description="MCP operations", path="/")
|
||||||
|
|
||||||
from . import mcp # pyright: ignore[reportUnusedImport]
|
from . import mcp
|
||||||
|
|
||||||
api.add_namespace(mcp_ns)
|
api.add_namespace(mcp_ns)
|
||||||
|
|||||||
@ -99,7 +99,7 @@ class MCPAppApi(Resource):
|
|||||||
|
|
||||||
return mcp_server, app
|
return mcp_server, app
|
||||||
|
|
||||||
def _validate_server_status(self, mcp_server: AppMCPServer):
|
def _validate_server_status(self, mcp_server: AppMCPServer) -> None:
|
||||||
"""Validate MCP server status"""
|
"""Validate MCP server status"""
|
||||||
if mcp_server.status != AppMCPServerStatus.ACTIVE:
|
if mcp_server.status != AppMCPServerStatus.ACTIVE:
|
||||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")
|
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")
|
||||||
|
|||||||
@ -10,31 +10,14 @@ api = ExternalApi(
|
|||||||
version="1.0",
|
version="1.0",
|
||||||
title="Service API",
|
title="Service API",
|
||||||
description="API for application services",
|
description="API for application services",
|
||||||
|
doc="/docs", # Enable Swagger UI at /v1/docs
|
||||||
)
|
)
|
||||||
|
|
||||||
service_api_ns = Namespace("service_api", description="Service operations", path="/")
|
service_api_ns = Namespace("service_api", description="Service operations", path="/")
|
||||||
|
|
||||||
from . import index # pyright: ignore[reportUnusedImport]
|
from . import index
|
||||||
from .app import (
|
from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow
|
||||||
annotation, # pyright: ignore[reportUnusedImport]
|
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
|
||||||
app, # pyright: ignore[reportUnusedImport]
|
from .workspace import models
|
||||||
audio, # pyright: ignore[reportUnusedImport]
|
|
||||||
completion, # pyright: ignore[reportUnusedImport]
|
|
||||||
conversation, # pyright: ignore[reportUnusedImport]
|
|
||||||
file, # pyright: ignore[reportUnusedImport]
|
|
||||||
file_preview, # pyright: ignore[reportUnusedImport]
|
|
||||||
message, # pyright: ignore[reportUnusedImport]
|
|
||||||
site, # pyright: ignore[reportUnusedImport]
|
|
||||||
workflow, # pyright: ignore[reportUnusedImport]
|
|
||||||
)
|
|
||||||
from .dataset import (
|
|
||||||
dataset, # pyright: ignore[reportUnusedImport]
|
|
||||||
document, # pyright: ignore[reportUnusedImport]
|
|
||||||
hit_testing, # pyright: ignore[reportUnusedImport]
|
|
||||||
metadata, # pyright: ignore[reportUnusedImport]
|
|
||||||
segment, # pyright: ignore[reportUnusedImport]
|
|
||||||
upload_file, # pyright: ignore[reportUnusedImport]
|
|
||||||
)
|
|
||||||
from .workspace import models # pyright: ignore[reportUnusedImport]
|
|
||||||
|
|
||||||
api.add_namespace(service_api_ns)
|
api.add_namespace(service_api_ns)
|
||||||
|
|||||||
@ -165,7 +165,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||||||
def put(self, app_model: App, annotation_id):
|
def put(self, app_model: App, annotation_id):
|
||||||
"""Update an existing annotation."""
|
"""Update an existing annotation."""
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
annotation_id = str(annotation_id)
|
annotation_id = str(annotation_id)
|
||||||
@ -189,7 +189,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||||||
"""Delete an annotation."""
|
"""Delete an annotation."""
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
|
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
annotation_id = str(annotation_id)
|
annotation_id = str(annotation_id)
|
||||||
|
|||||||
@ -55,7 +55,7 @@ class AudioApi(Resource):
|
|||||||
file = request.files["file"]
|
file = request.files["file"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.id)
|
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from flask_restx._http import HTTPStatus
|
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
@ -122,7 +121,7 @@ class ConversationDetailApi(Resource):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||||
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT)
|
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204)
|
||||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||||
"""Delete a specific conversation."""
|
"""Delete a specific conversation."""
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
|
|||||||
@ -59,7 +59,7 @@ class FilePreviewApi(Resource):
|
|||||||
args = file_preview_parser.parse_args()
|
args = file_preview_parser.parse_args()
|
||||||
|
|
||||||
# Validate file ownership and get file objects
|
# Validate file ownership and get file objects
|
||||||
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
|
message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)
|
||||||
|
|
||||||
# Get file content generator
|
# Get file content generator
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -559,7 +559,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||||||
def post(self, _, dataset_id):
|
def post(self, _, dataset_id):
|
||||||
"""Add a knowledge type tag."""
|
"""Add a knowledge type tag."""
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
args = tag_create_parser.parse_args()
|
args = tag_create_parser.parse_args()
|
||||||
@ -583,7 +583,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||||||
@validate_dataset_token
|
@validate_dataset_token
|
||||||
def patch(self, _, dataset_id):
|
def patch(self, _, dataset_id):
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
args = tag_update_parser.parse_args()
|
args = tag_update_parser.parse_args()
|
||||||
@ -610,7 +610,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||||||
def delete(self, _, dataset_id):
|
def delete(self, _, dataset_id):
|
||||||
"""Delete a knowledge type tag."""
|
"""Delete a knowledge type tag."""
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
args = tag_delete_parser.parse_args()
|
args = tag_delete_parser.parse_args()
|
||||||
TagService.delete_tag(args.get("tag_id"))
|
TagService.delete_tag(args.get("tag_id"))
|
||||||
@ -634,7 +634,7 @@ class DatasetTagBindingApi(DatasetApiResource):
|
|||||||
def post(self, _, dataset_id):
|
def post(self, _, dataset_id):
|
||||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
args = tag_binding_parser.parse_args()
|
args = tag_binding_parser.parse_args()
|
||||||
@ -660,7 +660,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
|||||||
def post(self, _, dataset_id):
|
def post(self, _, dataset_id):
|
||||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
args = tag_unbinding_parser.parse_args()
|
args = tag_unbinding_parser.parse_args()
|
||||||
|
|||||||
@ -30,7 +30,6 @@ from extensions.ext_database import db
|
|||||||
from fields.document_fields import document_fields, document_status_fields
|
from fields.document_fields import document_fields, document_status_fields
|
||||||
from libs.login import current_user
|
from libs.login import current_user
|
||||||
from models.dataset import Dataset, Document, DocumentSegment
|
from models.dataset import Dataset, Document, DocumentSegment
|
||||||
from models.model import EndUser
|
|
||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
@ -299,9 +298,6 @@ class DocumentAddByFileApi(DatasetApiResource):
|
|||||||
if not file.filename:
|
if not file.filename:
|
||||||
raise FilenameNotExistsError
|
raise FilenameNotExistsError
|
||||||
|
|
||||||
if not isinstance(current_user, EndUser):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
|
|
||||||
upload_file = FileService.upload_file(
|
upload_file = FileService.upload_file(
|
||||||
filename=file.filename,
|
filename=file.filename,
|
||||||
content=file.read(),
|
content=file.read(),
|
||||||
@ -391,8 +387,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
|||||||
raise FilenameNotExistsError
|
raise FilenameNotExistsError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, EndUser):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
upload_file = FileService.upload_file(
|
upload_file = FileService.upload_file(
|
||||||
filename=file.filename,
|
filename=file.filename,
|
||||||
content=file.read(),
|
content=file.read(),
|
||||||
@ -416,7 +410,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
|||||||
DocumentService.document_create_args_validate(knowledge_config)
|
DocumentService.document_create_args_validate(knowledge_config)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
documents, _ = DocumentService.save_document_with_dataset_id(
|
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
knowledge_config=knowledge_config,
|
knowledge_config=knowledge_config,
|
||||||
account=dataset.created_by_account,
|
account=dataset.created_by_account,
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user