mirror of
https://github.com/langgenius/dify.git
synced 2026-01-20 03:59:30 +08:00
Compare commits
10 Commits
docker-env
...
pinecone
| Author | SHA1 | Date | |
|---|---|---|---|
| 594906c1ff | |||
| 80f8245f2e | |||
| a12b437c16 | |||
| 12de554313 | |||
| 1f36c0c1c5 | |||
| 8b9297563c | |||
| 1cbe9eedb6 | |||
| 90fc5a1f12 | |||
| 41dfdf1ac0 | |||
| dd7de74aa6 |
15
.github/workflows/api-tests.yml
vendored
15
.github/workflows/api-tests.yml
vendored
@ -42,7 +42,11 @@ jobs:
|
||||
- name: Run Unit tests
|
||||
run: |
|
||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Run ty check
|
||||
run: |
|
||||
cd api
|
||||
uv add --dev ty
|
||||
uv run ty check || true
|
||||
- name: Run pyrefly check
|
||||
run: |
|
||||
cd api
|
||||
@ -62,6 +66,15 @@ jobs:
|
||||
- name: Run dify config tests
|
||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||
|
||||
- name: MyPy Cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: api/.mypy_cache
|
||||
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}
|
||||
|
||||
- name: Run MyPy Checks
|
||||
run: dev/mypy-check
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
|
||||
12
.github/workflows/style.yml
vendored
12
.github/workflows/style.yml
vendored
@ -44,14 +44,6 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run Basedpyright Checks
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: dev/basedpyright-check
|
||||
|
||||
- name: Run Mypy Type Checks
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
|
||||
- name: Dotenv check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
|
||||
@ -97,9 +89,7 @@ jobs:
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: |
|
||||
pnpm run lint
|
||||
pnpm run eslint
|
||||
run: pnpm run lint
|
||||
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
|
||||
@ -67,22 +67,12 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
|
||||
|
||||
- name: Generate i18n type definitions
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run gen:i18n-types
|
||||
|
||||
- name: Create Pull Request
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
uses: peter-evans/create-pull-request@v6
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
commit-message: Update i18n files and type definitions based on en-US changes
|
||||
title: 'chore: translate i18n files and update type definitions'
|
||||
body: |
|
||||
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
|
||||
|
||||
**Changes included:**
|
||||
- Updated translation files for all locales
|
||||
- Regenerated TypeScript type definitions for type safety
|
||||
commit-message: Update i18n files based on en-US changes
|
||||
title: 'chore: translate i18n files'
|
||||
body: This PR was automatically created to update i18n files based on changes in en-US locale.
|
||||
branch: chore/automated-i18n-updates
|
||||
|
||||
5
.github/workflows/web-tests.yml
vendored
5
.github/workflows/web-tests.yml
vendored
@ -47,11 +47,6 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Check i18n types synchronization
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run check:i18n-types
|
||||
|
||||
- name: Run tests
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
|
||||
13
.gitignore
vendored
13
.gitignore
vendored
@ -123,12 +123,10 @@ venv.bak/
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# type checking
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
pyrightconfig.json
|
||||
!api/pyrightconfig.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
@ -197,8 +195,8 @@ sdks/python-client/dify_client.egg-info
|
||||
.vscode/*
|
||||
!.vscode/launch.json.template
|
||||
!.vscode/README.md
|
||||
pyrightconfig.json
|
||||
api/.vscode
|
||||
web/.vscode
|
||||
# vscode Code History Extension
|
||||
.history
|
||||
|
||||
@ -216,13 +214,6 @@ mise.toml
|
||||
# Next.js build output
|
||||
.next/
|
||||
|
||||
# PWA generated files
|
||||
web/public/sw.js
|
||||
web/public/sw.js.map
|
||||
web/public/workbox-*.js
|
||||
web/public/workbox-*.js.map
|
||||
web/public/fallback-*.js
|
||||
|
||||
# AI Assistant
|
||||
.roo/
|
||||
api/.env.backup
|
||||
|
||||
@ -32,7 +32,7 @@ uv run --project api pytest tests/integration_tests/ # Integration tests
|
||||
./dev/reformat # Run all formatters and linters
|
||||
uv run --project api ruff check --fix ./ # Fix linting issues
|
||||
uv run --project api ruff format ./ # Format code
|
||||
uv run --directory api basedpyright # Type checking
|
||||
uv run --project api mypy . # Type checking
|
||||
```
|
||||
|
||||
### Frontend (Web)
|
||||
|
||||
60
Makefile
60
Makefile
@ -4,48 +4,6 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
|
||||
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
|
||||
VERSION=latest
|
||||
|
||||
# Backend Development Environment Setup
|
||||
.PHONY: dev-setup prepare-docker prepare-web prepare-api
|
||||
|
||||
# Default dev setup target
|
||||
dev-setup: prepare-docker prepare-web prepare-api
|
||||
@echo "✅ Backend development environment setup complete!"
|
||||
|
||||
# Step 1: Prepare Docker middleware
|
||||
prepare-docker:
|
||||
@echo "🐳 Setting up Docker middleware..."
|
||||
@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
|
||||
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
|
||||
@echo "✅ Docker middleware started"
|
||||
|
||||
# Step 2: Prepare web environment
|
||||
prepare-web:
|
||||
@echo "🌐 Setting up web environment..."
|
||||
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
||||
@cd web && pnpm install
|
||||
@cd web && pnpm build
|
||||
@echo "✅ Web environment prepared (not started)"
|
||||
|
||||
# Step 3: Prepare API environment
|
||||
prepare-api:
|
||||
@echo "🔧 Setting up API environment..."
|
||||
@cp -n api/.env.example api/.env 2>/dev/null || echo "API .env already exists"
|
||||
@cd api && uv sync --dev
|
||||
@cd api && uv run flask db upgrade
|
||||
@echo "✅ API environment prepared (not started)"
|
||||
|
||||
# Clean dev environment
|
||||
dev-clean:
|
||||
@echo "⚠️ Stopping Docker containers..."
|
||||
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
|
||||
@echo "🗑️ Removing volumes..."
|
||||
@rm -rf docker/volumes/db
|
||||
@rm -rf docker/volumes/redis
|
||||
@rm -rf docker/volumes/plugin_daemon
|
||||
@rm -rf docker/volumes/weaviate
|
||||
@rm -rf api/storage
|
||||
@echo "✅ Cleanup complete"
|
||||
|
||||
# Build Docker images
|
||||
build-web:
|
||||
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
|
||||
@ -81,21 +39,5 @@ build-push-web: build-web push-web
|
||||
build-push-all: build-all push-all
|
||||
@echo "All Docker images have been built and pushed."
|
||||
|
||||
# Help target
|
||||
help:
|
||||
@echo "Development Setup Targets:"
|
||||
@echo " make dev-setup - Run all setup steps for backend dev environment"
|
||||
@echo " make prepare-docker - Set up Docker middleware"
|
||||
@echo " make prepare-web - Set up web environment"
|
||||
@echo " make prepare-api - Set up API environment"
|
||||
@echo " make dev-clean - Stop Docker middleware containers"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
@echo " make build-web - Build web Docker image"
|
||||
@echo " make build-api - Build API Docker image"
|
||||
@echo " make build-all - Build all Docker images"
|
||||
@echo " make push-all - Push all Docker images"
|
||||
@echo " make build-push-all - Build and push all Docker images"
|
||||
|
||||
# Phony targets
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all
|
||||
|
||||
@ -75,7 +75,6 @@ DB_PASSWORD=difyai123456
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_DATABASE=dify
|
||||
SQLALCHEMY_POOL_PRE_PING=true
|
||||
|
||||
# Storage configuration
|
||||
# use for store upload files, private keys...
|
||||
@ -157,7 +156,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
|
||||
# Vector database configuration
|
||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `pinecone`.
|
||||
VECTOR_STORE=weaviate
|
||||
# Prefix used to create collection name in vector database
|
||||
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
@ -362,6 +361,17 @@ PROMPT_GENERATION_MAX_TOKENS=512
|
||||
CODE_GENERATION_MAX_TOKENS=1024
|
||||
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
|
||||
|
||||
|
||||
# Pinecone configuration, only available when VECTOR_STORE is `pinecone`
|
||||
PINECONE_API_KEY=your-pinecone-api-key
|
||||
PINECONE_ENVIRONMENT=your-pinecone-environment
|
||||
PINECONE_INDEX_NAME=dify-index
|
||||
PINECONE_CLIENT_TIMEOUT=30
|
||||
PINECONE_BATCH_SIZE=100
|
||||
PINECONE_METRIC=cosine
|
||||
PINECONE_PODS=1
|
||||
PINECONE_POD_TYPE=s1
|
||||
|
||||
# Mail configuration, support: resend, smtp, sendgrid
|
||||
MAIL_TYPE=
|
||||
# If using SendGrid, use the 'from' field for authentication if necessary.
|
||||
|
||||
@ -108,5 +108,5 @@ uv run celery -A app.celery beat
|
||||
../dev/reformat # Run all formatters and linters
|
||||
uv run ruff check --fix ./ # Fix linting issues
|
||||
uv run ruff format ./ # Format code
|
||||
uv run basedpyright . # Type checking
|
||||
uv run mypy . # Type checking
|
||||
```
|
||||
|
||||
@ -25,9 +25,6 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
# add an unique identifier to each request
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
# Capture the decorator's return value to avoid pyright reportUnusedFunction
|
||||
_ = before_request
|
||||
|
||||
return dify_app
|
||||
|
||||
|
||||
|
||||
11
api/child_class.py
Normal file
11
api/child_class.py
Normal file
@ -0,0 +1,11 @@
|
||||
from tests.integration_tests.utils.parent_class import ParentClass
|
||||
|
||||
|
||||
class ChildClass(ParentClass):
|
||||
"""Test child class for module import helper tests"""
|
||||
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
|
||||
def get_name(self):
|
||||
return f"Child: {self.name}"
|
||||
@ -571,7 +571,7 @@ def old_metadata_migration():
|
||||
for document in documents:
|
||||
if document.doc_metadata:
|
||||
doc_metadata = document.doc_metadata
|
||||
for key in doc_metadata:
|
||||
for key, value in doc_metadata.items():
|
||||
for field in BuiltInField:
|
||||
if field.value == key:
|
||||
break
|
||||
|
||||
@ -35,6 +35,7 @@ from .vdb.opensearch_config import OpenSearchConfig
|
||||
from .vdb.oracle_config import OracleConfig
|
||||
from .vdb.pgvector_config import PGVectorConfig
|
||||
from .vdb.pgvectors_config import PGVectoRSConfig
|
||||
from .vdb.pinecone_config import PineconeConfig
|
||||
from .vdb.qdrant_config import QdrantConfig
|
||||
from .vdb.relyt_config import RelytConfig
|
||||
from .vdb.tablestore_config import TableStoreConfig
|
||||
@ -300,7 +301,8 @@ class DatasetQueueMonitorConfig(BaseSettings):
|
||||
|
||||
class MiddlewareConfig(
|
||||
# place the configs in alphabet order
|
||||
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
|
||||
CeleryConfig,
|
||||
DatabaseConfig,
|
||||
KeywordStoreConfig,
|
||||
RedisConfig,
|
||||
# configs of storage and storage providers
|
||||
@ -330,6 +332,7 @@ class MiddlewareConfig(
|
||||
PGVectorConfig,
|
||||
VastbaseVectorConfig,
|
||||
PGVectoRSConfig,
|
||||
PineconeConfig,
|
||||
QdrantConfig,
|
||||
RelytConfig,
|
||||
TencentVectorDBConfig,
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ClickzettaConfig(BaseSettings):
|
||||
class ClickzettaConfig(BaseModel):
|
||||
"""
|
||||
Clickzetta Lakehouse vector database configuration
|
||||
"""
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MatrixoneConfig(BaseSettings):
|
||||
class MatrixoneConfig(BaseModel):
|
||||
"""Matrixone vector database configuration."""
|
||||
|
||||
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")
|
||||
|
||||
41
api/configs/middleware/vdb/pinecone_config.py
Normal file
41
api/configs/middleware/vdb/pinecone_config.py
Normal file
@ -0,0 +1,41 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class PineconeConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Pinecone vector database
|
||||
"""
|
||||
|
||||
PINECONE_API_KEY: Optional[str] = Field(
|
||||
description="API key for authenticating with Pinecone service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
PINECONE_ENVIRONMENT: Optional[str] = Field(
|
||||
description="Pinecone environment (e.g., 'us-west1-gcp', 'us-east-1-aws')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
PINECONE_INDEX_NAME: Optional[str] = Field(
|
||||
description="Default Pinecone index name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
PINECONE_CLIENT_TIMEOUT: PositiveInt = Field(
|
||||
description="Timeout in seconds for Pinecone client operations (default is 30 seconds)",
|
||||
default=30,
|
||||
)
|
||||
|
||||
PINECONE_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Batch size for Pinecone operations (default is 100)",
|
||||
default=100,
|
||||
)
|
||||
|
||||
PINECONE_METRIC: str = Field(
|
||||
description="Distance metric for Pinecone index (cosine, euclidean, dotproduct)",
|
||||
default="cosine",
|
||||
)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from pydantic import Field
|
||||
|
||||
from configs.packaging.pyproject import PyProjectTomlConfig
|
||||
from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig
|
||||
|
||||
|
||||
class PackagingInfo(PyProjectTomlConfig):
|
||||
|
||||
@ -4,9 +4,8 @@ import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable, Mapping
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .python_3x import http_request, makedirs_wrapper
|
||||
from .utils import (
|
||||
@ -26,13 +25,13 @@ logger = logging.getLogger(__name__)
|
||||
class ApolloClient:
|
||||
def __init__(
|
||||
self,
|
||||
config_url: str,
|
||||
app_id: str,
|
||||
cluster: str = "default",
|
||||
secret: str = "",
|
||||
start_hot_update: bool = True,
|
||||
change_listener: Callable[[str, str, str, Any], None] | None = None,
|
||||
_notification_map: dict[str, int] | None = None,
|
||||
config_url,
|
||||
app_id,
|
||||
cluster="default",
|
||||
secret="",
|
||||
start_hot_update=True,
|
||||
change_listener=None,
|
||||
_notification_map=None,
|
||||
):
|
||||
# Core routing parameters
|
||||
self.config_url = config_url
|
||||
@ -48,17 +47,17 @@ class ApolloClient:
|
||||
# Private control variables
|
||||
self._cycle_time = 5
|
||||
self._stopping = False
|
||||
self._cache: dict[str, dict[str, Any]] = {}
|
||||
self._no_key: dict[str, str] = {}
|
||||
self._hash: dict[str, str] = {}
|
||||
self._cache = {}
|
||||
self._no_key = {}
|
||||
self._hash = {}
|
||||
self._pull_timeout = 75
|
||||
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
|
||||
self._long_poll_thread: threading.Thread | None = None
|
||||
self._long_poll_thread = None
|
||||
self._change_listener = change_listener # "add" "delete" "update"
|
||||
if _notification_map is None:
|
||||
_notification_map = {"application": -1}
|
||||
self._notification_map = _notification_map
|
||||
self.last_release_key: str | None = None
|
||||
self.last_release_key = None
|
||||
# Private startup method
|
||||
self._path_checker()
|
||||
if start_hot_update:
|
||||
@ -69,7 +68,7 @@ class ApolloClient:
|
||||
heartbeat.daemon = True
|
||||
heartbeat.start()
|
||||
|
||||
def get_json_from_net(self, namespace: str = "application") -> dict[str, Any] | None:
|
||||
def get_json_from_net(self, namespace="application"):
|
||||
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
|
||||
self.config_url, self.app_id, self.cluster, namespace, "", self.ip
|
||||
)
|
||||
@ -89,7 +88,7 @@ class ApolloClient:
|
||||
logger.exception("an error occurred in get_json_from_net")
|
||||
return None
|
||||
|
||||
def get_value(self, key: str, default_val: Any = None, namespace: str = "application") -> Any:
|
||||
def get_value(self, key, default_val=None, namespace="application"):
|
||||
try:
|
||||
# read memory configuration
|
||||
namespace_cache = self._cache.get(namespace)
|
||||
@ -105,8 +104,7 @@ class ApolloClient:
|
||||
namespace_data = self.get_json_from_net(namespace)
|
||||
val = get_value_from_dict(namespace_data, key)
|
||||
if val is not None:
|
||||
if namespace_data is not None:
|
||||
self._update_cache_and_file(namespace_data, namespace)
|
||||
self._update_cache_and_file(namespace_data, namespace)
|
||||
return val
|
||||
|
||||
# read the file configuration
|
||||
@ -128,23 +126,23 @@ class ApolloClient:
|
||||
# to ensure the real-time correctness of the function call.
|
||||
# If the user does not have the same default val twice
|
||||
# and the default val is used here, there may be a problem.
|
||||
def _set_local_cache_none(self, namespace: str, key: str) -> None:
|
||||
def _set_local_cache_none(self, namespace, key):
|
||||
no_key = no_key_cache_key(namespace, key)
|
||||
self._no_key[no_key] = key
|
||||
|
||||
def _start_hot_update(self) -> None:
|
||||
def _start_hot_update(self):
|
||||
self._long_poll_thread = threading.Thread(target=self._listener)
|
||||
# When the asynchronous thread is started, the daemon thread will automatically exit
|
||||
# when the main thread is launched.
|
||||
self._long_poll_thread.daemon = True
|
||||
self._long_poll_thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
def stop(self):
|
||||
self._stopping = True
|
||||
logger.info("Stopping listener...")
|
||||
|
||||
# Call the set callback function, and if it is abnormal, try it out
|
||||
def _call_listener(self, namespace: str, old_kv: dict[str, Any] | None, new_kv: dict[str, Any] | None) -> None:
|
||||
def _call_listener(self, namespace, old_kv, new_kv):
|
||||
if self._change_listener is None:
|
||||
return
|
||||
if old_kv is None:
|
||||
@ -170,12 +168,12 @@ class ApolloClient:
|
||||
except BaseException as e:
|
||||
logger.warning(str(e))
|
||||
|
||||
def _path_checker(self) -> None:
|
||||
def _path_checker(self):
|
||||
if not os.path.isdir(self._cache_file_path):
|
||||
makedirs_wrapper(self._cache_file_path)
|
||||
|
||||
# update the local cache and file cache
|
||||
def _update_cache_and_file(self, namespace_data: dict[str, Any], namespace: str = "application") -> None:
|
||||
def _update_cache_and_file(self, namespace_data, namespace="application"):
|
||||
# update the local cache
|
||||
self._cache[namespace] = namespace_data
|
||||
# update the file cache
|
||||
@ -189,7 +187,7 @@ class ApolloClient:
|
||||
self._hash[namespace] = new_hash
|
||||
|
||||
# get the configuration from the local file
|
||||
def _get_local_cache(self, namespace: str = "application") -> dict[str, Any]:
|
||||
def _get_local_cache(self, namespace="application"):
|
||||
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
|
||||
if os.path.isfile(cache_file_path):
|
||||
with open(cache_file_path) as f:
|
||||
@ -197,8 +195,8 @@ class ApolloClient:
|
||||
return result
|
||||
return {}
|
||||
|
||||
def _long_poll(self) -> None:
|
||||
notifications: list[dict[str, Any]] = []
|
||||
def _long_poll(self):
|
||||
notifications = []
|
||||
for key in self._cache:
|
||||
namespace_data = self._cache[key]
|
||||
notification_id = -1
|
||||
@ -238,7 +236,7 @@ class ApolloClient:
|
||||
except Exception as e:
|
||||
logger.warning(str(e))
|
||||
|
||||
def _get_net_and_set_local(self, namespace: str, n_id: int, call_change: bool = False) -> None:
|
||||
def _get_net_and_set_local(self, namespace, n_id, call_change=False):
|
||||
namespace_data = self.get_json_from_net(namespace)
|
||||
if not namespace_data:
|
||||
return
|
||||
@ -250,7 +248,7 @@ class ApolloClient:
|
||||
new_kv = namespace_data.get(CONFIGURATIONS)
|
||||
self._call_listener(namespace, old_kv, new_kv)
|
||||
|
||||
def _listener(self) -> None:
|
||||
def _listener(self):
|
||||
logger.info("start long_poll")
|
||||
while not self._stopping:
|
||||
self._long_poll()
|
||||
@ -268,13 +266,13 @@ class ApolloClient:
|
||||
headers["Timestamp"] = time_unix_now
|
||||
return headers
|
||||
|
||||
def _heart_beat(self) -> None:
|
||||
def _heart_beat(self):
|
||||
while not self._stopping:
|
||||
for namespace in self._notification_map:
|
||||
self._do_heart_beat(namespace)
|
||||
time.sleep(60 * 10) # 10 minutes
|
||||
|
||||
def _do_heart_beat(self, namespace: str) -> None:
|
||||
def _do_heart_beat(self, namespace):
|
||||
url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
|
||||
try:
|
||||
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
|
||||
@ -294,7 +292,7 @@ class ApolloClient:
|
||||
logger.exception("an error occurred in _do_heart_beat")
|
||||
return None
|
||||
|
||||
def get_all_dicts(self, namespace: str) -> dict[str, Any] | None:
|
||||
def get_all_dicts(self, namespace):
|
||||
namespace_data = self._cache.get(namespace)
|
||||
if namespace_data is None:
|
||||
net_namespace_data = self.get_json_from_net(namespace)
|
||||
|
||||
@ -2,8 +2,6 @@ import logging
|
||||
import os
|
||||
import ssl
|
||||
import urllib.request
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from urllib import parse
|
||||
from urllib.error import HTTPError
|
||||
|
||||
@ -21,9 +19,9 @@ urllib.request.install_opener(opener)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}) -> tuple[int, str | None]:
|
||||
def http_request(url, timeout, headers={}):
|
||||
try:
|
||||
request = urllib.request.Request(url, headers=dict(headers))
|
||||
request = urllib.request.Request(url, headers=headers)
|
||||
res = urllib.request.urlopen(request, timeout=timeout)
|
||||
body = res.read().decode("utf-8")
|
||||
return res.code, body
|
||||
@ -35,9 +33,9 @@ def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}
|
||||
raise e
|
||||
|
||||
|
||||
def url_encode(params: dict[str, Any]) -> str:
|
||||
def url_encode(params):
|
||||
return parse.urlencode(params)
|
||||
|
||||
|
||||
def makedirs_wrapper(path: str) -> None:
|
||||
def makedirs_wrapper(path):
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import hashlib
|
||||
import socket
|
||||
from typing import Any
|
||||
|
||||
from .python_3x import url_encode
|
||||
|
||||
@ -11,7 +10,7 @@ NAMESPACE_NAME = "namespaceName"
|
||||
|
||||
|
||||
# add timestamps uris and keys
|
||||
def signature(timestamp: str, uri: str, secret: str) -> str:
|
||||
def signature(timestamp, uri, secret):
|
||||
import base64
|
||||
import hmac
|
||||
|
||||
@ -20,16 +19,16 @@ def signature(timestamp: str, uri: str, secret: str) -> str:
|
||||
return base64.b64encode(hmac_code).decode()
|
||||
|
||||
|
||||
def url_encode_wrapper(params: dict[str, Any]) -> str:
|
||||
def url_encode_wrapper(params):
|
||||
return url_encode(params)
|
||||
|
||||
|
||||
def no_key_cache_key(namespace: str, key: str) -> str:
|
||||
def no_key_cache_key(namespace, key):
|
||||
return f"{namespace}{len(namespace)}{key}"
|
||||
|
||||
|
||||
# Returns whether the obtained value is obtained, and None if it does not
|
||||
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
|
||||
def get_value_from_dict(namespace_cache, key):
|
||||
if namespace_cache:
|
||||
kv_data = namespace_cache.get(CONFIGURATIONS)
|
||||
if kv_data is None:
|
||||
@ -39,7 +38,7 @@ def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any
|
||||
return None
|
||||
|
||||
|
||||
def init_ip() -> str:
|
||||
def init_ip():
|
||||
ip = ""
|
||||
s = None
|
||||
try:
|
||||
|
||||
@ -11,5 +11,5 @@ class RemoteSettingsSource:
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool):
|
||||
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
|
||||
return value
|
||||
|
||||
@ -11,16 +11,16 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
from configs.remote_settings_sources.base import RemoteSettingsSource
|
||||
|
||||
from .utils import parse_config
|
||||
from .utils import _parse_config
|
||||
|
||||
|
||||
class NacosSettingsSource(RemoteSettingsSource):
|
||||
def __init__(self, configs: Mapping[str, Any]):
|
||||
self.configs = configs
|
||||
self.remote_configs: dict[str, str] = {}
|
||||
self.remote_configs: dict[str, Any] = {}
|
||||
self.async_init()
|
||||
|
||||
def async_init(self) -> None:
|
||||
def async_init(self):
|
||||
data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
|
||||
group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
|
||||
tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")
|
||||
@ -29,19 +29,22 @@ class NacosSettingsSource(RemoteSettingsSource):
|
||||
try:
|
||||
content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
|
||||
self.remote_configs = self._parse_config(content)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.exception("[get-access-token] exception occurred")
|
||||
raise
|
||||
|
||||
def _parse_config(self, content: str) -> dict[str, str]:
|
||||
def _parse_config(self, content: str) -> dict:
|
||||
if not content:
|
||||
return {}
|
||||
try:
|
||||
return parse_config(content)
|
||||
return _parse_config(self, content)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to parse config: {e}")
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
if not isinstance(self.remote_configs, dict):
|
||||
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
|
||||
|
||||
field_value = self.remote_configs.get(field_name)
|
||||
if field_value is None:
|
||||
return None, field_name, False
|
||||
|
||||
@ -17,26 +17,20 @@ class NacosHttpClient:
|
||||
self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
|
||||
self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
|
||||
self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
|
||||
self.token: str | None = None
|
||||
self.token = None
|
||||
self.token_ttl = 18000
|
||||
self.token_expire_time: float = 0
|
||||
|
||||
def http_request(
|
||||
self, url: str, method: str = "GET", headers: dict[str, str] | None = None, params: dict[str, str] | None = None
|
||||
) -> str:
|
||||
if headers is None:
|
||||
headers = {}
|
||||
if params is None:
|
||||
params = {}
|
||||
def http_request(self, url, method="GET", headers=None, params=None):
|
||||
try:
|
||||
self._inject_auth_info(headers, params)
|
||||
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except requests.RequestException as e:
|
||||
except requests.exceptions.RequestException as e:
|
||||
return f"Request to Nacos failed: {e}"
|
||||
|
||||
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
|
||||
def _inject_auth_info(self, headers, params, module="config"):
|
||||
headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})
|
||||
|
||||
if module == "login":
|
||||
@ -51,17 +45,16 @@ class NacosHttpClient:
|
||||
headers["timeStamp"] = ts
|
||||
if self.username and self.password:
|
||||
self.get_access_token(force_refresh=False)
|
||||
if self.token is not None:
|
||||
params["accessToken"] = self.token
|
||||
params["accessToken"] = self.token
|
||||
|
||||
def __do_sign(self, sign_str: str, sk: str) -> str:
|
||||
def __do_sign(self, sign_str, sk):
|
||||
return (
|
||||
base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
|
||||
.decode()
|
||||
.strip()
|
||||
)
|
||||
|
||||
def get_sign_str(self, group: str, tenant: str, ts: str) -> str:
|
||||
def get_sign_str(self, group, tenant, ts):
|
||||
sign_str = ""
|
||||
if tenant:
|
||||
sign_str = tenant + "+"
|
||||
@ -70,7 +63,7 @@ class NacosHttpClient:
|
||||
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
|
||||
return sign_str
|
||||
|
||||
def get_access_token(self, force_refresh: bool = False) -> str | None:
|
||||
def get_access_token(self, force_refresh=False):
|
||||
current_time = time.time()
|
||||
if self.token and not force_refresh and self.token_expire_time > current_time:
|
||||
return self.token
|
||||
@ -84,7 +77,6 @@ class NacosHttpClient:
|
||||
self.token = response_data.get("accessToken")
|
||||
self.token_ttl = response_data.get("tokenTtl", 18000)
|
||||
self.token_expire_time = current_time + self.token_ttl - 10
|
||||
return self.token
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.exception("[get-access-token] exception occur")
|
||||
raise
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
def parse_config(content: str) -> dict[str, str]:
|
||||
def _parse_config(self, content: str) -> dict[str, str]:
|
||||
config: dict[str, str] = {}
|
||||
if not content:
|
||||
return config
|
||||
|
||||
@ -19,7 +19,6 @@ language_timezone_mapping = {
|
||||
"fa-IR": "Asia/Tehran",
|
||||
"sl-SI": "Europe/Ljubljana",
|
||||
"th-TH": "Asia/Bangkok",
|
||||
"id-ID": "Asia/Jakarta",
|
||||
}
|
||||
|
||||
languages = list(language_timezone_mapping.keys())
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
@ -8,8 +6,6 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
@ -18,9 +14,9 @@ from extensions.ext_database import db
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
|
||||
def admin_required(view: Callable[P, R]):
|
||||
def admin_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
@ -134,19 +130,15 @@ class InsertExploreAppApi(Resource):
|
||||
app.is_public = False
|
||||
|
||||
with Session(db.engine) as session:
|
||||
installed_apps = (
|
||||
session.execute(
|
||||
select(InstalledApp).where(
|
||||
InstalledApp.app_id == recommended_app.app_id,
|
||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
|
||||
)
|
||||
installed_apps = session.execute(
|
||||
select(InstalledApp).where(
|
||||
InstalledApp.app_id == recommended_app.app_id,
|
||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
for installed_app in installed_apps:
|
||||
session.delete(installed_app)
|
||||
for installed_app in installed_apps:
|
||||
db.session.delete(installed_app)
|
||||
|
||||
db.session.delete(recommended_app)
|
||||
db.session.commit()
|
||||
|
||||
@ -84,10 +84,10 @@ class BaseApiKeyListResource(Resource):
|
||||
flask_restx.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
custom="max_keys_exceeded",
|
||||
code="max_keys_exceeded",
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_user.current_tenant_id
|
||||
|
||||
@ -237,14 +237,9 @@ class AppExportApi(Resource):
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||
parser.add_argument("workflow_id", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
return {
|
||||
"data": AppDslService.export_dsl(
|
||||
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
|
||||
)
|
||||
}
|
||||
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
|
||||
|
||||
|
||||
class AppNameApi(Resource):
|
||||
|
||||
@ -117,7 +117,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def delete(self, app_model, conversation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -207,7 +207,7 @@ class InstructionGenerationTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
def post(self) -> dict:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -526,7 +526,7 @@ class PublishedWorkflowApi(Resource):
|
||||
)
|
||||
|
||||
app_model.workflow_id = workflow.id
|
||||
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
|
||||
db.session.commit()
|
||||
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
|
||||
@ -27,9 +27,7 @@ class WorkflowAppLogApi(Resource):
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument(
|
||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
||||
)
|
||||
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
parser.add_argument(
|
||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||
)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import NoReturn
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
@ -29,7 +29,7 @@ from services.workflow_service import WorkflowService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
@ -40,7 +40,7 @@ def _convert_values_to_json_serializable_object(value: Segment):
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||
value = variable.get_value()
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
|
||||
@ -81,7 +81,7 @@ class OAuthDataSourceBinding(Resource):
|
||||
return {"error": "Invalid code"}, 400
|
||||
try:
|
||||
oauth_provider.get_access_token(code)
|
||||
except requests.HTTPError as e:
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
@ -104,7 +104,7 @@ class OAuthDataSourceSync(Resource):
|
||||
return {"error": "Invalid provider"}, 400
|
||||
try:
|
||||
oauth_provider.sync_data_source(binding_id)
|
||||
except requests.HTTPError as e:
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
|
||||
@ -130,7 +130,7 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
except AccountRegisterError:
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
@ -162,7 +162,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
except AccountRegisterError:
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
@ -200,7 +200,7 @@ class EmailCodeLoginApi(Resource):
|
||||
AccountService.revoke_email_code_login_token(args["token"])
|
||||
try:
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
except AccountRegisterError:
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
if account:
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
@ -223,7 +223,7 @@ class EmailCodeLoginApi(Resource):
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
except AccountRegisterError:
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
|
||||
@ -80,7 +80,7 @@ class OAuthCallback(Resource):
|
||||
try:
|
||||
token = oauth_provider.get_access_token(code)
|
||||
user_info = oauth_provider.get_user_info(token)
|
||||
except requests.RequestException as e:
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_text = e.response.text if e.response else str(e)
|
||||
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||
from typing import cast
|
||||
|
||||
import flask_login
|
||||
from flask import jsonify, request
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
@ -16,14 +15,10 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
|
||||
|
||||
from .. import api
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
def oauth_server_client_id_required(view):
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
@ -35,53 +30,43 @@ def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderA
|
||||
if not oauth_provider_app:
|
||||
raise NotFound("client_id is invalid")
|
||||
|
||||
return view(self, oauth_provider_app, *args, **kwargs)
|
||||
kwargs["oauth_provider_app"] = oauth_provider_app
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
|
||||
def oauth_server_access_token_required(view):
|
||||
@wraps(view)
|
||||
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
|
||||
if not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
def decorated(*args, **kwargs):
|
||||
oauth_provider_app = kwargs.get("oauth_provider_app")
|
||||
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
raise BadRequest("Invalid oauth_provider_app")
|
||||
|
||||
authorization_header = request.headers.get("Authorization")
|
||||
if not authorization_header:
|
||||
response = jsonify({"error": "Authorization header is required"})
|
||||
response.status_code = 401
|
||||
response.headers["WWW-Authenticate"] = "Bearer"
|
||||
return response
|
||||
raise BadRequest("Authorization header is required")
|
||||
|
||||
parts = authorization_header.strip().split(None, 1)
|
||||
parts = authorization_header.strip().split(" ")
|
||||
if len(parts) != 2:
|
||||
response = jsonify({"error": "Invalid Authorization header format"})
|
||||
response.status_code = 401
|
||||
response.headers["WWW-Authenticate"] = "Bearer"
|
||||
return response
|
||||
raise BadRequest("Invalid Authorization header format")
|
||||
|
||||
token_type = parts[0].strip()
|
||||
if token_type.lower() != "bearer":
|
||||
response = jsonify({"error": "token_type is invalid"})
|
||||
response.status_code = 401
|
||||
response.headers["WWW-Authenticate"] = "Bearer"
|
||||
return response
|
||||
raise BadRequest("token_type is invalid")
|
||||
|
||||
access_token = parts[1].strip()
|
||||
if not access_token:
|
||||
response = jsonify({"error": "access_token is required"})
|
||||
response.status_code = 401
|
||||
response.headers["WWW-Authenticate"] = "Bearer"
|
||||
return response
|
||||
raise BadRequest("access_token is required")
|
||||
|
||||
account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
|
||||
if not account:
|
||||
response = jsonify({"error": "access_token or client_id is invalid"})
|
||||
response.status_code = 401
|
||||
response.headers["WWW-Authenticate"] = "Bearer"
|
||||
return response
|
||||
raise BadRequest("access_token or client_id is invalid")
|
||||
|
||||
return view(self, oauth_provider_app, account, *args, **kwargs)
|
||||
kwargs["account"] = account
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import Account
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@ -17,10 +17,9 @@ class Subscription(Resource):
|
||||
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
||||
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
args = parser.parse_args()
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
return BillingService.get_subscription(
|
||||
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
|
||||
)
|
||||
@ -32,9 +31,7 @@ class Invoices(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
||||
|
||||
|
||||
|
||||
@ -10,7 +10,6 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
@ -215,7 +214,7 @@ class DataSourceNotionApi(Resource):
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
|
||||
@ -22,7 +22,6 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
@ -423,9 +422,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
if file_details:
|
||||
for file_detail in file_details:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value,
|
||||
upload_file=file_detail,
|
||||
document_model=args["doc_form"],
|
||||
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif args["info_list"]["data_source_type"] == "notion_import":
|
||||
@ -434,7 +431,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
@ -448,7 +445,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
website_info_list = args["info_list"]["website_info_list"]
|
||||
for url in website_info_list["urls"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
datasource_type="website_crawl",
|
||||
website_info={
|
||||
"provider": website_info_list["provider"],
|
||||
"job_id": website_info_list["job_id"],
|
||||
@ -663,6 +660,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
| VectorType.PINECONE
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
@ -714,6 +712,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
| VectorType.PINECONE
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
|
||||
@ -40,7 +40,6 @@ from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from extensions.ext_database import db
|
||||
from fields.document_fields import (
|
||||
@ -355,6 +354,9 @@ class DatasetInitApi(Resource):
|
||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
if knowledge_config.indexing_technique == "high_quality":
|
||||
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
||||
@ -426,7 +428,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
|
||||
datasource_type="upload_file", upload_file=file, document_model=document.doc_form
|
||||
)
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
@ -475,8 +477,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
data_source_info = document.data_source_info_dict
|
||||
|
||||
if document.data_source_type == "upload_file":
|
||||
if not data_source_info:
|
||||
continue
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
@ -488,15 +488,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
|
||||
datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
elif document.data_source_type == "notion_import":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
@ -507,10 +505,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif document.data_source_type == "website_crawl":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
datasource_type="website_crawl",
|
||||
website_info={
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
|
||||
@ -61,6 +61,7 @@ class ConversationApi(InstalledAppResource):
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@ -43,8 +43,6 @@ class ExploreAppMetaApi(InstalledAppResource):
|
||||
def get(self, installed_app: InstalledApp):
|
||||
"""Get app meta"""
|
||||
app_model = installed_app.app
|
||||
if not app_model:
|
||||
raise ValueError("App not found")
|
||||
return AppService().get_app_meta(app_model)
|
||||
|
||||
|
||||
|
||||
@ -35,8 +35,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
Run workflow
|
||||
"""
|
||||
app_model = installed_app.app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
@ -75,8 +73,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||
Stop workflow task
|
||||
"""
|
||||
app_model = installed_app.app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, Optional, ParamSpec, TypeVar
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
@ -15,15 +13,19 @@ from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
|
||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
def installed_app_required(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
if not kwargs.get("installed_app_id"):
|
||||
raise ValueError("missing installed_app_id in path parameters")
|
||||
|
||||
installed_app_id = kwargs.get("installed_app_id")
|
||||
installed_app_id = str(installed_app_id)
|
||||
|
||||
del kwargs["installed_app_id"]
|
||||
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
.where(
|
||||
@ -50,10 +52,10 @@ def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P],
|
||||
return decorator
|
||||
|
||||
|
||||
def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
|
||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
def user_allowed_to_access_app(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(installed_app: InstalledApp, *args, **kwargs):
|
||||
feature = FeatureService.get_system_features()
|
||||
if feature.webapp_auth.enabled:
|
||||
app_id = installed_app.app_id
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy.orm import Session
|
||||
@ -9,17 +7,14 @@ from werkzeug.exceptions import Forbidden
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def plugin_permission_required(
|
||||
install_required: bool = False,
|
||||
debug_required: bool = False,
|
||||
):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
|
||||
@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
@ -219,11 +219,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
config_from=args.get("config_from", ""),
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
)
|
||||
|
||||
if args.get("config_from", "") == "predefined-model":
|
||||
@ -267,7 +263,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -313,7 +309,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
)
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
@ -2,9 +2,7 @@ import contextlib
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, request
|
||||
from flask_login import current_user
|
||||
@ -21,13 +19,10 @@ from services.operation_service import OperationService
|
||||
|
||||
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def account_initialization_required(view: Callable[P, R]):
|
||||
def account_initialization_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
# check account initialization
|
||||
account = current_user
|
||||
|
||||
@ -39,9 +34,9 @@ def account_initialization_required(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_cloud(view: Callable[P, R]):
|
||||
def only_edition_cloud(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
if dify_config.EDITION != "CLOUD":
|
||||
abort(404)
|
||||
|
||||
@ -50,9 +45,9 @@ def only_edition_cloud(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_enterprise(view: Callable[P, R]):
|
||||
def only_edition_enterprise(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
abort(404)
|
||||
|
||||
@ -61,9 +56,9 @@ def only_edition_enterprise(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_self_hosted(view: Callable[P, R]):
|
||||
def only_edition_self_hosted(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
abort(404)
|
||||
|
||||
@ -72,9 +67,9 @@ def only_edition_self_hosted(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
def cloud_edition_billing_enabled(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if not features.billing.enabled:
|
||||
abort(403, "Billing feature is not enabled.")
|
||||
@ -84,9 +79,9 @@ def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
@ -125,9 +120,9 @@ def cloud_edition_billing_resource_check(resource: str):
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
@ -147,9 +142,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
|
||||
|
||||
def cloud_edition_billing_rate_limit_check(resource: str):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
if resource == "knowledge":
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
@ -181,9 +176,9 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_utm_record(view: Callable[P, R]):
|
||||
def cloud_utm_record(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
|
||||
@ -199,9 +194,9 @@ def cloud_utm_record(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def setup_required(view: Callable[P, R]):
|
||||
def setup_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
# check setup
|
||||
if (
|
||||
dify_config.EDITION == "SELF_HOSTED"
|
||||
@ -217,9 +212,9 @@ def setup_required(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_license_required(view: Callable[P, R]):
|
||||
def enterprise_license_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
settings = FeatureService.get_system_features()
|
||||
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
|
||||
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
|
||||
@ -229,9 +224,9 @@ def enterprise_license_required(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def email_password_login_enabled(view: Callable[P, R]):
|
||||
def email_password_login_enabled(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if features.enable_email_password_login:
|
||||
return view(*args, **kwargs)
|
||||
@ -242,9 +237,9 @@ def email_password_login_enabled(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def enable_change_email(view: Callable[P, R]):
|
||||
def enable_change_email(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if features.enable_change_email:
|
||||
return view(*args, **kwargs)
|
||||
@ -255,9 +250,9 @@ def enable_change_email(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
def is_allow_transfer_owner(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if features.is_allow_transfer_workspace:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -99,7 +99,7 @@ class MCPAppApi(Resource):
|
||||
|
||||
return mcp_server, app
|
||||
|
||||
def _validate_server_status(self, mcp_server: AppMCPServer):
|
||||
def _validate_server_status(self, mcp_server: AppMCPServer) -> None:
|
||||
"""Validate MCP server status"""
|
||||
if mcp_server.status != AppMCPServerStatus.ACTIVE:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")
|
||||
|
||||
@ -55,7 +55,7 @@ class AudioApi(Resource):
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.id)
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
|
||||
@ -59,7 +59,7 @@ class FilePreviewApi(Resource):
|
||||
args = file_preview_parser.parse_args()
|
||||
|
||||
# Validate file ownership and get file objects
|
||||
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
|
||||
message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)
|
||||
|
||||
# Get file content generator
|
||||
try:
|
||||
|
||||
@ -410,7 +410,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, _ = DocumentService.save_document_with_dataset_id(
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
|
||||
@ -440,7 +440,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate segment belongs to the specified document
|
||||
if str(segment.document_id) != str(document_id):
|
||||
if segment.document_id != document_id:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
# check child chunk
|
||||
@ -451,7 +451,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
# validate child chunk belongs to the specified segment
|
||||
if str(child_chunk.segment_id) != str(segment.id):
|
||||
if child_chunk.segment_id != segment.id:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
try:
|
||||
@ -500,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate segment belongs to the specified document
|
||||
if str(segment.document_id) != str(document_id):
|
||||
if segment.document_id != document_id:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# get child chunk
|
||||
@ -511,7 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
# validate child chunk belongs to the specified segment
|
||||
if str(child_chunk.segment_id) != str(segment.id):
|
||||
if child_chunk.segment_id != segment.id:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
# validate args
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from enum import StrEnum, auto
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Optional, ParamSpec, TypeVar
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
@ -22,18 +22,15 @@ from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App, EndUser
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class WhereisUserArg(StrEnum):
|
||||
class WhereisUserArg(Enum):
|
||||
"""
|
||||
Enum for whereis_user_arg.
|
||||
"""
|
||||
|
||||
QUERY = auto()
|
||||
JSON = auto()
|
||||
FORM = auto()
|
||||
QUERY = "query"
|
||||
JSON = "json"
|
||||
FORM = "form"
|
||||
|
||||
|
||||
class FetchUserArg(BaseModel):
|
||||
@ -63,6 +60,27 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
||||
if tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("The workspace's status is archived.")
|
||||
|
||||
tenant_account_join = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == api_token.tenant_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.role.in_(["owner"]))
|
||||
.where(Tenant.status == TenantStatus.NORMAL)
|
||||
.one_or_none()
|
||||
) # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.query(Account).where(Account.id == ta.account_id).first()
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
else:
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
|
||||
kwargs["app_model"] = app_model
|
||||
|
||||
if fetch_user_arg:
|
||||
@ -100,8 +118,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def interceptor(view):
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
|
||||
@ -130,9 +148,9 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
if features.billing.enabled:
|
||||
@ -152,9 +170,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
|
||||
|
||||
|
||||
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
|
||||
if resource == "knowledge":
|
||||
@ -273,28 +291,27 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
|
||||
if not user_id:
|
||||
user_id = "DEFAULT-USER"
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
end_user = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.app_id == app_model.id,
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.type == "service_api",
|
||||
)
|
||||
.first()
|
||||
end_user = (
|
||||
db.session.query(EndUser)
|
||||
.where(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.app_id == app_model.id,
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.type == "service_api",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type="service_api",
|
||||
is_anonymous=user_id == "DEFAULT-USER",
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(end_user)
|
||||
session.commit()
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type="service_api",
|
||||
is_anonymous=user_id == "DEFAULT-USER",
|
||||
session_id=user_id,
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
|
||||
@ -73,6 +73,8 @@ class ConversationApi(WebApiResource):
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
WebConversationService.unpin(app_model, conversation_id, end_user)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
from datetime import UTC, datetime
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
|
||||
@ -16,9 +14,6 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def validate_jwt_token(view=None):
|
||||
def decorator(view):
|
||||
@ -54,19 +49,18 @@ def decode_jwt_token():
|
||||
decoded = PassportService().verify(tk)
|
||||
app_code = decoded.get("app_code")
|
||||
app_id = decoded.get("app_id")
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
app_model = session.scalar(select(App).where(App.id == app_id))
|
||||
site = session.scalar(select(Site).where(Site.code == app_code))
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
if not app_code or not site:
|
||||
raise BadRequest("Site URL is no longer valid.")
|
||||
if app_model.enable_site is False:
|
||||
raise BadRequest("Site is disabled.")
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
end_user = session.scalar(select(EndUser).where(EndUser.id == end_user_id))
|
||||
if not end_user:
|
||||
raise NotFound()
|
||||
app_model = db.session.scalar(select(App).where(App.id == app_id))
|
||||
site = db.session.scalar(select(Site).where(Site.code == app_code))
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
if not app_code or not site:
|
||||
raise BadRequest("Site URL is no longer valid.")
|
||||
if app_model.enable_site is False:
|
||||
raise BadRequest("Site is disabled.")
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
|
||||
if not end_user:
|
||||
raise NotFound()
|
||||
|
||||
# for enterprise webapp auth
|
||||
app_web_auth_enabled = False
|
||||
|
||||
@ -62,7 +62,7 @@ class BaseAgentRunner(AppRunner):
|
||||
model_instance: ModelInstance,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||
):
|
||||
) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.conversation = conversation
|
||||
@ -334,8 +334,7 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
Save agent thought
|
||||
"""
|
||||
stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id)
|
||||
agent_thought = db.session.scalar(stmt)
|
||||
agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
|
||||
if not agent_thought:
|
||||
raise ValueError("agent thought not found")
|
||||
|
||||
@ -493,8 +492,7 @@ class BaseAgentRunner(AppRunner):
|
||||
return result
|
||||
|
||||
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
|
||||
stmt = select(MessageFile).where(MessageFile.message_id == message.id)
|
||||
files = db.session.scalars(stmt).all()
|
||||
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
|
||||
if not files:
|
||||
return UserPromptMessage(content=message.query)
|
||||
if message.app_model_config:
|
||||
|
||||
@ -338,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
|
||||
return instruction
|
||||
|
||||
def _init_react_state(self, query):
|
||||
def _init_react_state(self, query) -> None:
|
||||
"""
|
||||
init agent scratchpad
|
||||
"""
|
||||
|
||||
@ -41,7 +41,7 @@ class AgentScratchpadUnit(BaseModel):
|
||||
action_name: str
|
||||
action_input: Union[dict, str]
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
Convert to dictionary.
|
||||
"""
|
||||
|
||||
@ -158,7 +158,7 @@ class DatasetConfigManager:
|
||||
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
|
||||
|
||||
@classmethod
|
||||
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
|
||||
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict:
|
||||
"""
|
||||
Extract dataset config for legacy compatibility
|
||||
|
||||
|
||||
@ -105,7 +105,7 @@ class ModelConfigManager:
|
||||
return dict(config), ["model"]
|
||||
|
||||
@classmethod
|
||||
def validate_model_completion_params(cls, cp: dict):
|
||||
def validate_model_completion_params(cls, cp: dict) -> dict:
|
||||
# model.completion_params
|
||||
if not isinstance(cp, dict):
|
||||
raise ValueError("model.completion_params must be of object type")
|
||||
|
||||
@ -122,7 +122,7 @@ class PromptTemplateConfigManager:
|
||||
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
|
||||
|
||||
@classmethod
|
||||
def validate_post_prompt_and_set_defaults(cls, config: dict):
|
||||
def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict:
|
||||
"""
|
||||
Validate post_prompt and set defaults for prompt feature
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ class MoreLikeThisConfigManager:
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
try:
|
||||
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
|
||||
except ValidationError:
|
||||
except ValidationError as e:
|
||||
raise ValueError(
|
||||
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
|
||||
)
|
||||
|
||||
@ -41,7 +41,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
||||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
|
||||
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
|
||||
"""
|
||||
Validate for advanced chat app model config
|
||||
|
||||
|
||||
@ -450,12 +450,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
# release database connection, because the following new thread operations may take a long time
|
||||
db.session.refresh(workflow)
|
||||
db.session.refresh(message)
|
||||
# db.session.refresh(user)
|
||||
db.session.close()
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_advanced_chat_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
@ -481,7 +475,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
message_id: str,
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
|
||||
@ -54,7 +54,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
app: App,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=variable_loader,
|
||||
@ -68,13 +68,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
self.system_user_id = system_user_id
|
||||
self._app = app
|
||||
|
||||
def run(self):
|
||||
def run(self) -> None:
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
|
||||
|
||||
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
@ -142,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
environment_variables=self._workflow.environment_variables,
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
conversation_variables=conversation_variables,
|
||||
conversation_variables=cast(list[VariableUnion], conversation_variables),
|
||||
)
|
||||
|
||||
# init graph
|
||||
@ -221,7 +219,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
return False
|
||||
|
||||
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy):
|
||||
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
|
||||
"""
|
||||
Direct output
|
||||
"""
|
||||
|
||||
@ -118,7 +118,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
|
||||
|
||||
@ -73,6 +73,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Conversation, EndUser, Message, MessageFile
|
||||
@ -101,7 +102,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
):
|
||||
) -> None:
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
@ -289,7 +290,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
session.rollback()
|
||||
raise
|
||||
|
||||
def _ensure_workflow_initialized(self):
|
||||
def _ensure_workflow_initialized(self) -> None:
|
||||
"""Fluent validation for workflow state."""
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
@ -310,8 +311,13 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
|
||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||
|
||||
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||
def _handle_workflow_started_event(
|
||||
self, event: QueueWorkflowStartedEvent, *, graph_runtime_state: Optional[GraphRuntimeState] = None, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow started events."""
|
||||
# Override graph runtime state - this is a side effect but necessary
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
|
||||
self._workflow_run_id = workflow_execution.id_
|
||||
@ -332,14 +338,15 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
"""Handle node retry events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
|
||||
workflow_execution_id=self._workflow_run_id, event=event
|
||||
)
|
||||
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
with self._database_session() as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
|
||||
workflow_execution_id=self._workflow_run_id, event=event
|
||||
)
|
||||
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if node_retry_resp:
|
||||
yield node_retry_resp
|
||||
@ -373,12 +380,13 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
|
||||
)
|
||||
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
|
||||
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
with self._database_session() as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
|
||||
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
|
||||
@ -888,7 +896,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None):
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
||||
message = self._get_message(session=session)
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
@ -931,6 +939,10 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
self._task_state.metadata.usage = usage
|
||||
else:
|
||||
self._task_state.metadata.usage = LLMUsage.empty_usage()
|
||||
message_was_created.send(
|
||||
message,
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
)
|
||||
|
||||
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
|
||||
"""
|
||||
|
||||
@ -86,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
|
||||
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict:
|
||||
"""
|
||||
Validate for agent chat app model config
|
||||
|
||||
|
||||
@ -222,7 +222,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from core.agent.entities import AgentEntity
|
||||
@ -35,7 +33,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Run assistant application
|
||||
:param application_generate_entity: application generate entity
|
||||
@ -46,8 +44,8 @@ class AgentChatAppRunner(AppRunner):
|
||||
"""
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(AgentChatAppConfig, app_config)
|
||||
app_stmt = select(App).where(App.id == app_config.app_id)
|
||||
app_record = db.session.scalar(app_stmt)
|
||||
|
||||
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
@ -184,12 +182,11 @@ class AgentChatAppRunner(AppRunner):
|
||||
|
||||
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
|
||||
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
|
||||
conversation_result = db.session.scalar(conversation_stmt)
|
||||
|
||||
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
|
||||
if conversation_result is None:
|
||||
raise ValueError("Conversation not found")
|
||||
msg_stmt = select(Message).where(Message.id == message.id)
|
||||
message_result = db.session.scalar(msg_stmt)
|
||||
message_result = db.session.query(Message).where(Message.id == message.id).first()
|
||||
if message_result is None:
|
||||
raise ValueError("Message not found")
|
||||
db.session.close()
|
||||
|
||||
@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = ChatbotAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
|
||||
@ -94,7 +94,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
def _error_to_stream_response(cls, e: Exception):
|
||||
def _error_to_stream_response(cls, e: Exception) -> dict:
|
||||
"""
|
||||
Error to stream response.
|
||||
:param e: exception
|
||||
|
||||
@ -157,7 +157,7 @@ class BaseAppGenerator:
|
||||
|
||||
return value
|
||||
|
||||
def _sanitize_value(self, value: Any):
|
||||
def _sanitize_value(self, value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
return value.replace("\x00", "")
|
||||
return value
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import queue
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from enum import IntEnum, auto
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
@ -19,13 +19,13 @@ from core.app.entities.queue_entities import (
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class PublishFrom(IntEnum):
|
||||
APPLICATION_MANAGER = auto()
|
||||
TASK_PIPELINE = auto()
|
||||
class PublishFrom(Enum):
|
||||
APPLICATION_MANAGER = 1
|
||||
TASK_PIPELINE = 2
|
||||
|
||||
|
||||
class AppQueueManager:
|
||||
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom):
|
||||
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
|
||||
if not user_id:
|
||||
raise ValueError("user is required")
|
||||
|
||||
@ -73,14 +73,14 @@ class AppQueueManager:
|
||||
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
|
||||
last_ping_time = elapsed_time // 10
|
||||
|
||||
def stop_listen(self):
|
||||
def stop_listen(self) -> None:
|
||||
"""
|
||||
Stop listen to queue
|
||||
:return:
|
||||
"""
|
||||
self._q.put(None)
|
||||
|
||||
def publish_error(self, e, pub_from: PublishFrom):
|
||||
def publish_error(self, e, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish error
|
||||
:param e: error
|
||||
@ -89,7 +89,7 @@ class AppQueueManager:
|
||||
"""
|
||||
self.publish(QueueErrorEvent(error=e), pub_from)
|
||||
|
||||
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
@ -100,7 +100,7 @@ class AppQueueManager:
|
||||
self._publish(event, pub_from)
|
||||
|
||||
@abstractmethod
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
@ -110,7 +110,7 @@ class AppQueueManager:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str):
|
||||
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
|
||||
"""
|
||||
Set task stop flag
|
||||
:return:
|
||||
@ -159,7 +159,7 @@ class AppQueueManager:
|
||||
def _check_for_sqlalchemy_models(self, data: Any):
|
||||
# from entity to dict or list
|
||||
if isinstance(data, dict):
|
||||
for value in data.values():
|
||||
for key, value in data.items():
|
||||
self._check_for_sqlalchemy_models(value)
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
|
||||
@ -162,7 +162,7 @@ class AppRunner:
|
||||
text: str,
|
||||
stream: bool,
|
||||
usage: Optional[LLMUsage] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Direct output
|
||||
:param queue_manager: application queue manager
|
||||
@ -204,7 +204,7 @@ class AppRunner:
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool,
|
||||
agent: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
@ -220,7 +220,9 @@ class AppRunner:
|
||||
else:
|
||||
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
|
||||
|
||||
def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
|
||||
def _handle_invoke_result_direct(
|
||||
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
|
||||
) -> None:
|
||||
"""
|
||||
Handle invoke result direct
|
||||
:param invoke_result: invoke result
|
||||
@ -237,7 +239,7 @@ class AppRunner:
|
||||
|
||||
def _handle_invoke_result_stream(
|
||||
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
|
||||
@ -81,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
||||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict):
|
||||
def config_validate(cls, tenant_id: str, config: dict) -> dict:
|
||||
"""
|
||||
Validate for chat app model config
|
||||
|
||||
|
||||
@ -211,7 +211,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfig
|
||||
@ -33,7 +31,7 @@ class ChatAppRunner(AppRunner):
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
@ -44,8 +42,8 @@ class ChatAppRunner(AppRunner):
|
||||
"""
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(ChatAppConfig, app_config)
|
||||
stmt = select(App).where(App.id == app_config.app_id)
|
||||
app_record = db.session.scalar(stmt)
|
||||
|
||||
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = ChatbotAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
|
||||
@ -62,7 +62,7 @@ class WorkflowResponseConverter:
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
user: Union[Account, EndUser],
|
||||
):
|
||||
) -> None:
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._user = user
|
||||
|
||||
|
||||
@ -66,7 +66,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
||||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict):
|
||||
def config_validate(cls, tenant_id: str, config: dict) -> dict:
|
||||
"""
|
||||
Validate for completion app model config
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@ from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, copy_current_request_context, current_app
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
@ -192,7 +191,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
application_generate_entity: CompletionAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
message_id: str,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
@ -249,22 +248,22 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
stmt = select(Message).where(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
|
||||
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
.where(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
|
||||
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
message = db.session.scalar(stmt)
|
||||
|
||||
if not message:
|
||||
raise MessageNotExistsError()
|
||||
|
||||
current_app_model_config = app_model.app_model_config
|
||||
if not current_app_model_config:
|
||||
raise MoreLikeThisDisabledError()
|
||||
|
||||
more_like_this = current_app_model_config.more_like_this_dict
|
||||
|
||||
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfig
|
||||
@ -27,7 +25,7 @@ class CompletionAppRunner(AppRunner):
|
||||
|
||||
def run(
|
||||
self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
@ -37,8 +35,8 @@ class CompletionAppRunner(AppRunner):
|
||||
"""
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(CompletionAppConfig, app_config)
|
||||
stmt = select(App).where(App.id == app_config.app_id)
|
||||
app_record = db.session.scalar(stmt)
|
||||
|
||||
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = CompletionAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
|
||||
@ -3,9 +3,6 @@ import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -86,10 +83,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
|
||||
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
|
||||
if conversation:
|
||||
stmt = select(AppModelConfig).where(
|
||||
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig)
|
||||
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
|
||||
.first()
|
||||
)
|
||||
app_model_config = db.session.scalar(stmt)
|
||||
|
||||
if not app_model_config:
|
||||
raise AppModelConfigBrokenError()
|
||||
@ -255,8 +253,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
:param conversation_id: conversation id
|
||||
:return: conversation
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
conversation = session.scalar(select(Conversation).where(Conversation.id == conversation_id))
|
||||
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError("Conversation not exists")
|
||||
@ -269,8 +266,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
:param message_id: message id
|
||||
:return: message
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message = session.scalar(select(Message).where(Message.id == message_id))
|
||||
message = db.session.query(Message).where(Message.id == message_id).first()
|
||||
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
@ -14,14 +14,14 @@ from core.app.entities.queue_entities import (
|
||||
class MessageBasedAppQueueManager(AppQueueManager):
|
||||
def __init__(
|
||||
self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(task_id, user_id, invoke_from)
|
||||
|
||||
self._conversation_id = str(conversation_id)
|
||||
self._app_mode = app_mode
|
||||
self._message_id = str(message_id)
|
||||
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
|
||||
@ -35,7 +35,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
|
||||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
|
||||
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
|
||||
"""
|
||||
Validate for workflow app model config
|
||||
|
||||
|
||||
@ -435,7 +435,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
|
||||
@ -14,12 +14,12 @@ from core.app.entities.queue_entities import (
|
||||
|
||||
|
||||
class WorkflowAppQueueManager(AppQueueManager):
|
||||
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str):
|
||||
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
|
||||
super().__init__(task_id, user_id, invoke_from)
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
|
||||
@ -34,7 +34,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=variable_loader,
|
||||
@ -45,7 +45,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
|
||||
def run(self):
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Run application
|
||||
"""
|
||||
|
||||
@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = WorkflowAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return dict(blocking_response.to_dict())
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield response_chunk
|
||||
|
||||
@ -92,7 +92,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
):
|
||||
) -> None:
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
@ -263,7 +263,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
session.rollback()
|
||||
raise
|
||||
|
||||
def _ensure_workflow_initialized(self):
|
||||
def _ensure_workflow_initialized(self) -> None:
|
||||
"""Fluent validation for workflow state."""
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
@ -300,15 +300,16 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
"""Handle node retry events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
with self._database_session() as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
@ -744,7 +745,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(None)
|
||||
|
||||
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution):
|
||||
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
|
||||
invoke_from = self._application_generate_entity.invoke_from
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
||||
|
||||
@ -74,7 +74,7 @@ class WorkflowBasedAppRunner:
|
||||
queue_manager: AppQueueManager,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
app_id: str,
|
||||
):
|
||||
) -> None:
|
||||
self._queue_manager = queue_manager
|
||||
self._variable_loader = variable_loader
|
||||
self._app_id = app_id
|
||||
@ -292,7 +292,7 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Handle event
|
||||
:param workflow_entry: workflow entry
|
||||
@ -694,5 +694,5 @@ class WorkflowBasedAppRunner:
|
||||
)
|
||||
)
|
||||
|
||||
def _publish_event(self, event: AppQueueEvent):
|
||||
def _publish_event(self, event: AppQueueEvent) -> None:
|
||||
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
@ -11,7 +11,7 @@ from core.file import File, FileUploadConfig
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
class InvokeFrom(Enum):
|
||||
"""
|
||||
Invoke From.
|
||||
"""
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from extensions.ext_database import db
|
||||
@ -27,8 +25,9 @@ class AnnotationReplyFeature:
|
||||
:param invoke_from: invoke from
|
||||
:return:
|
||||
"""
|
||||
stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
|
||||
annotation_setting = db.session.scalar(stmt)
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
|
||||
)
|
||||
|
||||
if not annotation_setting:
|
||||
return None
|
||||
|
||||
@ -96,11 +96,7 @@ class RateLimit:
|
||||
if isinstance(generator, Mapping):
|
||||
return generator
|
||||
else:
|
||||
return RateLimitGenerator(
|
||||
rate_limit=self,
|
||||
generator=generator, # ty: ignore [invalid-argument-type]
|
||||
request_id=request_id,
|
||||
)
|
||||
return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)
|
||||
|
||||
|
||||
class RateLimitGenerator:
|
||||
|
||||
@ -35,7 +35,7 @@ class BasedGenerateTaskPipeline:
|
||||
application_generate_entity: AppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool,
|
||||
):
|
||||
) -> None:
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self.queue_manager = queue_manager
|
||||
self._start_at = time.perf_counter()
|
||||
@ -50,7 +50,7 @@ class BasedGenerateTaskPipeline:
|
||||
if isinstance(e, InvokeAuthorizationError):
|
||||
err = InvokeAuthorizationError("Incorrect API key provided")
|
||||
elif isinstance(e, InvokeError | ValueError):
|
||||
err = e # ty: ignore [invalid-assignment]
|
||||
err = e
|
||||
else:
|
||||
description = getattr(e, "description", None)
|
||||
err = Exception(description if description is not None else str(e))
|
||||
|
||||
@ -80,7 +80,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
stream: bool,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None):
|
||||
def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None:
|
||||
"""
|
||||
Save message.
|
||||
:return:
|
||||
@ -412,7 +412,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
)
|
||||
|
||||
def _handle_stop(self, event: QueueStopEvent):
|
||||
def _handle_stop(self, event: QueueStopEvent) -> None:
|
||||
"""
|
||||
Handle stop.
|
||||
:return:
|
||||
@ -472,10 +472,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
:param event: agent thought event
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
agent_thought: Optional[MessageAgentThought] = (
|
||||
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
|
||||
)
|
||||
agent_thought: Optional[MessageAgentThought] = (
|
||||
db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
|
||||
)
|
||||
|
||||
if agent_thought:
|
||||
return AgentThoughtStreamResponse(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user