Compare commits


57 Commits

Author SHA1 Message Date
41dfdf1ac0 fix: score threshold 2025-09-01 16:34:17 +08:00
dd7de74aa6 Fix hard-coded top-k fallback issue 2025-09-01 14:27:43 +08:00
f11131f8b5 fix: basepath did not read from the environment variable (#24870) 2025-09-01 13:50:33 +08:00
2e6e414a9e the conversion OAuthGrantType(parsed_args["grant_type"]) can raise ValueError for invalid values, which is not caught and produces a 500 (#24854) 2025-09-01 10:05:54 +08:00
c45d676477 remove duplicated Authorization header handling; Bearer should be case-insensitive (#24852) 2025-09-01 10:05:19 +08:00
b8d8dddd5a example of decorator typing (#24857)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-01 10:04:24 +08:00
c45c22b1b2 fix translation of all oauth.ts (#24855) 2025-09-01 10:04:05 +08:00
3d57a9ccdc Fix never-hit check (!code || code.length === 0) (#24860) 2025-09-01 09:45:07 +08:00
cb04c21141 model_config = ConfigDict(extra='allow') (#24859)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-08-31 23:21:41 +08:00
f70272f638 refactor: replace clsx with classnames (#24776) 2025-08-31 17:08:29 +08:00
b4b71ded47 chore: remove unused i18n keys (#24803) 2025-08-31 17:07:15 +08:00
24e2b72b71 Update ast-grep pattern for session.query (#24828)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-08-31 17:03:51 +08:00
529791ce62 fix: Variable Aggregator cannot select conversation variables (#24793) 2025-08-31 17:03:36 +08:00
b66945b9b8 feat: add test containers based tests for api tool manage service (#24821) 2025-08-31 17:02:08 +08:00
f3c5d77ad5 chore: remove duplicate Python style checks handled by autofix CI (#24833) 2025-08-31 17:01:19 +08:00
e5e42bc483 fix: XSS vulnerability in block-input and support-var-input components (#24835) 2025-08-31 17:01:10 +08:00
bdfbfa391f feat: add test containers based tests for MCP tools manage service (#24840) 2025-08-31 17:01:01 +08:00
72acd9b483 Remove redundant from_variable_selector null-check (#24842) 2025-08-31 17:00:13 +08:00
9f528d23d4 PoC of config validation (#24837)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-08-31 02:41:52 +08:00
d937cc491d chore[docker]: fix Redis health check reporting healthy despite errors (#24778) 2025-08-30 06:19:43 -07:00
863f3aeb27 Fix: remove invalid errorMessage on e.toString() (#24805) 2025-08-30 06:18:51 -07:00
0fe078d25e fix: workflow_finish_to_stream_response assert exception with celery … (#24674)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-08-30 00:59:21 +08:00
d9420c7224 refactor: reorganize the CI pipeline (#24817)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-08-30 00:12:25 +08:00
9ff6baaf52 refactor: remove duplicate pull_request triggers from workflow files (#24814) 2025-08-29 23:09:26 +08:00
574d00bb13 fix: add missing statuses permission to main CI workflow (#24809) 2025-08-29 22:33:13 +08:00
8d60e5c342 chore(api): fix Alembic offline migration compatibility (#24795)
This PR fixes Alembic offline mode (`--sql` flag) by ensuring data migration functions only execute in online mode. When running in offline mode, these functions now skip data operations and output informational comments to the generated SQL.
2025-08-29 19:13:24 +08:00
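A minimal sketch of that pattern in an Alembic revision, for illustration only (the accounts table, status column, and _backfill_status() helper are assumptions, not taken from this PR):

from alembic import context, op
import sqlalchemy as sa

def upgrade():
    # Schema DDL is emitted in both online and offline (--sql) mode.
    op.add_column("accounts", sa.Column("status", sa.String(32), nullable=True))

    if context.is_offline_mode():
        # No live connection in --sql mode: emit an informational comment
        # into the generated script and skip the data step.
        op.execute("-- data backfill for accounts.status skipped in offline mode")
        return

    _backfill_status()  # assumed data-migration helper, runs in online mode only

def _backfill_status():
    bind = op.get_bind()
    bind.execute(sa.text("UPDATE accounts SET status = 'active' WHERE status IS NULL"))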
d9eb1a73af fix(api): fix DetachedInstanceError for Account.current_tenant_id (#24789)
The `Account._current_tenant` object is loaded by a database session (typically `db.session`) whose lifetime 
is not aligned with the Account model instance. This misalignment causes a `DetachedInstanceError` to be raised
when accessing attributes of `Account._current_tenant` after the original session has been closed.

To resolve this issue, we now reload the tenant object with `expire_on_commit=False`, ensuring the tenant remains
accessible even after the session is closed.
2025-08-29 19:12:02 +08:00
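A hedged sketch of that approach (the Tenant import path and helper name are assumptions for illustration):

from sqlalchemy import select
from sqlalchemy.orm import Session

from extensions.ext_database import db
from models.account import Tenant  # assumed import path

def load_current_tenant(tenant_id: str) -> Tenant | None:
    # Reload the tenant in a short-lived session; expire_on_commit=False keeps the
    # loaded attribute values readable after the session closes, avoiding
    # DetachedInstanceError on later attribute access.
    with Session(db.engine, expire_on_commit=False) as session:
        return session.scalar(select(Tenant).where(Tenant.id == tenant_id))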
1a34ff8a67 fix: change the MCP server structure to support GitHub Copilot (#24788) 2025-08-29 18:00:58 +08:00
14e7ba4818 chore: change the oauth_provider_apps table to uuidV7 (#24792) 2025-08-29 17:54:14 +08:00
52e9bcbfdb fix(web): improve floating UI positioning when scrolling (#24595) (#24782) 2025-08-29 16:49:13 +08:00
20ae3eae54 feat: add filename support to multi-modal prompt messages (#24777) 2025-08-29 16:22:26 +08:00
0fb145e667 refactor: Promote basepath to environment variable (#24445)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-08-29 15:39:37 +08:00
bcac43c812 fix(web): fix error notify when tagInput component is not required (#… (#24774) 2025-08-29 15:30:40 +08:00
929d9e0b3f feat(api): maintain assistant content parts and file handling in advanced chat (#24663)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-08-29 15:19:55 +08:00
d5e560a987 chore: translate i18n files (#24770)
Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>
2025-08-29 14:34:35 +08:00
e4383d6167 Chore: remove duplicate logic in DatasetApi.get() (#24769)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
2025-08-29 14:25:36 +08:00
f32e176d6a feat: oauth provider (#24206)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: yessenia <yessenia.contact@gmail.com>
2025-08-29 14:10:51 +08:00
3d5a4df9d0 chore: use orjson in streaming event JSON serialisation for performance improvement (#24763) 2025-08-29 14:06:07 +08:00
e47bfd2ca3 feat: orchestrate CI workflows to prevent duplicate runs when autofix makes changes (#24758) 2025-08-29 13:23:08 +08:00
f8f768873e fix: inconsistent text color for settings button in webapp cards (#24754) 2025-08-29 12:10:27 +08:00
d043e1a05a feat: add test containers based tests for workspace service (#24752) 2025-08-29 12:10:13 +08:00
837c0ddacc Chore: remove dead func AppModelConfig.copy() with wrong logic (#24747) 2025-08-29 11:38:24 +08:00
7c340695d6 fix: unclosed tag (#24733) 2025-08-28 23:59:04 +08:00
e87d4fbf69 chore: translate i18n files (#24727)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-28 23:40:10 +08:00
39064197da chore: cleanup unnecessary mypy suppressions on imports (#24712) 2025-08-28 23:17:25 +08:00
c4496e6cf2 chore: use DataFrame.map instead of deprecated DataFrame.applymap (#24726) 2025-08-28 21:13:47 +08:00
27d09d1783 feat: Add support for slash commands, optimize command selector logic. (#24723) 2025-08-28 21:13:18 +08:00
a174ee419e chore: fix some API descriptions (#24715)
Co-authored-by: zhuqingchao <zhuqingchao@xiaomi.com>
2025-08-28 20:47:12 +08:00
79e6138ce2 chore: simplify the workflow details logic (#24714) 2025-08-28 18:17:48 +08:00
5a64f69456 fix: Default value for input variable is null when starting new conversations on the web app (#24709) 2025-08-28 17:48:04 +08:00
5c01dd97e8 clean up typos. (#24667)
Signed-off-by: zhanluxianshen <zhanluxianshen@163.com>
2025-08-28 15:23:59 +08:00
ecf74d91e2 fix: has_more logic in ChatMessageListApi to ensure correct behavior when no more messages are available. (#24661) 2025-08-28 15:05:52 +08:00
62892ed8d7 refactor: relocate China npm registry config to base image (#24678) 2025-08-28 14:43:34 +08:00
7b399cc5e5 feat: add MCP configuration for Claude Code optimization (#24679) 2025-08-28 14:38:36 +08:00
fab5740778 fix: cannot choose file-type variable in aggregator node (#24689) 2025-08-28 14:28:46 +08:00
30f2d756a7 fix_trace_config (#24669)
Co-authored-by: renming <renming@renmingdeMacBook-Air.local>
2025-08-28 13:54:49 +08:00
0d745c64d8 chore: bump supabase and pyjwt versions and added tests (#24681) 2025-08-28 13:45:56 +08:00
316 changed files with 8276 additions and 3836 deletions

View File

@ -0,0 +1,19 @@
{
"permissions": {
"allow": [],
"deny": []
},
"env": {
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
},
"enabledMcpjsonServers": [
"context7",
"sequential-thinking",
"github",
"fetch",
"playwright",
"ide"
],
"enableAllProjectMcpServers": true
}

View File

@ -1,13 +1,7 @@
name: Run Pytest
on:
pull_request:
branches:
- main
paths:
- api/**
- docker/**
- .github/workflows/api-tests.yml
workflow_call:
concurrency:
group: api-tests-${{ github.head_ref || github.run_id }}

View File

@ -1,10 +1,7 @@
name: autofix.ci
on:
workflow_call:
pull_request:
branches: [ "main" ]
push:
branches: [ "main" ]
branches: ["main"]
permissions:
contents: read
@ -18,7 +15,7 @@ jobs:
# Use uv to ensure we have the same ruff version in CI and locally.
- uses: astral-sh/setup-uv@v6
with:
python-version: "3.12"
python-version: "3.12"
- run: |
cd api
uv sync --dev
@ -29,6 +26,7 @@ jobs:
- name: ast-grep
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
- name: mdformat
run: |
uvx mdformat .

View File

@ -1,13 +1,7 @@
name: DB Migration Test
on:
pull_request:
branches:
- main
- plugins/beta
paths:
- api/migrations/**
- .github/workflows/db-migration-test.yml
workflow_call:
concurrency:
group: db-migration-test-${{ github.ref }}
@ -33,6 +27,12 @@ jobs:
- name: Install dependencies
run: uv sync --project api
- name: Ensure Offline migration are supported
run: |
# upgrade
uv run --directory api flask db upgrade 'base:head' --sql
# downgrade
uv run --directory api flask db downgrade 'head:base' --sql
- name: Prepare middleware env
run: |

.github/workflows/main-ci.yml vendored Normal file (+78 lines)
View File

@ -0,0 +1,78 @@
name: Main CI Pipeline
on:
pull_request:
branches: ["main"]
push:
branches: ["main"]
permissions:
contents: write
pull-requests: write
checks: write
statuses: write
concurrency:
group: main-ci-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
# Check which paths were changed to determine which tests to run
check-changes:
name: Check Changed Files
runs-on: ubuntu-latest
outputs:
api-changed: ${{ steps.changes.outputs.api }}
web-changed: ${{ steps.changes.outputs.web }}
vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: changes
with:
filters: |
api:
- 'api/**'
- 'docker/**'
- '.github/workflows/api-tests.yml'
web:
- 'web/**'
vdb:
- 'api/core/rag/datasource/**'
- 'docker/**'
- '.github/workflows/vdb-tests.yml'
- 'api/uv.lock'
- 'api/pyproject.toml'
migration:
- 'api/migrations/**'
- '.github/workflows/db-migration-test.yml'
# Run tests in parallel
api-tests:
name: API Tests
needs: check-changes
if: needs.check-changes.outputs.api-changed == 'true'
uses: ./.github/workflows/api-tests.yml
web-tests:
name: Web Tests
needs: check-changes
if: needs.check-changes.outputs.web-changed == 'true'
uses: ./.github/workflows/web-tests.yml
style-check:
name: Style Check
uses: ./.github/workflows/style.yml
vdb-tests:
name: VDB Tests
needs: check-changes
if: needs.check-changes.outputs.vdb-changed == 'true'
uses: ./.github/workflows/vdb-tests.yml
db-migration-test:
name: DB Migration Test
needs: check-changes
if: needs.check-changes.outputs.migration-changed == 'true'
uses: ./.github/workflows/db-migration-test.yml

View File

@ -1,9 +1,7 @@
name: Style check
on:
pull_request:
branches:
- main
workflow_call:
concurrency:
group: style-${{ github.head_ref || github.run_id }}
@ -46,21 +44,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv sync --project api --dev
- name: Ruff check
if: steps.changed-files.outputs.any_changed == 'true'
run: |
uv run --directory api ruff --version
uv run --directory api ruff check ./
uv run --directory api ruff format --check ./
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
- name: Lint hints
if: failure()
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
web-style:
name: Web Style
runs-on: ubuntu-latest

View File

@ -1,15 +1,7 @@
name: Run VDB Tests
on:
pull_request:
branches:
- main
paths:
- api/core/rag/datasource/**
- docker/**
- .github/workflows/vdb-tests.yml
- api/uv.lock
- api/pyproject.toml
workflow_call:
concurrency:
group: vdb-tests-${{ github.head_ref || github.run_id }}

View File

@ -1,11 +1,7 @@
name: Web Tests
on:
pull_request:
branches:
- main
paths:
- web/**
workflow_call:
concurrency:
group: web-tests-${{ github.head_ref || github.run_id }}

.mcp.json Normal file (+34 lines)
View File

@ -0,0 +1,34 @@
{
"mcpServers": {
"context7": {
"type": "http",
"url": "https://mcp.context7.com/mcp"
},
"sequential-thinking": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"env": {}
},
"github": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
}
},
"fetch": {
"type": "stdio",
"command": "uvx",
"args": ["mcp-server-fetch"],
"env": {}
},
"playwright": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@playwright/mcp@latest"],
"env": {}
}
}
}

View File

@ -86,3 +86,4 @@ pnpm test # Run Jest tests
## Project-Specific Conventions
- All async tasks use Celery with Redis as broker
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.

View File

@ -70,7 +70,7 @@ from .app import (
)
# Import auth controllers
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
# Import billing controllers
from .billing import billing, compliance

View File

@ -95,18 +95,22 @@ class ChatMessageListApi(Resource):
.all()
)
# Initialize has_more based on whether we have a full page
if len(history_messages) == args["limit"]:
current_page_first_message = history_messages[-1]
has_more = db.session.scalar(
select(
exists().where(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
# Check if there are more messages before the current page
has_more = db.session.scalar(
select(
exists().where(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
)
)
)
)
else:
# If we don't have a full page, there are no more messages
has_more = False
history_messages = list(reversed(history_messages))
@ -126,7 +130,7 @@ class MessageFeedbackApi(Resource):
message_id = str(args["message_id"])
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")

View File

@ -0,0 +1,187 @@
from functools import wraps
from typing import cast
import flask_login
from flask import request
from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models.account import Account
from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
from .. import api
def oauth_server_client_id_required(view):
@wraps(view)
def decorated(*args, **kwargs):
parser = reqparse.RequestParser()
parser.add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args()
client_id = parsed_args.get("client_id")
if not client_id:
raise BadRequest("client_id is required")
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
if not oauth_provider_app:
raise NotFound("client_id is invalid")
kwargs["oauth_provider_app"] = oauth_provider_app
return view(*args, **kwargs)
return decorated
def oauth_server_access_token_required(view):
@wraps(view)
def decorated(*args, **kwargs):
oauth_provider_app = kwargs.get("oauth_provider_app")
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")
authorization_header = request.headers.get("Authorization")
if not authorization_header:
raise BadRequest("Authorization header is required")
parts = authorization_header.strip().split(" ")
if len(parts) != 2:
raise BadRequest("Invalid Authorization header format")
token_type = parts[0].strip()
if token_type.lower() != "bearer":
raise BadRequest("token_type is invalid")
access_token = parts[1].strip()
if not access_token:
raise BadRequest("access_token is required")
account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
if not account:
raise BadRequest("access_token or client_id is invalid")
kwargs["account"] = account
return view(*args, **kwargs)
return decorated
class OAuthServerAppApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser()
parser.add_argument("redirect_uri", type=str, required=True, location="json")
parsed_args = parser.parse_args()
redirect_uri = parsed_args.get("redirect_uri")
# check if redirect_uri is valid
if redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
return jsonable_encoder(
{
"app_icon": oauth_provider_app.app_icon,
"app_label": oauth_provider_app.app_label,
"scope": oauth_provider_app.scope,
}
)
class OAuthServerUserAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
account = cast(Account, flask_login.current_user)
user_account_id = account.id
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
return jsonable_encoder(
{
"code": code,
}
)
class OAuthServerUserTokenApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser()
parser.add_argument("grant_type", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=False, location="json")
parser.add_argument("client_secret", type=str, required=False, location="json")
parser.add_argument("redirect_uri", type=str, required=False, location="json")
parser.add_argument("refresh_token", type=str, required=False, location="json")
parsed_args = parser.parse_args()
try:
grant_type = OAuthGrantType(parsed_args["grant_type"])
except ValueError:
raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not parsed_args["code"]:
raise BadRequest("code is required")
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not parsed_args["refresh_token"]:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
class OAuthServerUserAccountApi(Resource):
@setup_required
@oauth_server_client_id_required
@oauth_server_access_token_required
def post(self, oauth_provider_app: OAuthProviderApp, account: Account):
return jsonable_encoder(
{
"name": account.name,
"email": account.email,
"avatar": account.avatar,
"interface_language": account.interface_language,
"timezone": account.timezone,
}
)
api.add_resource(OAuthServerAppApi, "/oauth/provider")
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")

View File

@ -1,8 +1,12 @@
from base64 import b64encode
from collections.abc import Callable
from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
from flask import abort, request
from configs import dify_config
@ -10,9 +14,9 @@ from extensions.ext_database import db
from models.model import EndUser
def billing_inner_api_only(view):
def billing_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@ -26,9 +30,9 @@ def billing_inner_api_only(view):
return decorated
def enterprise_inner_api_only(view):
def enterprise_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view):
return decorated
def plugin_inner_api_only(view):
def plugin_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.PLUGIN_DAEMON_KEY:
abort(404)

View File

@ -1,18 +1,27 @@
from typing import Optional, Union
from flask import Response
from flask_restx import Resource, reqparse
from pydantic import ValidationError
from sqlalchemy.orm import Session
from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity
from core.mcp import types
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
from core.mcp.types import ClientNotification, ClientRequest
from core.mcp.utils import create_mcp_error_response
from core.mcp import types as mcp_types
from core.mcp.server.streamable_http import handle_mcp_request
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMCPServer, AppMode
from models.model import App, AppMCPServer, AppMode, EndUser
class MCPRequestError(Exception):
"""Custom exception for MCP request processing errors"""
def __init__(self, error_code: int, message: str):
self.error_code = error_code
self.message = message
super().__init__(message)
def int_or_str(value):
@ -63,77 +72,173 @@ class MCPAppApi(Resource):
Raises:
ValidationError: Invalid request format or parameters
"""
# Parse and validate all arguments
args = mcp_request_parser.parse_args()
request_id: Optional[Union[int, str]] = args.get("id")
mcp_request = self._parse_mcp_request(args)
server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not server:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
)
with Session(db.engine, expire_on_commit=False) as session:
# Get MCP server and app
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
self._validate_server_status(mcp_server)
if server.status != AppMCPServerStatus.ACTIVE:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
)
# Get user input form
user_input_form = self._get_user_input_form(app)
app = db.session.query(App).where(App.id == server.app_id).first()
# Handle notification vs request differently
return self._process_mcp_message(mcp_request, request_id, app, mcp_server, user_input_form, session)
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
"""Get and validate MCP server and app in one query session"""
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not mcp_server:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
app = session.query(App).where(App.id == mcp_server.app_id).first()
if not app:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
)
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
)
return mcp_server, app
user_input_form = workflow.user_input_form(to_old_structure=True)
def _validate_server_status(self, mcp_server: AppMCPServer) -> None:
"""Validate MCP server status"""
if mcp_server.status != AppMCPServerStatus.ACTIVE:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")
def _process_mcp_message(
self,
mcp_request: mcp_types.ClientRequest | mcp_types.ClientNotification,
request_id: Optional[Union[int, str]],
app: App,
mcp_server: AppMCPServer,
user_input_form: list[VariableEntity],
session: Session,
) -> Response:
"""Process MCP message (notification or request)"""
if isinstance(mcp_request, mcp_types.ClientNotification):
return self._handle_notification(mcp_request)
else:
app_model_config = app.app_model_config
if app_model_config is None:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
)
return self._handle_request(mcp_request, request_id, app, mcp_server, user_input_form, session)
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
converted_user_input_form: list[VariableEntity] = []
try:
for item in user_input_form:
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
converted_user_input_form.append(
VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description") or "",
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],
)
)
except ValidationError as e:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
)
def _handle_notification(self, mcp_request: mcp_types.ClientNotification) -> Response:
"""Handle MCP notification"""
# For notifications, only support init notification
if mcp_request.root.method != "notifications/initialized":
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Invalid notification method")
# Return HTTP 202 Accepted for notifications (no response body)
return Response("", status=202, content_type="application/json")
def _handle_request(
self,
mcp_request: mcp_types.ClientRequest,
request_id: Optional[Union[int, str]],
app: App,
mcp_server: AppMCPServer,
user_input_form: list[VariableEntity],
session: Session,
) -> Response:
"""Handle MCP request"""
if request_id is None:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Request ID is required")
result = self._handle_mcp_request(app, mcp_server, mcp_request, user_input_form, session, request_id)
if result is None:
# This shouldn't happen for requests, but handle gracefully
raise MCPRequestError(mcp_types.INTERNAL_ERROR, "No response generated for request")
return helper.compact_generate_response(result.model_dump(by_alias=True, mode="json", exclude_none=True))
def _get_user_input_form(self, app: App) -> list[VariableEntity]:
"""Get and convert user input form"""
# Get raw user input form based on app mode
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if not app.workflow:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
raw_user_input_form = app.workflow.user_input_form(to_old_structure=True)
else:
if not app.app_model_config:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
features_dict = app.app_model_config.to_dict()
raw_user_input_form = features_dict.get("user_input_form", [])
# Convert to VariableEntity objects
try:
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
return self._convert_user_input_form(raw_user_input_form)
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
"""Convert raw user input form to VariableEntity objects"""
return [self._create_variable_entity(item) for item in raw_form]
def _create_variable_entity(self, item: dict) -> VariableEntity:
"""Create a single VariableEntity from raw form item"""
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
return VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description") or "",
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],
)
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
"""Parse and validate MCP request"""
try:
return mcp_types.ClientRequest.model_validate(args)
except ValidationError:
try:
notification = ClientNotification.model_validate(args)
request = notification
return mcp_types.ClientNotification.model_validate(args)
except ValidationError as e:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
)
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
response = mcp_server_handler.handle()
return helper.compact_generate_response(response)
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
"""Get end user from existing session - optimized query"""
return (
session.query(EndUser)
.where(EndUser.tenant_id == tenant_id)
.where(EndUser.session_id == mcp_server_id)
.where(EndUser.type == "mcp")
.first()
)
def _create_end_user(
self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
) -> EndUser:
"""Create end user in existing session"""
end_user = EndUser(
tenant_id=tenant_id,
app_id=app_id,
type="mcp",
name=client_name,
session_id=mcp_server_id,
)
session.add(end_user)
session.flush() # Use flush instead of commit to keep transaction open
session.refresh(end_user)
return end_user
def _handle_mcp_request(
self,
app: App,
mcp_server: AppMCPServer,
mcp_request: mcp_types.ClientRequest,
user_input_form: list[VariableEntity],
session: Session,
request_id: Union[int, str],
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
"""Handle MCP request and return response"""
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
client_info = mcp_request.root.params.clientInfo
client_name = f"{client_info.name}@{client_info.version}"
# Commit the session before creating end user to avoid transaction conflicts
session.commit()
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)

View File

@ -318,10 +318,6 @@ class DatasetApi(DatasetApiResource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
# check embedding setting
provider_manager = ProviderManager()
assert isinstance(current_user, Account)

View File

@ -1,6 +1,6 @@
from typing import Literal
from flask_login import current_user # type: ignore
from flask_login import current_user
from flask_restx import marshal, reqparse
from werkzeug.exceptions import NotFound

View File

@ -6,7 +6,7 @@ from functools import wraps
from typing import Optional
from flask import current_app, request
from flask_login import user_logged_in # type: ignore
from flask_login import user_logged_in
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select, update

View File

@ -1,5 +1,5 @@
from flask_restx import Resource, reqparse
from jwt import InvalidTokenError # type: ignore
from jwt import InvalidTokenError
import services
from controllers.console.auth.error import (

View File

@ -1,3 +1,16 @@
from pydantic import BaseModel, ConfigDict, Field, ValidationError
class MoreLikeThisConfig(BaseModel):
enabled: bool = False
model_config = ConfigDict(extra="allow")
class AppConfigModel(BaseModel):
more_like_this: MoreLikeThisConfig = Field(default_factory=MoreLikeThisConfig)
model_config = ConfigDict(extra="allow")
class MoreLikeThisConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
@ -6,31 +19,14 @@ class MoreLikeThisConfigManager:
:param config: model config args
"""
more_like_this = False
more_like_this_dict = config.get("more_like_this")
if more_like_this_dict:
if more_like_this_dict.get("enabled"):
more_like_this = True
return more_like_this
validated_config, _ = cls.validate_and_set_defaults(config)
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for more like this feature
:param config: app model config args
"""
if not config.get("more_like_this"):
config["more_like_this"] = {"enabled": False}
if not isinstance(config["more_like_this"], dict):
raise ValueError("more_like_this must be of dict type")
if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]:
config["more_like_this"]["enabled"] = False
if not isinstance(config["more_like_this"]["enabled"], bool):
raise ValueError("enabled in more_like_this must be of boolean type")
return config, ["more_like_this"]
try:
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError as e:
raise ValueError(
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
)

View File

@ -1,4 +1,5 @@
import logging
import re
import time
from collections.abc import Callable, Generator, Mapping
from contextlib import contextmanager
@ -143,6 +144,7 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
)
self._task_state = WorkflowTaskState()
@ -373,7 +375,7 @@ class AdvancedChatAppGenerateTaskPipeline:
) -> Generator[StreamResponse, None, None]:
"""Handle node succeeded events."""
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]:
if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]:
self._recorded_files.extend(
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)
@ -896,7 +898,14 @@ class AdvancedChatAppGenerateTaskPipeline:
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
message = self._get_message(session=session)
message.answer = self._task_state.answer
# If there are assistant files, remove markdown image links from answer
answer_text = self._task_state.answer
if self._recorded_files:
# Remove markdown image links since we're storing files separately
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
message.answer = answer_text
message.updated_at = naive_utc_now()
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = self._task_state.metadata.model_dump_json()

View File

@ -1,4 +1,3 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union, final
@ -14,6 +13,7 @@ from core.workflow.repositories.draft_variable_repository import (
NoopDraftVariableSaver,
)
from factories import file_factory
from libs.orjson import orjson_dumps
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
if TYPE_CHECKING:
@ -174,7 +174,7 @@ class BaseAppGenerator:
def gen():
for message in generator:
if isinstance(message, Mapping | dict):
yield f"data: {json.dumps(message)}\n\n"
yield f"data: {orjson_dumps(message)}\n\n"
else:
yield f"event: {message}\n\n"

View File

@ -3,7 +3,6 @@ from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
@ -53,9 +52,7 @@ from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from models import (
Account,
CreatorUserRole,
EndUser,
WorkflowRun,
)
@ -64,8 +61,10 @@ class WorkflowResponseConverter:
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
user: Union[Account, EndUser],
) -> None:
self._application_generate_entity = application_generate_entity
self._user = user
def workflow_start_to_stream_response(
self,
@ -92,27 +91,21 @@ class WorkflowResponseConverter:
workflow_execution: WorkflowExecution,
) -> WorkflowFinishStreamResponse:
created_by = None
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
assert workflow_run is not None
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
stmt = select(Account).where(Account.id == workflow_run.created_by)
account = session.scalar(stmt)
if account:
created_by = {
"id": account.id,
"name": account.name,
"email": account.email,
}
elif workflow_run.created_by_role == CreatorUserRole.END_USER:
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
end_user = session.scalar(stmt)
if end_user:
created_by = {
"id": end_user.id,
"user": end_user.session_id,
}
user = self._user
if isinstance(user, Account):
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
elif isinstance(user, EndUser):
created_by = {
"id": user.id,
"user": user.session_id,
}
else:
raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
raise NotImplementedError(f"User type not supported: {type(user)}")
# Handle the case where finished_at is None by using current time as default
finished_at_timestamp = (

View File

@ -131,6 +131,7 @@ class WorkflowAppGenerateTaskPipeline:
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
)
self._application_generate_entity = application_generate_entity

View File

@ -118,7 +118,7 @@ class QueueIterationNextEvent(AppQueueEvent):
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
"""iteration run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration
duration: Optional[float] = None
@ -201,7 +201,7 @@ class QueueLoopNextEvent(AppQueueEvent):
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
"""iteration run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current loop
duration: Optional[float] = None
@ -382,7 +382,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
"""loop id if node is in loop"""
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
"""iteration run in parallel mode run id"""
agent_strategy: Optional[AgentNodeStrategyInit] = None

View File

@ -1,8 +1,8 @@
class TaskPipilineError(ValueError):
class TaskPipelineError(ValueError):
pass
class RecordNotFoundError(TaskPipilineError):
class RecordNotFoundError(TaskPipelineError):
def __init__(self, record_name: str, record_id: str):
super().__init__(f"{record_name} with id {record_id} not found")

View File

@ -88,6 +88,7 @@ def to_prompt_message_content(
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
"filename": f.filename or "",
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

View File

@ -4,224 +4,259 @@ from collections.abc import Mapping
from typing import Any, cast
from configs import dify_config
from controllers.web.passport import generate_session_id
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.mcp.utils import create_mcp_error_response
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from core.mcp import types as mcp_types
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService
logger = logging.getLogger(__name__)
class MCPServerStreamableHTTPRequestHandler:
def handle_mcp_request(
app: App,
request: mcp_types.ClientRequest,
user_input_form: list[VariableEntity],
mcp_server: AppMCPServer,
end_user: EndUser | None = None,
request_id: int | str = 1,
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError:
"""
Apply to MCP HTTP streamable server with stateless http
Handle MCP request and return JSON-RPC response
Args:
app: The Dify app instance
request: The JSON-RPC request message
user_input_form: List of variable entities for the app
mcp_server: The MCP server configuration
end_user: Optional end user
request_id: The request ID
Returns:
JSON-RPC response or error
"""
def __init__(
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
):
self.app = app
self.request = request
mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
if not mcp_server:
raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = mcp_server
self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form
request_type = type(request.root)
@property
def request_type(self):
return type(self.request.root)
def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
"""Create success response with business result data"""
return mcp_types.JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=result_data.model_dump(by_alias=True, mode="json", exclude_none=True),
)
@property
def parameter_schema(self):
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
return {
"type": "object",
"properties": parameters,
"required": required,
}
def create_error_response(code: int, message: str) -> mcp_types.JSONRPCError:
"""Create error response with error code and message"""
from core.mcp.types import ErrorData
error_data = ErrorData(code=code, message=message)
return mcp_types.JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=error_data,
)
# Request handler mapping using functional approach
request_handlers = {
mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description),
mcp_types.ListToolsRequest: lambda: handle_list_tools(
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
),
mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user),
mcp_types.PingRequest: lambda: handle_ping(),
}
try:
# Dispatch request to appropriate handler
handler = request_handlers.get(request_type)
if handler:
return create_success_response(handler())
else:
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
except ValueError as e:
logger.exception("Invalid params")
return create_error_response(mcp_types.INVALID_PARAMS, str(e))
except Exception as e:
logger.exception("Internal server error")
return create_error_response(mcp_types.INTERNAL_ERROR, "Internal server error: " + str(e))
def handle_ping() -> mcp_types.EmptyResult:
"""Handle ping request"""
return mcp_types.EmptyResult()
def handle_initialize(description: str) -> mcp_types.InitializeResult:
"""Handle initialize request"""
capabilities = mcp_types.ServerCapabilities(
tools=mcp_types.ToolsCapability(listChanged=False),
)
return mcp_types.InitializeResult(
protocolVersion=mcp_types.SERVER_LATEST_PROTOCOL_VERSION,
capabilities=capabilities,
serverInfo=mcp_types.Implementation(name="Dify", version=dify_config.project.version),
instructions=description,
)
def handle_list_tools(
app_name: str,
app_mode: str,
user_input_form: list[VariableEntity],
description: str,
parameters_dict: dict[str, str],
) -> mcp_types.ListToolsResult:
"""Handle list tools request"""
parameter_schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)
return mcp_types.ListToolsResult(
tools=[
mcp_types.Tool(
name=app_name,
description=description,
inputSchema=parameter_schema,
)
],
)
def handle_call_tool(
app: App,
request: mcp_types.ClientRequest,
user_input_form: list[VariableEntity],
end_user: EndUser | None,
) -> mcp_types.CallToolResult:
"""Handle call tool request"""
request_obj = cast(mcp_types.CallToolRequest, request.root)
args = prepare_tool_arguments(app, request_obj.params.arguments or {})
if not end_user:
raise ValueError("End user not found")
response = AppGenerateService.generate(
app,
end_user,
args,
InvokeFrom.SERVICE_API,
streaming=app.mode == AppMode.AGENT_CHAT.value,
)
answer = extract_answer_from_response(app, response)
return mcp_types.CallToolResult(content=[mcp_types.TextContent(text=answer, type="text")])
def build_parameter_schema(
app_mode: str,
user_input_form: list[VariableEntity],
parameters_dict: dict[str, str],
) -> dict[str, Any]:
"""Build parameter schema for the tool"""
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
return {
"type": "object",
"properties": {
"query": {"type": "string", "description": "User Input/Question content"},
**parameters,
},
"required": ["query", *required],
"properties": parameters,
"required": required,
}
return {
"type": "object",
"properties": {
"query": {"type": "string", "description": "User Input/Question content"},
**parameters,
},
"required": ["query", *required],
}
@property
def capabilities(self):
return types.ServerCapabilities(
tools=types.ToolsCapability(listChanged=False),
)
def response(self, response: types.Result | str):
if isinstance(response, str):
sse_content = f"event: ping\ndata: {response}\n\n".encode()
yield sse_content
return
json_response = types.JSONRPCResponse(
jsonrpc="2.0",
id=(self.request.root.model_extra or {}).get("id", 1),
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
)
json_data = json.dumps(jsonable_encoder(json_response))
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
"""Prepare arguments based on app mode"""
if app.mode == AppMode.WORKFLOW.value:
return {"inputs": arguments}
elif app.mode == AppMode.COMPLETION.value:
return {"query": "", "inputs": arguments}
else:
# Chat modes - create a copy to avoid modifying original dict
args_copy = arguments.copy()
query = args_copy.pop("query", "")
return {"query": query, "inputs": args_copy}
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
yield sse_content
def extract_answer_from_response(app: App, response: Any) -> str:
"""Extract answer from app generate response"""
answer = ""
def error_response(self, code: int, message: str, data=None):
request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
return create_mcp_error_response(request_id, code, message, data)
if isinstance(response, RateLimitGenerator):
answer = process_streaming_response(response)
elif isinstance(response, Mapping):
answer = process_mapping_response(app, response)
else:
logger.warning("Unexpected response type: %s", type(response))
def handle(self):
handle_map = {
types.InitializeRequest: self.initialize,
types.ListToolsRequest: self.list_tools,
types.CallToolRequest: self.invoke_tool,
types.InitializedNotification: self.handle_notification,
types.PingRequest: self.handle_ping,
}
try:
if self.request_type in handle_map:
return self.response(handle_map[self.request_type]())
else:
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
except ValueError as e:
logger.exception("Invalid params")
return self.error_response(INVALID_PARAMS, str(e))
except Exception as e:
logger.exception("Internal server error")
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
return answer
def handle_notification(self):
return "ping"
def handle_ping(self):
return types.EmptyResult()
def initialize(self):
request = cast(types.InitializeRequest, self.request.root)
client_info = request.params.clientInfo
client_name = f"{client_info.name}@{client_info.version}"
if not self.end_user:
end_user = EndUser(
tenant_id=self.app.tenant_id,
app_id=self.app.id,
type="mcp",
name=client_name,
session_id=generate_session_id(),
external_user_id=self.mcp_server.id,
)
db.session.add(end_user)
db.session.commit()
return types.InitializeResult(
protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
capabilities=self.capabilities,
serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
instructions=self.mcp_server.description,
)
def list_tools(self):
if not self.end_user:
raise ValueError("User not found")
return types.ListToolsResult(
tools=[
types.Tool(
name=self.app.name,
description=self.mcp_server.description,
inputSchema=self.parameter_schema,
)
],
)
def invoke_tool(self):
if not self.end_user:
raise ValueError("User not found")
request = cast(types.CallToolRequest, self.request.root)
args = request.params.arguments or {}
if self.app.mode in {AppMode.WORKFLOW.value}:
args = {"inputs": args}
elif self.app.mode in {AppMode.COMPLETION.value}:
args = {"query": "", "inputs": args}
else:
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
response = AppGenerateService.generate(
self.app,
self.end_user,
args,
InvokeFrom.SERVICE_API,
streaming=self.app.mode == AppMode.AGENT_CHAT.value,
)
answer = ""
if isinstance(response, RateLimitGenerator):
for item in response.generator:
data = item
if isinstance(data, str) and data.startswith("data: "):
try:
json_str = data[6:].strip()
parsed_data = json.loads(json_str)
if parsed_data.get("event") == "agent_thought":
answer += parsed_data.get("thought", "")
except json.JSONDecodeError:
continue
if isinstance(response, Mapping):
if self.app.mode in {
AppMode.ADVANCED_CHAT.value,
AppMode.COMPLETION.value,
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
}:
answer = response["answer"]
elif self.app.mode in {AppMode.WORKFLOW.value}:
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode")
# Not support image yet
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
def retrieve_end_user(self):
return (
db.session.query(EndUser)
.where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first()
)
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
parameters: dict[str, dict[str, Any]] = {}
required = []
for item in user_input_form:
parameters[item.variable] = {}
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
if item.required:
required.append(item.variable)
# if the workflow republished, the parameters not changed
# we should not raise error here
def process_streaming_response(response: RateLimitGenerator) -> str:
"""Process streaming response for agent chat mode"""
answer = ""
for item in response.generator:
if isinstance(item, str) and item.startswith("data: "):
try:
description = self.mcp_server.parameters_dict[item.variable]
except KeyError:
description = ""
parameters[item.variable]["description"] = description
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "float"
return parameters, required
json_str = item[6:].strip()
parsed_data = json.loads(json_str)
if parsed_data.get("event") == "agent_thought":
answer += parsed_data.get("thought", "")
except json.JSONDecodeError:
continue
return answer
def process_mapping_response(app: App, response: Mapping) -> str:
"""Process mapping response based on app mode"""
if app.mode in {
AppMode.ADVANCED_CHAT.value,
AppMode.COMPLETION.value,
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
}:
return response.get("answer", "")
elif app.mode == AppMode.WORKFLOW.value:
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode: " + str(app.mode))
def convert_input_form_to_parameters(
user_input_form: list[VariableEntity],
parameters_dict: dict[str, str],
) -> tuple[dict[str, dict[str, Any]], list[str]]:
"""Convert user input form to parameter schema"""
parameters: dict[str, dict[str, Any]] = {}
required = []
for item in user_input_form:
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
parameters[item.variable] = {}
if item.required:
required.append(item.variable)
# if the workflow republished, the parameters not changed
# we should not raise error here
description = parameters_dict.get(item.variable, "")
parameters[item.variable]["description"] = description
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "float"
return parameters, required

View File

@ -138,5 +138,5 @@ def create_mcp_error_response(
error=error_data,
)
json_data = json.dumps(jsonable_encoder(json_response))
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
sse_content = json_data.encode()
yield sse_content

View File

@ -31,6 +31,65 @@ class TokenBufferMemory:
self.conversation = conversation
self.model_instance = model_instance
def _build_prompt_message_with_files(
self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
) -> PromptMessage:
"""
Build prompt message with files.
:param message_files: list of MessageFile objects
:param text_content: text content of the message
:param message: Message object
:param app_record: app record
:param is_user_message: whether this is a user message
:return: PromptMessage
"""
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_run = db.session.scalar(select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id))
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
if not workflow:
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
else:
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
detail = ImagePromptMessageContent.DETAIL.HIGH
if file_extra_config and app_record:
# Build files directly without filtering by belongs_to
file_objs = [
file_factory.build_from_message_file(
message_file=message_file, tenant_id=app_record.tenant_id, config=file_extra_config
)
for message_file in message_files
]
if file_extra_config.image_config and file_extra_config.image_config.detail:
detail = file_extra_config.image_config.detail
else:
file_objs = []
if not file_objs:
if is_user_message:
return UserPromptMessage(content=text_content)
else:
return AssistantPromptMessage(content=text_content)
else:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in file_objs:
prompt_message = file_manager.to_prompt_message_content(
file,
image_detail_config=detail,
)
prompt_message_contents.append(prompt_message)
prompt_message_contents.append(TextPromptMessageContent(data=text_content))
if is_user_message:
return UserPromptMessage(content=prompt_message_contents)
else:
return AssistantPromptMessage(content=prompt_message_contents)
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> Sequence[PromptMessage]:
@ -67,52 +126,46 @@ class TokenBufferMemory:
prompt_messages: list[PromptMessage] = []
for message in messages:
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
if files:
file_extra_config = None
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_run = db.session.scalar(
select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
)
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
if not workflow:
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
else:
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
detail = ImagePromptMessageContent.DETAIL.LOW
if file_extra_config and app_record:
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
)
if file_extra_config.image_config and file_extra_config.image_config.detail:
detail = file_extra_config.image_config.detail
else:
file_objs = []
if not file_objs:
prompt_messages.append(UserPromptMessage(content=message.query))
else:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in file_objs:
prompt_message = file_manager.to_prompt_message_content(
file,
image_detail_config=detail,
)
prompt_message_contents.append(prompt_message)
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
# Process user message with files
user_files = (
db.session.query(MessageFile)
.where(
MessageFile.message_id == message.id,
(MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
)
.all()
)
if user_files:
user_prompt_message = self._build_prompt_message_with_files(
message_files=user_files,
text_content=message.query,
message=message,
app_record=app_record,
is_user_message=True,
)
prompt_messages.append(user_prompt_message)
else:
prompt_messages.append(UserPromptMessage(content=message.query))
prompt_messages.append(AssistantPromptMessage(content=message.answer))
# Process assistant message with files
assistant_files = (
db.session.query(MessageFile)
.where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
.all()
)
if assistant_files:
assistant_prompt_message = self._build_prompt_message_with_files(
message_files=assistant_files,
text_content=message.answer,
message=message,
app_record=app_record,
is_user_message=False,
)
prompt_messages.append(assistant_prompt_message)
else:
prompt_messages.append(AssistantPromptMessage(content=message.answer))
if not prompt_messages:
return []
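The key behavioural change in this refactor is that message files are now split by belongs_to and attached to both the user and the assistant prompt messages. A tiny self-contained sketch of the routing rule (plain dicts stand in for MessageFile rows):

files = [
    {"id": "f1", "belongs_to": "user"},
    {"id": "f2", "belongs_to": None},         # legacy rows without an owner are treated as user files
    {"id": "f3", "belongs_to": "assistant"},  # e.g. an image generated by a tool call
]
user_files = [f for f in files if f["belongs_to"] in ("user", None)]
assistant_files = [f for f in files if f["belongs_to"] == "assistant"]
assert [f["id"] for f in user_files] == ["f1", "f2"]
assert [f["id"] for f in assistant_files] == ["f3"]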

View File

@ -87,6 +87,7 @@ class MultiModalPromptMessageContent(PromptMessageContent):
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
url: str = Field(default="", description="the url of multi-modal file")
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
filename: str = Field(default="", description="the filename of multi-modal file")
@property
def data(self):

View File

@ -43,7 +43,7 @@ class GPT2Tokenizer:
except Exception:
from os.path import abspath, dirname, join
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")

View File

@ -330,7 +330,7 @@ class OpsTraceManager:
except KeyError:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
else:
if tracing_provider is not None:
if tracing_provider is None:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first()

View File

@ -375,16 +375,16 @@ Here is the extra instruction you need to follow:
# merge lines into messages with max tokens
messages: list[str] = []
for i in new_lines: # type: ignore
for line in new_lines:
if len(messages) == 0:
messages.append(i) # type: ignore
messages.append(line)
else:
if len(messages[-1]) + len(i) < max_tokens * 0.5: # type: ignore
messages[-1] += i # type: ignore
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: # type: ignore
messages.append(i) # type: ignore
if len(messages[-1]) + len(line) < max_tokens * 0.5:
messages[-1] += line
if get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
messages.append(line)
else:
messages[-1] += i # type: ignore
messages[-1] += line
summaries = []
for i in range(len(messages)):

View File

@ -24,7 +24,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
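Two defaults recur throughout the rest of this change set: the hard-coded top_k fallback moves from 2 to 4, and score-threshold comparisons switch from > to >=, so results that land exactly on the threshold are no longer dropped. A toy illustration with made-up scores:

scores = [0.92, 0.80, 0.80, 0.75]
score_threshold = 0.80

kept_before = [s for s in scores if s > score_threshold]   # [0.92] - exact matches were dropped
kept_after = [s for s in scores if s >= score_threshold]   # [0.92, 0.80, 0.80]

stored_top_k = None  # e.g. missing from a persisted retrieval_model config
effective_top_k = stored_top_k or 4  # previously fell back to 2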

View File

@ -256,7 +256,7 @@ class AnalyticdbVectorOpenAPI:
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
@ -293,7 +293,7 @@ class AnalyticdbVectorOpenAPI:
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(

View File

@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from core.rag.models.document import Document
@ -229,7 +229,7 @@ class AnalyticdbVectorBySql:
documents = []
for record in cur:
id, vector, score, page_content, metadata = record
if score > score_threshold:
if score >= score_threshold:
metadata["score"] = score
doc = Document(
page_content=page_content,

View File

@ -157,7 +157,7 @@ class BaiduVector(BaseVector):
if meta is not None:
meta = json.loads(meta)
score = row.get("score", 0.0)
if score > score_threshold:
if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
docs.append(doc)

View File

@ -120,7 +120,7 @@ class ChromaVector(BaseVector):
distance = distances[index]
metadata = dict(metadatas[index])
score = 1 - distance
if score > score_threshold:
if score >= score_threshold:
metadata["score"] = score
doc = Document(
page_content=documents[index],

View File

@ -304,7 +304,7 @@ class CouchbaseVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 2)
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
search_iter = self._scope.search(

View File

@ -216,7 +216,7 @@ class ElasticSearchVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)

View File

@ -127,7 +127,7 @@ class HuaweiCloudVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)

View File

@ -275,7 +275,7 @@ class LindormVectorStore(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)

View File

@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from configs import dify_config
@ -194,7 +194,7 @@ class OpenGauss(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs

View File

@ -48,7 +48,7 @@ class OpenSearchConfig(BaseModel):
return values
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
import boto3 # type: ignore
import boto3
return Urllib3AWSV4SignerAuth(
credentials=boto3.Session().get_credentials(),
@ -211,7 +211,7 @@ class OpenSearchVector(BaseVector):
metadata["score"] = hit["_score"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if hit["_score"] > score_threshold:
if hit["_score"] >= score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)

View File

@ -261,7 +261,7 @@ class OracleVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
conn.close()
return docs

View File

@ -202,7 +202,7 @@ class PGVectoRS(BaseVector):
score = 1 - dis
metadata["score"] = score
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
if score >= score_threshold:
doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc)
return docs

View File

@ -6,8 +6,8 @@ from contextlib import contextmanager
from typing import Any
import psycopg2.errors
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from configs import dify_config
@ -195,7 +195,7 @@ class PGVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs

View File

@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from configs import dify_config
@ -170,7 +170,7 @@ class VastbaseVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs

View File

@ -369,7 +369,7 @@ class QdrantVector(BaseVector):
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
if result.score > score_threshold:
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),

View File

@ -233,7 +233,7 @@ class RelytVector(BaseVector):
docs = []
for document, score in results:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if 1 - score > score_threshold:
if 1 - score >= score_threshold:
docs.append(document)
return docs

View File

@ -300,7 +300,7 @@ class TableStoreVector(BaseVector):
)
documents = []
for search_hit in search_response.search_hits:
if search_hit.score > score_threshold:
if search_hit.score >= score_threshold:
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]

View File

@ -291,7 +291,7 @@ class TencentVector(BaseVector):
score = 1 - result.get("score", 0.0)
else:
score = result.get("score", 0.0)
if score > score_threshold:
if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=result.get(self.field_text), metadata=meta)
docs.append(doc)

View File

@ -351,7 +351,7 @@ class TidbOnQdrantVector(BaseVector):
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score > score_threshold:
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),

View File

@ -110,7 +110,7 @@ class UpstashVector(BaseVector):
score = record.score
if metadata is not None and text is not None:
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs

View File

@ -192,7 +192,7 @@ class VikingDBVector(BaseVector):
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
if metadata is not None:
metadata = json.loads(metadata)
if result.score > score_threshold:
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)

View File

@ -220,7 +220,7 @@ class WeaviateVector(BaseVector):
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)

View File

@ -4,7 +4,7 @@ import os
from typing import Optional, cast
import pandas as pd
from openpyxl import load_workbook # type: ignore
from openpyxl import load_workbook
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

View File

@ -1,6 +1,6 @@
"""Abstract interface for document loader implementations."""
from bs4 import BeautifulSoup # type: ignore
from bs4 import BeautifulSoup
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

View File

@ -3,7 +3,7 @@ import contextlib
import logging
from typing import Optional
from bs4 import BeautifulSoup # type: ignore
from bs4 import BeautifulSoup
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

View File

@ -123,7 +123,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs

View File

@ -162,7 +162,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs

View File

@ -158,7 +158,7 @@ class QAIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs

View File

@ -65,7 +65,7 @@ default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@ -647,7 +647,7 @@ class DatasetRetrieval:
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
top_k=retrieval_model.get("top_k") or 4,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
@ -743,7 +743,7 @@ class DatasetRetrieval:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 2,
top_k=retrieve_config.top_k or 4,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,

View File

@ -144,7 +144,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
"""Text splitter that uses HuggingFace tokenizer to count length."""
try:
from transformers import PreTrainedTokenizerBase # type: ignore
from transformers import PreTrainedTokenizerBase
if not isinstance(tokenizer, PreTrainedTokenizerBase):
raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")

View File

@ -181,7 +181,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
retrieval_method="keyword_search",
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
top_k=retrieval_model.get("top_k") or 4,
)
if documents:
all_documents.extend(documents)
@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
top_k=retrieval_model.get("top_k") or 4,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,

View File

@ -13,7 +13,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
name: str = "dataset"
description: str = "use this to retrieve a dataset. "
tenant_id: str
top_k: int = 2
top_k: int = 4
score_threshold: Optional[float] = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool

View File

@ -6,7 +6,7 @@ from typing import Optional
from flask import request
from requests import get
from yaml import YAMLError, safe_load # type: ignore
from yaml import YAMLError, safe_load
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle

View File

@ -166,7 +166,7 @@ class BaseIterationEvent(GraphEngineEvent):
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
"""iteration run in parallel mode run id"""
class IterationRunStartedEvent(BaseIterationEvent):

View File

@ -149,9 +149,6 @@ class AnswerStreamProcessor(StreamProcessor):
return []
stream_output_value_selector = event.from_variable_selector
if not stream_output_value_selector:
return []
stream_out_answer_node_ids = []
for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids:

View File

@ -515,14 +515,14 @@ def _extract_text_from_excel(file_content: bytes) -> str:
df.dropna(how="all", inplace=True)
# Combine multi-line text in each cell into a single line
df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore
df = df.map(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x)
# Combine multi-line text in column names into a single line
df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns])
# Manually construct the Markdown table
markdown_table += _construct_markdown_table(df) + "\n\n"
except Exception as e:
except Exception:
continue
return markdown_table
except Exception as e:
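DataFrame.applymap is deprecated in recent pandas releases in favour of the element-wise DataFrame.map, which behaves the same here; a quick check with illustrative data:

import pandas as pd

df = pd.DataFrame({"notes": ["line 1\nline 2", 3]})
flattened = df.map(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x)
assert flattened.loc[0, "notes"] == "line 1 line 2"
assert flattened.loc[1, "notes"] == 3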

View File

@ -78,7 +78,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}

View File

@ -5,7 +5,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
# register blueprint routers
from flask_cors import CORS # type: ignore
from flask_cors import CORS
from controllers.console import bp as console_app_bp
from controllers.files import bp as files_bp

View File

@ -9,7 +9,7 @@ from typing import Union
import flask
from celery.signals import worker_init
from flask_login import user_loaded_from_request, user_logged_in # type: ignore
from flask_login import user_loaded_from_request, user_logged_in
from configs import dify_config
from dify_app import DifyApp

View File

@ -1,9 +1,9 @@
import logging
from collections.abc import Generator
import boto3 # type: ignore
from botocore.client import Config # type: ignore
from botocore.exceptions import ClientError # type: ignore
import boto3
from botocore.client import Config
from botocore.exceptions import ClientError
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -41,8 +41,14 @@ def build_from_message_file(
"url": message_file.url,
"id": message_file.id,
"type": message_file.type,
"upload_file_id": message_file.upload_file_id,
}
# Set the correct ID field based on transfer method
if message_file.transfer_method == FileTransferMethod.TOOL_FILE.value:
mapping["tool_file_id"] = message_file.upload_file_id
else:
mapping["upload_file_id"] = message_file.upload_file_id
return build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
@ -318,6 +324,11 @@ def _is_file_valid_with_config(
file_transfer_method: FileTransferMethod,
config: FileUploadConfig,
) -> bool:
# FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model)
# These are internally generated and should bypass user upload restrictions
if file_transfer_method == FileTransferMethod.TOOL_FILE:
return True
if (
config.allowed_file_types
and input_file_type not in config.allowed_file_types

api/libs/orjson.py Normal file
View File

@ -0,0 +1,11 @@
from typing import Any, Optional
import orjson
def orjson_dumps(
obj: Any,
encoding: str = "utf-8",
option: Optional[int] = None,
) -> str:
return orjson.dumps(obj, option=option).decode(encoding)
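A usage note on the new helper: orjson.dumps returns bytes, so the wrapper decodes to str to stay drop-in compatible with json.dumps call sites. The import path below assumes the api package root is on sys.path:

import orjson

from libs.orjson import orjson_dumps

payload = {"answer": "你好", "score": 0.8}
assert orjson_dumps(payload) == orjson.dumps(payload).decode("utf-8")
# options are passed straight through to orjson
sorted_json = orjson_dumps(payload, option=orjson.OPT_SORT_KEYS)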

View File

@ -5,7 +5,7 @@ Revises: 8bcc02c9bd07
Create Date: 2025-08-09 15:53:54.341341
"""
from alembic import op
from alembic import op, context
from libs.uuid_utils import uuidv7
import models as models
import sqlalchemy as sa
@ -43,7 +43,15 @@ def upgrade():
with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True))
migrate_existing_providers_data()
if not context.is_offline_mode():
migrate_existing_providers_data()
else:
op.execute(
'-- [IMPORTANT] Data migration skipped!!!\n'
"-- You should manually run data migration function `migrate_existing_providers_data`\n"
f"-- inside file {__file__}\n"
"-- Please review the migration script carefully!"
)
# Remove encrypted_config column from providers table after migration
with op.batch_alter_table('providers', schema=None) as batch_op:
@ -119,7 +127,16 @@ def downgrade():
batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True))
# Migrate data back from provider_credentials to providers
migrate_data_back_to_providers()
if not context.is_offline_mode():
migrate_data_back_to_providers()
else:
op.execute(
'-- [IMPORTANT] Data migration skipped!!!\n'
"-- You should manually run data migration function `migrate_data_back_to_providers`\n"
f"-- inside file {__file__}\n"
"-- Please review the migration script carefully!"
)
# Remove credential_id columns
with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
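Both migration files apply the same guard: schema DDL still renders under `alembic upgrade --sql`, while the data-migration helpers, which need a live connection, only run online and otherwise leave a comment in the generated script. A stripped-down sketch of the pattern, with a hypothetical helper:

from alembic import context, op

def upgrade():
    # schema changes render in both online and offline (--sql) mode
    # op.add_column(...)
    if not context.is_offline_mode():
        migrate_existing_data()  # needs a live connection
    else:
        op.execute("-- Data migration skipped; run migrate_existing_data() manually after applying this script.")

def migrate_existing_data():
    ...  # hypothetical data backfill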

View File

@ -6,7 +6,7 @@ Create Date: 2025-08-13 16:05:42.657730
"""
from alembic import op
from alembic import op, context
from libs.uuid_utils import uuidv7
import models as models
import sqlalchemy as sa
@ -48,8 +48,16 @@ def upgrade():
with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('credential_source_type', sa.String(length=40), nullable=True))
# Migrate existing provider_models data
migrate_existing_provider_models_data()
if not context.is_offline_mode():
# Migrate existing provider_models data
migrate_existing_provider_models_data()
else:
op.execute(
'-- [IMPORTANT] Data migration skipped!!!\n'
"-- You should manually run data migration function `migrate_existing_provider_models_data`\n"
f"-- inside file {__file__}\n"
"-- Please review the migration script carefully!"
)
# Remove encrypted_config column from provider_models table after migration
with op.batch_alter_table('provider_models', schema=None) as batch_op:
@ -132,8 +140,16 @@ def downgrade():
with op.batch_alter_table('provider_models', schema=None) as batch_op:
batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True))
# Migrate data back from provider_model_credentials to provider_models
migrate_data_back_to_provider_models()
if not context.is_offline_mode():
# Migrate data back from provider_model_credentials to provider_models
migrate_data_back_to_provider_models()
else:
op.execute(
'-- [IMPORTANT] Data migration skipped!!!\n'
"-- You should manually run data migration function `migrate_data_back_to_provider_models`\n"
f"-- inside file {__file__}\n"
"-- Please review the migration script carefully!"
)
with op.batch_alter_table('provider_models', schema=None) as batch_op:
batch_op.drop_column('credential_id')

View File

@ -0,0 +1,45 @@
"""empty message
Revision ID: 8d289573e1da
Revises: 0e154742a5fa
Create Date: 2025-08-20 17:47:17.015695
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '8d289573e1da'
down_revision = '0e154742a5fa'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('oauth_provider_apps',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
sa.Column('app_icon', sa.String(length=255), nullable=False),
sa.Column('app_label', sa.JSON(), server_default='{}', nullable=False),
sa.Column('client_id', sa.String(length=255), nullable=False),
sa.Column('client_secret', sa.String(length=255), nullable=False),
sa.Column('redirect_uris', sa.JSON(), server_default='[]', nullable=False),
sa.Column('scope', sa.String(length=255), server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='oauth_provider_app_pkey')
)
with op.batch_alter_table('oauth_provider_apps', schema=None) as batch_op:
batch_op.create_index('oauth_provider_app_client_id_idx', ['client_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('oauth_provider_apps', schema=None) as batch_op:
batch_op.drop_index('oauth_provider_app_client_id_idx')
op.drop_table('oauth_provider_apps')
# ### end Alembic commands ###

View File

@ -1,12 +1,12 @@
import enum
import json
from datetime import datetime
from typing import Optional, cast
from typing import Optional
import sqlalchemy as sa
from flask_login import UserMixin # type: ignore
from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
from models.base import Base
@ -118,10 +118,24 @@ class Account(UserMixin, Base):
@current_tenant.setter
def current_tenant(self, tenant: "Tenant"):
ta = db.session.scalar(select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1))
if ta:
self.role = TenantAccountRole(ta.role)
self._current_tenant = tenant
with Session(db.engine, expire_on_commit=False) as session:
tenant_join_query = select(TenantAccountJoin).where(
TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == self.id
)
tenant_join = session.scalar(tenant_join_query)
tenant_query = select(Tenant).where(Tenant.id == tenant.id)
# TODO: A workaround to reload the tenant with `expire_on_commit=False`, allowing
# access to it after the session has been closed.
# This prevents `DetachedInstanceError` when accessing the tenant outside
# the session's lifecycle.
# (The `tenant` argument is typically loaded by `db.session` without the
# `expire_on_commit=False` flag, meaning its lifetime is tied to the web
# request's lifecycle.)
tenant_reloaded = session.scalars(tenant_query).one()
if tenant_join:
self.role = TenantAccountRole(tenant_join.role)
self._current_tenant = tenant_reloaded
return
self._current_tenant = None
@ -130,23 +144,19 @@ class Account(UserMixin, Base):
return self._current_tenant.id if self._current_tenant else None
def set_tenant_id(self, tenant_id: str):
tenant_account_join = cast(
tuple[Tenant, TenantAccountJoin],
(
db.session.query(Tenant, TenantAccountJoin)
.where(Tenant.id == tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.account_id == self.id)
.one_or_none()
),
query = (
select(Tenant, TenantAccountJoin)
.where(Tenant.id == tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.account_id == self.id)
)
if not tenant_account_join:
return
tenant, join = tenant_account_join
self.role = TenantAccountRole(join.role)
self._current_tenant = tenant
with Session(db.engine, expire_on_commit=False) as session:
tenant_account_join = session.execute(query).first()
if not tenant_account_join:
return
tenant, join = tenant_account_join
self.role = TenantAccountRole(join.role)
self._current_tenant = tenant
@property
def current_role(self):
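The fix works because attributes of an ORM object are expired when the owning session ends, and reading them afterwards raises DetachedInstanceError; reloading the row in a short-lived Session(db.engine, expire_on_commit=False) leaves the loaded attribute values intact after the session closes. A generic sketch of the idea, not tied to the Account model:

from sqlalchemy import select
from sqlalchemy.orm import Session

def load_detached(engine, model, pk):
    # expire_on_commit=False keeps column values populated after the session closes,
    # so plain attribute access on the returned object no longer triggers a refresh.
    # Lazy-loaded relationships would still need an active session.
    with Session(engine, expire_on_commit=False) as session:
        return session.scalars(select(model).where(model.id == pk)).one()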

View File

@ -522,33 +522,6 @@ class AppModelConfig(Base):
self.file_upload = json.dumps(model_config.get("file_upload")) if model_config.get("file_upload") else None
return self
def copy(self):
new_app_model_config = AppModelConfig(
id=self.id,
app_id=self.app_id,
opening_statement=self.opening_statement,
suggested_questions=self.suggested_questions,
suggested_questions_after_answer=self.suggested_questions_after_answer,
speech_to_text=self.speech_to_text,
text_to_speech=self.text_to_speech,
more_like_this=self.more_like_this,
sensitive_word_avoidance=self.sensitive_word_avoidance,
external_data_tools=self.external_data_tools,
model=self.model,
user_input_form=self.user_input_form,
dataset_query_variable=self.dataset_query_variable,
pre_prompt=self.pre_prompt,
agent_mode=self.agent_mode,
retriever_resource=self.retriever_resource,
prompt_type=self.prompt_type,
chat_prompt_config=self.chat_prompt_config,
completion_prompt_config=self.completion_prompt_config,
dataset_configs=self.dataset_configs,
file_upload=self.file_upload,
)
return new_app_model_config
class RecommendedApp(Base):
__tablename__ = "recommended_apps"
@ -607,6 +580,32 @@ class InstalledApp(Base):
return tenant
class OAuthProviderApp(Base):
"""
Globally shared OAuth provider app information.
Only for Dify Cloud.
"""
__tablename__ = "oauth_provider_apps"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="oauth_provider_app_pkey"),
sa.Index("oauth_provider_app_client_id_idx", "client_id"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
app_icon = mapped_column(String(255), nullable=False)
app_label = mapped_column(sa.JSON, nullable=False, server_default="{}")
client_id = mapped_column(String(255), nullable=False)
client_secret = mapped_column(String(255), nullable=False)
redirect_uris = mapped_column(sa.JSON, nullable=False, server_default="[]")
scope = mapped_column(
String(255),
nullable=False,
server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"),
)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
class Conversation(Base):
__tablename__ = "conversations"
__table_args__ = (

View File

@ -67,7 +67,7 @@ dependencies = [
"pydantic~=2.11.4",
"pydantic-extra-types~=2.10.3",
"pydantic-settings~=2.9.1",
"pyjwt~=2.8.0",
"pyjwt~=2.10.1",
"pypdfium2==4.30.0",
"python-docx~=1.1.0",
"python-dotenv==1.0.1",
@ -179,7 +179,7 @@ storage = [
"google-cloud-storage==2.16.0",
"opendal~=0.45.16",
"oss2==2.18.5",
"supabase~=2.8.1",
"supabase~=2.18.1",
"tos~=2.7.1",
]

View File

@ -20,7 +20,7 @@ def check_upgradable_plugin_task():
strategies = (
db.session.query(TenantPluginAutoUpgradeStrategy)
.filter(
.where(
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day,
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day
< now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL,

View File

@ -93,7 +93,7 @@ def _delete_batch_with_retry(workflow_run_ids: list[str], attempt_count: int) ->
with db.session.begin_nested():
message_data = (
db.session.query(Message.id, Message.conversation_id)
.filter(Message.workflow_run_id.in_(workflow_run_ids))
.where(Message.workflow_run_id.in_(workflow_run_ids))
.all()
)
message_id_list = [msg.id for msg in message_data]
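The .filter()-to-.where() swap that repeats across the remaining service, task, and test files is purely cosmetic: on legacy Query objects, .where() has been an alias of .filter() since SQLAlchemy 1.4, so this only aligns the code with 2.0-style select() statements. A minimal check against an in-memory SQLite database (illustrative model):

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Item(Base):
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]

engine = sa.create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all([Item(name="a"), Item(name="b")])
    session.commit()
    assert (
        session.query(Item).filter(Item.name == "a").count()
        == session.query(Item).where(Item.name == "a").count()
        == 1
    )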

View File

@ -282,7 +282,7 @@ class AppAnnotationService:
annotations_to_delete = (
db.session.query(MessageAnnotation, AppAnnotationSetting)
.outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id)
.filter(MessageAnnotation.id.in_(annotation_ids))
.where(MessageAnnotation.id.in_(annotation_ids))
.all()
)
@ -493,7 +493,7 @@ class AppAnnotationService:
def clear_all_annotations(cls, app_id: str) -> dict:
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)

View File

@ -55,7 +55,7 @@ class AppGenerateService:
cls.system_rate_limiter.increment_rate_limit(app_model.tenant_id)
# app level rate limiter
max_active_request = AppGenerateService._get_max_active_requests(app_model)
max_active_request = cls._get_max_active_requests(app_model)
rate_limit = RateLimit(app_model.id, max_active_request)
request_id = RateLimit.gen_request_key()
try:

View File

@ -62,7 +62,7 @@ class ClearFreePlanTenantExpiredLogs:
# Query records related to expired messages
records = (
session.query(model)
.filter(
.where(
model.message_id.in_(batch_message_ids), # type: ignore
)
.all()
@ -101,7 +101,7 @@ class ClearFreePlanTenantExpiredLogs:
except Exception:
logger.exception("Failed to save %s records", table_name)
session.query(model).filter(
session.query(model).where(
model.id.in_(record_ids), # type: ignore
).delete(synchronize_session=False)
@ -295,7 +295,7 @@ class ClearFreePlanTenantExpiredLogs:
with Session(db.engine).no_autoflush as session:
workflow_app_logs = (
session.query(WorkflowAppLog)
.filter(
.where(
WorkflowAppLog.tenant_id == tenant_id,
WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
@ -321,9 +321,9 @@ class ClearFreePlanTenantExpiredLogs:
workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]
# delete workflow app logs
session.query(WorkflowAppLog).filter(
WorkflowAppLog.id.in_(workflow_app_log_ids),
).delete(synchronize_session=False)
session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
synchronize_session=False
)
session.commit()
click.echo(

View File

@ -1149,7 +1149,7 @@ class DocumentService:
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@ -1612,7 +1612,7 @@ class DocumentService:
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
reranking_enable=False,
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
top_k=2,
top_k=4,
score_threshold_enabled=False,
)
# save dataset
@ -2346,7 +2346,7 @@ class SegmentService:
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
segments = (
db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
.filter(
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,

View File

@ -18,7 +18,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@ -66,7 +66,7 @@ class HitTestingService:
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k", 2),
top_k=retrieval_model.get("top_k", 4),
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,

View File

@ -0,0 +1,94 @@
import enum
import uuid
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account
from models.model import OAuthProviderApp
from services.account_service import AccountService
class OAuthGrantType(enum.StrEnum):
AUTHORIZATION_CODE = "authorization_code"
REFRESH_TOKEN = "refresh_token"
OAUTH_AUTHORIZATION_CODE_REDIS_KEY = "oauth_provider:{client_id}:authorization_code:{code}"
OAUTH_ACCESS_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:access_token:{token}"
OAUTH_ACCESS_TOKEN_EXPIRES_IN = 60 * 60 * 12 # 12 hours
OAUTH_REFRESH_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:refresh_token:{token}"
OAUTH_REFRESH_TOKEN_EXPIRES_IN = 60 * 60 * 24 * 30 # 30 days
class OAuthServerService:
@staticmethod
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
with Session(db.engine) as session:
return session.execute(query).scalar_one_or_none()
@staticmethod
def sign_oauth_authorization_code(client_id: str, user_account_id: str) -> str:
code = str(uuid.uuid4())
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
redis_client.set(redis_key, user_account_id, ex=60 * 10) # 10 minutes
return code
@staticmethod
def sign_oauth_access_token(
grant_type: OAuthGrantType,
code: str = "",
client_id: str = "",
refresh_token: str = "",
) -> tuple[str, str]:
match grant_type:
case OAuthGrantType.AUTHORIZATION_CODE:
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
user_account_id = redis_client.get(redis_key)
if not user_account_id:
raise BadRequest("invalid code")
# delete code
redis_client.delete(redis_key)
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
refresh_token = OAuthServerService._sign_oauth_refresh_token(client_id, user_account_id)
return access_token, refresh_token
case OAuthGrantType.REFRESH_TOKEN:
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=refresh_token)
user_account_id = redis_client.get(redis_key)
if not user_account_id:
raise BadRequest("invalid refresh token")
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
return access_token, refresh_token
@staticmethod
def _sign_oauth_access_token(client_id: str, user_account_id: str) -> str:
token = str(uuid.uuid4())
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
redis_client.set(redis_key, user_account_id, ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN)
return token
@staticmethod
def _sign_oauth_refresh_token(client_id: str, user_account_id: str) -> str:
token = str(uuid.uuid4())
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
redis_client.set(redis_key, user_account_id, ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN)
return token
@staticmethod
def validate_oauth_access_token(client_id: str, token: str) -> Account | None:
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
user_account_id = redis_client.get(redis_key)
if not user_account_id:
return None
user_id_str = user_account_id.decode("utf-8")
return AccountService.load_user(user_id_str)
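Taken together, the new service implements a small Redis-backed authorization-code flow: mint a short-lived code after consent, exchange it for an access/refresh token pair, validate access tokens on requests, and refresh when they expire. A hedged usage sketch (the import path and identifiers below are assumptions; in practice this needs a reachable Redis and a registered oauth_provider_apps row):

from services.oauth_server import OAuthGrantType, OAuthServerService  # module path assumed

client_id = "example-client-id"  # illustrative
account_id = "example-account-uuid"  # illustrative

# 1. after the user approves the consent screen, mint a 10-minute authorization code
code = OAuthServerService.sign_oauth_authorization_code(client_id, account_id)

# 2. exchange the code for tokens; the code is single-use and deleted on success
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
    OAuthGrantType.AUTHORIZATION_CODE, code=code, client_id=client_id
)

# 3. resource endpoints resolve the account from the access token (None if expired/invalid)
account = OAuthServerService.validate_oauth_access_token(client_id, access_token)

# 4. once the 12-hour access token expires, the 30-day refresh token mints a new one
new_access_token, _ = OAuthServerService.sign_oauth_access_token(
    OAuthGrantType.REFRESH_TOKEN, refresh_token=refresh_token, client_id=client_id
)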

View File

@ -10,7 +10,7 @@ class PluginAutoUpgradeService:
with Session(db.engine) as session:
return (
session.query(TenantPluginAutoUpgradeStrategy)
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
)
@ -26,7 +26,7 @@ class PluginAutoUpgradeService:
with Session(db.engine) as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
)
if not exist_strategy:
@ -54,7 +54,7 @@ class PluginAutoUpgradeService:
with Session(db.engine) as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
)
if not exist_strategy:

View File

@ -2,7 +2,7 @@ import logging
import time
import click
from celery import shared_task # type: ignore
from celery import shared_task
from extensions.ext_database import db
from models import ConversationVariable

View File

@ -1,6 +1,6 @@
import time
import psycopg2 # type: ignore
import psycopg2
from core.rag.datasource.vdb.opengauss.opengauss import OpenGauss, OpenGaussConfig
from tests.integration_tests.vdb.test_vector_store import (

View File

@ -674,7 +674,7 @@ class TestAnnotationService:
history = (
db.session.query(AppAnnotationHitHistory)
.filter(
.where(
AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id
)
.first()

View File

@ -166,7 +166,7 @@ class TestAppDslService:
assert result.imported_dsl_version == ""
# Verify no app was created in database
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count()
assert apps_count == 1 # Only the original test app
def test_import_app_missing_yaml_url(self, db_session_with_containers, mock_external_service_dependencies):
@ -191,7 +191,7 @@ class TestAppDslService:
assert result.imported_dsl_version == ""
# Verify no app was created in database
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count()
assert apps_count == 1 # Only the original test app
def test_import_app_invalid_import_mode(self, db_session_with_containers, mock_external_service_dependencies):
@ -215,7 +215,7 @@ class TestAppDslService:
)
# Verify no app was created in database
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count()
assert apps_count == 1 # Only the original test app
def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies):

View File

@ -0,0 +1,529 @@
from unittest.mock import patch
import pytest
from faker import Faker
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from services.workspace_service import WorkspaceService
class TestWorkspaceService:
"""Integration tests for WorkspaceService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.workspace_service.FeatureService") as mock_feature_service,
patch("services.workspace_service.TenantService") as mock_tenant_service,
patch("services.workspace_service.dify_config") as mock_dify_config,
):
# Setup default mock returns
mock_feature_service.get_features.return_value.can_replace_logo = True
mock_tenant_service.has_roles.return_value = True
mock_dify_config.FILES_URL = "https://example.com/files"
yield {
"feature_service": mock_feature_service,
"tenant_service": mock_tenant_service,
"dify_config": mock_dify_config,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (account, tenant) - Created account and tenant instances
"""
fake = Faker()
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
# Create tenant
tenant = Tenant(
name=fake.company(),
status="normal",
plan="basic",
custom_config='{"replace_webapp_logo": true, "remove_webapp_brand": false}',
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join with owner role
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
current=True,
)
db.session.add(join)
db.session.commit()
# Set current tenant for account
account.current_tenant = tenant
return account, tenant
def test_get_tenant_info_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful retrieval of tenant information with all features enabled.
This test verifies:
- Proper tenant info retrieval with all required fields
- Correct role assignment from TenantAccountJoin
- Custom config handling when features are enabled
- Logo replacement functionality for privileged users
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Setup mocks for feature service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
# Mock current_user for flask_login
with patch("services.workspace_service.current_user", account):
# Act: Execute the method under test
result = WorkspaceService.get_tenant_info(tenant)
# Assert: Verify the expected outcomes
assert result is not None
assert result["id"] == tenant.id
assert result["name"] == tenant.name
assert result["plan"] == tenant.plan
assert result["status"] == tenant.status
assert result["role"] == TenantAccountRole.OWNER.value
assert result["created_at"] == tenant.created_at
assert result["trial_end_reason"] is None
# Verify custom config is included for privileged users
assert "custom_config" in result
assert result["custom_config"]["remove_webapp_brand"] is False
assert "replace_webapp_logo" in result["custom_config"]
# Verify database state
from extensions.ext_database import db
db.session.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_without_custom_config(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant info retrieval when custom config features are disabled.
This test verifies:
- Tenant info retrieval without custom config when features are disabled
- Proper handling of disabled logo replacement functionality
- Role assignment still works correctly
- Basic tenant information is complete
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Setup mocks to disable custom config features
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = False
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
# Mock current_user for flask_login
with patch("services.workspace_service.current_user", account):
# Act: Execute the method under test
result = WorkspaceService.get_tenant_info(tenant)
# Assert: Verify the expected outcomes
assert result is not None
assert result["id"] == tenant.id
assert result["name"] == tenant.name
assert result["plan"] == tenant.plan
assert result["status"] == tenant.status
assert result["role"] == TenantAccountRole.OWNER.value
assert result["created_at"] == tenant.created_at
assert result["trial_end_reason"] is None
# Verify custom config is not included when features are disabled
assert "custom_config" not in result
# Verify database state
from extensions.ext_database import db
db.session.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_normal_user_role(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant info retrieval for normal user role without privileged features.
This test verifies:
- Tenant info retrieval for non-privileged users
- Role assignment for normal users
- Custom config is not accessible for normal users
- Proper handling of different user roles
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Update the join to have normal role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join.role = TenantAccountRole.NORMAL.value
db.session.commit()
# Setup mocks for feature service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
# Mock current_user for flask_login
with patch("services.workspace_service.current_user", account):
# Act: Execute the method under test
result = WorkspaceService.get_tenant_info(tenant)
# Assert: Verify the expected outcomes
assert result is not None
assert result["id"] == tenant.id
assert result["name"] == tenant.name
assert result["plan"] == tenant.plan
assert result["status"] == tenant.status
assert result["role"] == TenantAccountRole.NORMAL.value
assert result["created_at"] == tenant.created_at
assert result["trial_end_reason"] is None
# Verify custom config is not included for normal users
assert "custom_config" not in result
# Verify database state
db.session.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_admin_role_and_logo_replacement(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant info retrieval for admin role with logo replacement enabled.
This test verifies:
- Admin role can access custom config features
- Logo replacement functionality works for admin users
- Proper URL construction for logo replacement
- Custom config handling for admin role
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Update the join to have admin role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join.role = TenantAccountRole.ADMIN.value
db.session.commit()
# Setup mocks for feature service and tenant service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
mock_external_service_dependencies["dify_config"].FILES_URL = "https://cdn.example.com"
# Mock current_user for flask_login
with patch("services.workspace_service.current_user", account):
# Act: Execute the method under test
result = WorkspaceService.get_tenant_info(tenant)
# Assert: Verify the expected outcomes
assert result is not None
assert result["role"] == TenantAccountRole.ADMIN.value
# Verify custom config is included for admin users
assert "custom_config" in result
assert result["custom_config"]["remove_webapp_brand"] is False
assert "replace_webapp_logo" in result["custom_config"]
# Verify database state
db.session.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_tenant_none(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test tenant info retrieval when tenant parameter is None.
This test verifies:
- Proper handling of None tenant parameter
- Method returns None for invalid input
- No exceptions are raised for None input
- Graceful degradation for invalid data
"""
# Arrange: No test data needed for this test
# Act: Execute the method under test with None tenant
result = WorkspaceService.get_tenant_info(None)
# Assert: Verify the expected outcomes
assert result is None
def test_get_tenant_info_with_custom_config_variations(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant info retrieval with various custom config configurations.
This test verifies:
- Different custom config combinations work correctly
- Logo replacement URL construction with various configs
- Brand removal functionality
- Edge cases in custom config handling
"""
# Arrange: Create test data with different custom configs
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Test different custom config combinations
test_configs = [
# Case 1: Both logo and brand removal enabled
{"replace_webapp_logo": True, "remove_webapp_brand": True},
# Case 2: Only logo replacement enabled
{"replace_webapp_logo": True, "remove_webapp_brand": False},
# Case 3: Only brand removal enabled
{"replace_webapp_logo": False, "remove_webapp_brand": True},
# Case 4: Neither enabled
{"replace_webapp_logo": False, "remove_webapp_brand": False},
]
for config in test_configs:
# Update tenant custom config
import json
from extensions.ext_database import db
tenant.custom_config = json.dumps(config)
db.session.commit()
# Setup mocks
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
mock_external_service_dependencies["dify_config"].FILES_URL = "https://files.example.com"
# Mock current_user for flask_login
with patch("services.workspace_service.current_user", account):
# Act: Execute the method under test
result = WorkspaceService.get_tenant_info(tenant)
# Assert: Verify the expected outcomes
assert result is not None
assert "custom_config" in result
if config["replace_webapp_logo"]:
assert "replace_webapp_logo" in result["custom_config"]
if config["replace_webapp_logo"]:
expected_url = f"https://files.example.com/files/workspaces/{tenant.id}/webapp-logo"
assert result["custom_config"]["replace_webapp_logo"] == expected_url
else:
assert result["custom_config"]["replace_webapp_logo"] is None
assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"]
# Verify database state
db.session.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_editor_role_and_limited_permissions(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant info retrieval for editor role with limited permissions.
This test verifies:
- Editor role has limited access to custom config features
- Proper role-based permission checking
- Custom config handling for different role levels
- Role hierarchy and permission boundaries
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Update the join to have editor role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join.role = TenantAccountRole.EDITOR.value
db.session.commit()
# Setup mocks for feature service and tenant service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
# Editor role should not have admin/owner permissions
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
mock_external_service_dependencies["dify_config"].FILES_URL = "https://cdn.example.com"
# Mock current_user for flask_login
with patch("services.workspace_service.current_user", account):
# Act: Execute the method under test
result = WorkspaceService.get_tenant_info(tenant)
# Assert: Verify the expected outcomes
assert result is not None
assert result["role"] == TenantAccountRole.EDITOR.value
# Verify custom config is not included for editor users without admin privileges
assert "custom_config" not in result
# Verify database state
db.session.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_dataset_operator_role(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant info retrieval for dataset operator role.
This test verifies:
- Dataset operator role handling
- Role assignment for specialized roles
- Permission boundaries for dataset operators
- Custom config access for dataset operators
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Update the join to have dataset operator role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join.role = TenantAccountRole.DATASET_OPERATOR.value
db.session.commit()
# Setup mocks for feature service and tenant service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
# Dataset operator should not have admin/owner permissions
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
mock_external_service_dependencies["dify_config"].FILES_URL = "https://cdn.example.com"
# Mock current_user for flask_login
with patch("services.workspace_service.current_user", account):
# Act: Execute the method under test
result = WorkspaceService.get_tenant_info(tenant)
# Assert: Verify the expected outcomes
assert result is not None
assert result["role"] == TenantAccountRole.DATASET_OPERATOR.value
# Verify custom config is not included for dataset operators without admin privileges
assert "custom_config" not in result
# Verify database state
db.session.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_complex_custom_config_scenarios(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant info retrieval with complex custom config scenarios.
This test verifies:
- Complex custom config combinations
- Edge cases in custom config handling
- URL construction with various configs
- Error handling for malformed configs
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Test complex custom config scenarios
test_configs = [
# Case 1: Empty custom config
{},
# Case 2: Custom config with only logo replacement
{"replace_webapp_logo": True},
# Case 3: Custom config with only brand removal
{"remove_webapp_brand": True},
# Case 4: Custom config with additional fields
{
"replace_webapp_logo": True,
"remove_webapp_brand": False,
"custom_field": "custom_value",
"nested_config": {"key": "value"},
},
# Case 5: Custom config with null values
{"replace_webapp_logo": None, "remove_webapp_brand": None},
]
import json
from extensions.ext_database import db
for config in test_configs:
# Update tenant custom config
tenant.custom_config = json.dumps(config)
db.session.commit()
# Setup mocks
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
mock_external_service_dependencies["dify_config"].FILES_URL = "https://files.example.com"
# Mock current_user for flask_login
with patch("services.workspace_service.current_user", account):
# Act: Execute the method under test
result = WorkspaceService.get_tenant_info(tenant)
# Assert: Verify the expected outcomes
assert result is not None
assert "custom_config" in result
# Verify logo replacement handling
if config.get("replace_webapp_logo"):
assert "replace_webapp_logo" in result["custom_config"]
expected_url = f"https://files.example.com/files/workspaces/{tenant.id}/webapp-logo"
assert result["custom_config"]["replace_webapp_logo"] == expected_url
else:
assert result["custom_config"]["replace_webapp_logo"] is None
# Verify brand removal handling
if "remove_webapp_brand" in config:
assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"]
else:
assert result["custom_config"]["remove_webapp_brand"] is False
# Verify database state
db.session.refresh(tenant)
assert tenant.id is not None
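The assertions in these workspace tests jointly describe how get_tenant_info is expected to assemble the custom_config block: it only appears for admin/owner accounts when the can_replace_logo feature is on, replace_webapp_logo is rewritten to "{FILES_URL}/files/workspaces/{tenant_id}/webapp-logo" when enabled (otherwise None), and remove_webapp_brand defaults to False when missing. The snippet below is a minimal sketch of that contract reconstructed from the test expectations only; the function name and parameters are illustrative, not the actual WorkspaceService implementation.

import json
from typing import Any, Optional


def build_custom_config(
    custom_config_json: Optional[str],
    tenant_id: str,
    files_url: str,
    can_replace_logo: bool,
    is_admin_or_owner: bool,
) -> Optional[dict[str, Any]]:
    # Editors and dataset operators never see custom_config (the tests assert the key is absent).
    if not (can_replace_logo and is_admin_or_owner):
        return None
    config = json.loads(custom_config_json) if custom_config_json else {}
    return {
        # Absent -> False; an explicit null is passed through unchanged.
        "remove_webapp_brand": config.get("remove_webapp_brand", False),
        # Enabled -> public file URL; disabled, absent, or null -> None.
        "replace_webapp_logo": (
            f"{files_url}/files/workspaces/{tenant_id}/webapp-logo"
            if config.get("replace_webapp_logo")
            else None
        ),
    }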


@@ -0,0 +1,550 @@
from unittest.mock import patch
import pytest
from faker import Faker
from models.account import Account, Tenant
from models.tools import ApiToolProvider
from services.tools.api_tools_manage_service import ApiToolManageService
class TestApiToolManageService:
"""Integration tests for ApiToolManageService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.tools.api_tools_manage_service.ToolLabelManager") as mock_tool_label_manager,
patch("services.tools.api_tools_manage_service.create_tool_provider_encrypter") as mock_encrypter,
patch("services.tools.api_tools_manage_service.ApiToolProviderController") as mock_provider_controller,
):
# Setup default mock returns
mock_tool_label_manager.update_tool_labels.return_value = None
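# create_tool_provider_encrypter is stubbed to return an (encrypter, cache) pair, so the mock hands itself back as the encrypter and None for the cache.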
mock_encrypter.return_value = (mock_encrypter, None)
mock_encrypter.encrypt.return_value = {"encrypted": "credentials"}
mock_provider_controller.from_db.return_value = mock_provider_controller
mock_provider_controller.load_bundled_tools.return_value = None
yield {
"tool_label_manager": mock_tool_label_manager,
"encrypter": mock_encrypter,
"provider_controller": mock_provider_controller,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (account, tenant) - Created account and tenant instances
"""
fake = Faker()
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
from models.account import TenantAccountJoin, TenantAccountRole
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
current=True,
)
db.session.add(join)
db.session.commit()
# Set current tenant for account
account.current_tenant = tenant
return account, tenant
def _create_test_openapi_schema(self):
"""Helper method to create a test OpenAPI schema."""
return """
{
"openapi": "3.0.0",
"info": {
"title": "Test API",
"version": "1.0.0",
"description": "Test API for testing purposes"
},
"servers": [
{
"url": "https://api.example.com",
"description": "Production server"
}
],
"paths": {
"/test": {
"get": {
"operationId": "testOperation",
"summary": "Test operation",
"responses": {
"200": {
"description": "Success"
}
}
}
}
}
}
"""
def test_parser_api_schema_success(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful parsing of API schema.
This test verifies:
- Proper schema parsing with valid OpenAPI schema
- Correct credentials schema generation
- Proper warning handling
- Return value structure
"""
# Arrange: Create test schema
schema = self._create_test_openapi_schema()
# Act: Parse the schema
result = ApiToolManageService.parser_api_schema(schema)
# Assert: Verify the result structure
assert result is not None
assert "schema_type" in result
assert "parameters_schema" in result
assert "credentials_schema" in result
assert "warning" in result
# Verify credentials schema structure
credentials_schema = result["credentials_schema"]
assert len(credentials_schema) == 3
# Check auth_type field
auth_type_field = next(field for field in credentials_schema if field["name"] == "auth_type")
assert auth_type_field["required"] is True
assert auth_type_field["default"] == "none"
assert len(auth_type_field["options"]) == 2
# Check api_key_header field
api_key_header_field = next(field for field in credentials_schema if field["name"] == "api_key_header")
assert api_key_header_field["required"] is False
assert api_key_header_field["default"] == "api_key"
# Check api_key_value field
api_key_value_field = next(field for field in credentials_schema if field["name"] == "api_key_value")
assert api_key_value_field["required"] is False
assert api_key_value_field["default"] == ""
def test_parser_api_schema_invalid_schema(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
):
"""
Test parsing of invalid API schema.
This test verifies:
- Proper error handling for invalid schemas
- Correct exception type and message
- Error propagation from underlying parser
"""
# Arrange: Create invalid schema
invalid_schema = "invalid json schema"
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError) as exc_info:
ApiToolManageService.parser_api_schema(invalid_schema)
assert "invalid schema" in str(exc_info.value)
def test_parser_api_schema_malformed_json(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
):
"""
Test parsing of malformed JSON schema.
This test verifies:
- Proper error handling for malformed JSON
- Correct exception type and message
- Error propagation from JSON parsing
"""
# Arrange: Create malformed JSON schema
malformed_schema = '{"openapi": "3.0.0", "info": {"title": "Test", "version": "1.0.0"}, "paths": {'  # truncated JSON with unbalanced braces
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError) as exc_info:
ApiToolManageService.parser_api_schema(malformed_schema)
assert "invalid schema" in str(exc_info.value)
def test_convert_schema_to_tool_bundles_success(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful conversion of schema to tool bundles.
This test verifies:
- Proper schema conversion with valid OpenAPI schema
- Correct tool bundles generation
- Proper schema type detection
- Return value structure
"""
# Arrange: Create test schema
schema = self._create_test_openapi_schema()
# Act: Convert schema to tool bundles
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema)
# Assert: Verify the result structure
assert tool_bundles is not None
assert isinstance(tool_bundles, list)
assert len(tool_bundles) > 0
assert schema_type is not None
assert isinstance(schema_type, str)
# Verify tool bundle structure
tool_bundle = tool_bundles[0]
assert hasattr(tool_bundle, "operation_id")
assert tool_bundle.operation_id == "testOperation"
def test_convert_schema_to_tool_bundles_with_extra_info(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful conversion of schema to tool bundles with extra info.
This test verifies:
- Proper schema conversion with extra info parameter
- Correct tool bundles generation
- Extra info handling
- Return value structure
"""
# Arrange: Create test schema and extra info
schema = self._create_test_openapi_schema()
extra_info = {"description": "Custom description", "version": "2.0.0"}
# Act: Convert schema to tool bundles with extra info
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
# Assert: Verify the result structure
assert tool_bundles is not None
assert isinstance(tool_bundles, list)
assert len(tool_bundles) > 0
assert schema_type is not None
assert isinstance(schema_type, str)
def test_convert_schema_to_tool_bundles_invalid_schema(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
):
"""
Test conversion of invalid schema to tool bundles.
This test verifies:
- Proper error handling for invalid schemas
- Correct exception type and message
- Error propagation from underlying parser
"""
# Arrange: Create invalid schema
invalid_schema = "invalid schema content"
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError) as exc_info:
ApiToolManageService.convert_schema_to_tool_bundles(invalid_schema)
assert "invalid schema" in str(exc_info.value)
def test_create_api_tool_provider_success(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful creation of API tool provider.
This test verifies:
- Proper provider creation with valid parameters
- Correct database state after creation
- Proper relationship establishment
- External service integration
- Return value correctness
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""}
schema_type = "openapi"
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"
labels = ["test", "api"]
# Act: Create API tool provider
result = ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon=icon,
credentials=credentials,
schema_type=schema_type,
schema=schema,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
labels=labels,
)
# Assert: Verify the result
assert result == {"result": "success"}
# Verify database state
from extensions.ext_database import db
provider = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first()
)
assert provider is not None
assert provider.name == provider_name
assert provider.tenant_id == tenant.id
assert provider.user_id == account.id
assert provider.schema_type_str == schema_type
assert provider.privacy_policy == privacy_policy
assert provider.custom_disclaimer == custom_disclaimer
# Verify mock interactions
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once()
mock_external_service_dependencies["encrypter"].assert_called_once()
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once()
def test_create_api_tool_provider_duplicate_name(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
):
"""
Test creation of API tool provider with duplicate name.
This test verifies:
- Proper error handling for duplicate provider names
- Correct exception type and message
- Database constraint enforcement
"""
# Arrange: Create test data and existing provider
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
credentials = {"auth_type": "none"}
schema_type = "openapi"
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"
labels = ["test"]
# Create first provider
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon=icon,
credentials=credentials,
schema_type=schema_type,
schema=schema,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
labels=labels,
)
# Act & Assert: Try to create duplicate provider
with pytest.raises(ValueError) as exc_info:
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon=icon,
credentials=credentials,
schema_type=schema_type,
schema=schema,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
labels=labels,
)
assert f"provider {provider_name} already exists" in str(exc_info.value)
def test_create_api_tool_provider_invalid_schema_type(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
):
"""
Test creation of API tool provider with invalid schema type.
This test verifies:
- Proper error handling for invalid schema types
- Correct exception type and message
- Schema type validation
"""
# Arrange: Create test data with invalid schema type
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
credentials = {"auth_type": "none"}
schema_type = "invalid_type"
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"
labels = ["test"]
# Act & Assert: Try to create provider with invalid schema type
with pytest.raises(ValueError) as exc_info:
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon=icon,
credentials=credentials,
schema_type=schema_type,
schema=schema,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
labels=labels,
)
assert "invalid schema type" in str(exc_info.value)
def test_create_api_tool_provider_missing_auth_type(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
):
"""
Test creation of API tool provider with missing auth type.
This test verifies:
- Proper error handling for missing auth type
- Correct exception type and message
- Credentials validation
"""
# Arrange: Create test data with missing auth type
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
credentials = {} # Missing auth_type
schema_type = "openapi"
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"
labels = ["test"]
# Act & Assert: Try to create provider with missing auth type
with pytest.raises(ValueError) as exc_info:
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon=icon,
credentials=credentials,
schema_type=schema_type,
schema=schema,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
labels=labels,
)
assert "auth_type is required" in str(exc_info.value)
def test_create_api_tool_provider_with_api_key_auth(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful creation of API tool provider with API key authentication.
This test verifies:
- Proper provider creation with API key auth
- Correct credentials handling
- Proper authentication type processing
"""
# Arrange: Create test data with API key auth
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔑"}
credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()}
schema_type = "openapi"
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"
labels = ["api_key", "secure"]
# Act: Create API tool provider
result = ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon=icon,
credentials=credentials,
schema_type=schema_type,
schema=schema,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
labels=labels,
)
# Assert: Verify the result
assert result == {"result": "success"}
# Verify database state
from extensions.ext_database import db
provider = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first()
)
assert provider is not None
assert provider.name == provider_name
assert provider.tenant_id == tenant.id
assert provider.user_id == account.id
assert provider.schema_type_str == schema_type
# Verify mock interactions
mock_external_service_dependencies["encrypter"].assert_called_once()
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()

Some files were not shown because too many files have changed in this diff.