mirror of
https://github.com/langgenius/dify.git
synced 2026-01-21 04:25:23 +08:00
Compare commits
57 Commits
fix/model-
...
bug1
| Author | SHA1 | Date | |
|---|---|---|---|
| 41dfdf1ac0 | |||
| dd7de74aa6 | |||
| f11131f8b5 | |||
| 2e6e414a9e | |||
| c45d676477 | |||
| b8d8dddd5a | |||
| c45c22b1b2 | |||
| 3d57a9ccdc | |||
| cb04c21141 | |||
| f70272f638 | |||
| b4b71ded47 | |||
| 24e2b72b71 | |||
| 529791ce62 | |||
| b66945b9b8 | |||
| f3c5d77ad5 | |||
| e5e42bc483 | |||
| bdfbfa391f | |||
| 72acd9b483 | |||
| 9f528d23d4 | |||
| d937cc491d | |||
| 863f3aeb27 | |||
| 0fe078d25e | |||
| d9420c7224 | |||
| 9ff6baaf52 | |||
| 574d00bb13 | |||
| 8d60e5c342 | |||
| d9eb1a73af | |||
| 1a34ff8a67 | |||
| 14e7ba4818 | |||
| 52e9bcbfdb | |||
| 20ae3eae54 | |||
| 0fb145e667 | |||
| bcac43c812 | |||
| 929d9e0b3f | |||
| d5e560a987 | |||
| e4383d6167 | |||
| f32e176d6a | |||
| 3d5a4df9d0 | |||
| e47bfd2ca3 | |||
| f8f768873e | |||
| d043e1a05a | |||
| 837c0ddacc | |||
| 7c340695d6 | |||
| e87d4fbf69 | |||
| 39064197da | |||
| c4496e6cf2 | |||
| 27d09d1783 | |||
| a174ee419e | |||
| 79e6138ce2 | |||
| 5a64f69456 | |||
| 5c01dd97e8 | |||
| ecf74d91e2 | |||
| 62892ed8d7 | |||
| 7b399cc5e5 | |||
| fab5740778 | |||
| 30f2d756a7 | |||
| 0d745c64d8 |
19
.claude/settings.json.example
Normal file
19
.claude/settings.json.example
Normal file
@ -0,0 +1,19 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [],
|
||||
"deny": []
|
||||
},
|
||||
"env": {
|
||||
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
},
|
||||
"enabledMcpjsonServers": [
|
||||
"context7",
|
||||
"sequential-thinking",
|
||||
"github",
|
||||
"fetch",
|
||||
"playwright",
|
||||
"ide"
|
||||
],
|
||||
"enableAllProjectMcpServers": true
|
||||
}
|
||||
8
.github/workflows/api-tests.yml
vendored
8
.github/workflows/api-tests.yml
vendored
@ -1,13 +1,7 @@
|
||||
name: Run Pytest
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- api/**
|
||||
- docker/**
|
||||
- .github/workflows/api-tests.yml
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: api-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
8
.github/workflows/autofix.yml
vendored
8
.github/workflows/autofix.yml
vendored
@ -1,10 +1,7 @@
|
||||
name: autofix.ci
|
||||
on:
|
||||
workflow_call:
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
branches: ["main"]
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
@ -18,7 +15,7 @@ jobs:
|
||||
# Use uv to ensure we have the same ruff version in CI and locally.
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
python-version: "3.12"
|
||||
- run: |
|
||||
cd api
|
||||
uv sync --dev
|
||||
@ -29,6 +26,7 @@ jobs:
|
||||
- name: ast-grep
|
||||
run: |
|
||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
- name: mdformat
|
||||
run: |
|
||||
uvx mdformat .
|
||||
|
||||
14
.github/workflows/db-migration-test.yml
vendored
14
.github/workflows/db-migration-test.yml
vendored
@ -1,13 +1,7 @@
|
||||
name: DB Migration Test
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- plugins/beta
|
||||
paths:
|
||||
- api/migrations/**
|
||||
- .github/workflows/db-migration-test.yml
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: db-migration-test-${{ github.ref }}
|
||||
@ -33,6 +27,12 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api
|
||||
- name: Ensure Offline migration are supported
|
||||
run: |
|
||||
# upgrade
|
||||
uv run --directory api flask db upgrade 'base:head' --sql
|
||||
# downgrade
|
||||
uv run --directory api flask db downgrade 'head:base' --sql
|
||||
|
||||
- name: Prepare middleware env
|
||||
run: |
|
||||
|
||||
78
.github/workflows/main-ci.yml
vendored
Normal file
78
.github/workflows/main-ci.yml
vendored
Normal file
@ -0,0 +1,78 @@
|
||||
name: Main CI Pipeline
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
push:
|
||||
branches: ["main"]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
checks: write
|
||||
statuses: write
|
||||
|
||||
concurrency:
|
||||
group: main-ci-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
# Check which paths were changed to determine which tests to run
|
||||
check-changes:
|
||||
name: Check Changed Files
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
api-changed: ${{ steps.changes.outputs.api }}
|
||||
web-changed: ${{ steps.changes.outputs.web }}
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
migration-changed: ${{ steps.changes.outputs.migration }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dorny/paths-filter@v3
|
||||
id: changes
|
||||
with:
|
||||
filters: |
|
||||
api:
|
||||
- 'api/**'
|
||||
- 'docker/**'
|
||||
- '.github/workflows/api-tests.yml'
|
||||
web:
|
||||
- 'web/**'
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'docker/**'
|
||||
- '.github/workflows/vdb-tests.yml'
|
||||
- 'api/uv.lock'
|
||||
- 'api/pyproject.toml'
|
||||
migration:
|
||||
- 'api/migrations/**'
|
||||
- '.github/workflows/db-migration-test.yml'
|
||||
|
||||
# Run tests in parallel
|
||||
api-tests:
|
||||
name: API Tests
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.api-changed == 'true'
|
||||
uses: ./.github/workflows/api-tests.yml
|
||||
|
||||
web-tests:
|
||||
name: Web Tests
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.web-changed == 'true'
|
||||
uses: ./.github/workflows/web-tests.yml
|
||||
|
||||
style-check:
|
||||
name: Style Check
|
||||
uses: ./.github/workflows/style.yml
|
||||
|
||||
vdb-tests:
|
||||
name: VDB Tests
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.vdb-changed == 'true'
|
||||
uses: ./.github/workflows/vdb-tests.yml
|
||||
|
||||
db-migration-test:
|
||||
name: DB Migration Test
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.migration-changed == 'true'
|
||||
uses: ./.github/workflows/db-migration-test.yml
|
||||
15
.github/workflows/style.yml
vendored
15
.github/workflows/style.yml
vendored
@ -1,9 +1,7 @@
|
||||
name: Style check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: style-${{ github.head_ref || github.run_id }}
|
||||
@ -46,21 +44,10 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Ruff check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: |
|
||||
uv run --directory api ruff --version
|
||||
uv run --directory api ruff check ./
|
||||
uv run --directory api ruff format --check ./
|
||||
|
||||
- name: Dotenv check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
|
||||
|
||||
- name: Lint hints
|
||||
if: failure()
|
||||
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
|
||||
|
||||
web-style:
|
||||
name: Web Style
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
10
.github/workflows/vdb-tests.yml
vendored
10
.github/workflows/vdb-tests.yml
vendored
@ -1,15 +1,7 @@
|
||||
name: Run VDB Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- api/core/rag/datasource/**
|
||||
- docker/**
|
||||
- .github/workflows/vdb-tests.yml
|
||||
- api/uv.lock
|
||||
- api/pyproject.toml
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: vdb-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
6
.github/workflows/web-tests.yml
vendored
6
.github/workflows/web-tests.yml
vendored
@ -1,11 +1,7 @@
|
||||
name: Web Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- web/**
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: web-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
34
.mcp.json
Normal file
34
.mcp.json
Normal file
@ -0,0 +1,34 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"context7": {
|
||||
"type": "http",
|
||||
"url": "https://mcp.context7.com/mcp"
|
||||
},
|
||||
"sequential-thinking": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
|
||||
"env": {}
|
||||
},
|
||||
"github": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-github"],
|
||||
"env": {
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
|
||||
}
|
||||
},
|
||||
"fetch": {
|
||||
"type": "stdio",
|
||||
"command": "uvx",
|
||||
"args": ["mcp-server-fetch"],
|
||||
"env": {}
|
||||
},
|
||||
"playwright": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@playwright/mcp@latest"],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -86,3 +86,4 @@ pnpm test # Run Jest tests
|
||||
## Project-Specific Conventions
|
||||
|
||||
- All async tasks use Celery with Redis as broker
|
||||
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.
|
||||
|
||||
@ -70,7 +70,7 @@ from .app import (
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
|
||||
|
||||
# Import billing controllers
|
||||
from .billing import billing, compliance
|
||||
|
||||
@ -95,18 +95,22 @@ class ChatMessageListApi(Resource):
|
||||
.all()
|
||||
)
|
||||
|
||||
# Initialize has_more based on whether we have a full page
|
||||
if len(history_messages) == args["limit"]:
|
||||
current_page_first_message = history_messages[-1]
|
||||
|
||||
has_more = db.session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
# Check if there are more messages before the current page
|
||||
has_more = db.session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
# If we don't have a full page, there are no more messages
|
||||
has_more = False
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
@ -126,7 +130,7 @@ class MessageFeedbackApi(Resource):
|
||||
|
||||
message_id = str(args["message_id"])
|
||||
|
||||
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
187
api/controllers/console/auth/oauth_server.py
Normal file
187
api/controllers/console/auth/oauth_server.py
Normal file
@ -0,0 +1,187 @@
|
||||
from functools import wraps
|
||||
from typing import cast
|
||||
|
||||
import flask_login
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
|
||||
from .. import api
|
||||
|
||||
|
||||
def oauth_server_client_id_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
client_id = parsed_args.get("client_id")
|
||||
if not client_id:
|
||||
raise BadRequest("client_id is required")
|
||||
|
||||
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
|
||||
if not oauth_provider_app:
|
||||
raise NotFound("client_id is invalid")
|
||||
|
||||
kwargs["oauth_provider_app"] = oauth_provider_app
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def oauth_server_access_token_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
oauth_provider_app = kwargs.get("oauth_provider_app")
|
||||
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
raise BadRequest("Invalid oauth_provider_app")
|
||||
|
||||
authorization_header = request.headers.get("Authorization")
|
||||
if not authorization_header:
|
||||
raise BadRequest("Authorization header is required")
|
||||
|
||||
parts = authorization_header.strip().split(" ")
|
||||
if len(parts) != 2:
|
||||
raise BadRequest("Invalid Authorization header format")
|
||||
|
||||
token_type = parts[0].strip()
|
||||
if token_type.lower() != "bearer":
|
||||
raise BadRequest("token_type is invalid")
|
||||
|
||||
access_token = parts[1].strip()
|
||||
if not access_token:
|
||||
raise BadRequest("access_token is required")
|
||||
|
||||
account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
|
||||
if not account:
|
||||
raise BadRequest("access_token or client_id is invalid")
|
||||
|
||||
kwargs["account"] = account
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
class OAuthServerAppApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
redirect_uri = parsed_args.get("redirect_uri")
|
||||
|
||||
# check if redirect_uri is valid
|
||||
if redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"app_icon": oauth_provider_app.app_icon,
|
||||
"app_label": oauth_provider_app.app_label,
|
||||
"scope": oauth_provider_app.scope,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthServerUserAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
account = cast(Account, flask_login.current_user)
|
||||
user_account_id = account.id
|
||||
|
||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"code": code,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthServerUserTokenApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("grant_type", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=False, location="json")
|
||||
parser.add_argument("client_secret", type=str, required=False, location="json")
|
||||
parser.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
parser.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
|
||||
try:
|
||||
grant_type = OAuthGrantType(parsed_args["grant_type"])
|
||||
except ValueError:
|
||||
raise BadRequest("invalid grant_type")
|
||||
|
||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not parsed_args["code"]:
|
||||
raise BadRequest("code is required")
|
||||
|
||||
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
|
||||
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
||||
if not parsed_args["refresh_token"]:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthServerUserAccountApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
@oauth_server_access_token_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp, account: Account):
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"name": account.name,
|
||||
"email": account.email,
|
||||
"avatar": account.avatar,
|
||||
"interface_language": account.interface_language,
|
||||
"timezone": account.timezone,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(OAuthServerAppApi, "/oauth/provider")
|
||||
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
|
||||
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
|
||||
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")
|
||||
@ -1,8 +1,12 @@
|
||||
from base64 import b64encode
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
@ -10,9 +14,9 @@ from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def billing_inner_api_only(view):
|
||||
def billing_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
@ -26,9 +30,9 @@ def billing_inner_api_only(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_inner_api_only(view):
|
||||
def enterprise_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def plugin_inner_api_only(view):
|
||||
def plugin_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.PLUGIN_DAEMON_KEY:
|
||||
abort(404)
|
||||
|
||||
|
||||
@ -1,18 +1,27 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, reqparse
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import mcp_ns
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.mcp import types
|
||||
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
|
||||
from core.mcp.types import ClientNotification, ClientRequest
|
||||
from core.mcp.utils import create_mcp_error_response
|
||||
from core.mcp import types as mcp_types
|
||||
from core.mcp.server.streamable_http import handle_mcp_request
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.model import App, AppMCPServer, AppMode
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
|
||||
|
||||
class MCPRequestError(Exception):
|
||||
"""Custom exception for MCP request processing errors"""
|
||||
|
||||
def __init__(self, error_code: int, message: str):
|
||||
self.error_code = error_code
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def int_or_str(value):
|
||||
@ -63,77 +72,173 @@ class MCPAppApi(Resource):
|
||||
Raises:
|
||||
ValidationError: Invalid request format or parameters
|
||||
"""
|
||||
# Parse and validate all arguments
|
||||
args = mcp_request_parser.parse_args()
|
||||
|
||||
request_id: Optional[Union[int, str]] = args.get("id")
|
||||
mcp_request = self._parse_mcp_request(args)
|
||||
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
||||
if not server:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get MCP server and app
|
||||
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
|
||||
self._validate_server_status(mcp_server)
|
||||
|
||||
if server.status != AppMCPServerStatus.ACTIVE:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
|
||||
)
|
||||
# Get user input form
|
||||
user_input_form = self._get_user_input_form(app)
|
||||
|
||||
app = db.session.query(App).where(App.id == server.app_id).first()
|
||||
# Handle notification vs request differently
|
||||
return self._process_mcp_message(mcp_request, request_id, app, mcp_server, user_input_form, session)
|
||||
|
||||
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
|
||||
"""Get and validate MCP server and app in one query session"""
|
||||
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
||||
if not mcp_server:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
|
||||
|
||||
app = session.query(App).where(App.id == mcp_server.app_id).first()
|
||||
if not app:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
|
||||
)
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
|
||||
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
return mcp_server, app
|
||||
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
def _validate_server_status(self, mcp_server: AppMCPServer) -> None:
|
||||
"""Validate MCP server status"""
|
||||
if mcp_server.status != AppMCPServerStatus.ACTIVE:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")
|
||||
|
||||
def _process_mcp_message(
|
||||
self,
|
||||
mcp_request: mcp_types.ClientRequest | mcp_types.ClientNotification,
|
||||
request_id: Optional[Union[int, str]],
|
||||
app: App,
|
||||
mcp_server: AppMCPServer,
|
||||
user_input_form: list[VariableEntity],
|
||||
session: Session,
|
||||
) -> Response:
|
||||
"""Process MCP message (notification or request)"""
|
||||
if isinstance(mcp_request, mcp_types.ClientNotification):
|
||||
return self._handle_notification(mcp_request)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
return self._handle_request(mcp_request, request_id, app, mcp_server, user_input_form, session)
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
converted_user_input_form: list[VariableEntity] = []
|
||||
try:
|
||||
for item in user_input_form:
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
converted_user_input_form.append(
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
)
|
||||
)
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
)
|
||||
def _handle_notification(self, mcp_request: mcp_types.ClientNotification) -> Response:
|
||||
"""Handle MCP notification"""
|
||||
# For notifications, only support init notification
|
||||
if mcp_request.root.method != "notifications/initialized":
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Invalid notification method")
|
||||
# Return HTTP 202 Accepted for notifications (no response body)
|
||||
return Response("", status=202, content_type="application/json")
|
||||
|
||||
def _handle_request(
|
||||
self,
|
||||
mcp_request: mcp_types.ClientRequest,
|
||||
request_id: Optional[Union[int, str]],
|
||||
app: App,
|
||||
mcp_server: AppMCPServer,
|
||||
user_input_form: list[VariableEntity],
|
||||
session: Session,
|
||||
) -> Response:
|
||||
"""Handle MCP request"""
|
||||
if request_id is None:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Request ID is required")
|
||||
|
||||
result = self._handle_mcp_request(app, mcp_server, mcp_request, user_input_form, session, request_id)
|
||||
if result is None:
|
||||
# This shouldn't happen for requests, but handle gracefully
|
||||
raise MCPRequestError(mcp_types.INTERNAL_ERROR, "No response generated for request")
|
||||
|
||||
return helper.compact_generate_response(result.model_dump(by_alias=True, mode="json", exclude_none=True))
|
||||
|
||||
def _get_user_input_form(self, app: App) -> list[VariableEntity]:
|
||||
"""Get and convert user input form"""
|
||||
# Get raw user input form based on app mode
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if not app.workflow:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
|
||||
raw_user_input_form = app.workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
if not app.app_model_config:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
|
||||
features_dict = app.app_model_config.to_dict()
|
||||
raw_user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
# Convert to VariableEntity objects
|
||||
try:
|
||||
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
|
||||
return self._convert_user_input_form(raw_user_input_form)
|
||||
except ValidationError as e:
|
||||
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
|
||||
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
|
||||
"""Convert raw user input form to VariableEntity objects"""
|
||||
return [self._create_variable_entity(item) for item in raw_form]
|
||||
|
||||
def _create_variable_entity(self, item: dict) -> VariableEntity:
|
||||
"""Create a single VariableEntity from raw form item"""
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
|
||||
return VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
)
|
||||
|
||||
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
"""Parse and validate MCP request"""
|
||||
try:
|
||||
return mcp_types.ClientRequest.model_validate(args)
|
||||
except ValidationError:
|
||||
try:
|
||||
notification = ClientNotification.model_validate(args)
|
||||
request = notification
|
||||
return mcp_types.ClientNotification.model_validate(args)
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||
)
|
||||
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||
|
||||
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
|
||||
response = mcp_server_handler.handle()
|
||||
return helper.compact_generate_response(response)
|
||||
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
|
||||
"""Get end user from existing session - optimized query"""
|
||||
return (
|
||||
session.query(EndUser)
|
||||
.where(EndUser.tenant_id == tenant_id)
|
||||
.where(EndUser.session_id == mcp_server_id)
|
||||
.where(EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
|
||||
def _create_end_user(
|
||||
self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
|
||||
) -> EndUser:
|
||||
"""Create end user in existing session"""
|
||||
end_user = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
type="mcp",
|
||||
name=client_name,
|
||||
session_id=mcp_server_id,
|
||||
)
|
||||
session.add(end_user)
|
||||
session.flush() # Use flush instead of commit to keep transaction open
|
||||
session.refresh(end_user)
|
||||
return end_user
|
||||
|
||||
def _handle_mcp_request(
|
||||
self,
|
||||
app: App,
|
||||
mcp_server: AppMCPServer,
|
||||
mcp_request: mcp_types.ClientRequest,
|
||||
user_input_form: list[VariableEntity],
|
||||
session: Session,
|
||||
request_id: Union[int, str],
|
||||
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
|
||||
"""Handle MCP request and return response"""
|
||||
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)
|
||||
|
||||
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
|
||||
client_info = mcp_request.root.params.clientInfo
|
||||
client_name = f"{client_info.name}@{client_info.version}"
|
||||
# Commit the session before creating end user to avoid transaction conflicts
|
||||
session.commit()
|
||||
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
|
||||
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
|
||||
|
||||
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)
|
||||
|
||||
@ -318,10 +318,6 @@ class DatasetApi(DatasetApiResource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
if data.get("permission") == "partial_members":
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({"partial_member_list": part_users_list})
|
||||
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in # type: ignore
|
||||
from flask_login import user_logged_in
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, update
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
from jwt import InvalidTokenError # type: ignore
|
||||
from jwt import InvalidTokenError
|
||||
|
||||
import services
|
||||
from controllers.console.auth.error import (
|
||||
|
||||
@ -1,3 +1,16 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
|
||||
class MoreLikeThisConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class AppConfigModel(BaseModel):
|
||||
more_like_this: MoreLikeThisConfig = Field(default_factory=MoreLikeThisConfig)
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class MoreLikeThisConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
@ -6,31 +19,14 @@ class MoreLikeThisConfigManager:
|
||||
|
||||
:param config: model config args
|
||||
"""
|
||||
more_like_this = False
|
||||
more_like_this_dict = config.get("more_like_this")
|
||||
if more_like_this_dict:
|
||||
if more_like_this_dict.get("enabled"):
|
||||
more_like_this = True
|
||||
|
||||
return more_like_this
|
||||
validated_config, _ = cls.validate_and_set_defaults(config)
|
||||
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for more like this feature
|
||||
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("more_like_this"):
|
||||
config["more_like_this"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["more_like_this"], dict):
|
||||
raise ValueError("more_like_this must be of dict type")
|
||||
|
||||
if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]:
|
||||
config["more_like_this"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["more_like_this"]["enabled"], bool):
|
||||
raise ValueError("enabled in more_like_this must be of boolean type")
|
||||
|
||||
return config, ["more_like_this"]
|
||||
try:
|
||||
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
|
||||
except ValidationError as e:
|
||||
raise ValueError(
|
||||
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
|
||||
)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from contextlib import contextmanager
|
||||
@ -143,6 +144,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
self._workflow_response_converter = WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
)
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
@ -373,7 +375,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle node succeeded events."""
|
||||
# Record files if it's an answer node or end node
|
||||
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
||||
if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]:
|
||||
self._recorded_files.extend(
|
||||
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
|
||||
)
|
||||
@ -896,7 +898,14 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
||||
message = self._get_message(session=session)
|
||||
message.answer = self._task_state.answer
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
answer_text = self._task_state.answer
|
||||
if self._recorded_files:
|
||||
# Remove markdown image links since we're storing files separately
|
||||
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
|
||||
|
||||
message.answer = answer_text
|
||||
message.updated_at = naive_utc_now()
|
||||
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, final
|
||||
|
||||
@ -14,6 +13,7 @@ from core.workflow.repositories.draft_variable_repository import (
|
||||
NoopDraftVariableSaver,
|
||||
)
|
||||
from factories import file_factory
|
||||
from libs.orjson import orjson_dumps
|
||||
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -174,7 +174,7 @@ class BaseAppGenerator:
|
||||
def gen():
|
||||
for message in generator:
|
||||
if isinstance(message, Mapping | dict):
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
yield f"data: {orjson_dumps(message)}\n\n"
|
||||
else:
|
||||
yield f"event: {message}\n\n"
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@ from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
@ -53,9 +52,7 @@ from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
EndUser,
|
||||
WorkflowRun,
|
||||
)
|
||||
|
||||
|
||||
@ -64,8 +61,10 @@ class WorkflowResponseConverter:
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
user: Union[Account, EndUser],
|
||||
) -> None:
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._user = user
|
||||
|
||||
def workflow_start_to_stream_response(
|
||||
self,
|
||||
@ -92,27 +91,21 @@ class WorkflowResponseConverter:
|
||||
workflow_execution: WorkflowExecution,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
created_by = None
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
|
||||
assert workflow_run is not None
|
||||
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
|
||||
stmt = select(Account).where(Account.id == workflow_run.created_by)
|
||||
account = session.scalar(stmt)
|
||||
if account:
|
||||
created_by = {
|
||||
"id": account.id,
|
||||
"name": account.name,
|
||||
"email": account.email,
|
||||
}
|
||||
elif workflow_run.created_by_role == CreatorUserRole.END_USER:
|
||||
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
|
||||
end_user = session.scalar(stmt)
|
||||
if end_user:
|
||||
created_by = {
|
||||
"id": end_user.id,
|
||||
"user": end_user.session_id,
|
||||
}
|
||||
|
||||
user = self._user
|
||||
if isinstance(user, Account):
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
}
|
||||
elif isinstance(user, EndUser):
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"user": user.session_id,
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
|
||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||
|
||||
# Handle the case where finished_at is None by using current time as default
|
||||
finished_at_timestamp = (
|
||||
|
||||
@ -131,6 +131,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
self._workflow_response_converter = WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
)
|
||||
|
||||
self._application_generate_entity = application_generate_entity
|
||||
|
||||
@ -118,7 +118,7 @@ class QueueIterationNextEvent(AppQueueEvent):
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
"""iteration run in parallel mode run id"""
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current iteration
|
||||
duration: Optional[float] = None
|
||||
@ -201,7 +201,7 @@ class QueueLoopNextEvent(AppQueueEvent):
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
"""iteration run in parallel mode run id"""
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current loop
|
||||
duration: Optional[float] = None
|
||||
@ -382,7 +382,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
"""iteration run in parallel mode run id"""
|
||||
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
class TaskPipilineError(ValueError):
|
||||
class TaskPipelineError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class RecordNotFoundError(TaskPipilineError):
|
||||
class RecordNotFoundError(TaskPipelineError):
|
||||
def __init__(self, record_name: str, record_id: str):
|
||||
super().__init__(f"{record_name} with id {record_id} not found")
|
||||
|
||||
|
||||
@ -88,6 +88,7 @@ def to_prompt_message_content(
|
||||
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
|
||||
"format": f.extension.removeprefix("."),
|
||||
"mime_type": f.mime_type,
|
||||
"filename": f.filename or "",
|
||||
}
|
||||
if f.type == FileType.IMAGE:
|
||||
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
@ -4,224 +4,259 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web.passport import generate_session_id
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
from core.mcp import types
|
||||
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
|
||||
from core.mcp.utils import create_mcp_error_response
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from core.mcp import types as mcp_types
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPServerStreamableHTTPRequestHandler:
|
||||
def handle_mcp_request(
|
||||
app: App,
|
||||
request: mcp_types.ClientRequest,
|
||||
user_input_form: list[VariableEntity],
|
||||
mcp_server: AppMCPServer,
|
||||
end_user: EndUser | None = None,
|
||||
request_id: int | str = 1,
|
||||
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError:
|
||||
"""
|
||||
Apply to MCP HTTP streamable server with stateless http
|
||||
Handle MCP request and return JSON-RPC response
|
||||
|
||||
Args:
|
||||
app: The Dify app instance
|
||||
request: The JSON-RPC request message
|
||||
user_input_form: List of variable entities for the app
|
||||
mcp_server: The MCP server configuration
|
||||
end_user: Optional end user
|
||||
request_id: The request ID
|
||||
|
||||
Returns:
|
||||
JSON-RPC response or error
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
|
||||
):
|
||||
self.app = app
|
||||
self.request = request
|
||||
mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
|
||||
if not mcp_server:
|
||||
raise ValueError("MCP server not found")
|
||||
self.mcp_server: AppMCPServer = mcp_server
|
||||
self.end_user = self.retrieve_end_user()
|
||||
self.user_input_form = user_input_form
|
||||
request_type = type(request.root)
|
||||
|
||||
@property
|
||||
def request_type(self):
|
||||
return type(self.request.root)
|
||||
def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
|
||||
"""Create success response with business result data"""
|
||||
return mcp_types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
result=result_data.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
@property
|
||||
def parameter_schema(self):
|
||||
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
|
||||
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
}
|
||||
def create_error_response(code: int, message: str) -> mcp_types.JSONRPCError:
|
||||
"""Create error response with error code and message"""
|
||||
from core.mcp.types import ErrorData
|
||||
|
||||
error_data = ErrorData(code=code, message=message)
|
||||
return mcp_types.JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
error=error_data,
|
||||
)
|
||||
|
||||
# Request handler mapping using functional approach
|
||||
request_handlers = {
|
||||
mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description),
|
||||
mcp_types.ListToolsRequest: lambda: handle_list_tools(
|
||||
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
|
||||
),
|
||||
mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user),
|
||||
mcp_types.PingRequest: lambda: handle_ping(),
|
||||
}
|
||||
|
||||
try:
|
||||
# Dispatch request to appropriate handler
|
||||
handler = request_handlers.get(request_type)
|
||||
if handler:
|
||||
return create_success_response(handler())
|
||||
else:
|
||||
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Invalid params")
|
||||
return create_error_response(mcp_types.INVALID_PARAMS, str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Internal server error")
|
||||
return create_error_response(mcp_types.INTERNAL_ERROR, "Internal server error: " + str(e))
|
||||
|
||||
|
||||
def handle_ping() -> mcp_types.EmptyResult:
|
||||
"""Handle ping request"""
|
||||
return mcp_types.EmptyResult()
|
||||
|
||||
|
||||
def handle_initialize(description: str) -> mcp_types.InitializeResult:
|
||||
"""Handle initialize request"""
|
||||
capabilities = mcp_types.ServerCapabilities(
|
||||
tools=mcp_types.ToolsCapability(listChanged=False),
|
||||
)
|
||||
|
||||
return mcp_types.InitializeResult(
|
||||
protocolVersion=mcp_types.SERVER_LATEST_PROTOCOL_VERSION,
|
||||
capabilities=capabilities,
|
||||
serverInfo=mcp_types.Implementation(name="Dify", version=dify_config.project.version),
|
||||
instructions=description,
|
||||
)
|
||||
|
||||
|
||||
def handle_list_tools(
|
||||
app_name: str,
|
||||
app_mode: str,
|
||||
user_input_form: list[VariableEntity],
|
||||
description: str,
|
||||
parameters_dict: dict[str, str],
|
||||
) -> mcp_types.ListToolsResult:
|
||||
"""Handle list tools request"""
|
||||
parameter_schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)
|
||||
|
||||
return mcp_types.ListToolsResult(
|
||||
tools=[
|
||||
mcp_types.Tool(
|
||||
name=app_name,
|
||||
description=description,
|
||||
inputSchema=parameter_schema,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def handle_call_tool(
|
||||
app: App,
|
||||
request: mcp_types.ClientRequest,
|
||||
user_input_form: list[VariableEntity],
|
||||
end_user: EndUser | None,
|
||||
) -> mcp_types.CallToolResult:
|
||||
"""Handle call tool request"""
|
||||
request_obj = cast(mcp_types.CallToolRequest, request.root)
|
||||
args = prepare_tool_arguments(app, request_obj.params.arguments or {})
|
||||
|
||||
if not end_user:
|
||||
raise ValueError("End user not found")
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app,
|
||||
end_user,
|
||||
args,
|
||||
InvokeFrom.SERVICE_API,
|
||||
streaming=app.mode == AppMode.AGENT_CHAT.value,
|
||||
)
|
||||
|
||||
answer = extract_answer_from_response(app, response)
|
||||
return mcp_types.CallToolResult(content=[mcp_types.TextContent(text=answer, type="text")])
|
||||
|
||||
|
||||
def build_parameter_schema(
|
||||
app_mode: str,
|
||||
user_input_form: list[VariableEntity],
|
||||
parameters_dict: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
"""Build parameter schema for the tool"""
|
||||
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||
|
||||
if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "User Input/Question content"},
|
||||
**parameters,
|
||||
},
|
||||
"required": ["query", *required],
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
}
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "User Input/Question content"},
|
||||
**parameters,
|
||||
},
|
||||
"required": ["query", *required],
|
||||
}
|
||||
|
||||
@property
|
||||
def capabilities(self):
|
||||
return types.ServerCapabilities(
|
||||
tools=types.ToolsCapability(listChanged=False),
|
||||
)
|
||||
|
||||
def response(self, response: types.Result | str):
|
||||
if isinstance(response, str):
|
||||
sse_content = f"event: ping\ndata: {response}\n\n".encode()
|
||||
yield sse_content
|
||||
return
|
||||
json_response = types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=(self.request.root.model_extra or {}).get("id", 1),
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Prepare arguments based on app mode"""
|
||||
if app.mode == AppMode.WORKFLOW.value:
|
||||
return {"inputs": arguments}
|
||||
elif app.mode == AppMode.COMPLETION.value:
|
||||
return {"query": "", "inputs": arguments}
|
||||
else:
|
||||
# Chat modes - create a copy to avoid modifying original dict
|
||||
args_copy = arguments.copy()
|
||||
query = args_copy.pop("query", "")
|
||||
return {"query": query, "inputs": args_copy}
|
||||
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
|
||||
yield sse_content
|
||||
def extract_answer_from_response(app: App, response: Any) -> str:
|
||||
"""Extract answer from app generate response"""
|
||||
answer = ""
|
||||
|
||||
def error_response(self, code: int, message: str, data=None):
|
||||
request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
|
||||
return create_mcp_error_response(request_id, code, message, data)
|
||||
if isinstance(response, RateLimitGenerator):
|
||||
answer = process_streaming_response(response)
|
||||
elif isinstance(response, Mapping):
|
||||
answer = process_mapping_response(app, response)
|
||||
else:
|
||||
logger.warning("Unexpected response type: %s", type(response))
|
||||
|
||||
def handle(self):
|
||||
handle_map = {
|
||||
types.InitializeRequest: self.initialize,
|
||||
types.ListToolsRequest: self.list_tools,
|
||||
types.CallToolRequest: self.invoke_tool,
|
||||
types.InitializedNotification: self.handle_notification,
|
||||
types.PingRequest: self.handle_ping,
|
||||
}
|
||||
try:
|
||||
if self.request_type in handle_map:
|
||||
return self.response(handle_map[self.request_type]())
|
||||
else:
|
||||
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
|
||||
except ValueError as e:
|
||||
logger.exception("Invalid params")
|
||||
return self.error_response(INVALID_PARAMS, str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Internal server error")
|
||||
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
|
||||
return answer
|
||||
|
||||
def handle_notification(self):
|
||||
return "ping"
|
||||
|
||||
def handle_ping(self):
|
||||
return types.EmptyResult()
|
||||
|
||||
def initialize(self):
|
||||
request = cast(types.InitializeRequest, self.request.root)
|
||||
client_info = request.params.clientInfo
|
||||
client_name = f"{client_info.name}@{client_info.version}"
|
||||
if not self.end_user:
|
||||
end_user = EndUser(
|
||||
tenant_id=self.app.tenant_id,
|
||||
app_id=self.app.id,
|
||||
type="mcp",
|
||||
name=client_name,
|
||||
session_id=generate_session_id(),
|
||||
external_user_id=self.mcp_server.id,
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
return types.InitializeResult(
|
||||
protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self.capabilities,
|
||||
serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
|
||||
instructions=self.mcp_server.description,
|
||||
)
|
||||
|
||||
def list_tools(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
return types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name=self.app.name,
|
||||
description=self.mcp_server.description,
|
||||
inputSchema=self.parameter_schema,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def invoke_tool(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
request = cast(types.CallToolRequest, self.request.root)
|
||||
args = request.params.arguments or {}
|
||||
if self.app.mode in {AppMode.WORKFLOW.value}:
|
||||
args = {"inputs": args}
|
||||
elif self.app.mode in {AppMode.COMPLETION.value}:
|
||||
args = {"query": "", "inputs": args}
|
||||
else:
|
||||
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
|
||||
response = AppGenerateService.generate(
|
||||
self.app,
|
||||
self.end_user,
|
||||
args,
|
||||
InvokeFrom.SERVICE_API,
|
||||
streaming=self.app.mode == AppMode.AGENT_CHAT.value,
|
||||
)
|
||||
answer = ""
|
||||
if isinstance(response, RateLimitGenerator):
|
||||
for item in response.generator:
|
||||
data = item
|
||||
if isinstance(data, str) and data.startswith("data: "):
|
||||
try:
|
||||
json_str = data[6:].strip()
|
||||
parsed_data = json.loads(json_str)
|
||||
if parsed_data.get("event") == "agent_thought":
|
||||
answer += parsed_data.get("thought", "")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(response, Mapping):
|
||||
if self.app.mode in {
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
AppMode.COMPLETION.value,
|
||||
AppMode.CHAT.value,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
}:
|
||||
answer = response["answer"]
|
||||
elif self.app.mode in {AppMode.WORKFLOW.value}:
|
||||
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
# Not support image yet
|
||||
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
|
||||
|
||||
def retrieve_end_user(self):
|
||||
return (
|
||||
db.session.query(EndUser)
|
||||
.where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
|
||||
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
|
||||
parameters: dict[str, dict[str, Any]] = {}
|
||||
required = []
|
||||
for item in user_input_form:
|
||||
parameters[item.variable] = {}
|
||||
if item.type in (
|
||||
VariableEntityType.FILE,
|
||||
VariableEntityType.FILE_LIST,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
):
|
||||
continue
|
||||
if item.required:
|
||||
required.append(item.variable)
|
||||
# if the workflow republished, the parameters not changed
|
||||
# we should not raise error here
|
||||
def process_streaming_response(response: RateLimitGenerator) -> str:
|
||||
"""Process streaming response for agent chat mode"""
|
||||
answer = ""
|
||||
for item in response.generator:
|
||||
if isinstance(item, str) and item.startswith("data: "):
|
||||
try:
|
||||
description = self.mcp_server.parameters_dict[item.variable]
|
||||
except KeyError:
|
||||
description = ""
|
||||
parameters[item.variable]["description"] = description
|
||||
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||
parameters[item.variable]["type"] = "string"
|
||||
elif item.type == VariableEntityType.SELECT:
|
||||
parameters[item.variable]["type"] = "string"
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "float"
|
||||
return parameters, required
|
||||
json_str = item[6:].strip()
|
||||
parsed_data = json.loads(json_str)
|
||||
if parsed_data.get("event") == "agent_thought":
|
||||
answer += parsed_data.get("thought", "")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
return answer
|
||||
|
||||
|
||||
def process_mapping_response(app: App, response: Mapping) -> str:
|
||||
"""Process mapping response based on app mode"""
|
||||
if app.mode in {
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
AppMode.COMPLETION.value,
|
||||
AppMode.CHAT.value,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
}:
|
||||
return response.get("answer", "")
|
||||
elif app.mode == AppMode.WORKFLOW.value:
|
||||
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode: " + str(app.mode))
|
||||
|
||||
|
||||
def convert_input_form_to_parameters(
|
||||
user_input_form: list[VariableEntity],
|
||||
parameters_dict: dict[str, str],
|
||||
) -> tuple[dict[str, dict[str, Any]], list[str]]:
|
||||
"""Convert user input form to parameter schema"""
|
||||
parameters: dict[str, dict[str, Any]] = {}
|
||||
required = []
|
||||
|
||||
for item in user_input_form:
|
||||
if item.type in (
|
||||
VariableEntityType.FILE,
|
||||
VariableEntityType.FILE_LIST,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
):
|
||||
continue
|
||||
parameters[item.variable] = {}
|
||||
if item.required:
|
||||
required.append(item.variable)
|
||||
# if the workflow republished, the parameters not changed
|
||||
# we should not raise error here
|
||||
description = parameters_dict.get(item.variable, "")
|
||||
parameters[item.variable]["description"] = description
|
||||
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||
parameters[item.variable]["type"] = "string"
|
||||
elif item.type == VariableEntityType.SELECT:
|
||||
parameters[item.variable]["type"] = "string"
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "float"
|
||||
return parameters, required
|
||||
|
||||
@ -138,5 +138,5 @@ def create_mcp_error_response(
|
||||
error=error_data,
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
sse_content = json_data.encode()
|
||||
yield sse_content
|
||||
|
||||
@ -31,6 +31,65 @@ class TokenBufferMemory:
|
||||
self.conversation = conversation
|
||||
self.model_instance = model_instance
|
||||
|
||||
def _build_prompt_message_with_files(
|
||||
self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
|
||||
) -> PromptMessage:
|
||||
"""
|
||||
Build prompt message with files.
|
||||
:param message_files: list of MessageFile objects
|
||||
:param text_content: text content of the message
|
||||
:param message: Message object
|
||||
:param app_record: app record
|
||||
:param is_user_message: whether this is a user message
|
||||
:return: PromptMessage
|
||||
"""
|
||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow_run = db.session.scalar(select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id))
|
||||
if not workflow_run:
|
||||
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
||||
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
else:
|
||||
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
|
||||
|
||||
detail = ImagePromptMessageContent.DETAIL.HIGH
|
||||
if file_extra_config and app_record:
|
||||
# Build files directly without filtering by belongs_to
|
||||
file_objs = [
|
||||
file_factory.build_from_message_file(
|
||||
message_file=message_file, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||
)
|
||||
for message_file in message_files
|
||||
]
|
||||
if file_extra_config.image_config and file_extra_config.image_config.detail:
|
||||
detail = file_extra_config.image_config.detail
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
if not file_objs:
|
||||
if is_user_message:
|
||||
return UserPromptMessage(content=text_content)
|
||||
else:
|
||||
return AssistantPromptMessage(content=text_content)
|
||||
else:
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in file_objs:
|
||||
prompt_message = file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=detail,
|
||||
)
|
||||
prompt_message_contents.append(prompt_message)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=text_content))
|
||||
|
||||
if is_user_message:
|
||||
return UserPromptMessage(content=prompt_message_contents)
|
||||
else:
|
||||
return AssistantPromptMessage(content=prompt_message_contents)
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
||||
) -> Sequence[PromptMessage]:
|
||||
@ -67,52 +126,46 @@ class TokenBufferMemory:
|
||||
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
for message in messages:
|
||||
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
|
||||
if files:
|
||||
file_extra_config = None
|
||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow_run = db.session.scalar(
|
||||
select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
|
||||
)
|
||||
if not workflow_run:
|
||||
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
||||
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
else:
|
||||
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
|
||||
|
||||
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||
if file_extra_config and app_record:
|
||||
file_objs = file_factory.build_from_message_files(
|
||||
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||
)
|
||||
if file_extra_config.image_config and file_extra_config.image_config.detail:
|
||||
detail = file_extra_config.image_config.detail
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
if not file_objs:
|
||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
else:
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in file_objs:
|
||||
prompt_message = file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=detail,
|
||||
)
|
||||
prompt_message_contents.append(prompt_message)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
# Process user message with files
|
||||
user_files = (
|
||||
db.session.query(MessageFile)
|
||||
.where(
|
||||
MessageFile.message_id == message.id,
|
||||
(MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if user_files:
|
||||
user_prompt_message = self._build_prompt_message_with_files(
|
||||
message_files=user_files,
|
||||
text_content=message.query,
|
||||
message=message,
|
||||
app_record=app_record,
|
||||
is_user_message=True,
|
||||
)
|
||||
prompt_messages.append(user_prompt_message)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
|
||||
prompt_messages.append(AssistantPromptMessage(content=message.answer))
|
||||
# Process assistant message with files
|
||||
assistant_files = (
|
||||
db.session.query(MessageFile)
|
||||
.where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
|
||||
.all()
|
||||
)
|
||||
|
||||
if assistant_files:
|
||||
assistant_prompt_message = self._build_prompt_message_with_files(
|
||||
message_files=assistant_files,
|
||||
text_content=message.answer,
|
||||
message=message,
|
||||
app_record=app_record,
|
||||
is_user_message=False,
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message)
|
||||
else:
|
||||
prompt_messages.append(AssistantPromptMessage(content=message.answer))
|
||||
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
@ -87,6 +87,7 @@ class MultiModalPromptMessageContent(PromptMessageContent):
|
||||
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
|
||||
url: str = Field(default="", description="the url of multi-modal file")
|
||||
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
|
||||
filename: str = Field(default="", description="the filename of multi-modal file")
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
|
||||
@ -43,7 +43,7 @@ class GPT2Tokenizer:
|
||||
except Exception:
|
||||
from os.path import abspath, dirname, join
|
||||
|
||||
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
|
||||
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
|
||||
|
||||
base_path = abspath(__file__)
|
||||
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
|
||||
|
||||
@ -330,7 +330,7 @@ class OpsTraceManager:
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
else:
|
||||
if tracing_provider is not None:
|
||||
if tracing_provider is None:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
|
||||
app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first()
|
||||
|
||||
@ -375,16 +375,16 @@ Here is the extra instruction you need to follow:
|
||||
|
||||
# merge lines into messages with max tokens
|
||||
messages: list[str] = []
|
||||
for i in new_lines: # type: ignore
|
||||
for line in new_lines:
|
||||
if len(messages) == 0:
|
||||
messages.append(i) # type: ignore
|
||||
messages.append(line)
|
||||
else:
|
||||
if len(messages[-1]) + len(i) < max_tokens * 0.5: # type: ignore
|
||||
messages[-1] += i # type: ignore
|
||||
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: # type: ignore
|
||||
messages.append(i) # type: ignore
|
||||
if len(messages[-1]) + len(line) < max_tokens * 0.5:
|
||||
messages[-1] += line
|
||||
if get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
|
||||
messages.append(line)
|
||||
else:
|
||||
messages[-1] += i # type: ignore
|
||||
messages[-1] += line
|
||||
|
||||
summaries = []
|
||||
for i in range(len(messages)):
|
||||
|
||||
@ -24,7 +24,7 @@ default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
@ -256,7 +256,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
response = self._client.query_collection_data(request)
|
||||
documents = []
|
||||
for match in response.body.matches.match:
|
||||
if match.score > score_threshold:
|
||||
if match.score >= score_threshold:
|
||||
metadata = json.loads(match.metadata.get("metadata_"))
|
||||
metadata["score"] = match.score
|
||||
doc = Document(
|
||||
@ -293,7 +293,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
response = self._client.query_collection_data(request)
|
||||
documents = []
|
||||
for match in response.body.matches.match:
|
||||
if match.score > score_threshold:
|
||||
if match.score >= score_threshold:
|
||||
metadata = json.loads(match.metadata.get("metadata_"))
|
||||
metadata["score"] = match.score
|
||||
doc = Document(
|
||||
|
||||
@ -3,8 +3,8 @@ import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.extras # type: ignore
|
||||
import psycopg2.pool # type: ignore
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.rag.models.document import Document
|
||||
@ -229,7 +229,7 @@ class AnalyticdbVectorBySql:
|
||||
documents = []
|
||||
for record in cur:
|
||||
id, vector, score, page_content, metadata = record
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
metadata["score"] = score
|
||||
doc = Document(
|
||||
page_content=page_content,
|
||||
|
||||
@ -157,7 +157,7 @@ class BaiduVector(BaseVector):
|
||||
if meta is not None:
|
||||
meta = json.loads(meta)
|
||||
score = row.get("score", 0.0)
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
meta["score"] = score
|
||||
doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
|
||||
docs.append(doc)
|
||||
|
||||
@ -120,7 +120,7 @@ class ChromaVector(BaseVector):
|
||||
distance = distances[index]
|
||||
metadata = dict(metadatas[index])
|
||||
score = 1 - distance
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
metadata["score"] = score
|
||||
doc = Document(
|
||||
page_content=documents[index],
|
||||
|
||||
@ -304,7 +304,7 @@ class CouchbaseVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 2)
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
try:
|
||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
|
||||
search_iter = self._scope.search(
|
||||
|
||||
@ -216,7 +216,7 @@ class ElasticSearchVector(BaseVector):
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
if doc.metadata is not None:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
|
||||
@ -127,7 +127,7 @@ class HuaweiCloudVector(BaseVector):
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
if doc.metadata is not None:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
|
||||
@ -275,7 +275,7 @@ class LindormVectorStore(BaseVector):
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
if doc.metadata is not None:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
|
||||
@ -3,8 +3,8 @@ import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.extras # type: ignore
|
||||
import psycopg2.pool # type: ignore
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
@ -194,7 +194,7 @@ class OpenGauss(BaseVector):
|
||||
metadata, text, distance = record
|
||||
score = 1 - distance
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ class OpenSearchConfig(BaseModel):
|
||||
return values
|
||||
|
||||
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
|
||||
import boto3 # type: ignore
|
||||
import boto3
|
||||
|
||||
return Urllib3AWSV4SignerAuth(
|
||||
credentials=boto3.Session().get_credentials(),
|
||||
@ -211,7 +211,7 @@ class OpenSearchVector(BaseVector):
|
||||
|
||||
metadata["score"] = hit["_score"]
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if hit["_score"] > score_threshold:
|
||||
if hit["_score"] >= score_threshold:
|
||||
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
|
||||
@ -261,7 +261,7 @@ class OracleVector(BaseVector):
|
||||
metadata, text, distance = record
|
||||
score = 1 - distance
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
conn.close()
|
||||
return docs
|
||||
|
||||
@ -202,7 +202,7 @@ class PGVectoRS(BaseVector):
|
||||
score = 1 - dis
|
||||
metadata["score"] = score
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
doc = Document(page_content=record.text, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@ -6,8 +6,8 @@ from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.errors
|
||||
import psycopg2.extras # type: ignore
|
||||
import psycopg2.pool # type: ignore
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
@ -195,7 +195,7 @@ class PGVector(BaseVector):
|
||||
metadata, text, distance = record
|
||||
score = 1 - distance
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
|
||||
|
||||
@ -3,8 +3,8 @@ import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.extras # type: ignore
|
||||
import psycopg2.pool # type: ignore
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
@ -170,7 +170,7 @@ class VastbaseVector(BaseVector):
|
||||
metadata, text, distance = record
|
||||
score = 1 - distance
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
|
||||
|
||||
@ -369,7 +369,7 @@ class QdrantVector(BaseVector):
|
||||
continue
|
||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||
# duplicate check score threshold
|
||||
if result.score > score_threshold:
|
||||
if result.score >= score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
|
||||
|
||||
@ -233,7 +233,7 @@ class RelytVector(BaseVector):
|
||||
docs = []
|
||||
for document, score in results:
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if 1 - score > score_threshold:
|
||||
if 1 - score >= score_threshold:
|
||||
docs.append(document)
|
||||
return docs
|
||||
|
||||
|
||||
@ -300,7 +300,7 @@ class TableStoreVector(BaseVector):
|
||||
)
|
||||
documents = []
|
||||
for search_hit in search_response.search_hits:
|
||||
if search_hit.score > score_threshold:
|
||||
if search_hit.score >= score_threshold:
|
||||
ots_column_map = {}
|
||||
for col in search_hit.row[1]:
|
||||
ots_column_map[col[0]] = col[1]
|
||||
|
||||
@ -291,7 +291,7 @@ class TencentVector(BaseVector):
|
||||
score = 1 - result.get("score", 0.0)
|
||||
else:
|
||||
score = result.get("score", 0.0)
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
meta["score"] = score
|
||||
doc = Document(page_content=result.get(self.field_text), metadata=meta)
|
||||
docs.append(doc)
|
||||
|
||||
@ -351,7 +351,7 @@ class TidbOnQdrantVector(BaseVector):
|
||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||
# duplicate check score threshold
|
||||
score_threshold = kwargs.get("score_threshold") or 0.0
|
||||
if result.score > score_threshold:
|
||||
if result.score >= score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
|
||||
|
||||
@ -110,7 +110,7 @@ class UpstashVector(BaseVector):
|
||||
score = record.score
|
||||
if metadata is not None and text is not None:
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
|
||||
|
||||
@ -192,7 +192,7 @@ class VikingDBVector(BaseVector):
|
||||
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
|
||||
if metadata is not None:
|
||||
metadata = json.loads(metadata)
|
||||
if result.score > score_threshold:
|
||||
if result.score >= score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
@ -220,7 +220,7 @@ class WeaviateVector(BaseVector):
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
# check score threshold
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
if doc.metadata is not None:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
|
||||
@ -4,7 +4,7 @@ import os
|
||||
from typing import Optional, cast
|
||||
|
||||
import pandas as pd
|
||||
from openpyxl import load_workbook # type: ignore
|
||||
from openpyxl import load_workbook
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from bs4 import BeautifulSoup # type: ignore
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
@ -3,7 +3,7 @@ import contextlib
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from bs4 import BeautifulSoup # type: ignore
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
@ -123,7 +123,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
metadata["score"] = result.score
|
||||
if result.score > score_threshold:
|
||||
if result.score >= score_threshold:
|
||||
doc = Document(page_content=result.page_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@ -162,7 +162,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
metadata["score"] = result.score
|
||||
if result.score > score_threshold:
|
||||
if result.score >= score_threshold:
|
||||
doc = Document(page_content=result.page_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@ -158,7 +158,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
metadata["score"] = result.score
|
||||
if result.score > score_threshold:
|
||||
if result.score >= score_threshold:
|
||||
doc = Document(page_content=result.page_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@ -65,7 +65,7 @@ default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
@ -647,7 +647,7 @@ class DatasetRetrieval:
|
||||
retrieval_method=retrieval_model["search_method"],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 2,
|
||||
top_k=retrieval_model.get("top_k") or 4,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
@ -743,7 +743,7 @@ class DatasetRetrieval:
|
||||
tool = DatasetMultiRetrieverTool.from_dataset(
|
||||
dataset_ids=[dataset.id for dataset in available_datasets],
|
||||
tenant_id=tenant_id,
|
||||
top_k=retrieve_config.top_k or 2,
|
||||
top_k=retrieve_config.top_k or 4,
|
||||
score_threshold=retrieve_config.score_threshold,
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
|
||||
@ -144,7 +144,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
|
||||
"""Text splitter that uses HuggingFace tokenizer to count length."""
|
||||
try:
|
||||
from transformers import PreTrainedTokenizerBase # type: ignore
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerBase):
|
||||
raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")
|
||||
|
||||
@ -181,7 +181,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
retrieval_method="keyword_search",
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 2,
|
||||
top_k=retrieval_model.get("top_k") or 4,
|
||||
)
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
retrieval_method=retrieval_model["search_method"],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 2,
|
||||
top_k=retrieval_model.get("top_k") or 4,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
|
||||
@ -13,7 +13,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
|
||||
name: str = "dataset"
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
tenant_id: str
|
||||
top_k: int = 2
|
||||
top_k: int = 4
|
||||
score_threshold: Optional[float] = None
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
||||
return_resource: bool
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Optional
|
||||
|
||||
from flask import request
|
||||
from requests import get
|
||||
from yaml import YAMLError, safe_load # type: ignore
|
||||
from yaml import YAMLError, safe_load
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
|
||||
@ -166,7 +166,7 @@ class BaseIterationEvent(GraphEngineEvent):
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
"""iteration run in parallel mode run id"""
|
||||
|
||||
|
||||
class IterationRunStartedEvent(BaseIterationEvent):
|
||||
|
||||
@ -149,9 +149,6 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
return []
|
||||
|
||||
stream_output_value_selector = event.from_variable_selector
|
||||
if not stream_output_value_selector:
|
||||
return []
|
||||
|
||||
stream_out_answer_node_ids = []
|
||||
for answer_node_id, route_position in self.route_position.items():
|
||||
if answer_node_id not in self.rest_node_ids:
|
||||
|
||||
@ -515,14 +515,14 @@ def _extract_text_from_excel(file_content: bytes) -> str:
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
# Combine multi-line text in each cell into a single line
|
||||
df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore
|
||||
df = df.map(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x)
|
||||
|
||||
# Combine multi-line text in column names into a single line
|
||||
df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns])
|
||||
|
||||
# Manually construct the Markdown table
|
||||
markdown_table += _construct_markdown_table(df) + "\n\n"
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
continue
|
||||
return markdown_table
|
||||
except Exception as e:
|
||||
|
||||
@ -78,7 +78,7 @@ default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from dify_app import DifyApp
|
||||
def init_app(app: DifyApp):
|
||||
# register blueprint routers
|
||||
|
||||
from flask_cors import CORS # type: ignore
|
||||
from flask_cors import CORS
|
||||
|
||||
from controllers.console import bp as console_app_bp
|
||||
from controllers.files import bp as files_bp
|
||||
|
||||
@ -9,7 +9,7 @@ from typing import Union
|
||||
|
||||
import flask
|
||||
from celery.signals import worker_init
|
||||
from flask_login import user_loaded_from_request, user_logged_in # type: ignore
|
||||
from flask_login import user_loaded_from_request, user_logged_in
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
import boto3 # type: ignore
|
||||
from botocore.client import Config # type: ignore
|
||||
from botocore.exceptions import ClientError # type: ignore
|
||||
import boto3
|
||||
from botocore.client import Config
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
||||
@ -41,8 +41,14 @@ def build_from_message_file(
|
||||
"url": message_file.url,
|
||||
"id": message_file.id,
|
||||
"type": message_file.type,
|
||||
"upload_file_id": message_file.upload_file_id,
|
||||
}
|
||||
|
||||
# Set the correct ID field based on transfer method
|
||||
if message_file.transfer_method == FileTransferMethod.TOOL_FILE.value:
|
||||
mapping["tool_file_id"] = message_file.upload_file_id
|
||||
else:
|
||||
mapping["upload_file_id"] = message_file.upload_file_id
|
||||
|
||||
return build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
@ -318,6 +324,11 @@ def _is_file_valid_with_config(
|
||||
file_transfer_method: FileTransferMethod,
|
||||
config: FileUploadConfig,
|
||||
) -> bool:
|
||||
# FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model)
|
||||
# These are internally generated and should bypass user upload restrictions
|
||||
if file_transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
return True
|
||||
|
||||
if (
|
||||
config.allowed_file_types
|
||||
and input_file_type not in config.allowed_file_types
|
||||
|
||||
11
api/libs/orjson.py
Normal file
11
api/libs/orjson.py
Normal file
@ -0,0 +1,11 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import orjson
|
||||
|
||||
|
||||
def orjson_dumps(
|
||||
obj: Any,
|
||||
encoding: str = "utf-8",
|
||||
option: Optional[int] = None,
|
||||
) -> str:
|
||||
return orjson.dumps(obj, option=option).decode(encoding)
|
||||
@ -5,7 +5,7 @@ Revises: 8bcc02c9bd07
|
||||
Create Date: 2025-08-09 15:53:54.341341
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from alembic import op, context
|
||||
from libs.uuid_utils import uuidv7
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
@ -43,7 +43,15 @@ def upgrade():
|
||||
with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True))
|
||||
|
||||
migrate_existing_providers_data()
|
||||
if not context.is_offline_mode():
|
||||
migrate_existing_providers_data()
|
||||
else:
|
||||
op.execute(
|
||||
'-- [IMPORTANT] Data migration skipped!!!\n'
|
||||
"-- You should manually run data migration function `migrate_existing_providers_data`\n"
|
||||
f"-- inside file {__file__}\n"
|
||||
"-- Please review the migration script carefully!"
|
||||
)
|
||||
|
||||
# Remove encrypted_config column from providers table after migration
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
@ -119,7 +127,16 @@ def downgrade():
|
||||
batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True))
|
||||
|
||||
# Migrate data back from provider_credentials to providers
|
||||
migrate_data_back_to_providers()
|
||||
|
||||
if not context.is_offline_mode():
|
||||
migrate_data_back_to_providers()
|
||||
else:
|
||||
op.execute(
|
||||
'-- [IMPORTANT] Data migration skipped!!!\n'
|
||||
"-- You should manually run data migration function `migrate_data_back_to_providers`\n"
|
||||
f"-- inside file {__file__}\n"
|
||||
"-- Please review the migration script carefully!"
|
||||
)
|
||||
|
||||
# Remove credential_id columns
|
||||
with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
|
||||
|
||||
@ -6,7 +6,7 @@ Create Date: 2025-08-13 16:05:42.657730
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from alembic import op, context
|
||||
from libs.uuid_utils import uuidv7
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
@ -48,8 +48,16 @@ def upgrade():
|
||||
with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('credential_source_type', sa.String(length=40), nullable=True))
|
||||
|
||||
# Migrate existing provider_models data
|
||||
migrate_existing_provider_models_data()
|
||||
if not context.is_offline_mode():
|
||||
# Migrate existing provider_models data
|
||||
migrate_existing_provider_models_data()
|
||||
else:
|
||||
op.execute(
|
||||
'-- [IMPORTANT] Data migration skipped!!!\n'
|
||||
"-- You should manually run data migration function `migrate_existing_provider_models_data`\n"
|
||||
f"-- inside file {__file__}\n"
|
||||
"-- Please review the migration script carefully!"
|
||||
)
|
||||
|
||||
# Remove encrypted_config column from provider_models table after migration
|
||||
with op.batch_alter_table('provider_models', schema=None) as batch_op:
|
||||
@ -132,8 +140,16 @@ def downgrade():
|
||||
with op.batch_alter_table('provider_models', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True))
|
||||
|
||||
# Migrate data back from provider_model_credentials to provider_models
|
||||
migrate_data_back_to_provider_models()
|
||||
if not context.is_offline_mode():
|
||||
# Migrate data back from provider_model_credentials to provider_models
|
||||
migrate_data_back_to_provider_models()
|
||||
else:
|
||||
op.execute(
|
||||
'-- [IMPORTANT] Data migration skipped!!!\n'
|
||||
"-- You should manually run data migration function `migrate_data_back_to_provider_models`\n"
|
||||
f"-- inside file {__file__}\n"
|
||||
"-- Please review the migration script carefully!"
|
||||
)
|
||||
|
||||
with op.batch_alter_table('provider_models', schema=None) as batch_op:
|
||||
batch_op.drop_column('credential_id')
|
||||
|
||||
@ -0,0 +1,45 @@
|
||||
"""empty message
|
||||
|
||||
Revision ID: 8d289573e1da
|
||||
Revises: 0e154742a5fa
|
||||
Create Date: 2025-08-20 17:47:17.015695
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '8d289573e1da'
|
||||
down_revision = '0e154742a5fa'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('oauth_provider_apps',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||
sa.Column('app_icon', sa.String(length=255), nullable=False),
|
||||
sa.Column('app_label', sa.JSON(), server_default='{}', nullable=False),
|
||||
sa.Column('client_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('client_secret', sa.String(length=255), nullable=False),
|
||||
sa.Column('redirect_uris', sa.JSON(), server_default='[]', nullable=False),
|
||||
sa.Column('scope', sa.String(length=255), server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='oauth_provider_app_pkey')
|
||||
)
|
||||
with op.batch_alter_table('oauth_provider_apps', schema=None) as batch_op:
|
||||
batch_op.create_index('oauth_provider_app_client_id_idx', ['client_id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('oauth_provider_apps', schema=None) as batch_op:
|
||||
batch_op.drop_index('oauth_provider_app_client_id_idx')
|
||||
|
||||
op.drop_table('oauth_provider_apps')
|
||||
# ### end Alembic commands ###
|
||||
@ -1,12 +1,12 @@
|
||||
import enum
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional, cast
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_login import UserMixin # type: ignore
|
||||
from flask_login import UserMixin
|
||||
from sqlalchemy import DateTime, String, func, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
|
||||
|
||||
from models.base import Base
|
||||
|
||||
@ -118,10 +118,24 @@ class Account(UserMixin, Base):
|
||||
|
||||
@current_tenant.setter
|
||||
def current_tenant(self, tenant: "Tenant"):
|
||||
ta = db.session.scalar(select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1))
|
||||
if ta:
|
||||
self.role = TenantAccountRole(ta.role)
|
||||
self._current_tenant = tenant
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
tenant_join_query = select(TenantAccountJoin).where(
|
||||
TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == self.id
|
||||
)
|
||||
tenant_join = session.scalar(tenant_join_query)
|
||||
tenant_query = select(Tenant).where(Tenant.id == tenant.id)
|
||||
# TODO: A workaround to reload the tenant with `expire_on_commit=False`, allowing
|
||||
# access to it after the session has been closed.
|
||||
# This prevents `DetachedInstanceError` when accessing the tenant outside
|
||||
# the session's lifecycle.
|
||||
# (The `tenant` argument is typically loaded by `db.session` without the
|
||||
# `expire_on_commit=False` flag, meaning its lifetime is tied to the web
|
||||
# request's lifecycle.)
|
||||
tenant_reloaded = session.scalars(tenant_query).one()
|
||||
|
||||
if tenant_join:
|
||||
self.role = TenantAccountRole(tenant_join.role)
|
||||
self._current_tenant = tenant_reloaded
|
||||
return
|
||||
self._current_tenant = None
|
||||
|
||||
@ -130,23 +144,19 @@ class Account(UserMixin, Base):
|
||||
return self._current_tenant.id if self._current_tenant else None
|
||||
|
||||
def set_tenant_id(self, tenant_id: str):
|
||||
tenant_account_join = cast(
|
||||
tuple[Tenant, TenantAccountJoin],
|
||||
(
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == tenant_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.account_id == self.id)
|
||||
.one_or_none()
|
||||
),
|
||||
query = (
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == tenant_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.account_id == self.id)
|
||||
)
|
||||
|
||||
if not tenant_account_join:
|
||||
return
|
||||
|
||||
tenant, join = tenant_account_join
|
||||
self.role = TenantAccountRole(join.role)
|
||||
self._current_tenant = tenant
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
tenant_account_join = session.execute(query).first()
|
||||
if not tenant_account_join:
|
||||
return
|
||||
tenant, join = tenant_account_join
|
||||
self.role = TenantAccountRole(join.role)
|
||||
self._current_tenant = tenant
|
||||
|
||||
@property
|
||||
def current_role(self):
|
||||
|
||||
@ -522,33 +522,6 @@ class AppModelConfig(Base):
|
||||
self.file_upload = json.dumps(model_config.get("file_upload")) if model_config.get("file_upload") else None
|
||||
return self
|
||||
|
||||
def copy(self):
|
||||
new_app_model_config = AppModelConfig(
|
||||
id=self.id,
|
||||
app_id=self.app_id,
|
||||
opening_statement=self.opening_statement,
|
||||
suggested_questions=self.suggested_questions,
|
||||
suggested_questions_after_answer=self.suggested_questions_after_answer,
|
||||
speech_to_text=self.speech_to_text,
|
||||
text_to_speech=self.text_to_speech,
|
||||
more_like_this=self.more_like_this,
|
||||
sensitive_word_avoidance=self.sensitive_word_avoidance,
|
||||
external_data_tools=self.external_data_tools,
|
||||
model=self.model,
|
||||
user_input_form=self.user_input_form,
|
||||
dataset_query_variable=self.dataset_query_variable,
|
||||
pre_prompt=self.pre_prompt,
|
||||
agent_mode=self.agent_mode,
|
||||
retriever_resource=self.retriever_resource,
|
||||
prompt_type=self.prompt_type,
|
||||
chat_prompt_config=self.chat_prompt_config,
|
||||
completion_prompt_config=self.completion_prompt_config,
|
||||
dataset_configs=self.dataset_configs,
|
||||
file_upload=self.file_upload,
|
||||
)
|
||||
|
||||
return new_app_model_config
|
||||
|
||||
|
||||
class RecommendedApp(Base):
|
||||
__tablename__ = "recommended_apps"
|
||||
@ -607,6 +580,32 @@ class InstalledApp(Base):
|
||||
return tenant
|
||||
|
||||
|
||||
class OAuthProviderApp(Base):
|
||||
"""
|
||||
Globally shared OAuth provider app information.
|
||||
Only for Dify Cloud.
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_provider_apps"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="oauth_provider_app_pkey"),
|
||||
sa.Index("oauth_provider_app_client_id_idx", "client_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
app_icon = mapped_column(String(255), nullable=False)
|
||||
app_label = mapped_column(sa.JSON, nullable=False, server_default="{}")
|
||||
client_id = mapped_column(String(255), nullable=False)
|
||||
client_secret = mapped_column(String(255), nullable=False)
|
||||
redirect_uris = mapped_column(sa.JSON, nullable=False, server_default="[]")
|
||||
scope = mapped_column(
|
||||
String(255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"),
|
||||
)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
||||
|
||||
class Conversation(Base):
|
||||
__tablename__ = "conversations"
|
||||
__table_args__ = (
|
||||
|
||||
@ -67,7 +67,7 @@ dependencies = [
|
||||
"pydantic~=2.11.4",
|
||||
"pydantic-extra-types~=2.10.3",
|
||||
"pydantic-settings~=2.9.1",
|
||||
"pyjwt~=2.8.0",
|
||||
"pyjwt~=2.10.1",
|
||||
"pypdfium2==4.30.0",
|
||||
"python-docx~=1.1.0",
|
||||
"python-dotenv==1.0.1",
|
||||
@ -179,7 +179,7 @@ storage = [
|
||||
"google-cloud-storage==2.16.0",
|
||||
"opendal~=0.45.16",
|
||||
"oss2==2.18.5",
|
||||
"supabase~=2.8.1",
|
||||
"supabase~=2.18.1",
|
||||
"tos~=2.7.1",
|
||||
]
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ def check_upgradable_plugin_task():
|
||||
|
||||
strategies = (
|
||||
db.session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(
|
||||
.where(
|
||||
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day,
|
||||
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day
|
||||
< now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL,
|
||||
|
||||
@ -93,7 +93,7 @@ def _delete_batch_with_retry(workflow_run_ids: list[str], attempt_count: int) ->
|
||||
with db.session.begin_nested():
|
||||
message_data = (
|
||||
db.session.query(Message.id, Message.conversation_id)
|
||||
.filter(Message.workflow_run_id.in_(workflow_run_ids))
|
||||
.where(Message.workflow_run_id.in_(workflow_run_ids))
|
||||
.all()
|
||||
)
|
||||
message_id_list = [msg.id for msg in message_data]
|
||||
|
||||
@ -282,7 +282,7 @@ class AppAnnotationService:
|
||||
annotations_to_delete = (
|
||||
db.session.query(MessageAnnotation, AppAnnotationSetting)
|
||||
.outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id)
|
||||
.filter(MessageAnnotation.id.in_(annotation_ids))
|
||||
.where(MessageAnnotation.id.in_(annotation_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
@ -493,7 +493,7 @@ class AppAnnotationService:
|
||||
def clear_all_annotations(cls, app_id: str) -> dict:
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ class AppGenerateService:
|
||||
cls.system_rate_limiter.increment_rate_limit(app_model.tenant_id)
|
||||
|
||||
# app level rate limiter
|
||||
max_active_request = AppGenerateService._get_max_active_requests(app_model)
|
||||
max_active_request = cls._get_max_active_requests(app_model)
|
||||
rate_limit = RateLimit(app_model.id, max_active_request)
|
||||
request_id = RateLimit.gen_request_key()
|
||||
try:
|
||||
|
||||
@ -62,7 +62,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
# Query records related to expired messages
|
||||
records = (
|
||||
session.query(model)
|
||||
.filter(
|
||||
.where(
|
||||
model.message_id.in_(batch_message_ids), # type: ignore
|
||||
)
|
||||
.all()
|
||||
@ -101,7 +101,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
except Exception:
|
||||
logger.exception("Failed to save %s records", table_name)
|
||||
|
||||
session.query(model).filter(
|
||||
session.query(model).where(
|
||||
model.id.in_(record_ids), # type: ignore
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
@ -295,7 +295,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
workflow_app_logs = (
|
||||
session.query(WorkflowAppLog)
|
||||
.filter(
|
||||
.where(
|
||||
WorkflowAppLog.tenant_id == tenant_id,
|
||||
WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
)
|
||||
@ -321,9 +321,9 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]
|
||||
|
||||
# delete workflow app logs
|
||||
session.query(WorkflowAppLog).filter(
|
||||
WorkflowAppLog.id.in_(workflow_app_log_ids),
|
||||
).delete(synchronize_session=False)
|
||||
session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
|
||||
click.echo(
|
||||
|
||||
@ -1149,7 +1149,7 @@ class DocumentService:
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
@ -1612,7 +1612,7 @@ class DocumentService:
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
reranking_enable=False,
|
||||
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
|
||||
top_k=2,
|
||||
top_k=4,
|
||||
score_threshold_enabled=False,
|
||||
)
|
||||
# save dataset
|
||||
@ -2346,7 +2346,7 @@ class SegmentService:
|
||||
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
|
||||
segments = (
|
||||
db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
|
||||
.filter(
|
||||
.where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
|
||||
@ -18,7 +18,7 @@ default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
@ -66,7 +66,7 @@ class HitTestingService:
|
||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k", 2),
|
||||
top_k=retrieval_model.get("top_k", 4),
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
|
||||
94
api/services/oauth_server.py
Normal file
94
api/services/oauth_server.py
Normal file
@ -0,0 +1,94 @@
|
||||
import enum
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
class OAuthGrantType(enum.StrEnum):
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
REFRESH_TOKEN = "refresh_token"
|
||||
|
||||
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY = "oauth_provider:{client_id}:authorization_code:{code}"
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:access_token:{token}"
|
||||
OAUTH_ACCESS_TOKEN_EXPIRES_IN = 60 * 60 * 12 # 12 hours
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:refresh_token:{token}"
|
||||
OAUTH_REFRESH_TOKEN_EXPIRES_IN = 60 * 60 * 24 * 30 # 30 days
|
||||
|
||||
|
||||
class OAuthServerService:
|
||||
@staticmethod
|
||||
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
|
||||
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
return session.execute(query).scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
def sign_oauth_authorization_code(client_id: str, user_account_id: str) -> str:
|
||||
code = str(uuid.uuid4())
|
||||
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
|
||||
redis_client.set(redis_key, user_account_id, ex=60 * 10) # 10 minutes
|
||||
return code
|
||||
|
||||
@staticmethod
|
||||
def sign_oauth_access_token(
|
||||
grant_type: OAuthGrantType,
|
||||
code: str = "",
|
||||
client_id: str = "",
|
||||
refresh_token: str = "",
|
||||
) -> tuple[str, str]:
|
||||
match grant_type:
|
||||
case OAuthGrantType.AUTHORIZATION_CODE:
|
||||
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
|
||||
user_account_id = redis_client.get(redis_key)
|
||||
if not user_account_id:
|
||||
raise BadRequest("invalid code")
|
||||
|
||||
# delete code
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
|
||||
refresh_token = OAuthServerService._sign_oauth_refresh_token(client_id, user_account_id)
|
||||
return access_token, refresh_token
|
||||
case OAuthGrantType.REFRESH_TOKEN:
|
||||
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=refresh_token)
|
||||
user_account_id = redis_client.get(redis_key)
|
||||
if not user_account_id:
|
||||
raise BadRequest("invalid refresh token")
|
||||
|
||||
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
|
||||
return access_token, refresh_token
|
||||
|
||||
@staticmethod
|
||||
def _sign_oauth_access_token(client_id: str, user_account_id: str) -> str:
|
||||
token = str(uuid.uuid4())
|
||||
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
|
||||
redis_client.set(redis_key, user_account_id, ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN)
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def _sign_oauth_refresh_token(client_id: str, user_account_id: str) -> str:
|
||||
token = str(uuid.uuid4())
|
||||
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
|
||||
redis_client.set(redis_key, user_account_id, ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN)
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def validate_oauth_access_token(client_id: str, token: str) -> Account | None:
|
||||
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
|
||||
user_account_id = redis_client.get(redis_key)
|
||||
if not user_account_id:
|
||||
return None
|
||||
|
||||
user_id_str = user_account_id.decode("utf-8")
|
||||
|
||||
return AccountService.load_user(user_id_str)
|
||||
@ -10,7 +10,7 @@ class PluginAutoUpgradeService:
|
||||
with Session(db.engine) as session:
|
||||
return (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -26,7 +26,7 @@ class PluginAutoUpgradeService:
|
||||
with Session(db.engine) as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not exist_strategy:
|
||||
@ -54,7 +54,7 @@ class PluginAutoUpgradeService:
|
||||
with Session(db.engine) as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not exist_strategy:
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from celery import shared_task
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import time
|
||||
|
||||
import psycopg2 # type: ignore
|
||||
import psycopg2
|
||||
|
||||
from core.rag.datasource.vdb.opengauss.opengauss import OpenGauss, OpenGaussConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
|
||||
@ -674,7 +674,7 @@ class TestAnnotationService:
|
||||
|
||||
history = (
|
||||
db.session.query(AppAnnotationHitHistory)
|
||||
.filter(
|
||||
.where(
|
||||
AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id
|
||||
)
|
||||
.first()
|
||||
|
||||
@ -166,7 +166,7 @@ class TestAppDslService:
|
||||
assert result.imported_dsl_version == ""
|
||||
|
||||
# Verify no app was created in database
|
||||
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
|
||||
apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count()
|
||||
assert apps_count == 1 # Only the original test app
|
||||
|
||||
def test_import_app_missing_yaml_url(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
@ -191,7 +191,7 @@ class TestAppDslService:
|
||||
assert result.imported_dsl_version == ""
|
||||
|
||||
# Verify no app was created in database
|
||||
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
|
||||
apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count()
|
||||
assert apps_count == 1 # Only the original test app
|
||||
|
||||
def test_import_app_invalid_import_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
@ -215,7 +215,7 @@ class TestAppDslService:
|
||||
)
|
||||
|
||||
# Verify no app was created in database
|
||||
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
|
||||
apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count()
|
||||
assert apps_count == 1 # Only the original test app
|
||||
|
||||
def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
|
||||
@ -0,0 +1,529 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
|
||||
class TestWorkspaceService:
|
||||
"""Integration tests for WorkspaceService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.workspace_service.FeatureService") as mock_feature_service,
|
||||
patch("services.workspace_service.TenantService") as mock_tenant_service,
|
||||
patch("services.workspace_service.dify_config") as mock_dify_config,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_feature_service.get_features.return_value.can_replace_logo = True
|
||||
mock_tenant_service.has_roles.return_value = True
|
||||
mock_dify_config.FILES_URL = "https://example.com/files"
|
||||
|
||||
yield {
|
||||
"feature_service": mock_feature_service,
|
||||
"tenant_service": mock_tenant_service,
|
||||
"dify_config": mock_dify_config,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant) - Created account and tenant instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
plan="basic",
|
||||
custom_config='{"replace_webapp_logo": true, "remove_webapp_brand": false}',
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant-account join with owner role
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER.value,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def test_get_tenant_info_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of tenant information with all features enabled.
|
||||
|
||||
This test verifies:
|
||||
- Proper tenant info retrieval with all required fields
|
||||
- Correct role assignment from TenantAccountJoin
|
||||
- Custom config handling when features are enabled
|
||||
- Logo replacement functionality for privileged users
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Setup mocks for feature service
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result["id"] == tenant.id
|
||||
assert result["name"] == tenant.name
|
||||
assert result["plan"] == tenant.plan
|
||||
assert result["status"] == tenant.status
|
||||
assert result["role"] == TenantAccountRole.OWNER.value
|
||||
assert result["created_at"] == tenant.created_at
|
||||
assert result["trial_end_reason"] is None
|
||||
|
||||
# Verify custom config is included for privileged users
|
||||
assert "custom_config" in result
|
||||
assert result["custom_config"]["remove_webapp_brand"] is False
|
||||
assert "replace_webapp_logo" in result["custom_config"]
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_without_custom_config(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval when custom config features are disabled.
|
||||
|
||||
This test verifies:
|
||||
- Tenant info retrieval without custom config when features are disabled
|
||||
- Proper handling of disabled logo replacement functionality
|
||||
- Role assignment still works correctly
|
||||
- Basic tenant information is complete
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Setup mocks to disable custom config features
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = False
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result["id"] == tenant.id
|
||||
assert result["name"] == tenant.name
|
||||
assert result["plan"] == tenant.plan
|
||||
assert result["status"] == tenant.status
|
||||
assert result["role"] == TenantAccountRole.OWNER.value
|
||||
assert result["created_at"] == tenant.created_at
|
||||
assert result["trial_end_reason"] is None
|
||||
|
||||
# Verify custom config is not included when features are disabled
|
||||
assert "custom_config" not in result
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_normal_user_role(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval for normal user role without privileged features.
|
||||
|
||||
This test verifies:
|
||||
- Tenant info retrieval for non-privileged users
|
||||
- Role assignment for normal users
|
||||
- Custom config is not accessible for normal users
|
||||
- Proper handling of different user roles
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Update the join to have normal role
|
||||
from extensions.ext_database import db
|
||||
|
||||
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
join.role = TenantAccountRole.NORMAL.value
|
||||
db.session.commit()
|
||||
|
||||
# Setup mocks for feature service
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result["id"] == tenant.id
|
||||
assert result["name"] == tenant.name
|
||||
assert result["plan"] == tenant.plan
|
||||
assert result["status"] == tenant.status
|
||||
assert result["role"] == TenantAccountRole.NORMAL.value
|
||||
assert result["created_at"] == tenant.created_at
|
||||
assert result["trial_end_reason"] is None
|
||||
|
||||
# Verify custom config is not included for normal users
|
||||
assert "custom_config" not in result
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_admin_role_and_logo_replacement(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval for admin role with logo replacement enabled.
|
||||
|
||||
This test verifies:
|
||||
- Admin role can access custom config features
|
||||
- Logo replacement functionality works for admin users
|
||||
- Proper URL construction for logo replacement
|
||||
- Custom config handling for admin role
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Update the join to have admin role
|
||||
from extensions.ext_database import db
|
||||
|
||||
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
join.role = TenantAccountRole.ADMIN.value
|
||||
db.session.commit()
|
||||
|
||||
# Setup mocks for feature service and tenant service
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
|
||||
mock_external_service_dependencies["dify_config"].FILES_URL = "https://cdn.example.com"
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result["role"] == TenantAccountRole.ADMIN.value
|
||||
|
||||
# Verify custom config is included for admin users
|
||||
assert "custom_config" in result
|
||||
assert result["custom_config"]["remove_webapp_brand"] is False
|
||||
assert "replace_webapp_logo" in result["custom_config"]
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_tenant_none(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test tenant info retrieval when tenant parameter is None.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of None tenant parameter
|
||||
- Method returns None for invalid input
|
||||
- No exceptions are raised for None input
|
||||
- Graceful degradation for invalid data
|
||||
"""
|
||||
# Arrange: No test data needed for this test
|
||||
|
||||
# Act: Execute the method under test with None tenant
|
||||
result = WorkspaceService.get_tenant_info(None)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is None
|
||||
|
||||
def test_get_tenant_info_with_custom_config_variations(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval with various custom config configurations.
|
||||
|
||||
This test verifies:
|
||||
- Different custom config combinations work correctly
|
||||
- Logo replacement URL construction with various configs
|
||||
- Brand removal functionality
|
||||
- Edge cases in custom config handling
|
||||
"""
|
||||
# Arrange: Create test data with different custom configs
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Test different custom config combinations
|
||||
test_configs = [
|
||||
# Case 1: Both logo and brand removal enabled
|
||||
{"replace_webapp_logo": True, "remove_webapp_brand": True},
|
||||
# Case 2: Only logo replacement enabled
|
||||
{"replace_webapp_logo": True, "remove_webapp_brand": False},
|
||||
# Case 3: Only brand removal enabled
|
||||
{"replace_webapp_logo": False, "remove_webapp_brand": True},
|
||||
# Case 4: Neither enabled
|
||||
{"replace_webapp_logo": False, "remove_webapp_brand": False},
|
||||
]
|
||||
|
||||
for config in test_configs:
|
||||
# Update tenant custom config
|
||||
import json
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
tenant.custom_config = json.dumps(config)
|
||||
db.session.commit()
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
|
||||
mock_external_service_dependencies["dify_config"].FILES_URL = "https://files.example.com"
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "custom_config" in result
|
||||
|
||||
if config["replace_webapp_logo"]:
|
||||
assert "replace_webapp_logo" in result["custom_config"]
|
||||
if config["replace_webapp_logo"]:
|
||||
expected_url = f"https://files.example.com/files/workspaces/{tenant.id}/webapp-logo"
|
||||
assert result["custom_config"]["replace_webapp_logo"] == expected_url
|
||||
else:
|
||||
assert result["custom_config"]["replace_webapp_logo"] is None
|
||||
|
||||
assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"]
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_editor_role_and_limited_permissions(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval for editor role with limited permissions.
|
||||
|
||||
This test verifies:
|
||||
- Editor role has limited access to custom config features
|
||||
- Proper role-based permission checking
|
||||
- Custom config handling for different role levels
|
||||
- Role hierarchy and permission boundaries
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Update the join to have editor role
|
||||
from extensions.ext_database import db
|
||||
|
||||
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
join.role = TenantAccountRole.EDITOR.value
|
||||
db.session.commit()
|
||||
|
||||
# Setup mocks for feature service and tenant service
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
# Editor role should not have admin/owner permissions
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
mock_external_service_dependencies["dify_config"].FILES_URL = "https://cdn.example.com"
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result["role"] == TenantAccountRole.EDITOR.value
|
||||
|
||||
# Verify custom config is not included for editor users without admin privileges
|
||||
assert "custom_config" not in result
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_dataset_operator_role(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval for dataset operator role.
|
||||
|
||||
This test verifies:
|
||||
- Dataset operator role handling
|
||||
- Role assignment for specialized roles
|
||||
- Permission boundaries for dataset operators
|
||||
- Custom config access for dataset operators
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Update the join to have dataset operator role
|
||||
from extensions.ext_database import db
|
||||
|
||||
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
join.role = TenantAccountRole.DATASET_OPERATOR.value
|
||||
db.session.commit()
|
||||
|
||||
# Setup mocks for feature service and tenant service
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
# Dataset operator should not have admin/owner permissions
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
mock_external_service_dependencies["dify_config"].FILES_URL = "https://cdn.example.com"
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result["role"] == TenantAccountRole.DATASET_OPERATOR.value
|
||||
|
||||
# Verify custom config is not included for dataset operators without admin privileges
|
||||
assert "custom_config" not in result
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_complex_custom_config_scenarios(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval with complex custom config scenarios.
|
||||
|
||||
This test verifies:
|
||||
- Complex custom config combinations
|
||||
- Edge cases in custom config handling
|
||||
- URL construction with various configs
|
||||
- Error handling for malformed configs
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Test complex custom config scenarios
|
||||
test_configs = [
|
||||
# Case 1: Empty custom config
|
||||
{},
|
||||
# Case 2: Custom config with only logo replacement
|
||||
{"replace_webapp_logo": True},
|
||||
# Case 3: Custom config with only brand removal
|
||||
{"remove_webapp_brand": True},
|
||||
# Case 4: Custom config with additional fields
|
||||
{
|
||||
"replace_webapp_logo": True,
|
||||
"remove_webapp_brand": False,
|
||||
"custom_field": "custom_value",
|
||||
"nested_config": {"key": "value"},
|
||||
},
|
||||
# Case 5: Custom config with null values
|
||||
{"replace_webapp_logo": None, "remove_webapp_brand": None},
|
||||
]
|
||||
|
||||
for config in test_configs:
|
||||
# Update tenant custom config
|
||||
import json
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
tenant.custom_config = json.dumps(config)
|
||||
db.session.commit()
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
|
||||
mock_external_service_dependencies["dify_config"].FILES_URL = "https://files.example.com"
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "custom_config" in result
|
||||
|
||||
# Verify logo replacement handling
|
||||
if config.get("replace_webapp_logo"):
|
||||
assert "replace_webapp_logo" in result["custom_config"]
|
||||
expected_url = f"https://files.example.com/files/workspaces/{tenant.id}/webapp-logo"
|
||||
assert result["custom_config"]["replace_webapp_logo"] == expected_url
|
||||
else:
|
||||
assert result["custom_config"]["replace_webapp_logo"] is None
|
||||
|
||||
# Verify brand removal handling
|
||||
if "remove_webapp_brand" in config:
|
||||
assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"]
|
||||
else:
|
||||
assert result["custom_config"]["remove_webapp_brand"] is False
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
@ -0,0 +1,550 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from models.account import Account, Tenant
|
||||
from models.tools import ApiToolProvider
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
|
||||
|
||||
class TestApiToolManageService:
|
||||
"""Integration tests for ApiToolManageService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.tools.api_tools_manage_service.ToolLabelManager") as mock_tool_label_manager,
|
||||
patch("services.tools.api_tools_manage_service.create_tool_provider_encrypter") as mock_encrypter,
|
||||
patch("services.tools.api_tools_manage_service.ApiToolProviderController") as mock_provider_controller,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_tool_label_manager.update_tool_labels.return_value = None
|
||||
mock_encrypter.return_value = (mock_encrypter, None)
|
||||
mock_encrypter.encrypt.return_value = {"encrypted": "credentials"}
|
||||
mock_provider_controller.from_db.return_value = mock_provider_controller
|
||||
mock_provider_controller.load_bundled_tools.return_value = None
|
||||
|
||||
yield {
|
||||
"tool_label_manager": mock_tool_label_manager,
|
||||
"encrypter": mock_encrypter,
|
||||
"provider_controller": mock_provider_controller,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant) - Created account and tenant instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER.value,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_openapi_schema(self):
|
||||
"""Helper method to create a test OpenAPI schema."""
|
||||
return """
|
||||
{
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": "Test API",
|
||||
"version": "1.0.0",
|
||||
"description": "Test API for testing purposes"
|
||||
},
|
||||
"servers": [
|
||||
{
|
||||
"url": "https://api.example.com",
|
||||
"description": "Production server"
|
||||
}
|
||||
],
|
||||
"paths": {
|
||||
"/test": {
|
||||
"get": {
|
||||
"operationId": "testOperation",
|
||||
"summary": "Test operation",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Success"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def test_parser_api_schema_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful parsing of API schema.
|
||||
|
||||
This test verifies:
|
||||
- Proper schema parsing with valid OpenAPI schema
|
||||
- Correct credentials schema generation
|
||||
- Proper warning handling
|
||||
- Return value structure
|
||||
"""
|
||||
# Arrange: Create test schema
|
||||
schema = self._create_test_openapi_schema()
|
||||
|
||||
# Act: Parse the schema
|
||||
result = ApiToolManageService.parser_api_schema(schema)
|
||||
|
||||
# Assert: Verify the result structure
|
||||
assert result is not None
|
||||
assert "schema_type" in result
|
||||
assert "parameters_schema" in result
|
||||
assert "credentials_schema" in result
|
||||
assert "warning" in result
|
||||
|
||||
# Verify credentials schema structure
|
||||
credentials_schema = result["credentials_schema"]
|
||||
assert len(credentials_schema) == 3
|
||||
|
||||
# Check auth_type field
|
||||
auth_type_field = next(field for field in credentials_schema if field["name"] == "auth_type")
|
||||
assert auth_type_field["required"] is True
|
||||
assert auth_type_field["default"] == "none"
|
||||
assert len(auth_type_field["options"]) == 2
|
||||
|
||||
# Check api_key_header field
|
||||
api_key_header_field = next(field for field in credentials_schema if field["name"] == "api_key_header")
|
||||
assert api_key_header_field["required"] is False
|
||||
assert api_key_header_field["default"] == "api_key"
|
||||
|
||||
# Check api_key_value field
|
||||
api_key_value_field = next(field for field in credentials_schema if field["name"] == "api_key_value")
|
||||
assert api_key_value_field["required"] is False
|
||||
assert api_key_value_field["default"] == ""
|
||||
|
||||
def test_parser_api_schema_invalid_schema(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test parsing of invalid API schema.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for invalid schemas
|
||||
- Correct exception type and message
|
||||
- Error propagation from underlying parser
|
||||
"""
|
||||
# Arrange: Create invalid schema
|
||||
invalid_schema = "invalid json schema"
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiToolManageService.parser_api_schema(invalid_schema)
|
||||
|
||||
assert "invalid schema" in str(exc_info.value)
|
||||
|
||||
def test_parser_api_schema_malformed_json(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test parsing of malformed JSON schema.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for malformed JSON
|
||||
- Correct exception type and message
|
||||
- Error propagation from JSON parsing
|
||||
"""
|
||||
# Arrange: Create malformed JSON schema
|
||||
malformed_schema = '{"openapi": "3.0.0", "info": {"title": "Test", "version": "1.0.0"}, "paths": {}}'
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiToolManageService.parser_api_schema(malformed_schema)
|
||||
|
||||
assert "invalid schema" in str(exc_info.value)
|
||||
|
||||
def test_convert_schema_to_tool_bundles_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of schema to tool bundles.
|
||||
|
||||
This test verifies:
|
||||
- Proper schema conversion with valid OpenAPI schema
|
||||
- Correct tool bundles generation
|
||||
- Proper schema type detection
|
||||
- Return value structure
|
||||
"""
|
||||
# Arrange: Create test schema
|
||||
schema = self._create_test_openapi_schema()
|
||||
|
||||
# Act: Convert schema to tool bundles
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema)
|
||||
|
||||
# Assert: Verify the result structure
|
||||
assert tool_bundles is not None
|
||||
assert isinstance(tool_bundles, list)
|
||||
assert len(tool_bundles) > 0
|
||||
assert schema_type is not None
|
||||
assert isinstance(schema_type, str)
|
||||
|
||||
# Verify tool bundle structure
|
||||
tool_bundle = tool_bundles[0]
|
||||
assert hasattr(tool_bundle, "operation_id")
|
||||
assert tool_bundle.operation_id == "testOperation"
|
||||
|
||||
def test_convert_schema_to_tool_bundles_with_extra_info(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of schema to tool bundles with extra info.
|
||||
|
||||
This test verifies:
|
||||
- Proper schema conversion with extra info parameter
|
||||
- Correct tool bundles generation
|
||||
- Extra info handling
|
||||
- Return value structure
|
||||
"""
|
||||
# Arrange: Create test schema and extra info
|
||||
schema = self._create_test_openapi_schema()
|
||||
extra_info = {"description": "Custom description", "version": "2.0.0"}
|
||||
|
||||
# Act: Convert schema to tool bundles with extra info
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
# Assert: Verify the result structure
|
||||
assert tool_bundles is not None
|
||||
assert isinstance(tool_bundles, list)
|
||||
assert len(tool_bundles) > 0
|
||||
assert schema_type is not None
|
||||
assert isinstance(schema_type, str)
|
||||
|
||||
def test_convert_schema_to_tool_bundles_invalid_schema(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test conversion of invalid schema to tool bundles.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for invalid schemas
|
||||
- Correct exception type and message
|
||||
- Error propagation from underlying parser
|
||||
"""
|
||||
# Arrange: Create invalid schema
|
||||
invalid_schema = "invalid schema content"
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiToolManageService.convert_schema_to_tool_bundles(invalid_schema)
|
||||
|
||||
assert "invalid schema" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful creation of API tool provider.
|
||||
|
||||
This test verifies:
|
||||
- Proper provider creation with valid parameters
|
||||
- Correct database state after creation
|
||||
- Proper relationship establishment
|
||||
- External service integration
|
||||
- Return value correctness
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""}
|
||||
schema_type = "openapi"
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
labels = ["test", "api"]
|
||||
|
||||
# Act: Create API tool provider
|
||||
result = ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon=icon,
|
||||
credentials=credentials,
|
||||
schema_type=schema_type,
|
||||
schema=schema,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
# Assert: Verify the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert provider is not None
|
||||
assert provider.name == provider_name
|
||||
assert provider.tenant_id == tenant.id
|
||||
assert provider.user_id == account.id
|
||||
assert provider.schema_type_str == schema_type
|
||||
assert provider.privacy_policy == privacy_policy
|
||||
assert provider.custom_disclaimer == custom_disclaimer
|
||||
|
||||
# Verify mock interactions
|
||||
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once()
|
||||
mock_external_service_dependencies["encrypter"].assert_called_once()
|
||||
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
|
||||
mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once()
|
||||
|
||||
def test_create_api_tool_provider_duplicate_name(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creation of API tool provider with duplicate name.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for duplicate provider names
|
||||
- Correct exception type and message
|
||||
- Database constraint enforcement
|
||||
"""
|
||||
# Arrange: Create test data and existing provider
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
credentials = {"auth_type": "none"}
|
||||
schema_type = "openapi"
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
labels = ["test"]
|
||||
|
||||
# Create first provider
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon=icon,
|
||||
credentials=credentials,
|
||||
schema_type=schema_type,
|
||||
schema=schema,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
# Act & Assert: Try to create duplicate provider
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon=icon,
|
||||
credentials=credentials,
|
||||
schema_type=schema_type,
|
||||
schema=schema,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
assert f"provider {provider_name} already exists" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_invalid_schema_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creation of API tool provider with invalid schema type.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for invalid schema types
|
||||
- Correct exception type and message
|
||||
- Schema type validation
|
||||
"""
|
||||
# Arrange: Create test data with invalid schema type
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
credentials = {"auth_type": "none"}
|
||||
schema_type = "invalid_type"
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
labels = ["test"]
|
||||
|
||||
# Act & Assert: Try to create provider with invalid schema type
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon=icon,
|
||||
credentials=credentials,
|
||||
schema_type=schema_type,
|
||||
schema=schema,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
assert "invalid schema type" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_missing_auth_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creation of API tool provider with missing auth type.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing auth type
|
||||
- Correct exception type and message
|
||||
- Credentials validation
|
||||
"""
|
||||
# Arrange: Create test data with missing auth type
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
credentials = {} # Missing auth_type
|
||||
schema_type = "openapi"
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
labels = ["test"]
|
||||
|
||||
# Act & Assert: Try to create provider with missing auth type
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon=icon,
|
||||
credentials=credentials,
|
||||
schema_type=schema_type,
|
||||
schema=schema,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
assert "auth_type is required" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_with_api_key_auth(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful creation of API tool provider with API key authentication.
|
||||
|
||||
This test verifies:
|
||||
- Proper provider creation with API key auth
|
||||
- Correct credentials handling
|
||||
- Proper authentication type processing
|
||||
"""
|
||||
# Arrange: Create test data with API key auth
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔑"}
|
||||
credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()}
|
||||
schema_type = "openapi"
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
labels = ["api_key", "secure"]
|
||||
|
||||
# Act: Create API tool provider
|
||||
result = ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon=icon,
|
||||
credentials=credentials,
|
||||
schema_type=schema_type,
|
||||
schema=schema,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
# Assert: Verify the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert provider is not None
|
||||
assert provider.name == provider_name
|
||||
assert provider.tenant_id == tenant.id
|
||||
assert provider.user_id == account.id
|
||||
assert provider.schema_type_str == schema_type
|
||||
|
||||
# Verify mock interactions
|
||||
mock_external_service_dependencies["encrypter"].assert_called_once()
|
||||
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user