mirror of
https://github.com/langgenius/dify.git
synced 2026-05-19 08:17:14 +08:00
Compare commits
9 Commits
copilot/ch
...
laipz8200/
| Author | SHA1 | Date | |
|---|---|---|---|
| 6b406af821 | |||
| 4ef85749fd | |||
| 2d5186fb28 | |||
| 06f076e0ff | |||
| 5b79f7e99d | |||
| 1cee1a25b6 | |||
| c0f237bf35 | |||
| 75d7fc0526 | |||
| c057b5c5ff |
@ -1,5 +1,6 @@
|
||||
[run]
|
||||
omit =
|
||||
api/conftest.py
|
||||
api/tests/*
|
||||
api/migrations/*
|
||||
api/core/rag/datasource/vdb/*
|
||||
|
||||
73
.github/scripts/check-hotfix-cherry-picks.sh
vendored
Normal file
73
.github/scripts/check-hotfix-cherry-picks.sh
vendored
Normal file
@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
BASE_SHA=${BASE_SHA:-}
|
||||
HEAD_SHA=${HEAD_SHA:-}
|
||||
MAIN_REF=${MAIN_REF:-origin/main}
|
||||
REMEDIATION_HINT="Changes should be made from the main branch using git cherry-pick -x."
|
||||
|
||||
error() {
|
||||
printf 'ERROR: %s\n' "$1" >&2
|
||||
}
|
||||
|
||||
if [[ -z "$BASE_SHA" || -z "$HEAD_SHA" ]]; then
|
||||
error "BASE_SHA and HEAD_SHA are required. $REMEDIATION_HINT"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if ! git rev-parse --verify "$BASE_SHA^{commit}" > /dev/null 2>&1; then
|
||||
error "Base commit '$BASE_SHA' is not available in the local git checkout."
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if ! git rev-parse --verify "$HEAD_SHA^{commit}" > /dev/null 2>&1; then
|
||||
error "Head commit '$HEAD_SHA' is not available in the local git checkout."
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if ! git rev-parse --verify "$MAIN_REF^{commit}" > /dev/null 2>&1; then
|
||||
error "Main ref '$MAIN_REF' is not available in the local git checkout. $REMEDIATION_HINT"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
failed=0
|
||||
checked=0
|
||||
|
||||
while IFS= read -r commit_sha; do
|
||||
[[ -n "$commit_sha" ]] || continue
|
||||
|
||||
checked=$((checked + 1))
|
||||
subject=$(git log -1 --format=%s "$commit_sha")
|
||||
source_sha=$(
|
||||
git log -1 --format=%B "$commit_sha" \
|
||||
| sed -nE 's/^\(cherry picked from commit ([0-9a-fA-F]{7,64})\)$/\1/p' \
|
||||
| tail -n 1
|
||||
)
|
||||
|
||||
if [[ -z "$source_sha" ]]; then
|
||||
error "Commit $commit_sha ($subject) is missing cherry-pick provenance. $REMEDIATION_HINT"
|
||||
failed=1
|
||||
continue
|
||||
fi
|
||||
|
||||
if ! git cat-file -e "$source_sha^{commit}" 2> /dev/null; then
|
||||
error "Commit $commit_sha ($subject) references source $source_sha, but that commit is not available locally. $REMEDIATION_HINT"
|
||||
failed=1
|
||||
continue
|
||||
fi
|
||||
|
||||
if ! git merge-base --is-ancestor "$source_sha" "$MAIN_REF"; then
|
||||
error "Commit $commit_sha ($subject) references source $source_sha, but that source is not reachable from main ($MAIN_REF). $REMEDIATION_HINT"
|
||||
failed=1
|
||||
fi
|
||||
done < <(git rev-list --reverse "$BASE_SHA..$HEAD_SHA")
|
||||
|
||||
if [[ "$failed" -ne 0 ]]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$checked" -eq 0 ]]; then
|
||||
echo "No PR commits to check."
|
||||
else
|
||||
echo "Verified $checked PR commit(s) include cherry-pick provenance from main."
|
||||
fi
|
||||
42
.github/workflows/api-tests.yml
vendored
42
.github/workflows/api-tests.yml
vendored
@ -48,10 +48,23 @@ jobs:
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run dify config tests
|
||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||
run: uv run --project api pytest api/tests/unit_tests/configs/test_env_consistency.py
|
||||
|
||||
- name: Run Unit Tests
|
||||
run: uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
run: |
|
||||
uv run --project api pytest \
|
||||
-p no:benchmark \
|
||||
--timeout "${PYTEST_TIMEOUT:-20}" \
|
||||
-n auto \
|
||||
api/tests/unit_tests \
|
||||
api/providers/vdb/*/tests/unit_tests \
|
||||
api/providers/trace/*/tests/unit_tests \
|
||||
--ignore=api/tests/unit_tests/controllers
|
||||
# Controller tests register Flask routes at import time, so keep them out of xdist.
|
||||
uv run --project api pytest \
|
||||
--timeout "${PYTEST_TIMEOUT:-20}" \
|
||||
--cov-append \
|
||||
api/tests/unit_tests/controllers
|
||||
|
||||
- name: Upload unit coverage data
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||
@ -96,32 +109,11 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/envs/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
- name: Set up Sandbox
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db_postgres
|
||||
redis
|
||||
sandbox
|
||||
ssrf_proxy
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
- name: Run Integration Tests
|
||||
run: |
|
||||
uv run --project api pytest \
|
||||
-p no:benchmark \
|
||||
--start-middleware \
|
||||
-n auto \
|
||||
--timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/tests/integration_tests/workflow \
|
||||
|
||||
17
.github/workflows/expose_service_ports.sh
vendored
17
.github/workflows/expose_service_ports.sh
vendored
@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml
|
||||
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
|
||||
|
||||
echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
|
||||
49
.github/workflows/hotfix-cherry-pick.yml
vendored
Normal file
49
.github/workflows/hotfix-cherry-pick.yml
vendored
Normal file
@ -0,0 +1,49 @@
|
||||
name: Hotfix Cherry-Pick Provenance
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- 'hotfix/**'
|
||||
- 'lts/**'
|
||||
types:
|
||||
- opened
|
||||
- edited
|
||||
- reopened
|
||||
- ready_for_review
|
||||
- synchronize
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: hotfix-cherry-pick-${{ github.event.pull_request.number || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
check-cherry-pick-provenance:
|
||||
name: Require cherry-pick provenance
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Fetch PR base, PR head, and main
|
||||
env:
|
||||
BASE_REF: ${{ github.base_ref }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
run: |
|
||||
git fetch --no-tags --prune origin \
|
||||
"+refs/heads/main:refs/remotes/origin/main" \
|
||||
"+refs/heads/${BASE_REF}:refs/remotes/origin/${BASE_REF}" \
|
||||
"+refs/pull/${PR_NUMBER}/head:refs/remotes/pull/${PR_NUMBER}/head"
|
||||
|
||||
- name: Load checker from main
|
||||
run: git show origin/main:.github/scripts/check-hotfix-cherry-picks.sh > "$RUNNER_TEMP/check-hotfix-cherry-picks.sh"
|
||||
|
||||
- name: Check PR commits
|
||||
env:
|
||||
BASE_SHA: ${{ github.event.pull_request.base.sha }}
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha }}
|
||||
MAIN_REF: origin/main
|
||||
run: bash "$RUNNER_TEMP/check-hotfix-cherry-picks.sh"
|
||||
6
.github/workflows/main-ci.yml
vendored
6
.github/workflows/main-ci.yml
vendored
@ -55,7 +55,6 @@ jobs:
|
||||
api:
|
||||
- 'api/**'
|
||||
- '.github/workflows/api-tests.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.example'
|
||||
- 'docker/envs/middleware.env.example'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
@ -90,11 +89,13 @@ jobs:
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'api/tests/integration_tests/vdb/**'
|
||||
- 'api/conftest.py'
|
||||
- 'api/tests/pytest_dify.py'
|
||||
- 'api/providers/vdb/*/tests/**'
|
||||
- '.github/workflows/vdb-tests.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.example'
|
||||
- 'docker/envs/middleware.env.example'
|
||||
- 'docker/docker-compose.pytest.ports.yaml'
|
||||
- 'docker/docker-compose.yaml'
|
||||
- 'docker/docker-compose-template.yaml'
|
||||
- 'docker/generate_docker_compose'
|
||||
@ -114,7 +115,6 @@ jobs:
|
||||
- 'api/migrations/**'
|
||||
- 'api/.env.example'
|
||||
- '.github/workflows/db-migration-test.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.example'
|
||||
- 'docker/envs/middleware.env.example'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
|
||||
39
.github/workflows/vdb-tests-full.yml
vendored
39
.github/workflows/vdb-tests-full.yml
vendored
@ -48,14 +48,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/envs/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
# - name: Set up Vector Store (TiDB)
|
||||
# uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
# with:
|
||||
@ -64,32 +56,13 @@ jobs:
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Full Vector Store Matrix
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
services: |
|
||||
weaviate
|
||||
qdrant
|
||||
couchbase-server
|
||||
etcd
|
||||
minio
|
||||
milvus-standalone
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
chroma
|
||||
elasticsearch
|
||||
oceanbase
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
echo $(pwd)
|
||||
ls -lah .
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
run: |
|
||||
uv run --project api pytest \
|
||||
--start-vdb \
|
||||
--vdb-services "weaviate,qdrant,couchbase-server,etcd,minio,milvus-standalone,pgvecto-rs,pgvector,chroma,elasticsearch,oceanbase" \
|
||||
--timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/providers/vdb/*/tests/integration_tests
|
||||
|
||||
31
.github/workflows/vdb-tests.yml
vendored
31
.github/workflows/vdb-tests.yml
vendored
@ -45,14 +45,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/envs/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
# - name: Set up Vector Store (TiDB)
|
||||
# uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
# with:
|
||||
@ -61,31 +53,14 @@ jobs:
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Vector Stores for Smoke Coverage
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
services: |
|
||||
db_postgres
|
||||
redis
|
||||
weaviate
|
||||
qdrant
|
||||
pgvector
|
||||
chroma
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
echo $(pwd)
|
||||
ls -lah .
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: |
|
||||
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
uv run --project api pytest \
|
||||
--start-vdb \
|
||||
--timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/providers/vdb/vdb-chroma/tests/integration_tests \
|
||||
api/providers/vdb/vdb-pgvector/tests/integration_tests \
|
||||
api/providers/vdb/vdb-qdrant/tests/integration_tests \
|
||||
|
||||
30
Makefile
30
Makefile
@ -85,22 +85,44 @@ lint:
|
||||
type-check:
|
||||
@echo "📝 Running type checks (pyrefly + mypy)..."
|
||||
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude '(^|/)conftest\.py$$' --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Type checks complete"
|
||||
|
||||
type-check-core:
|
||||
@echo "📝 Running core type checks (pyrefly + mypy)..."
|
||||
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude '(^|/)conftest\.py$$' --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Core type checks complete"
|
||||
|
||||
test:
|
||||
@echo "🧪 Running backend unit tests..."
|
||||
@echo "🧪 Running backend tests..."
|
||||
@if [ -n "$(TARGET_TESTS)" ]; then \
|
||||
echo "Target: $(TARGET_TESTS)"; \
|
||||
uv run --project api --dev pytest $(TARGET_TESTS); \
|
||||
else \
|
||||
PYTEST_XDIST_ARGS="-n auto" uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
|
||||
echo "Running backend unit tests"; \
|
||||
uv run --project api --dev pytest -p no:benchmark --timeout "$${PYTEST_TIMEOUT:-20}" -n auto \
|
||||
api/tests/unit_tests \
|
||||
api/providers/vdb/*/tests/unit_tests \
|
||||
api/providers/trace/*/tests/unit_tests \
|
||||
--ignore=api/tests/unit_tests/controllers; \
|
||||
uv run --project api --dev pytest --timeout "$${PYTEST_TIMEOUT:-20}" --cov-append \
|
||||
api/tests/unit_tests/controllers; \
|
||||
echo "Running backend integration tests"; \
|
||||
uv run --project api --dev pytest -p no:benchmark --start-middleware -n auto \
|
||||
--timeout "$${PYTEST_TIMEOUT:-180}" \
|
||||
--cov-append \
|
||||
api/tests/integration_tests/workflow \
|
||||
api/tests/integration_tests/tools \
|
||||
api/tests/test_containers_integration_tests; \
|
||||
echo "Running VDB smoke tests"; \
|
||||
uv run --project api --dev pytest --start-vdb \
|
||||
--timeout "$${PYTEST_TIMEOUT:-180}" \
|
||||
--cov-append \
|
||||
api/providers/vdb/vdb-chroma/tests/integration_tests \
|
||||
api/providers/vdb/vdb-pgvector/tests/integration_tests \
|
||||
api/providers/vdb/vdb-qdrant/tests/integration_tests \
|
||||
api/providers/vdb/vdb-weaviate/tests/integration_tests; \
|
||||
fi
|
||||
@echo "✅ Tests complete"
|
||||
|
||||
|
||||
112
api/conftest.py
Normal file
112
api/conftest.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""Global pytest hooks for Dify backend tests.
|
||||
|
||||
This root conftest is loaded before package-specific conftests, which lets tests opt
|
||||
into Docker-backed middleware before application modules read environment config.
|
||||
It intentionally lives at the API root because pytest applies conftest.py files to
|
||||
tests below their directory, and this setup is shared by api/tests and api/providers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.pytest_dify import (
|
||||
DEFAULT_MIDDLEWARE_SERVICES,
|
||||
DEFAULT_VDB_SERVICES,
|
||||
DockerComposeStack,
|
||||
build_middleware_stack,
|
||||
build_vdb_stack,
|
||||
ensure_backend_test_environment,
|
||||
ensure_compose_env_files,
|
||||
parse_services,
|
||||
)
|
||||
|
||||
_REPO_ROOT = Path(__file__).resolve().parent.parent
|
||||
_DIFY_COMPOSE_STACKS_KEY = pytest.StashKey[list[DockerComposeStack]]()
|
||||
|
||||
# This must run at import time because package-specific conftests can import the
|
||||
# Flask app before pytest_configure hooks from this file are called.
|
||||
ensure_backend_test_environment(_REPO_ROOT)
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
group = parser.getgroup("dify")
|
||||
group.addoption(
|
||||
"--start-middleware",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Start the Docker middleware services needed by API integration tests.",
|
||||
)
|
||||
group.addoption(
|
||||
"--middleware-services",
|
||||
default=",".join(DEFAULT_MIDDLEWARE_SERVICES),
|
||||
help="Comma-separated services from docker/docker-compose.middleware.yaml to start.",
|
||||
)
|
||||
group.addoption(
|
||||
"--start-vdb",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Start vector-store Docker services for VDB integration tests.",
|
||||
)
|
||||
group.addoption(
|
||||
"--vdb-services",
|
||||
default=",".join(DEFAULT_VDB_SERVICES),
|
||||
help="Comma-separated services from docker/docker-compose.yaml to start for VDB tests.",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
config.stash[_DIFY_COMPOSE_STACKS_KEY] = []
|
||||
|
||||
|
||||
def _stop_stacks(stacks: list[DockerComposeStack]) -> None:
|
||||
errors: list[BaseException] = []
|
||||
for stack in reversed(stacks):
|
||||
try:
|
||||
stack.down()
|
||||
except BaseException as error:
|
||||
errors.append(error)
|
||||
if errors:
|
||||
raise BaseExceptionGroup("Failed to stop one or more Docker compose stacks", errors)
|
||||
|
||||
|
||||
def pytest_sessionstart(session: pytest.Session) -> None:
|
||||
config = session.config
|
||||
if hasattr(config, "workerinput"):
|
||||
return
|
||||
|
||||
stacks: list[DockerComposeStack] = []
|
||||
config.stash[_DIFY_COMPOSE_STACKS_KEY] = stacks
|
||||
try:
|
||||
if config.getoption("start_middleware"):
|
||||
ensure_compose_env_files(_REPO_ROOT)
|
||||
stack = build_middleware_stack(_REPO_ROOT, parse_services(config.getoption("middleware_services")))
|
||||
stacks.append(stack)
|
||||
stack.up()
|
||||
|
||||
if config.getoption("start_vdb"):
|
||||
ensure_compose_env_files(_REPO_ROOT)
|
||||
stack = build_vdb_stack(_REPO_ROOT, parse_services(config.getoption("vdb_services")))
|
||||
stacks.append(stack)
|
||||
stack.up()
|
||||
except BaseException as start_error:
|
||||
try:
|
||||
_stop_stacks(stacks)
|
||||
except BaseException as cleanup_error:
|
||||
config.stash[_DIFY_COMPOSE_STACKS_KEY] = []
|
||||
raise BaseExceptionGroup(
|
||||
"Failed to start Docker compose stacks and clean up",
|
||||
[start_error, cleanup_error],
|
||||
) from start_error
|
||||
config.stash[_DIFY_COMPOSE_STACKS_KEY] = []
|
||||
raise
|
||||
|
||||
|
||||
def pytest_unconfigure(config: pytest.Config) -> None:
|
||||
if hasattr(config, "workerinput"):
|
||||
return
|
||||
|
||||
stacks = config.stash.get(_DIFY_COMPOSE_STACKS_KEY, [])
|
||||
_stop_stacks(stacks)
|
||||
@ -874,6 +874,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@console_ns.expect(console_ns.models[BuiltinProviderDefaultCredentialPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -0,0 +1,67 @@
|
||||
"""Dify event package.
|
||||
|
||||
The package name intentionally stays as ``events`` for existing Dify imports. Some
|
||||
third-party clients also import ``Events`` from a top-level ``events`` package, so
|
||||
we expose a small compatible implementation to avoid import shadowing failures.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable, Iterator
|
||||
from typing import Any
|
||||
|
||||
|
||||
class EventsError(Exception):
|
||||
"""Raised for invalid event slot operations."""
|
||||
|
||||
|
||||
EventsException = EventsError
|
||||
|
||||
|
||||
class _EventSlot:
|
||||
"""A dynamically-created event slot supporting ``+=`` and call dispatch."""
|
||||
|
||||
targets: list[Callable[..., Any]]
|
||||
__name__: str
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.targets = []
|
||||
self.__name__ = name
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> None:
|
||||
for target in tuple(self.targets):
|
||||
target(*args, **kwargs)
|
||||
|
||||
def __iadd__(self, target: Callable[..., Any]) -> "_EventSlot":
|
||||
self.targets.append(target)
|
||||
return self
|
||||
|
||||
def __isub__(self, target: Callable[..., Any]) -> "_EventSlot":
|
||||
while target in self.targets:
|
||||
self.targets.remove(target)
|
||||
return self
|
||||
|
||||
def __iter__(self) -> Iterator[Callable[..., Any]]:
|
||||
return iter(self.targets)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.targets)
|
||||
|
||||
|
||||
class Events:
|
||||
"""A minimal C#-style event container compatible with the external Events package."""
|
||||
|
||||
_slots: dict[str, _EventSlot]
|
||||
|
||||
def __init__(self, *event_names: str) -> None:
|
||||
self._slots = {}
|
||||
for event_name in event_names:
|
||||
self._slots[event_name] = _EventSlot(event_name)
|
||||
|
||||
def __getattr__(self, name: str) -> _EventSlot:
|
||||
if name.startswith("_"):
|
||||
raise AttributeError(name)
|
||||
slot = _EventSlot(name)
|
||||
self._slots[name] = slot
|
||||
return slot
|
||||
|
||||
|
||||
__all__ = ["Events", "EventsError", "EventsException"]
|
||||
|
||||
@ -13,7 +13,7 @@ class ChromaVectorTest(AbstractVectorTest):
|
||||
self.vector = ChromaVector(
|
||||
collection_name=self.collection_name,
|
||||
config=ChromaConfig(
|
||||
host="localhost",
|
||||
host="127.0.0.1",
|
||||
port=8000,
|
||||
tenant=chromadb.DEFAULT_TENANT,
|
||||
database=chromadb.DEFAULT_DATABASE,
|
||||
|
||||
@ -16,7 +16,7 @@ class QdrantVectorTest(AbstractVectorTest):
|
||||
collection_name=self.collection_name,
|
||||
group_id=self.dataset_id,
|
||||
config=QdrantConfig(
|
||||
endpoint="http://localhost:6333",
|
||||
endpoint="http://127.0.0.1:6333",
|
||||
api_key="difyai123456",
|
||||
),
|
||||
)
|
||||
|
||||
@ -16,6 +16,7 @@ from pydantic import TypeAdapter
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.helper import marketplace
|
||||
from core.plugin.entities.plugin import PluginInstallationSource
|
||||
@ -310,6 +311,8 @@ class PluginMigration:
|
||||
"""
|
||||
Fetch plugin unique identifier using plugin id.
|
||||
"""
|
||||
if not dify_config.MARKETPLACE_ENABLED:
|
||||
return None
|
||||
plugin_manifest = marketplace.batch_fetch_plugin_manifests([plugin_id])
|
||||
if not plugin_manifest:
|
||||
return None
|
||||
@ -542,6 +545,11 @@ class PluginMigration:
|
||||
"""
|
||||
Install plugins for a tenant.
|
||||
"""
|
||||
if plugin_identifiers_map and not dify_config.MARKETPLACE_ENABLED:
|
||||
raise ValueError(
|
||||
"Marketplace disabled in offline mode; cannot bulk-install plugins. "
|
||||
"Pre-upload plugin packages via Console first."
|
||||
)
|
||||
manager = PluginInstaller()
|
||||
|
||||
# download all the plugins and upload
|
||||
|
||||
@ -73,35 +73,43 @@ class PluginService:
|
||||
cache_not_exists.append(plugin_id)
|
||||
|
||||
if cache_not_exists:
|
||||
manifests = {
|
||||
manifest.plugin_id: manifest
|
||||
for manifest in marketplace.batch_fetch_plugin_manifests(cache_not_exists)
|
||||
}
|
||||
|
||||
for plugin_id, manifest in manifests.items():
|
||||
latest_plugin = PluginService.LatestPluginCache(
|
||||
plugin_id=plugin_id,
|
||||
version=manifest.latest_version,
|
||||
unique_identifier=manifest.latest_package_identifier,
|
||||
status=manifest.status,
|
||||
deprecated_reason=manifest.deprecated_reason,
|
||||
alternative_plugin_id=manifest.alternative_plugin_id,
|
||||
if not dify_config.MARKETPLACE_ENABLED:
|
||||
logger.info(
|
||||
"Marketplace disabled; skipping latest-plugins metadata fetch for %d ids",
|
||||
len(cache_not_exists),
|
||||
)
|
||||
for plugin_id in cache_not_exists:
|
||||
result[plugin_id] = None
|
||||
else:
|
||||
manifests = {
|
||||
manifest.plugin_id: manifest
|
||||
for manifest in marketplace.batch_fetch_plugin_manifests(cache_not_exists)
|
||||
}
|
||||
|
||||
# Store in Redis
|
||||
redis_client.setex(
|
||||
f"{PluginService.REDIS_KEY_PREFIX}{plugin_id}",
|
||||
PluginService.REDIS_TTL,
|
||||
latest_plugin.model_dump_json(),
|
||||
)
|
||||
for plugin_id, manifest in manifests.items():
|
||||
latest_plugin = PluginService.LatestPluginCache(
|
||||
plugin_id=plugin_id,
|
||||
version=manifest.latest_version,
|
||||
unique_identifier=manifest.latest_package_identifier,
|
||||
status=manifest.status,
|
||||
deprecated_reason=manifest.deprecated_reason,
|
||||
alternative_plugin_id=manifest.alternative_plugin_id,
|
||||
)
|
||||
|
||||
result[plugin_id] = latest_plugin
|
||||
# Store in Redis
|
||||
redis_client.setex(
|
||||
f"{PluginService.REDIS_KEY_PREFIX}{plugin_id}",
|
||||
PluginService.REDIS_TTL,
|
||||
latest_plugin.model_dump_json(),
|
||||
)
|
||||
|
||||
# pop plugin_id from cache_not_exists
|
||||
cache_not_exists.remove(plugin_id)
|
||||
result[plugin_id] = latest_plugin
|
||||
|
||||
for plugin_id in cache_not_exists:
|
||||
result[plugin_id] = None
|
||||
# pop plugin_id from cache_not_exists
|
||||
cache_not_exists.remove(plugin_id)
|
||||
|
||||
for plugin_id in cache_not_exists:
|
||||
result[plugin_id] = None
|
||||
|
||||
return result
|
||||
except Exception:
|
||||
|
||||
@ -1350,6 +1350,12 @@ class RagPipelineService:
|
||||
)
|
||||
return workflow_node_execution_db_model
|
||||
|
||||
def _fetch_recommended_plugin_manifests(self, plugin_ids: list[str]) -> list[Any]:
|
||||
if not dify_config.MARKETPLACE_ENABLED:
|
||||
logger.info("Marketplace disabled; recommended-plugins list empty")
|
||||
return []
|
||||
return marketplace.batch_fetch_plugin_by_ids(plugin_ids)
|
||||
|
||||
def get_recommended_plugins(self, type: str) -> dict[str, Any]:
|
||||
# Query active recommended plugins
|
||||
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
|
||||
@ -1372,7 +1378,7 @@ class RagPipelineService:
|
||||
)
|
||||
providers_map = {provider.plugin_id: provider.to_dict() for provider in providers}
|
||||
|
||||
plugin_manifests = marketplace.batch_fetch_plugin_by_ids(plugin_ids)
|
||||
plugin_manifests = self._fetch_recommended_plugin_manifests(plugin_ids)
|
||||
plugin_manifests_map = {manifest["plugin_id"]: manifest for manifest in plugin_manifests}
|
||||
|
||||
installed_plugin_list = []
|
||||
|
||||
@ -9,6 +9,7 @@ import yaml
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
@ -273,6 +274,13 @@ class RagPipelineTransformService:
|
||||
plugin_unique_identifier = dependency.get("value", {}).get("plugin_unique_identifier")
|
||||
plugin_id = plugin_unique_identifier.split(":")[0]
|
||||
if plugin_id not in installed_plugins_ids:
|
||||
if not dify_config.MARKETPLACE_ENABLED:
|
||||
logger.warning(
|
||||
"Marketplace disabled; skipping auto-install of %s. "
|
||||
"Pre-install via Console if pipeline requires it.",
|
||||
plugin_id,
|
||||
)
|
||||
continue
|
||||
plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(plugin_id) # type: ignore
|
||||
if plugin_unique_identifier:
|
||||
need_install_plugin_unique_identifiers.append(plugin_unique_identifier)
|
||||
|
||||
@ -26,20 +26,24 @@ _logger = logging.getLogger(__name__)
|
||||
# Loading the .env file if it exists
|
||||
def _load_env():
|
||||
current_file_path = pathlib.Path(__file__).absolute()
|
||||
# Items later in the list have higher precedence.
|
||||
env_file_paths = [
|
||||
os.getenv("DIFY_TEST_ENV_FILE", str(current_file_path.parent / _DEFUALT_TEST_ENV)),
|
||||
os.getenv("DIFY_VDB_TEST_ENV_FILE", str(current_file_path.parent / _DEFAULT_VDB_TEST_ENV)),
|
||||
pathlib.Path(os.getenv("DIFY_TEST_ENV_FILE", str(current_file_path.parent / _DEFUALT_TEST_ENV))),
|
||||
]
|
||||
vdb_env_path = pathlib.Path(
|
||||
os.getenv("DIFY_VDB_TEST_ENV_FILE", str(current_file_path.parent / _DEFAULT_VDB_TEST_ENV))
|
||||
)
|
||||
if vdb_env_path.exists() or "DIFY_VDB_TEST_ENV_FILE" in os.environ:
|
||||
env_file_paths.append(vdb_env_path)
|
||||
|
||||
for env_path_str in env_file_paths:
|
||||
if not pathlib.Path(env_path_str).exists():
|
||||
_logger.warning("specified configuration file %s not exist", env_path_str)
|
||||
for env_path in env_file_paths:
|
||||
if not env_path.exists():
|
||||
_logger.warning("specified configuration file %s not exist", env_path)
|
||||
continue
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Set `override=True` to ensure values from `vdb.env` take priority over values from `.env`
|
||||
load_dotenv(str(env_path_str), override=True)
|
||||
load_dotenv(str(env_path), override=True)
|
||||
|
||||
|
||||
_load_env()
|
||||
|
||||
209
api/tests/pytest_dify.py
Normal file
209
api/tests/pytest_dify.py
Normal file
@ -0,0 +1,209 @@
|
||||
"""Pytest support helpers for Dify backend test environment setup.
|
||||
|
||||
The helpers in this module keep Docker and environment preparation behind explicit
|
||||
pytest options so ordinary unit-test runs do not start external services.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
DEFAULT_LOG_FORMAT = "%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s"
|
||||
DEFAULT_MIDDLEWARE_SERVICES = ("db_postgres", "redis", "sandbox", "ssrf_proxy")
|
||||
DEFAULT_VDB_SERVICES = ("db_postgres", "redis", "weaviate", "qdrant", "pgvector", "chroma")
|
||||
COMPOSE_WAIT_FLAGS = ("--wait", "--wait-timeout")
|
||||
COMPOSE_WAIT_TIMEOUT_SECONDS = "180"
|
||||
COMPOSE_NO_WAIT_READY_DELAY_SECONDS = 5.0
|
||||
VDB_SERVICE_PROFILES = {
|
||||
"db_postgres": "postgresql",
|
||||
"weaviate": "weaviate",
|
||||
"qdrant": "qdrant",
|
||||
"couchbase-server": "couchbase",
|
||||
"etcd": "milvus",
|
||||
"minio": "milvus",
|
||||
"milvus-standalone": "milvus",
|
||||
"pgvecto-rs": "pgvecto-rs",
|
||||
"pgvector": "pgvector",
|
||||
"chroma": "chroma",
|
||||
"elasticsearch": "elasticsearch",
|
||||
"oceanbase": "oceanbase",
|
||||
}
|
||||
|
||||
|
||||
def parse_services(value: str) -> list[str]:
|
||||
"""Parse a comma-separated service list from a pytest option."""
|
||||
return [service.strip() for service in value.split(",") if service.strip()]
|
||||
|
||||
|
||||
def compose_up_supports_wait(repo_root: Path) -> bool:
|
||||
"""Return whether the installed Docker Compose supports healthcheck wait flags."""
|
||||
completed = subprocess.run(
|
||||
["docker", "compose", "up", "--help"],
|
||||
cwd=repo_root,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
)
|
||||
if completed.returncode != 0:
|
||||
return True
|
||||
help_output = f"{completed.stdout}\n{completed.stderr}"
|
||||
return all(flag in help_output for flag in COMPOSE_WAIT_FLAGS)
|
||||
|
||||
|
||||
def ensure_backend_test_environment(repo_root: Path) -> None:
|
||||
"""Set deterministic defaults needed before test conftests import application config."""
|
||||
integration_tests_dir = repo_root / "api" / "tests" / "integration_tests"
|
||||
test_env_file = integration_tests_dir / ".env"
|
||||
test_env_example_file = integration_tests_dir / ".env.example"
|
||||
vdb_env_file = integration_tests_dir / "vdb.env"
|
||||
|
||||
if "DIFY_TEST_ENV_FILE" not in os.environ:
|
||||
os.environ["DIFY_TEST_ENV_FILE"] = str(test_env_file if test_env_file.exists() else test_env_example_file)
|
||||
|
||||
if "DIFY_VDB_TEST_ENV_FILE" not in os.environ and vdb_env_file.exists():
|
||||
os.environ["DIFY_VDB_TEST_ENV_FILE"] = str(vdb_env_file)
|
||||
|
||||
os.environ["LOG_OUTPUT_FORMAT"] = "text"
|
||||
os.environ["LOG_FORMAT"] = DEFAULT_LOG_FORMAT
|
||||
os.environ.setdefault("STORAGE_TYPE", "opendal")
|
||||
os.environ.setdefault("OPENDAL_SCHEME", "fs")
|
||||
os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage")
|
||||
Path(os.environ["OPENDAL_FS_ROOT"]).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def ensure_compose_env_files(repo_root: Path) -> None:
|
||||
"""Create ignored Docker env files from examples when Docker-backed tests request compose."""
|
||||
docker_dir = repo_root / "docker"
|
||||
env_file = docker_dir / ".env"
|
||||
env_example_file = docker_dir / ".env.example"
|
||||
middleware_env_file = docker_dir / "middleware.env"
|
||||
middleware_env_example_file = docker_dir / "envs" / "middleware.env.example"
|
||||
|
||||
if not env_file.exists():
|
||||
shutil.copyfile(env_example_file, env_file)
|
||||
if not middleware_env_file.exists():
|
||||
shutil.copyfile(middleware_env_example_file, middleware_env_file)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DockerComposeStack:
|
||||
"""A docker compose project that pytest can start before collection and stop at shutdown."""
|
||||
|
||||
name: str
|
||||
project_name: str
|
||||
repo_root: Path
|
||||
compose_files: tuple[Path, ...]
|
||||
env_file: Path
|
||||
services: tuple[str, ...]
|
||||
profiles: tuple[str, ...] = ()
|
||||
ready_delay_seconds: float = 0.0
|
||||
warmup_urls: tuple[str, ...] = ()
|
||||
|
||||
def _compose_command(self) -> list[str]:
|
||||
command = [
|
||||
"docker",
|
||||
"compose",
|
||||
"--project-name",
|
||||
self.project_name,
|
||||
"--env-file",
|
||||
str(self.env_file),
|
||||
]
|
||||
for profile in self.profiles:
|
||||
command.extend(("--profile", profile))
|
||||
for compose_file in self.compose_files:
|
||||
command.extend(("-f", str(compose_file)))
|
||||
return command
|
||||
|
||||
def up(self) -> None:
|
||||
"""Start the configured services and wait for compose healthchecks when supported."""
|
||||
supports_wait = compose_up_supports_wait(self.repo_root)
|
||||
up_command = self._compose_command() + [
|
||||
"up",
|
||||
"-d",
|
||||
]
|
||||
if supports_wait:
|
||||
up_command.extend((*COMPOSE_WAIT_FLAGS, COMPOSE_WAIT_TIMEOUT_SECONDS))
|
||||
up_command.extend(self.services)
|
||||
|
||||
completed = subprocess.run(up_command, cwd=self.repo_root, text=True, capture_output=True)
|
||||
if completed.returncode != 0:
|
||||
raise subprocess.CalledProcessError(
|
||||
returncode=completed.returncode,
|
||||
cmd=up_command,
|
||||
output=completed.stdout,
|
||||
stderr=completed.stderr,
|
||||
)
|
||||
if self.ready_delay_seconds > 0:
|
||||
time.sleep(self.ready_delay_seconds)
|
||||
elif not supports_wait:
|
||||
time.sleep(COMPOSE_NO_WAIT_READY_DELAY_SECONDS)
|
||||
self._warm_up()
|
||||
|
||||
def down(self) -> None:
|
||||
"""Stop services started for this pytest run."""
|
||||
subprocess.run(self._compose_command() + ["down"], cwd=self.repo_root, check=True)
|
||||
|
||||
def _warm_up(self) -> None:
|
||||
for url in self.warmup_urls:
|
||||
deadline = time.monotonic() + 30.0
|
||||
last_error: Exception | None = None
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=5) as response:
|
||||
if 200 <= response.status < 300:
|
||||
break
|
||||
except urllib.error.HTTPError as error:
|
||||
if error.code < 500:
|
||||
break
|
||||
last_error = error
|
||||
except (OSError, urllib.error.URLError) as error:
|
||||
last_error = error
|
||||
time.sleep(1)
|
||||
else:
|
||||
raise RuntimeError(f"Timed out waiting for {self.name} warmup URL {url}") from last_error
|
||||
|
||||
|
||||
def build_middleware_stack(repo_root: Path, services: list[str]) -> DockerComposeStack:
|
||||
"""Build the middleware compose stack used by API integration tests."""
|
||||
return DockerComposeStack(
|
||||
name="middleware",
|
||||
project_name="dify-pytest-middleware",
|
||||
repo_root=repo_root,
|
||||
compose_files=(repo_root / "docker" / "docker-compose.middleware.yaml",),
|
||||
env_file=repo_root / "docker" / "middleware.env",
|
||||
services=tuple(services),
|
||||
ready_delay_seconds=5.0,
|
||||
warmup_urls=("http://127.0.0.1:8194/health",),
|
||||
)
|
||||
|
||||
|
||||
def build_vdb_stack(repo_root: Path, services: list[str]) -> DockerComposeStack:
|
||||
"""Build the vector-store compose stack used by VDB integration tests."""
|
||||
profiles = tuple(
|
||||
dict.fromkeys(profile for service in services if (profile := VDB_SERVICE_PROFILES.get(service)) is not None)
|
||||
)
|
||||
service_names = set(services)
|
||||
warmup_urls = []
|
||||
if "qdrant" in service_names:
|
||||
warmup_urls.append("http://127.0.0.1:6333/collections")
|
||||
if "chroma" in service_names:
|
||||
warmup_urls.append("http://127.0.0.1:8000/api/v2/auth/identity")
|
||||
return DockerComposeStack(
|
||||
name="vdb",
|
||||
project_name="dify-pytest-vdb",
|
||||
repo_root=repo_root,
|
||||
compose_files=(
|
||||
repo_root / "docker" / "docker-compose.yaml",
|
||||
repo_root / "docker" / "docker-compose.pytest.ports.yaml",
|
||||
),
|
||||
env_file=repo_root / "docker" / ".env",
|
||||
services=tuple(services),
|
||||
profiles=profiles,
|
||||
warmup_urls=tuple(warmup_urls),
|
||||
)
|
||||
@ -252,6 +252,9 @@ class TestRedisBroadcastChannelIntegration:
|
||||
def consumer_thread() -> set[bytes]:
|
||||
received_msgs: set[bytes] = set()
|
||||
with subscription:
|
||||
# Prime the subscription before producers publish. Redis Pub/Sub does not
|
||||
# replay messages sent before the subscribe command is active.
|
||||
subscription.receive(timeout=0.1)
|
||||
consumer_ready.set()
|
||||
while True:
|
||||
try:
|
||||
@ -278,8 +281,10 @@ class TestRedisBroadcastChannelIntegration:
|
||||
for future in as_completed(producer_futures, timeout=30.0):
|
||||
sent_msgs.update(future.result())
|
||||
|
||||
subscription.close()
|
||||
consumer_received_msgs = consumer_future.result(timeout=30.0)
|
||||
try:
|
||||
consumer_received_msgs = consumer_future.result(timeout=30.0)
|
||||
finally:
|
||||
subscription.close()
|
||||
|
||||
# Verify message content
|
||||
assert sent_msgs == consumer_received_msgs
|
||||
|
||||
@ -37,9 +37,9 @@ class TestAppGenerateService:
|
||||
"services.app_generate_service.MessageBasedAppGenerator", autospec=True
|
||||
) as mock_message_based_generator,
|
||||
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
|
||||
patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
|
||||
patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config,
|
||||
patch("configs.dify_config", autospec=True) as mock_global_dify_config,
|
||||
patch("services.app_generate_service.dify_config") as mock_dify_config,
|
||||
patch("services.quota_service.dify_config") as mock_quota_dify_config,
|
||||
patch("configs.dify_config") as mock_global_dify_config,
|
||||
):
|
||||
# Setup default mock returns for billing service
|
||||
mock_billing_service.quota_reserve.return_value = {
|
||||
|
||||
@ -84,7 +84,7 @@ def _mock_factory_for_apps(
|
||||
|
||||
class TestRecommendedAppServiceGetApps:
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_success_with_apps(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
expected = _apps_response()
|
||||
@ -103,7 +103,7 @@ class TestRecommendedAppServiceGetApps:
|
||||
mock_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
empty_response = {"recommended_apps": [], "categories": []}
|
||||
@ -126,7 +126,7 @@ class TestRecommendedAppServiceGetApps:
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db"
|
||||
none_response = {"recommended_apps": None, "categories": ["test"]}
|
||||
@ -146,7 +146,7 @@ class TestRecommendedAppServiceGetApps:
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_different_languages(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
|
||||
|
||||
@ -164,7 +164,7 @@ class TestRecommendedAppServiceGetApps:
|
||||
mock_instance.get_recommended_apps_and_categories.assert_called_with(language)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_uses_correct_factory_mode(self, mock_config, mock_factory_class):
|
||||
for mode in ["remote", "builtin", "db"]:
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
|
||||
@ -183,7 +183,7 @@ class TestRecommendedAppServiceGetApps:
|
||||
|
||||
class TestRecommendedAppServiceGetDetail:
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_success(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
expected = _app_detail(app_id="app-123", name="Productivity App", description="A great app")
|
||||
@ -199,7 +199,7 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_instance.get_recommend_app_detail.assert_called_once_with("app-123")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_different_modes(self, mock_config, mock_factory_class):
|
||||
for mode in ["remote", "builtin", "db"]:
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
|
||||
@ -214,7 +214,7 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_returns_none_when_not_found(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
mock_instance = MagicMock()
|
||||
@ -227,7 +227,7 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_instance.get_recommend_app_detail.assert_called_once_with("nonexistent")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_returns_empty_dict(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
|
||||
mock_instance = MagicMock()
|
||||
@ -239,7 +239,7 @@ class TestRecommendedAppServiceGetDetail:
|
||||
assert result == {}
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_complex_model_config(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
complex_config = {
|
||||
|
||||
@ -60,7 +60,7 @@ class TestMailInviteMemberTask:
|
||||
with (
|
||||
patch("tasks.mail_invite_member_task.mail", autospec=True) as mock_mail,
|
||||
patch("tasks.mail_invite_member_task.get_email_i18n_service", autospec=True) as mock_email_service,
|
||||
patch("tasks.mail_invite_member_task.dify_config", autospec=True) as mock_config,
|
||||
patch("tasks.mail_invite_member_task.dify_config") as mock_config,
|
||||
):
|
||||
# Setup mail service mock
|
||||
mock_mail.is_inited.return_value = True
|
||||
|
||||
@ -90,7 +90,7 @@ class TestMailRegisterTask:
|
||||
to_email = fake.email()
|
||||
account_name = fake.name()
|
||||
|
||||
with patch("tasks.mail_register_task.dify_config", autospec=True) as mock_config:
|
||||
with patch("tasks.mail_register_task.dify_config") as mock_config:
|
||||
mock_config.CONSOLE_WEB_URL = "https://console.dify.ai"
|
||||
|
||||
send_email_register_mail_task_when_account_exist(language=language, to=to_email, account_name=account_name)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from pathlib import Path
|
||||
|
||||
import yaml # type: ignore
|
||||
from dotenv import dotenv_values
|
||||
|
||||
BASE_API_AND_DOCKER_CONFIG_SET_DIFF: frozenset[str] = frozenset(
|
||||
@ -91,34 +90,29 @@ BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF: frozenset[str] = frozenset(
|
||||
)
|
||||
)
|
||||
|
||||
API_CONFIG_SET = set(dotenv_values(Path("api") / Path(".env.example")).keys())
|
||||
DOCKER_CONFIG_SET = set(dotenv_values(Path("docker") / Path(".env.example")).keys())
|
||||
DOCKER_COMPOSE_CONFIG_SET = set(DOCKER_CONFIG_SET)
|
||||
|
||||
# Read environment variables from the split env files used by docker-compose
|
||||
# Walk through all .env.example files in subdirectories (per-module structure)
|
||||
envs_dir = Path("docker") / Path("envs")
|
||||
if envs_dir.exists():
|
||||
for env_file_path in envs_dir.rglob("*.env.example"):
|
||||
env_keys = set(dotenv_values(env_file_path).keys())
|
||||
DOCKER_CONFIG_SET.update(env_keys)
|
||||
DOCKER_COMPOSE_CONFIG_SET.update(env_keys)
|
||||
REPO_ROOT = Path(__file__).resolve().parents[4]
|
||||
|
||||
|
||||
def test_yaml_config():
|
||||
# python set == operator is used to compare two sets
|
||||
DIFF_API_WITH_DOCKER = API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
|
||||
if DIFF_API_WITH_DOCKER:
|
||||
print(f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}")
|
||||
raise Exception("API and Docker config sets are different")
|
||||
DIFF_API_WITH_DOCKER_COMPOSE = (
|
||||
API_CONFIG_SET - DOCKER_COMPOSE_CONFIG_SET - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
|
||||
)
|
||||
if DIFF_API_WITH_DOCKER_COMPOSE:
|
||||
print(f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}")
|
||||
raise Exception("API and Docker Compose config sets are different")
|
||||
print("All tests passed!")
|
||||
def _api_config_set() -> set[str]:
|
||||
return set(dotenv_values(REPO_ROOT / "api" / ".env.example").keys())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_yaml_config()
|
||||
def _docker_config_set() -> set[str]:
|
||||
docker_config_set = set(dotenv_values(REPO_ROOT / "docker" / ".env.example").keys())
|
||||
envs_dir = REPO_ROOT / "docker" / "envs"
|
||||
if envs_dir.exists():
|
||||
for env_file_path in envs_dir.rglob("*.env.example"):
|
||||
docker_config_set.update(dotenv_values(env_file_path).keys())
|
||||
return docker_config_set
|
||||
|
||||
|
||||
def test_api_env_keys_exist_in_docker_env_examples():
|
||||
diff = _api_config_set() - _docker_config_set() - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
|
||||
|
||||
assert not diff, f"API and Docker config sets are different with keys: {sorted(diff)}"
|
||||
|
||||
|
||||
def test_api_env_keys_exist_in_docker_compose_env_examples():
|
||||
diff = _api_config_set() - _docker_config_set() - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
|
||||
|
||||
assert not diff, f"API and Docker Compose config sets are different with keys: {sorted(diff)}"
|
||||
@ -15,6 +15,11 @@ from models.model import App, AppMode
|
||||
from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result
|
||||
|
||||
|
||||
def _configure_current_app_mock(mock_current_app):
|
||||
mock_current_app.login_manager = Mock()
|
||||
mock_current_app._get_current_object = Mock(return_value=Mock())
|
||||
|
||||
|
||||
class TestAppParameterApi:
|
||||
"""Test suite for AppParameterApi"""
|
||||
|
||||
@ -45,7 +50,7 @@ class TestAppParameterApi:
|
||||
):
|
||||
"""Test retrieving parameters for a chat app."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
_configure_current_app_mock(mock_current_app)
|
||||
|
||||
mock_config = Mock()
|
||||
mock_config.id = str(uuid.uuid4())
|
||||
@ -95,7 +100,7 @@ class TestAppParameterApi:
|
||||
):
|
||||
"""Test retrieving parameters for a workflow app."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
_configure_current_app_mock(mock_current_app)
|
||||
|
||||
mock_app_model.mode = AppMode.WORKFLOW
|
||||
mock_workflow = Mock()
|
||||
@ -140,7 +145,7 @@ class TestAppParameterApi:
|
||||
):
|
||||
"""Test that AppUnavailableError is raised when chat app has no config."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
_configure_current_app_mock(mock_current_app)
|
||||
|
||||
mock_app_model.app_model_config = None
|
||||
mock_app_model.workflow = None
|
||||
@ -178,7 +183,7 @@ class TestAppParameterApi:
|
||||
):
|
||||
"""Test that AppUnavailableError is raised when workflow app has no workflow."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
_configure_current_app_mock(mock_current_app)
|
||||
|
||||
mock_app_model.mode = AppMode.WORKFLOW
|
||||
mock_app_model.workflow = None
|
||||
@ -245,7 +250,7 @@ class TestAppMetaApi:
|
||||
):
|
||||
"""Test retrieving app metadata via AppService."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
_configure_current_app_mock(mock_current_app)
|
||||
|
||||
mock_service_instance = Mock()
|
||||
mock_service_instance.get_app_meta.return_value = {
|
||||
@ -320,7 +325,7 @@ class TestAppInfoApi:
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
|
||||
):
|
||||
"""Test retrieving basic app information."""
|
||||
mock_current_app.login_manager = Mock()
|
||||
_configure_current_app_mock(mock_current_app)
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
@ -361,7 +366,7 @@ class TestAppInfoApi:
|
||||
):
|
||||
"""Test retrieving app info with multiple tags."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
_configure_current_app_mock(mock_current_app)
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = str(uuid.uuid4())
|
||||
@ -414,7 +419,7 @@ class TestAppInfoApi:
|
||||
):
|
||||
"""Test retrieving app info when app has no tags."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
_configure_current_app_mock(mock_current_app)
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = str(uuid.uuid4())
|
||||
@ -466,7 +471,7 @@ class TestAppInfoApi:
|
||||
):
|
||||
"""Test that all app modes are correctly returned."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
_configure_current_app_mock(mock_current_app)
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = str(uuid.uuid4())
|
||||
|
||||
@ -2,31 +2,29 @@
|
||||
Unit tests for Service API Index endpoint
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.service_api import index as index_module
|
||||
from controllers.service_api.index import IndexApi
|
||||
|
||||
|
||||
def _get_index_response(app: Flask, version: str) -> dict[str, str]:
|
||||
with patch.object(index_module.dify_config.project, "version", version):
|
||||
with app.test_request_context("/", method="GET"):
|
||||
index_api = IndexApi()
|
||||
return index_api.get()
|
||||
|
||||
|
||||
class TestIndexApi:
|
||||
"""Test suite for IndexApi resource."""
|
||||
|
||||
@patch("controllers.service_api.index.dify_config", autospec=True)
|
||||
def test_get_returns_api_info(self, mock_config, app: Flask):
|
||||
def test_get_returns_api_info(self, app: Flask):
|
||||
"""Test that GET returns API metadata with correct structure."""
|
||||
# Arrange
|
||||
mock_config.project.version = "1.0.0-test"
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/", method="GET"):
|
||||
index_api = IndexApi()
|
||||
response = index_api.get()
|
||||
with patch("controllers.service_api.index.dify_config", mock_config):
|
||||
with app.test_request_context("/", method="GET"):
|
||||
index_api = IndexApi()
|
||||
response = index_api.get()
|
||||
response = _get_index_response(app, "1.0.0-test")
|
||||
|
||||
# Assert
|
||||
assert response["welcome"] == "Dify OpenAPI"
|
||||
@ -35,15 +33,8 @@ class TestIndexApi:
|
||||
|
||||
def test_get_response_has_required_fields(self, app: Flask):
|
||||
"""Test that response contains all required fields."""
|
||||
# Arrange
|
||||
mock_config = MagicMock()
|
||||
mock_config.project.version = "1.11.4"
|
||||
|
||||
# Act
|
||||
with patch("controllers.service_api.index.dify_config", mock_config):
|
||||
with app.test_request_context("/", method="GET"):
|
||||
index_api = IndexApi()
|
||||
response = index_api.get()
|
||||
response = _get_index_response(app, "1.11.4")
|
||||
|
||||
# Assert
|
||||
assert "welcome" in response
|
||||
@ -56,15 +47,8 @@ class TestIndexApi:
|
||||
@pytest.mark.parametrize("version", ["0.0.1", "1.0.0", "2.0.0-beta", "1.11.4"])
|
||||
def test_get_returns_correct_version(self, app: Flask, version):
|
||||
"""Test that server_version matches config version."""
|
||||
# Arrange
|
||||
mock_config = MagicMock()
|
||||
mock_config.project.version = version
|
||||
|
||||
# Act
|
||||
with patch("controllers.service_api.index.dify_config", mock_config):
|
||||
with app.test_request_context("/", method="GET"):
|
||||
index_api = IndexApi()
|
||||
response = index_api.get()
|
||||
response = _get_index_response(app, version)
|
||||
|
||||
# Assert
|
||||
assert response["server_version"] == version
|
||||
|
||||
@ -29,6 +29,11 @@ from tests.unit_tests.conftest import (
|
||||
)
|
||||
|
||||
|
||||
def _configure_current_app_mock(mock_current_app):
|
||||
mock_current_app.login_manager = Mock()
|
||||
mock_current_app._get_current_object = Mock(return_value=Mock())
|
||||
|
||||
|
||||
class TestValidateAndGetApiToken:
|
||||
"""Test suite for validate_and_get_api_token function"""
|
||||
|
||||
@ -120,8 +125,7 @@ class TestValidateAppToken:
|
||||
):
|
||||
"""Test that valid app token allows access to decorated view."""
|
||||
# Arrange
|
||||
# Use standard Mock for login_manager to avoid AsyncMockMixin warnings
|
||||
mock_current_app.login_manager = Mock()
|
||||
_configure_current_app_mock(mock_current_app)
|
||||
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = str(uuid.uuid4())
|
||||
@ -448,8 +452,7 @@ class TestValidateDatasetToken:
|
||||
def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app: Flask):
|
||||
"""Test that valid dataset token allows access."""
|
||||
# Arrange
|
||||
# Use standard Mock for login_manager
|
||||
mock_current_app.login_manager = Mock()
|
||||
_configure_current_app_mock(mock_current_app)
|
||||
|
||||
tenant_id = str(uuid.uuid4())
|
||||
mock_api_token = Mock()
|
||||
|
||||
@ -32,7 +32,7 @@ class TestConstants:
|
||||
class TestCreateSSRFProxyMCPHTTPClient:
|
||||
"""Test create_ssrf_proxy_mcp_http_client function."""
|
||||
|
||||
@patch("core.mcp.utils.dify_config", autospec=True)
|
||||
@patch("core.mcp.utils.dify_config")
|
||||
def test_create_client_with_all_url_proxy(self, mock_config):
|
||||
"""Test client creation with SSRF_PROXY_ALL_URL configured."""
|
||||
mock_config.SSRF_PROXY_ALL_URL = "http://proxy.example.com:8080"
|
||||
@ -50,7 +50,7 @@ class TestCreateSSRFProxyMCPHTTPClient:
|
||||
# Clean up
|
||||
client.close()
|
||||
|
||||
@patch("core.mcp.utils.dify_config", autospec=True)
|
||||
@patch("core.mcp.utils.dify_config")
|
||||
def test_create_client_with_http_https_proxies(self, mock_config):
|
||||
"""Test client creation with separate HTTP/HTTPS proxies."""
|
||||
mock_config.SSRF_PROXY_ALL_URL = None
|
||||
@ -66,7 +66,7 @@ class TestCreateSSRFProxyMCPHTTPClient:
|
||||
# Clean up
|
||||
client.close()
|
||||
|
||||
@patch("core.mcp.utils.dify_config", autospec=True)
|
||||
@patch("core.mcp.utils.dify_config")
|
||||
def test_create_client_without_proxy(self, mock_config):
|
||||
"""Test client creation without proxy configuration."""
|
||||
mock_config.SSRF_PROXY_ALL_URL = None
|
||||
@ -88,7 +88,7 @@ class TestCreateSSRFProxyMCPHTTPClient:
|
||||
# Clean up
|
||||
client.close()
|
||||
|
||||
@patch("core.mcp.utils.dify_config", autospec=True)
|
||||
@patch("core.mcp.utils.dify_config")
|
||||
def test_create_client_default_params(self, mock_config):
|
||||
"""Test client creation with default parameters."""
|
||||
mock_config.SSRF_PROXY_ALL_URL = None
|
||||
@ -140,7 +140,7 @@ class TestSSRFProxySSEConnect:
|
||||
|
||||
@patch("core.mcp.utils.connect_sse", autospec=True)
|
||||
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client", autospec=True)
|
||||
@patch("core.mcp.utils.dify_config", autospec=True)
|
||||
@patch("core.mcp.utils.dify_config")
|
||||
def test_sse_connect_without_client(self, mock_config, mock_create_client, mock_connect_sse):
|
||||
"""Test SSE connection without pre-configured client."""
|
||||
# Setup config
|
||||
|
||||
@ -88,7 +88,7 @@ class TestOutputModeration:
|
||||
def test_start_thread(self, output_moderation):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
with patch("core.moderation.output_moderation.current_app") as mock_current_app:
|
||||
mock_current_app._get_current_object.return_value = mock_app
|
||||
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
|
||||
with patch("threading.Thread") as mock_thread_class:
|
||||
mock_thread_instance = MagicMock()
|
||||
mock_thread_class.return_value = mock_thread_instance
|
||||
|
||||
@ -106,7 +106,7 @@ class TestQAIndexProcessor:
|
||||
patch.object(processor, "_format_qa_document", side_effect=_append_document) as mock_format,
|
||||
patch("core.rag.index_processor.processor.qa_index_processor.current_app") as mock_current_app,
|
||||
):
|
||||
mock_current_app._get_current_object.return_value = fake_flask_app
|
||||
mock_current_app._get_current_object = Mock(return_value=fake_flask_app)
|
||||
result = processor.transform(
|
||||
[document],
|
||||
process_rule=process_rule,
|
||||
@ -155,7 +155,7 @@ class TestQAIndexProcessor:
|
||||
"core.rag.index_processor.processor.qa_index_processor.threading.Thread", side_effect=_ImmediateThread
|
||||
),
|
||||
):
|
||||
mock_current_app._get_current_object.return_value = fake_flask_app
|
||||
mock_current_app._get_current_object = Mock(return_value=fake_flask_app)
|
||||
result = processor.transform(documents, process_rule=process_rule, preview=False, tenant_id="tenant-1")
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
@ -594,6 +594,7 @@ class TestIndexingRunnerLoad:
|
||||
patch("core.indexing_runner.threading.Thread") as mock_thread,
|
||||
patch("core.indexing_runner.concurrent.futures.ThreadPoolExecutor") as mock_executor,
|
||||
):
|
||||
mock_app._get_current_object = Mock(return_value=Mock())
|
||||
yield {
|
||||
"db": mock_db,
|
||||
"model_manager": mock_model_manager,
|
||||
|
||||
@ -51,7 +51,7 @@ class TestRepositoryFactory:
|
||||
import_string("invalidpath")
|
||||
assert "doesn't look like a module path" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config", autospec=True)
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_success(self, mock_config):
|
||||
"""Test successful WorkflowExecutionRepository creation."""
|
||||
# Setup mock configuration
|
||||
@ -86,7 +86,7 @@ class TestRepositoryFactory:
|
||||
)
|
||||
assert result is mock_repository_instance
|
||||
|
||||
@patch("core.repositories.factory.dify_config", autospec=True)
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_import_error(self, mock_config):
|
||||
"""Test WorkflowExecutionRepository creation with import error."""
|
||||
# Setup mock configuration with invalid class path
|
||||
@ -104,7 +104,7 @@ class TestRepositoryFactory:
|
||||
)
|
||||
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config", autospec=True)
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_instantiation_error(self, mock_config):
|
||||
"""Test WorkflowExecutionRepository creation with instantiation error."""
|
||||
# Setup mock configuration
|
||||
@ -128,7 +128,7 @@ class TestRepositoryFactory:
|
||||
)
|
||||
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config", autospec=True)
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_success(self, mock_config):
|
||||
"""Test successful WorkflowNodeExecutionRepository creation."""
|
||||
# Setup mock configuration
|
||||
@ -163,7 +163,7 @@ class TestRepositoryFactory:
|
||||
)
|
||||
assert result is mock_repository_instance
|
||||
|
||||
@patch("core.repositories.factory.dify_config", autospec=True)
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_import_error(self, mock_config):
|
||||
"""Test WorkflowNodeExecutionRepository creation with import error."""
|
||||
# Setup mock configuration with invalid class path
|
||||
@ -181,7 +181,7 @@ class TestRepositoryFactory:
|
||||
)
|
||||
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config", autospec=True)
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
|
||||
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
|
||||
# Setup mock configuration
|
||||
@ -211,7 +211,7 @@ class TestRepositoryFactory:
|
||||
error = RepositoryImportError(error_message)
|
||||
assert str(error) == error_message
|
||||
|
||||
@patch("core.repositories.factory.dify_config", autospec=True)
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_with_engine_instead_of_sessionmaker(self, mock_config):
|
||||
"""Test repository creation with Engine instead of sessionmaker."""
|
||||
# Setup mock configuration
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import MagicMock, patch
|
||||
@ -472,7 +473,7 @@ class TestSchemaResolverClass:
|
||||
assert resolved[2]["title"] == "Q&A Structure"
|
||||
|
||||
def test_cache_performance(self):
|
||||
"""Test that caching improves performance"""
|
||||
"""Test that repeated references share cached schema lookups."""
|
||||
SchemaResolver.clear_cache()
|
||||
|
||||
# Create a schema with many references to the same schema
|
||||
@ -484,36 +485,16 @@ class TestSchemaResolverClass:
|
||||
},
|
||||
}
|
||||
|
||||
# First run (no cache) - run multiple times to warm up
|
||||
results1 = []
|
||||
for _ in range(3):
|
||||
SchemaResolver.clear_cache()
|
||||
start = time.perf_counter()
|
||||
result1 = resolve_dify_schema_refs(schema)
|
||||
time_no_cache = time.perf_counter() - start
|
||||
results1.append(time_no_cache)
|
||||
registry = SchemaRegistry.default_registry()
|
||||
file_schema = registry.get_schema("https://dify.ai/schemas/v1/file.json")
|
||||
assert file_schema is not None
|
||||
|
||||
avg_time_no_cache = sum(results1) / len(results1)
|
||||
with patch.object(registry, "get_schema", wraps=registry.get_schema) as mock_get:
|
||||
result1 = resolve_dify_schema_refs(copy.deepcopy(schema), registry=registry)
|
||||
result2 = resolve_dify_schema_refs(copy.deepcopy(schema), registry=registry)
|
||||
|
||||
# Second run (with cache) - run multiple times
|
||||
# Warm up cache first
|
||||
resolve_dify_schema_refs(schema)
|
||||
|
||||
results2 = []
|
||||
for _ in range(3):
|
||||
start = time.perf_counter()
|
||||
result2 = resolve_dify_schema_refs(schema)
|
||||
time_with_cache = time.perf_counter() - start
|
||||
results2.append(time_with_cache)
|
||||
|
||||
avg_time_with_cache = sum(results2) / len(results2)
|
||||
|
||||
# Cache should make it faster (more lenient check)
|
||||
assert result1 == result2
|
||||
# Cache should provide some performance benefit (allow for measurement variance)
|
||||
# We expect cache to be faster, but allow for small timing variations
|
||||
performance_ratio = avg_time_with_cache / avg_time_no_cache if avg_time_no_cache > 0 else 1.0
|
||||
assert performance_ratio <= 2.0, f"Cache performance degraded too much: {performance_ratio}"
|
||||
mock_get.assert_called_once_with("https://dify.ai/schemas/v1/file.json")
|
||||
|
||||
def test_fast_path_performance_no_refs(self):
|
||||
"""Test that schemas without $refs use fast path and avoid deep copying"""
|
||||
|
||||
42
api/tests/unit_tests/events/test_events_package_compat.py
Normal file
42
api/tests/unit_tests/events/test_events_package_compat.py
Normal file
@ -0,0 +1,42 @@
|
||||
import pytest
|
||||
|
||||
from events import Events, EventsError, EventsException
|
||||
|
||||
|
||||
def test_events_package_exposes_opensearchpy_compatible_events_class():
|
||||
calls: list[str] = []
|
||||
events = Events()
|
||||
|
||||
events.request_start += lambda: calls.append("start")
|
||||
events.request_end += lambda: calls.append("end")
|
||||
|
||||
events.request_start()
|
||||
events.request_end()
|
||||
|
||||
assert calls == ["start", "end"]
|
||||
|
||||
|
||||
def test_events_package_supports_named_slots_iteration_removal_and_private_attrs():
|
||||
calls: list[str] = []
|
||||
|
||||
def handler() -> None:
|
||||
calls.append("handled")
|
||||
|
||||
events = Events("request_start")
|
||||
|
||||
events.request_start += handler
|
||||
events.request_start += handler
|
||||
|
||||
assert len(events.request_start) == 2
|
||||
assert list(events.request_start) == [handler, handler]
|
||||
|
||||
events.request_start -= handler
|
||||
|
||||
assert len(events.request_start) == 0
|
||||
events.request_start()
|
||||
assert calls == []
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
_ = events._private # type: ignore[attr-defined]
|
||||
|
||||
assert EventsException is EventsError
|
||||
@ -11,7 +11,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_init_success_with_all_config(self):
|
||||
"""Test successful initialization when all required config is provided."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
@ -31,7 +31,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_init_raises_error_when_url_missing(self):
|
||||
"""Test initialization raises ValueError when SUPABASE_URL is None."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
mock_config.SUPABASE_URL = None
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
@ -41,7 +41,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_init_raises_error_when_api_key_missing(self):
|
||||
"""Test initialization raises ValueError when SUPABASE_API_KEY is None."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = None
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
@ -51,7 +51,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_init_raises_error_when_bucket_name_missing(self):
|
||||
"""Test initialization raises ValueError when SUPABASE_BUCKET_NAME is None."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = None
|
||||
@ -61,7 +61,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_create_bucket_when_not_exists(self):
|
||||
"""Test create_bucket creates bucket when it doesn't exist."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
@ -77,7 +77,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_create_bucket_when_exists(self):
|
||||
"""Test create_bucket does not create bucket when it already exists."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
@ -94,7 +94,7 @@ class TestSupabaseStorage:
|
||||
@pytest.fixture
|
||||
def storage_with_mock_client(self):
|
||||
"""Fixture providing SupabaseStorage with mocked client."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
@ -209,7 +209,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_bucket_exists_returns_true_when_bucket_found(self):
|
||||
"""Test bucket_exists returns True when bucket is found in list."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
@ -229,7 +229,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_bucket_exists_returns_false_when_bucket_not_found(self):
|
||||
"""Test bucket_exists returns False when bucket is not found in list."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
@ -252,7 +252,7 @@ class TestSupabaseStorage:
|
||||
|
||||
def test_bucket_exists_returns_false_when_no_buckets(self):
|
||||
"""Test bucket_exists returns False when no buckets exist."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
@ -0,0 +1,76 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
|
||||
MIGRATION_MODULE = "services.plugin.plugin_migration"
|
||||
|
||||
|
||||
def test_fetch_plugin_unique_identifier_returns_none_when_disabled(mocker: MockerFixture) -> None:
|
||||
mocker.patch("services.plugin.plugin_migration.dify_config.MARKETPLACE_ENABLED", False)
|
||||
batch_fetch = mocker.patch("services.plugin.plugin_migration.marketplace.batch_fetch_plugin_manifests")
|
||||
|
||||
result = PluginMigration._fetch_plugin_unique_identifier("langgenius/openai")
|
||||
|
||||
assert result is None
|
||||
batch_fetch.assert_not_called()
|
||||
|
||||
|
||||
def test_fetch_plugin_unique_identifier_calls_marketplace_when_enabled(mocker: MockerFixture) -> None:
|
||||
mocker.patch("services.plugin.plugin_migration.dify_config.MARKETPLACE_ENABLED", True)
|
||||
manifest = mocker.MagicMock()
|
||||
manifest.latest_package_identifier = "langgenius/openai:1.0.0@abc"
|
||||
mocker.patch(
|
||||
"services.plugin.plugin_migration.marketplace.batch_fetch_plugin_manifests",
|
||||
return_value=[manifest],
|
||||
)
|
||||
|
||||
result = PluginMigration._fetch_plugin_unique_identifier("langgenius/openai")
|
||||
|
||||
assert result == "langgenius/openai:1.0.0@abc"
|
||||
|
||||
|
||||
class TestHandlePluginInstanceInstall:
|
||||
def test_raises_when_disabled_and_map_nonempty(self) -> None:
|
||||
with patch(f"{MIGRATION_MODULE}.dify_config") as mock_cfg:
|
||||
mock_cfg.MARKETPLACE_ENABLED = False
|
||||
|
||||
with pytest.raises(ValueError, match="Marketplace disabled"):
|
||||
PluginMigration.handle_plugin_instance_install(
|
||||
"tenant1", {"langgenius/openai": "langgenius/openai:1.0.0@abc"}
|
||||
)
|
||||
|
||||
def test_no_raise_when_disabled_and_map_empty(self) -> None:
|
||||
with (
|
||||
patch(f"{MIGRATION_MODULE}.dify_config") as mock_cfg,
|
||||
patch(f"{MIGRATION_MODULE}.PluginInstaller") as mock_installer_cls,
|
||||
):
|
||||
mock_cfg.MARKETPLACE_ENABLED = False
|
||||
mock_installer = MagicMock()
|
||||
mock_installer_cls.return_value = mock_installer
|
||||
mock_installer.install_from_identifiers.return_value = MagicMock(all_installed=True)
|
||||
|
||||
result = PluginMigration.handle_plugin_instance_install("tenant1", {})
|
||||
|
||||
assert isinstance(result, dict)
|
||||
|
||||
def test_proceeds_when_enabled(self) -> None:
|
||||
with (
|
||||
patch(f"{MIGRATION_MODULE}.dify_config") as mock_cfg,
|
||||
patch(f"{MIGRATION_MODULE}.marketplace") as mock_marketplace,
|
||||
patch(f"{MIGRATION_MODULE}.PluginInstaller") as mock_installer_cls,
|
||||
):
|
||||
mock_cfg.MARKETPLACE_ENABLED = True
|
||||
mock_marketplace.download_plugin_pkg.return_value = b"pkg_data"
|
||||
mock_installer = MagicMock()
|
||||
mock_installer_cls.return_value = mock_installer
|
||||
mock_installer.install_from_identifiers.return_value = MagicMock(all_installed=True)
|
||||
|
||||
result = PluginMigration.handle_plugin_instance_install(
|
||||
"tenant1", {"langgenius/openai": "langgenius/openai:1.0.0@abc"}
|
||||
)
|
||||
|
||||
mock_marketplace.download_plugin_pkg.assert_called_once()
|
||||
assert "success" in result or "failed" in result
|
||||
50
api/tests/unit_tests/services/plugin/test_plugin_service.py
Normal file
50
api/tests/unit_tests/services/plugin/test_plugin_service.py
Normal file
@ -0,0 +1,50 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
MODULE = "services.plugin.plugin_service"
|
||||
|
||||
|
||||
class TestFetchLatestPluginVersion:
|
||||
def test_skips_marketplace_fetch_when_disabled(self) -> None:
|
||||
"""Cache misses stay None; marketplace is never called when disabled."""
|
||||
with (
|
||||
patch(f"{MODULE}.dify_config") as mock_cfg,
|
||||
patch(f"{MODULE}.redis_client") as mock_redis,
|
||||
patch(f"{MODULE}.marketplace") as mock_marketplace,
|
||||
):
|
||||
mock_cfg.MARKETPLACE_ENABLED = False
|
||||
mock_redis.get.return_value = None # all cache misses
|
||||
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.fetch_latest_plugin_version(["langgenius/openai", "langgenius/anthropic"])
|
||||
|
||||
mock_marketplace.batch_fetch_plugin_manifests.assert_not_called()
|
||||
assert result == {"langgenius/openai": None, "langgenius/anthropic": None}
|
||||
|
||||
def test_calls_marketplace_fetch_when_enabled(self) -> None:
|
||||
"""Cache misses trigger marketplace fetch when enabled."""
|
||||
manifest = MagicMock()
|
||||
manifest.plugin_id = "langgenius/openai"
|
||||
manifest.latest_version = "1.0.0"
|
||||
manifest.latest_package_identifier = "langgenius/openai:1.0.0@abc"
|
||||
manifest.status = "active"
|
||||
manifest.deprecated_reason = ""
|
||||
manifest.alternative_plugin_id = ""
|
||||
|
||||
with (
|
||||
patch(f"{MODULE}.dify_config") as mock_cfg,
|
||||
patch(f"{MODULE}.redis_client") as mock_redis,
|
||||
patch(f"{MODULE}.marketplace") as mock_marketplace,
|
||||
):
|
||||
mock_cfg.MARKETPLACE_ENABLED = True
|
||||
mock_redis.get.return_value = None
|
||||
mock_marketplace.batch_fetch_plugin_manifests.return_value = [manifest]
|
||||
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
result = PluginService.fetch_latest_plugin_version(["langgenius/openai"])
|
||||
|
||||
# The list arg is mutated by remove() after the call, so check call count + result.
|
||||
mock_marketplace.batch_fetch_plugin_manifests.assert_called_once()
|
||||
assert result["langgenius/openai"] is not None
|
||||
assert result["langgenius/openai"].version == "1.0.0"
|
||||
@ -0,0 +1,36 @@
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
||||
def _make_service() -> RagPipelineService:
|
||||
return RagPipelineService.__new__(RagPipelineService)
|
||||
|
||||
|
||||
def test_fetch_recommended_plugin_manifests_returns_empty_when_disabled(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.dify_config.MARKETPLACE_ENABLED", False)
|
||||
batch_fetch = mocker.patch("services.rag_pipeline.rag_pipeline.marketplace.batch_fetch_plugin_by_ids")
|
||||
|
||||
service = _make_service()
|
||||
result = service._fetch_recommended_plugin_manifests(["langgenius/openai"])
|
||||
|
||||
assert result == []
|
||||
batch_fetch.assert_not_called()
|
||||
|
||||
|
||||
def test_fetch_recommended_plugin_manifests_returns_data_when_enabled(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.dify_config.MARKETPLACE_ENABLED", True)
|
||||
expected = [{"plugin_id": "langgenius/openai", "name": "OpenAI"}]
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline.marketplace.batch_fetch_plugin_by_ids",
|
||||
return_value=expected,
|
||||
)
|
||||
|
||||
service = _make_service()
|
||||
result = service._fetch_recommended_plugin_manifests(["langgenius/openai"])
|
||||
|
||||
assert result == expected
|
||||
@ -1,8 +1,10 @@
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from models.dataset import Dataset
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration
|
||||
@ -514,3 +516,64 @@ def test_deal_document_data_upload_file_with_existing_file(mocker) -> None:
|
||||
assert document.data_source_type == "local_file"
|
||||
assert "real_file_id" in document.data_source_info
|
||||
assert add_mock.call_count >= 2
|
||||
|
||||
|
||||
def _make_service():
|
||||
return RagPipelineTransformService.__new__(RagPipelineTransformService)
|
||||
|
||||
|
||||
def test_deal_dependencies_skips_marketplace_when_disabled(mocker: MockerFixture, caplog) -> None:
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.dify_config.MARKETPLACE_ENABLED",
|
||||
False,
|
||||
)
|
||||
installer = mocker.patch("services.rag_pipeline.rag_pipeline_transform_service.PluginInstaller").return_value
|
||||
installer.list_plugins.return_value = []
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_transform_service.PluginMigration")
|
||||
install_call = mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.PluginService.install_from_marketplace_pkg"
|
||||
)
|
||||
|
||||
pipeline_yaml = {
|
||||
"dependencies": [
|
||||
{
|
||||
"type": "marketplace",
|
||||
"value": {"plugin_unique_identifier": "langgenius/openai:1.0.0@abc"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
service = _make_service()
|
||||
with caplog.at_level(logging.WARNING):
|
||||
service._deal_dependencies(pipeline_yaml, "tenant-1")
|
||||
|
||||
install_call.assert_not_called()
|
||||
assert any("Marketplace disabled" in rec.message for rec in caplog.records)
|
||||
|
||||
|
||||
def test_deal_dependencies_installs_when_enabled(mocker: MockerFixture) -> None:
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.dify_config.MARKETPLACE_ENABLED",
|
||||
True,
|
||||
)
|
||||
installer = mocker.patch("services.rag_pipeline.rag_pipeline_transform_service.PluginInstaller").return_value
|
||||
installer.list_plugins.return_value = []
|
||||
migration = mocker.patch("services.rag_pipeline.rag_pipeline_transform_service.PluginMigration").return_value
|
||||
migration._fetch_plugin_unique_identifier.return_value = "langgenius/openai:1.0.0@abc"
|
||||
install_call = mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_transform_service.PluginService.install_from_marketplace_pkg"
|
||||
)
|
||||
|
||||
pipeline_yaml = {
|
||||
"dependencies": [
|
||||
{
|
||||
"type": "marketplace",
|
||||
"value": {"plugin_unique_identifier": "langgenius/openai:1.0.0@abc"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
service = _make_service()
|
||||
service._deal_dependencies(pipeline_yaml, "tenant-1")
|
||||
|
||||
install_call.assert_called_once_with("tenant-1", ["langgenius/openai:1.0.0@abc"])
|
||||
|
||||
@ -15,7 +15,7 @@ from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
class TestWorkflowRunArchiver:
|
||||
"""Tests for the WorkflowRunArchiver class."""
|
||||
|
||||
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config", autospec=True)
|
||||
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config")
|
||||
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage", autospec=True)
|
||||
def test_archiver_initialization(self, mock_get_storage, mock_config):
|
||||
"""Test archiver can be initialized with various options."""
|
||||
|
||||
@ -403,7 +403,7 @@ class TestBillingDisabledPolicyFilterMessageIds:
|
||||
class TestCreateMessageCleanPolicy:
|
||||
"""Unit tests for create_message_clean_policy factory function."""
|
||||
|
||||
@patch("services.retention.conversation.messages_clean_policy.dify_config", autospec=True)
|
||||
@patch("services.retention.conversation.messages_clean_policy.dify_config")
|
||||
def test_billing_disabled_returns_billing_disabled_policy(self, mock_config):
|
||||
"""Test that BILLING_ENABLED=False returns BillingDisabledPolicy."""
|
||||
# Arrange
|
||||
@ -416,7 +416,7 @@ class TestCreateMessageCleanPolicy:
|
||||
assert isinstance(policy, BillingDisabledPolicy)
|
||||
|
||||
@patch("services.retention.conversation.messages_clean_policy.BillingService", autospec=True)
|
||||
@patch("services.retention.conversation.messages_clean_policy.dify_config", autospec=True)
|
||||
@patch("services.retention.conversation.messages_clean_policy.dify_config")
|
||||
def test_billing_enabled_policy_has_correct_internals(self, mock_config, mock_billing_service):
|
||||
"""Test that BillingSandboxPolicy is created with correct internal values."""
|
||||
# Arrange
|
||||
|
||||
@ -13,7 +13,7 @@ from services.recommended_app_service import RecommendedAppService
|
||||
class TestGetRecommendAppDetailNullCheck:
|
||||
@patch("services.recommended_app_service.FeatureService", autospec=True)
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_returns_none_when_retrieval_returns_none_and_trial_disabled(
|
||||
self, mock_config, mock_factory_class, mock_feature_service
|
||||
):
|
||||
@ -29,7 +29,7 @@ class TestGetRecommendAppDetailNullCheck:
|
||||
|
||||
@patch("services.recommended_app_service.FeatureService", autospec=True)
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
def test_returns_none_when_retrieval_returns_none_and_trial_enabled(
|
||||
self, mock_config, mock_factory_class, mock_feature_service
|
||||
):
|
||||
|
||||
23
api/tests/unit_tests/test_makefile_backend_tests.py
Normal file
23
api/tests/unit_tests/test_makefile_backend_tests.py
Normal file
@ -0,0 +1,23 @@
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_default_make_test_runs_backend_pytest_suites():
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
|
||||
completed = subprocess.run(["make", "-n", "test"], cwd=repo_root, check=True, capture_output=True, text=True)
|
||||
dry_run_output = completed.stdout
|
||||
|
||||
assert "api/tests/unit_tests" in dry_run_output
|
||||
assert "api/providers/vdb/*/tests/unit_tests" in dry_run_output
|
||||
assert "api/providers/trace/*/tests/unit_tests" in dry_run_output
|
||||
assert "-p no:benchmark" in dry_run_output
|
||||
assert "--start-middleware" in dry_run_output
|
||||
assert "api/tests/integration_tests/workflow" in dry_run_output
|
||||
assert "api/tests/integration_tests/tools" in dry_run_output
|
||||
assert "api/tests/test_containers_integration_tests" in dry_run_output
|
||||
assert "--start-vdb" in dry_run_output
|
||||
assert "api/providers/vdb/vdb-chroma/tests/integration_tests" in dry_run_output
|
||||
assert "api/providers/vdb/vdb-pgvector/tests/integration_tests" in dry_run_output
|
||||
assert "api/providers/vdb/vdb-qdrant/tests/integration_tests" in dry_run_output
|
||||
assert "api/providers/vdb/vdb-weaviate/tests/integration_tests" in dry_run_output
|
||||
239
api/tests/unit_tests/test_pytest_dify.py
Normal file
239
api/tests/unit_tests/test_pytest_dify.py
Normal file
@ -0,0 +1,239 @@
|
||||
import importlib.util
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.pytest_dify import (
|
||||
DEFAULT_LOG_FORMAT,
|
||||
DockerComposeStack,
|
||||
build_middleware_stack,
|
||||
build_vdb_stack,
|
||||
ensure_backend_test_environment,
|
||||
ensure_compose_env_files,
|
||||
parse_services,
|
||||
)
|
||||
|
||||
|
||||
def _load_api_conftest():
|
||||
api_conftest_path = Path(__file__).resolve().parents[2] / "conftest.py"
|
||||
spec = importlib.util.spec_from_file_location("api_root_conftest_for_tests", api_conftest_path)
|
||||
assert spec is not None
|
||||
assert spec.loader is not None
|
||||
api_conftest = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = api_conftest
|
||||
spec.loader.exec_module(api_conftest)
|
||||
return api_conftest
|
||||
|
||||
|
||||
def test_ensure_backend_test_environment_uses_example_env_and_stable_logging(
|
||||
tmp_path: Path,
|
||||
monkeypatch,
|
||||
):
|
||||
repo_root = tmp_path
|
||||
integration_tests_dir = repo_root / "api" / "tests" / "integration_tests"
|
||||
integration_tests_dir.mkdir(parents=True)
|
||||
env_example = integration_tests_dir / ".env.example"
|
||||
env_example.write_text("LOG_LEVEL=INFO\n")
|
||||
storage_root = repo_root / "storage"
|
||||
|
||||
monkeypatch.setenv("LOG_FORMAT", "json")
|
||||
monkeypatch.delenv("LOG_OUTPUT_FORMAT", raising=False)
|
||||
monkeypatch.delenv("DIFY_TEST_ENV_FILE", raising=False)
|
||||
monkeypatch.delenv("DIFY_VDB_TEST_ENV_FILE", raising=False)
|
||||
monkeypatch.setenv("OPENDAL_FS_ROOT", str(storage_root))
|
||||
|
||||
ensure_backend_test_environment(repo_root)
|
||||
|
||||
assert os.environ["DIFY_TEST_ENV_FILE"] == str(env_example)
|
||||
assert "DIFY_VDB_TEST_ENV_FILE" not in os.environ
|
||||
assert os.environ["LOG_OUTPUT_FORMAT"] == "text"
|
||||
assert os.environ["LOG_FORMAT"] == DEFAULT_LOG_FORMAT
|
||||
assert os.environ["STORAGE_TYPE"] == "opendal"
|
||||
assert os.environ["OPENDAL_SCHEME"] == "fs"
|
||||
assert storage_root.is_dir()
|
||||
|
||||
|
||||
def test_ensure_compose_env_files_copies_missing_env_files(tmp_path: Path):
|
||||
docker_dir = tmp_path / "docker"
|
||||
envs_dir = docker_dir / "envs"
|
||||
envs_dir.mkdir(parents=True)
|
||||
(docker_dir / ".env.example").write_text("APP_WEB_URL=http://localhost\n")
|
||||
(envs_dir / "middleware.env.example").write_text("DB_PASSWORD=difyai123456\n")
|
||||
|
||||
ensure_compose_env_files(tmp_path)
|
||||
|
||||
assert (docker_dir / ".env").read_text() == "APP_WEB_URL=http://localhost\n"
|
||||
assert (docker_dir / "middleware.env").read_text() == "DB_PASSWORD=difyai123456\n"
|
||||
|
||||
|
||||
def test_parse_services_discards_empty_items():
|
||||
assert parse_services(" db_postgres, redis,, sandbox ") == ["db_postgres", "redis", "sandbox"]
|
||||
|
||||
|
||||
def test_stack_up_uses_waiting_compose_command(monkeypatch, tmp_path: Path):
|
||||
calls: list[list[str]] = []
|
||||
|
||||
def fake_run(args, **kwargs):
|
||||
calls.append(args)
|
||||
if args == ["docker", "compose", "up", "--help"]:
|
||||
return subprocess.CompletedProcess(args=args, returncode=0, stdout="--wait\n--wait-timeout\n")
|
||||
return subprocess.CompletedProcess(args=args, returncode=0)
|
||||
|
||||
monkeypatch.setattr(subprocess, "run", fake_run)
|
||||
monkeypatch.setattr("time.sleep", lambda _: None)
|
||||
|
||||
stack = DockerComposeStack(
|
||||
name="middleware",
|
||||
project_name="dify-pytest-middleware",
|
||||
repo_root=tmp_path,
|
||||
compose_files=(tmp_path / "docker-compose.yaml",),
|
||||
env_file=tmp_path / "middleware.env",
|
||||
services=("db_postgres", "redis"),
|
||||
)
|
||||
|
||||
stack.up()
|
||||
|
||||
assert calls == [
|
||||
["docker", "compose", "up", "--help"],
|
||||
[
|
||||
"docker",
|
||||
"compose",
|
||||
"--project-name",
|
||||
"dify-pytest-middleware",
|
||||
"--env-file",
|
||||
str(tmp_path / "middleware.env"),
|
||||
"-f",
|
||||
str(tmp_path / "docker-compose.yaml"),
|
||||
"up",
|
||||
"-d",
|
||||
"--wait",
|
||||
"--wait-timeout",
|
||||
"180",
|
||||
"db_postgres",
|
||||
"redis",
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
def test_stack_up_skips_wait_flags_when_compose_help_omits_them(monkeypatch, tmp_path: Path):
|
||||
calls: list[list[str]] = []
|
||||
|
||||
def fake_run(args, **kwargs):
|
||||
calls.append(args)
|
||||
if args == ["docker", "compose", "up", "--help"]:
|
||||
return subprocess.CompletedProcess(args=args, returncode=0, stdout="Usage: docker compose up\n")
|
||||
return subprocess.CompletedProcess(args=args, returncode=0)
|
||||
|
||||
monkeypatch.setattr(subprocess, "run", fake_run)
|
||||
monkeypatch.setattr("time.sleep", lambda _: None)
|
||||
|
||||
stack = DockerComposeStack(
|
||||
name="middleware",
|
||||
project_name="dify-pytest-middleware",
|
||||
repo_root=tmp_path,
|
||||
compose_files=(tmp_path / "docker-compose.yaml",),
|
||||
env_file=tmp_path / "middleware.env",
|
||||
services=("db_postgres",),
|
||||
)
|
||||
|
||||
stack.up()
|
||||
|
||||
assert calls[1][-3:] == ["up", "-d", "db_postgres"]
|
||||
assert "--wait" not in calls[1]
|
||||
assert "--wait-timeout" not in calls[1]
|
||||
|
||||
|
||||
def test_builders_use_expected_compose_files(tmp_path: Path):
|
||||
middleware = build_middleware_stack(tmp_path, ["db_postgres"])
|
||||
vdb = build_vdb_stack(tmp_path, ["weaviate", "qdrant"])
|
||||
|
||||
assert middleware.compose_files == (tmp_path / "docker" / "docker-compose.middleware.yaml",)
|
||||
assert middleware.env_file == tmp_path / "docker" / "middleware.env"
|
||||
assert middleware.ready_delay_seconds == 5.0
|
||||
assert vdb.compose_files == (
|
||||
tmp_path / "docker" / "docker-compose.yaml",
|
||||
tmp_path / "docker" / "docker-compose.pytest.ports.yaml",
|
||||
)
|
||||
assert vdb.env_file == tmp_path / "docker" / ".env"
|
||||
assert vdb.profiles == ("weaviate", "qdrant")
|
||||
|
||||
|
||||
def test_pytest_sessionstart_cleans_started_stacks_when_later_stack_fails(monkeypatch):
|
||||
api_conftest = _load_api_conftest()
|
||||
|
||||
events: list[str] = []
|
||||
|
||||
class FakeStack:
|
||||
def __init__(self, name: str, fail: bool = False) -> None:
|
||||
self.name = name
|
||||
self.fail = fail
|
||||
|
||||
def up(self) -> None:
|
||||
events.append(f"{self.name}:up")
|
||||
if self.fail:
|
||||
raise RuntimeError(f"{self.name} failed")
|
||||
|
||||
def down(self) -> None:
|
||||
events.append(f"{self.name}:down")
|
||||
|
||||
class FakeConfig:
|
||||
stash: dict[object, list[FakeStack]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.stash = {}
|
||||
|
||||
def getoption(self, name: str) -> bool | str:
|
||||
options: dict[str, bool | str] = {
|
||||
"start_middleware": True,
|
||||
"middleware_services": "db_postgres",
|
||||
"start_vdb": True,
|
||||
"vdb_services": "qdrant",
|
||||
}
|
||||
return options[name]
|
||||
|
||||
middleware_stack = FakeStack("middleware")
|
||||
vdb_stack = FakeStack("vdb", fail=True)
|
||||
monkeypatch.setattr(api_conftest, "ensure_compose_env_files", lambda _repo_root: None)
|
||||
monkeypatch.setattr(api_conftest, "build_middleware_stack", lambda *_args: middleware_stack)
|
||||
monkeypatch.setattr(api_conftest, "build_vdb_stack", lambda *_args: vdb_stack)
|
||||
|
||||
config = FakeConfig()
|
||||
session = SimpleNamespace(config=config)
|
||||
|
||||
with pytest.raises(RuntimeError, match="vdb failed"):
|
||||
api_conftest.pytest_sessionstart(session)
|
||||
|
||||
assert events == ["middleware:up", "vdb:up", "vdb:down", "middleware:down"]
|
||||
assert config.stash[api_conftest._DIFY_COMPOSE_STACKS_KEY] == []
|
||||
|
||||
|
||||
def test_stop_stacks_attempts_all_stacks_before_reporting_errors():
|
||||
api_conftest = _load_api_conftest()
|
||||
events: list[str] = []
|
||||
|
||||
class FakeStack:
|
||||
def __init__(self, name: str, fail: bool = False) -> None:
|
||||
self.name = name
|
||||
self.fail = fail
|
||||
|
||||
def down(self) -> None:
|
||||
events.append(f"{self.name}:down")
|
||||
if self.fail:
|
||||
raise RuntimeError(f"{self.name} failed")
|
||||
|
||||
with pytest.raises(BaseExceptionGroup) as exc_info:
|
||||
api_conftest._stop_stacks(
|
||||
[
|
||||
FakeStack("middleware"),
|
||||
FakeStack("vdb", fail=True),
|
||||
FakeStack("extra"),
|
||||
]
|
||||
)
|
||||
|
||||
assert events == ["extra:down", "vdb:down", "middleware:down"]
|
||||
assert len(exc_info.value.exceptions) == 1
|
||||
assert "vdb failed" in str(exc_info.value.exceptions[0])
|
||||
@ -1,58 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
set -ex
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
cd "$SCRIPT_DIR/../.."
|
||||
|
||||
PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}"
|
||||
|
||||
# Ensure OpenDAL local storage works even if .env isn't loaded
|
||||
export STORAGE_TYPE=${STORAGE_TYPE:-opendal}
|
||||
export OPENDAL_SCHEME=${OPENDAL_SCHEME:-fs}
|
||||
export OPENDAL_FS_ROOT=${OPENDAL_FS_ROOT:-/tmp/dify-storage}
|
||||
mkdir -p "${OPENDAL_FS_ROOT}"
|
||||
|
||||
# Prepare env files like CI
|
||||
cp -n docker/.env.example docker/.env || true
|
||||
cp -n docker/envs/middleware.env.example docker/middleware.env || true
|
||||
cp -n api/tests/integration_tests/.env.example api/tests/integration_tests/.env || true
|
||||
|
||||
# Expose service ports (same as CI) without leaving the repo dirty
|
||||
EXPOSE_BACKUPS=()
|
||||
for f in docker/docker-compose.yaml docker/tidb/docker-compose.yaml; do
|
||||
if [[ -f "$f" ]]; then
|
||||
cp "$f" "$f.ci.bak"
|
||||
EXPOSE_BACKUPS+=("$f")
|
||||
fi
|
||||
done
|
||||
if command -v yq >/dev/null 2>&1; then
|
||||
sh .github/workflows/expose_service_ports.sh || true
|
||||
else
|
||||
echo "skip expose_service_ports (yq not installed)" >&2
|
||||
fi
|
||||
|
||||
# Optionally start middleware stack (db, redis, sandbox, ssrf proxy) to mirror CI
|
||||
STARTED_MIDDLEWARE=0
|
||||
if [[ "${SKIP_MIDDLEWARE:-0}" != "1" ]]; then
|
||||
docker compose -f docker/docker-compose.middleware.yaml --env-file docker/middleware.env up -d db_postgres redis sandbox ssrf_proxy
|
||||
STARTED_MIDDLEWARE=1
|
||||
# Give services a moment to come up
|
||||
sleep 5
|
||||
fi
|
||||
|
||||
cleanup() {
|
||||
if [[ $STARTED_MIDDLEWARE -eq 1 ]]; then
|
||||
docker compose -f docker/docker-compose.middleware.yaml --env-file docker/middleware.env down
|
||||
fi
|
||||
for f in "${EXPOSE_BACKUPS[@]}"; do
|
||||
mv "$f.ci.bak" "$f"
|
||||
done
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
pytest --timeout "${PYTEST_TIMEOUT}" \
|
||||
api/tests/integration_tests/workflow \
|
||||
api/tests/integration_tests/tools \
|
||||
api/tests/test_containers_integration_tests \
|
||||
api/tests/unit_tests
|
||||
@ -1,21 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -euxo pipefail
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
cd "$SCRIPT_DIR/../.."
|
||||
|
||||
PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-20}"
|
||||
PYTEST_XDIST_ARGS="${PYTEST_XDIST_ARGS:--n auto}"
|
||||
|
||||
# Run most tests in parallel (excluding controllers which have import conflicts with xdist)
|
||||
# Controller tests have module-level side effects (Flask route registration) that cause
|
||||
# race conditions when imported concurrently by multiple pytest-xdist workers.
|
||||
pytest --timeout "${PYTEST_TIMEOUT}" ${PYTEST_XDIST_ARGS} \
|
||||
api/tests/unit_tests \
|
||||
api/providers/vdb/*/tests/unit_tests \
|
||||
api/providers/trace/*/tests/unit_tests \
|
||||
--ignore=api/tests/unit_tests/controllers
|
||||
|
||||
# Run controller tests sequentially to avoid import race conditions
|
||||
pytest --timeout "${PYTEST_TIMEOUT}" --cov-append api/tests/unit_tests/controllers
|
||||
|
||||
@ -1,12 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -x
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
cd "$SCRIPT_DIR/../.."
|
||||
|
||||
PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}"
|
||||
|
||||
uv sync --project api --group dev
|
||||
|
||||
uv run --project api pytest --timeout "${PYTEST_TIMEOUT}" \
|
||||
api/providers/vdb/*/tests/integration_tests \
|
||||
@ -5,7 +5,7 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
|
||||
### What's Updated
|
||||
|
||||
- **Certbot Container**: `docker-compose.yaml` now contains `certbot` for managing SSL certificates. This container automatically renews certificates and ensures secure HTTPS connections.\
|
||||
For more information, refer `docker/certbot/README.md`.
|
||||
For more information, refer to `docker/certbot/README.md`.
|
||||
|
||||
- **Persistent Environment Variables**: Essential startup defaults are provided in `.env.example`, while local values are stored in `.env`, ensuring that your configurations persist across deployments.
|
||||
|
||||
@ -17,26 +17,26 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
|
||||
### How to Deploy Dify with `docker-compose.yaml`
|
||||
|
||||
1. **Prerequisites**: Ensure Docker and Docker Compose are installed on your system.
|
||||
1. **Environment Setup**:
|
||||
2. **Environment Setup**:
|
||||
- Navigate to the `docker` directory.
|
||||
- Copy `.env.example` to `.env`.
|
||||
- Customize `.env` when you need to change essential startup defaults. Copy optional files from `envs/` without the `.example` suffix when you need advanced settings.
|
||||
- **Optional (for advanced deployments)**:
|
||||
If you maintain a full `.env` file copied from `.env.example`, you may use the environment synchronization tool to keep it aligned with the latest `.env.example` updates while preserving your custom settings.
|
||||
See the [Environment Variables Synchronization](#environment-variables-synchronization) section below.
|
||||
1. **Running the Services**:
|
||||
3. **Running the Services**:
|
||||
- Execute `docker compose up -d` from the `docker` directory to start the services.
|
||||
- To specify a vector database, set the `VECTOR_STORE` variable in your `.env` file to your desired vector database service, such as `milvus`, `weaviate`, or `opensearch`.
|
||||
- To specify a vector database, set the `VECTOR_STORE` variable in your `.env` file to your desired vector database service, such as `milvus`, `weaviate`, or `opensearch`. See `envs/vectorstores/` for the full list of supported options.
|
||||
```bash
|
||||
cp .env.example .env
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
1. **SSL Certificate Setup**:
|
||||
- Refer `docker/certbot/README.md` to set up SSL certificates using Certbot.
|
||||
1. **OpenTelemetry Collector Setup**:
|
||||
- Change `ENABLE_OTEL` to `true` in `.env`.
|
||||
- Configure `OTLP_BASE_ENDPOINT` properly.
|
||||
4. **SSL Certificate Setup**:
|
||||
- Refer to `docker/certbot/README.md` to set up SSL certificates using Certbot.
|
||||
5. **OpenTelemetry Collector Setup**:
|
||||
- Copy `envs/core-services/shared.env.example` to `envs/core-services/shared.env`.
|
||||
- Set `ENABLE_OTEL=true` and configure `OTLP_BASE_ENDPOINT`. Tune the other `OTEL_*` knobs in the same file if needed.
|
||||
|
||||
### How to Deploy Middleware for Developing Dify
|
||||
|
||||
@ -44,7 +44,7 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
|
||||
- Use the `docker-compose.middleware.yaml` for setting up essential middleware services like databases and caches.
|
||||
- Navigate to the `docker` directory.
|
||||
- Ensure the `middleware.env` file is created by running `cp envs/middleware.env.example middleware.env` (refer to the `envs/middleware.env.example` file).
|
||||
1. **Running Middleware Services**:
|
||||
2. **Running Middleware Services**:
|
||||
- Navigate to the `docker` directory.
|
||||
- Execute `docker compose --env-file middleware.env -f docker-compose.middleware.yaml -p dify up -d` to start PostgreSQL/MySQL (per `DB_TYPE`) plus the bundled Weaviate instance.
|
||||
|
||||
@ -55,9 +55,9 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
|
||||
For users migrating from the `docker-legacy` setup:
|
||||
|
||||
1. **Review Changes**: Familiarize yourself with the new `.env` configuration and Docker Compose setup.
|
||||
1. **Transfer Customizations**:
|
||||
2. **Transfer Customizations**:
|
||||
- If you have customized configurations such as `docker-compose.yaml`, `ssrf_proxy/squid.conf`, or `nginx/conf.d/default.conf`, you will need to reflect these changes in the `.env` file you create.
|
||||
1. **Data Migration**:
|
||||
3. **Data Migration**:
|
||||
- Ensure that data from services like databases and caches is backed up and migrated appropriately to the new structure if necessary.
|
||||
|
||||
### Overview of `.env`, `.env.example`, and `envs/`
|
||||
@ -80,49 +80,51 @@ The root `.env.example` file contains the essential startup settings. Optional a
|
||||
|
||||
1. **Common Variables**:
|
||||
|
||||
- `CONSOLE_API_URL`, `SERVICE_API_URL`: URLs for different API services.
|
||||
- `APP_WEB_URL`: Frontend application URL.
|
||||
- `FILES_URL`: Base URL for file downloads and previews.
|
||||
- `CONSOLE_API_URL`, `CONSOLE_WEB_URL`, `SERVICE_API_URL`, `APP_API_URL`, `APP_WEB_URL`: URLs for the API and frontend services.
|
||||
- `FILES_URL`, `INTERNAL_FILES_URL`: Public and internal base URLs for file downloads and previews.
|
||||
- `ENDPOINT_URL_TEMPLATE`, `NEXT_PUBLIC_SOCKET_URL`, `TRIGGER_URL`: Additional service URLs.
|
||||
|
||||
See `.env.example` for the full list.
|
||||
|
||||
1. **Server Configuration**:
|
||||
2. **Server Configuration**:
|
||||
|
||||
- `LOG_LEVEL`, `DEBUG`, `FLASK_DEBUG`: Logging and debug settings.
|
||||
- `SECRET_KEY`: A key for signing sessions, JWTs, and file URLs. Leave it empty to let Dify generate a persistent key in the storage directory, or set a unique value yourself.
|
||||
|
||||
1. **Database Configuration**:
|
||||
3. **Database Configuration**:
|
||||
|
||||
- `DB_USERNAME`, `DB_PASSWORD`, `DB_HOST`, `DB_PORT`, `DB_DATABASE`: PostgreSQL database credentials and connection details.
|
||||
|
||||
1. **Redis Configuration**:
|
||||
4. **Redis Configuration**:
|
||||
|
||||
- `REDIS_HOST`, `REDIS_PORT`, `REDIS_PASSWORD`: Redis server connection settings.
|
||||
- `REDIS_KEY_PREFIX`: Optional global namespace prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
|
||||
|
||||
1. **Celery Configuration**:
|
||||
5. **Celery Configuration**:
|
||||
|
||||
- `CELERY_BROKER_URL`: Configuration for Celery message broker.
|
||||
|
||||
1. **Storage Configuration**:
|
||||
6. **Storage Configuration**:
|
||||
|
||||
- `STORAGE_TYPE`, `OPENDAL_SCHEME`, `OPENDAL_FS_ROOT`: Default local file storage settings. Optional storage backends are configured from the files under `envs/`.
|
||||
|
||||
1. **Vector Database Configuration**:
|
||||
7. **Vector Database Configuration**:
|
||||
|
||||
- `VECTOR_STORE`: Type of vector database (e.g., `weaviate`, `milvus`).
|
||||
- `VECTOR_STORE`: Type of vector database (e.g., `weaviate`, `milvus`). See `envs/vectorstores/` for the full list of supported options.
|
||||
- Specific settings for each vector store like `WEAVIATE_ENDPOINT`, `MILVUS_URI`.
|
||||
|
||||
1. **CORS Configuration**:
|
||||
8. **CORS Configuration**:
|
||||
|
||||
- `WEB_API_CORS_ALLOW_ORIGINS`, `CONSOLE_CORS_ALLOW_ORIGINS`: Settings for cross-origin resource sharing.
|
||||
|
||||
1. **OpenTelemetry Configuration**:
|
||||
9. **OpenTelemetry Configuration**:
|
||||
|
||||
- `ENABLE_OTEL`: Enable OpenTelemetry collector in api.
|
||||
- `OTLP_BASE_ENDPOINT`: Endpoint for your OTLP exporter.
|
||||
|
||||
1. **Other Service-Specific Environment Variables**:
|
||||
10. **Other Service-Specific Environment Variables**:
|
||||
|
||||
- Each service like `nginx`, `redis`, `db`, and vector databases have specific environment variables that are directly referenced in the `docker-compose.yaml`.
|
||||
- Each service like `nginx`, `redis`, `db`, and vector databases have specific environment variables that are directly referenced in the `docker-compose.yaml`.
|
||||
|
||||
### Environment Variables Synchronization
|
||||
|
||||
|
||||
30
docker/docker-compose.pytest.ports.yaml
Normal file
30
docker/docker-compose.pytest.ports.yaml
Normal file
@ -0,0 +1,30 @@
|
||||
services:
|
||||
weaviate:
|
||||
ports:
|
||||
- "${EXPOSE_WEAVIATE_PORT:-8080}:8080"
|
||||
- "${EXPOSE_WEAVIATE_GRPC_PORT:-50051}:50051"
|
||||
|
||||
qdrant:
|
||||
ports:
|
||||
- "${EXPOSE_QDRANT_PORT:-6333}:6333"
|
||||
|
||||
pgvector:
|
||||
ports:
|
||||
- "${EXPOSE_PGVECTOR_PORT:-5433}:5432"
|
||||
|
||||
pgvecto-rs:
|
||||
ports:
|
||||
- "${EXPOSE_PGVECTO_RS_PORT:-5431}:5432"
|
||||
|
||||
chroma:
|
||||
ports:
|
||||
- "${EXPOSE_CHROMA_PORT:-8000}:8000"
|
||||
|
||||
couchbase-server:
|
||||
ports:
|
||||
- "${EXPOSE_COUCHBASE_WEB_PORT_RANGE:-8091-8096}:8091-8096"
|
||||
- "${EXPOSE_COUCHBASE_DATA_PORT:-11210}:11210"
|
||||
|
||||
opensearch:
|
||||
ports:
|
||||
- "${EXPOSE_OPENSEARCH_PORT:-9200}:9200"
|
||||
@ -40,7 +40,7 @@ vi.mock('../score-slider', () => ({
|
||||
<input
|
||||
role="slider"
|
||||
type="range"
|
||||
min={80}
|
||||
min={0}
|
||||
max={100}
|
||||
value={value}
|
||||
onChange={e => onChange(Number((e.target as HTMLInputElement).value))}
|
||||
@ -272,7 +272,7 @@ describe('ConfigParamModal', () => {
|
||||
)
|
||||
|
||||
const slider = screen.getByRole('slider')
|
||||
expect(slider).toHaveAttribute('min', '80')
|
||||
expect(slider).toHaveAttribute('min', '0')
|
||||
expect(slider).toHaveAttribute('max', '100')
|
||||
expect(slider).toHaveValue('90')
|
||||
})
|
||||
@ -375,7 +375,7 @@ describe('ConfigParamModal', () => {
|
||||
it('should use ANNOTATION_DEFAULT score_threshold when config has no score_threshold', () => {
|
||||
const configWithoutThreshold = {
|
||||
...defaultAnnotationConfig,
|
||||
score_threshold: 0,
|
||||
score_threshold: undefined as unknown as number,
|
||||
}
|
||||
render(
|
||||
<ConfigParamModal
|
||||
@ -390,6 +390,35 @@ describe('ConfigParamModal', () => {
|
||||
expect(screen.getByRole('slider')).toHaveValue('90')
|
||||
})
|
||||
|
||||
it('should preserve zero score threshold instead of falling back to default', async () => {
|
||||
const onSave = vi.fn().mockResolvedValue(undefined)
|
||||
render(
|
||||
<ConfigParamModal
|
||||
appId="test-app"
|
||||
isShow={true}
|
||||
onHide={vi.fn()}
|
||||
onSave={onSave}
|
||||
annotationConfig={{
|
||||
...defaultAnnotationConfig,
|
||||
score_threshold: 0,
|
||||
}}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('slider')).toHaveValue('0')
|
||||
|
||||
const buttons = screen.getAllByRole('button')
|
||||
const saveBtn = buttons.find(b => b.textContent?.includes('initSetup'))
|
||||
fireEvent.click(saveBtn!)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onSave).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ embedding_provider_name: 'openai' }),
|
||||
0,
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('should set loading state while saving', async () => {
|
||||
let resolveOnSave: () => void
|
||||
const onSave = vi.fn().mockImplementation(() => new Promise<void>((resolve) => {
|
||||
|
||||
@ -175,6 +175,22 @@ describe('AnnotationReply', () => {
|
||||
expect(screen.getByText('text-embedding-ada-002')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show zero score threshold when enabled', () => {
|
||||
renderWithProvider({}, {
|
||||
annotationReply: {
|
||||
enabled: true,
|
||||
score_threshold: 0,
|
||||
embedding_model: {
|
||||
embedding_provider_name: 'openai',
|
||||
embedding_model_name: 'text-embedding-ada-002',
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(screen.getByText('0')).toBeInTheDocument()
|
||||
expect(screen.getByText('text-embedding-ada-002')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show dash when score threshold is not set', () => {
|
||||
renderWithProvider({}, {
|
||||
annotationReply: {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import type { AnnotationReplyConfig } from '@/models/debug'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { queryAnnotationJobStatus } from '@/service/annotation'
|
||||
import { queryAnnotationJobStatus, updateAnnotationStatus } from '@/service/annotation'
|
||||
import { sleep } from '@/utils'
|
||||
import useAnnotationConfig from '../use-annotation-config'
|
||||
|
||||
@ -162,6 +162,35 @@ describe('useAnnotationConfig', () => {
|
||||
expect(updatedConfig.score_threshold).toBe(0.85)
|
||||
})
|
||||
|
||||
it('should preserve zero score threshold when enabling annotation', async () => {
|
||||
const zeroScoreConfig = { ...defaultConfig, score_threshold: 0 }
|
||||
const setAnnotationConfig = vi.fn()
|
||||
const { result } = renderHook(() => useAnnotationConfig({
|
||||
appId: 'test-app',
|
||||
annotationConfig: zeroScoreConfig,
|
||||
setAnnotationConfig,
|
||||
}))
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleEnableAnnotation({
|
||||
embedding_provider_name: 'openai',
|
||||
embedding_model_name: 'text-embedding-3-small',
|
||||
}, 0)
|
||||
})
|
||||
|
||||
expect(updateAnnotationStatus).toHaveBeenCalledWith(
|
||||
'test-app',
|
||||
'enable',
|
||||
{
|
||||
embedding_provider_name: 'openai',
|
||||
embedding_model_name: 'text-embedding-3-small',
|
||||
},
|
||||
0,
|
||||
)
|
||||
const updatedConfig = setAnnotationConfig.mock.calls[0]![0]
|
||||
expect(updatedConfig.score_threshold).toBe(0)
|
||||
})
|
||||
|
||||
it('should set score and embedding model together', () => {
|
||||
const setAnnotationConfig = vi.fn()
|
||||
const { result } = renderHook(() => useAnnotationConfig({
|
||||
|
||||
@ -75,7 +75,7 @@ const ConfigParamModal: FC<Props> = ({ isShow, onHide: doHide, onSave, isInit, a
|
||||
<Item title={t('feature.annotation.scoreThreshold.title', { ns: 'appDebug' })} tooltip={t('feature.annotation.scoreThreshold.description', { ns: 'appDebug' })}>
|
||||
<ScoreSlider
|
||||
className="mt-1"
|
||||
value={(annotationConfig.score_threshold || ANNOTATION_DEFAULT.score_threshold) * 100}
|
||||
value={(annotationConfig.score_threshold ?? ANNOTATION_DEFAULT.score_threshold) * 100}
|
||||
onChange={(val) => {
|
||||
setAnnotationConfig({
|
||||
...annotationConfig,
|
||||
|
||||
@ -100,7 +100,7 @@ const AnnotationReply = ({
|
||||
<div className="flex items-center gap-4 pt-0.5">
|
||||
<div className="">
|
||||
<div className="mb-0.5 system-2xs-medium-uppercase text-text-tertiary">{t('feature.annotation.scoreThreshold.title', { ns: 'appDebug' })}</div>
|
||||
<div className="system-xs-regular text-text-secondary">{annotationReply.score_threshold || '-'}</div>
|
||||
<div className="system-xs-regular text-text-secondary">{annotationReply.score_threshold ?? '-'}</div>
|
||||
</div>
|
||||
<div className="h-[27px] w-px rotate-12 bg-divider-subtle"></div>
|
||||
<div className="">
|
||||
|
||||
@ -17,7 +17,7 @@ describe('ScoreSlider', () => {
|
||||
it('should display easy match and accurate match labels', () => {
|
||||
render(<ScoreSlider value={90} onChange={vi.fn()} />)
|
||||
|
||||
expect(screen.getByText('0.8')).toBeInTheDocument()
|
||||
expect(screen.getByText('0.0')).toBeInTheDocument()
|
||||
expect(screen.getByText('1.0')).toBeInTheDocument()
|
||||
expect(screen.getByText(/feature\.annotation\.scoreThreshold\.easyMatch/)).toBeInTheDocument()
|
||||
expect(screen.getByText(/feature\.annotation\.scoreThreshold\.accurateMatch/)).toBeInTheDocument()
|
||||
@ -36,4 +36,11 @@ describe('ScoreSlider', () => {
|
||||
expect(getSliderInput()).toHaveValue('95')
|
||||
expect(screen.getByText('0.95')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should allow zero as the minimum score threshold', () => {
|
||||
render(<ScoreSlider value={0} onChange={vi.fn()} />)
|
||||
|
||||
expect(getSliderInput()).toHaveValue('0')
|
||||
expect(screen.getByText('0.00')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@ -17,13 +17,16 @@ const clamp = (value: number, min: number, max: number) => {
|
||||
return Math.min(Math.max(value, min), max)
|
||||
}
|
||||
|
||||
const SCORE_MIN = 0
|
||||
const SCORE_MAX = 100
|
||||
|
||||
const ScoreSlider: FC<Props> = ({
|
||||
className,
|
||||
value,
|
||||
onChange,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const safeValue = clamp(value, 80, 100)
|
||||
const safeValue = clamp(value, SCORE_MIN, SCORE_MAX)
|
||||
|
||||
return (
|
||||
<div className={className}>
|
||||
@ -31,8 +34,8 @@ const ScoreSlider: FC<Props> = ({
|
||||
<Slider
|
||||
className="w-full"
|
||||
value={safeValue}
|
||||
min={80}
|
||||
max={100}
|
||||
min={SCORE_MIN}
|
||||
max={SCORE_MAX}
|
||||
step={1}
|
||||
onValueChange={onChange}
|
||||
aria-label={t('feature.annotation.scoreThreshold.title', { ns: 'appDebug' })}
|
||||
@ -40,7 +43,7 @@ const ScoreSlider: FC<Props> = ({
|
||||
<div
|
||||
className="pointer-events-none absolute top-[-16px] system-sm-semibold text-text-primary"
|
||||
style={{
|
||||
left: `calc(4px + ${(safeValue - 80) / 20} * (100% - 8px))`,
|
||||
left: `calc(4px + ${safeValue / SCORE_MAX} * (100% - 8px))`,
|
||||
transform: 'translateX(-50%)',
|
||||
}}
|
||||
>
|
||||
@ -49,7 +52,7 @@ const ScoreSlider: FC<Props> = ({
|
||||
</div>
|
||||
<div className="mt-[10px] flex items-center justify-between system-xs-semibold-uppercase">
|
||||
<div className="flex space-x-1 text-util-colors-cyan-cyan-500">
|
||||
<div>0.8</div>
|
||||
<div>0.0</div>
|
||||
<div>·</div>
|
||||
<div>{t('feature.annotation.scoreThreshold.easyMatch', { ns: 'appDebug' })}</div>
|
||||
</div>
|
||||
|
||||
@ -53,7 +53,7 @@ const useAnnotationConfig = ({
|
||||
setAnnotationConfig(produce(annotationConfig, (draft: AnnotationReplyConfig) => {
|
||||
draft.enabled = true
|
||||
draft.embedding_model = embeddingModel
|
||||
if (!draft.score_threshold)
|
||||
if (draft.score_threshold === undefined || draft.score_threshold === null)
|
||||
draft.score_threshold = ANNOTATION_DEFAULT.score_threshold
|
||||
}))
|
||||
}
|
||||
|
||||
@ -13,6 +13,7 @@ let parameterRules: Array<Record<string, unknown>> | undefined = [
|
||||
},
|
||||
]
|
||||
let isRulesLoading = false
|
||||
let isRulesPending = false
|
||||
let currentProvider: Record<string, unknown> | undefined = { provider: 'openai', label: { en_US: 'OpenAI' } }
|
||||
let currentModel: Record<string, unknown> | undefined = {
|
||||
model: 'gpt-3.5-turbo',
|
||||
@ -49,7 +50,7 @@ vi.mock('@/service/use-common', () => ({
|
||||
data: parameterRules,
|
||||
},
|
||||
isLoading: isRulesLoading,
|
||||
isPending: isRulesLoading,
|
||||
isPending: isRulesPending,
|
||||
}),
|
||||
}))
|
||||
|
||||
@ -92,9 +93,21 @@ vi.mock('../../model-selector', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('../presets-parameter', () => ({
|
||||
default: ({ onSelect }: { onSelect: (id: number) => void }) => (
|
||||
<button onClick={() => onSelect(1)}>Preset 1</button>
|
||||
),
|
||||
default: ({ onSelect, supportedParameterNames }: { onSelect: (id: number) => void, supportedParameterNames?: string[] }) => {
|
||||
if (supportedParameterNames && !supportedParameterNames.includes('temperature'))
|
||||
return null
|
||||
|
||||
return <button onClick={() => onSelect(1)}>Preset 1</button>
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('../presets-parameter-utils', () => ({
|
||||
getSupportedPresetConfig: (_toneId: number, supportedParameterNames?: string[]) => {
|
||||
if (supportedParameterNames && !supportedParameterNames.includes('temperature'))
|
||||
return {}
|
||||
|
||||
return { temperature: 0.8 }
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('../trigger', () => ({
|
||||
@ -126,6 +139,7 @@ describe('ModelParameterModal', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
isRulesLoading = false
|
||||
isRulesPending = false
|
||||
parameterRules = [
|
||||
{
|
||||
name: 'temperature',
|
||||
@ -194,7 +208,28 @@ describe('ModelParameterModal', () => {
|
||||
render(<ModelParameterModal {...defaultProps} />)
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
fireEvent.click(screen.getByText('Preset 1'))
|
||||
expect(defaultProps.onCompletionParamsChange).toHaveBeenCalled()
|
||||
expect(defaultProps.onCompletionParamsChange).toHaveBeenCalledWith({
|
||||
...defaultProps.completionParams,
|
||||
temperature: 0.8,
|
||||
})
|
||||
})
|
||||
|
||||
it('should not render preset control when visible parameters do not support preset keys', () => {
|
||||
parameterRules = [
|
||||
{
|
||||
name: 'max_tokens',
|
||||
label: { en_US: 'Max Tokens' },
|
||||
type: 'int',
|
||||
default: 256,
|
||||
min: 1,
|
||||
max: 4096,
|
||||
},
|
||||
]
|
||||
|
||||
render(<ModelParameterModal {...defaultProps} />)
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
|
||||
expect(screen.queryByText('Preset 1')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call setModel when model selector picks another model', () => {
|
||||
@ -219,11 +254,29 @@ describe('ModelParameterModal', () => {
|
||||
|
||||
it('should render loading state when parameter rules are loading', () => {
|
||||
isRulesLoading = true
|
||||
isRulesPending = true
|
||||
render(<ModelParameterModal {...defaultProps} />)
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render parameter loading when model is not configured and parameter rules query is pending but disabled', () => {
|
||||
isRulesPending = true
|
||||
parameterRules = []
|
||||
|
||||
render(
|
||||
<ModelParameterModal
|
||||
{...defaultProps}
|
||||
provider=""
|
||||
modelId=""
|
||||
/>,
|
||||
)
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
|
||||
expect(screen.queryByRole('status')).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('model-selector')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not open content when readonly is true', () => {
|
||||
render(<ModelParameterModal {...defaultProps} readonly />)
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
@ -299,6 +352,7 @@ describe('ModelParameterModal', () => {
|
||||
it('should render the empty loading fallback when rules resolve to an empty list', () => {
|
||||
parameterRules = []
|
||||
isRulesLoading = true
|
||||
isRulesPending = true
|
||||
|
||||
render(<ModelParameterModal {...defaultProps} />)
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { vi } from 'vitest'
|
||||
import PresetsParameter from '../presets-parameter'
|
||||
import { getSupportedPresetConfig } from '../presets-parameter-utils'
|
||||
|
||||
describe('PresetsParameter', () => {
|
||||
beforeEach(() => {
|
||||
@ -47,4 +48,22 @@ describe('PresetsParameter', () => {
|
||||
|
||||
expect(onSelect).toHaveBeenCalledWith(3)
|
||||
})
|
||||
|
||||
it('should render presets when at least one preset parameter is supported', () => {
|
||||
render(<PresetsParameter onSelect={vi.fn()} supportedParameterNames={['temperature']} />)
|
||||
|
||||
expect(screen.getByRole('button', { name: /common\.modelProvider\.loadPresets/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render presets when no preset parameters are supported', () => {
|
||||
render(<PresetsParameter onSelect={vi.fn()} supportedParameterNames={['max_tokens']} />)
|
||||
|
||||
expect(screen.queryByRole('button', { name: /common\.modelProvider\.loadPresets/i })).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should return only supported preset config keys', () => {
|
||||
expect(getSupportedPresetConfig(1, ['temperature'])).toEqual({
|
||||
temperature: 0.8,
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -24,7 +24,7 @@ import { useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { PROVIDER_WITH_PRESET_TONE, STOP_PARAMETER_RULE, TONE_LIST } from '@/config'
|
||||
import { PROVIDER_WITH_PRESET_TONE, STOP_PARAMETER_RULE } from '@/config'
|
||||
import { useModelParameterRules } from '@/service/use-common'
|
||||
import {
|
||||
useTextGenerationCurrentProviderAndModelAndModelList,
|
||||
@ -32,6 +32,7 @@ import {
|
||||
import ModelSelector from '../model-selector'
|
||||
import ParameterItem from './parameter-item'
|
||||
import PresetsParameter from './presets-parameter'
|
||||
import { getSupportedPresetConfig } from './presets-parameter-utils'
|
||||
import Trigger from './trigger'
|
||||
|
||||
export type ModelParameterModalProps = {
|
||||
@ -75,10 +76,9 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
const settingsIconRef = useRef<HTMLDivElement>(null)
|
||||
const {
|
||||
data: parameterRulesData,
|
||||
isPending,
|
||||
isLoading,
|
||||
} = useModelParameterRules(provider, modelId)
|
||||
const isRulesLoading = isPending || isLoading
|
||||
const isRulesLoading = !!provider && !!modelId && isLoading
|
||||
const {
|
||||
currentProvider,
|
||||
currentModel,
|
||||
@ -90,6 +90,9 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
const parameterRules: ModelParameterRule[] = useMemo(() => {
|
||||
return parameterRulesData?.data || []
|
||||
}, [parameterRulesData])
|
||||
const supportedPresetParameterNames = useMemo(() => {
|
||||
return parameterRules.map(parameterRule => parameterRule.name)
|
||||
}, [parameterRules])
|
||||
|
||||
const handleParamChange = (key: string, value: ParameterValue) => {
|
||||
onCompletionParamsChange({
|
||||
@ -125,13 +128,10 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
}
|
||||
|
||||
const handleSelectPresetParameter = (toneId: number) => {
|
||||
const tone = TONE_LIST.find(tone => tone.id === toneId)
|
||||
if (tone) {
|
||||
onCompletionParamsChange({
|
||||
...completionParams,
|
||||
...tone.config,
|
||||
})
|
||||
}
|
||||
onCompletionParamsChange({
|
||||
...completionParams,
|
||||
...getSupportedPresetConfig(toneId, supportedPresetParameterNames),
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
@ -199,7 +199,10 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
<div className="flex flex-1 items-center system-sm-semibold-uppercase text-text-secondary">{t('modelProvider.parameters', { ns: 'common' })}</div>
|
||||
{
|
||||
PROVIDER_WITH_PRESET_TONE.includes(provider) && (
|
||||
<PresetsParameter onSelect={handleSelectPresetParameter} />
|
||||
<PresetsParameter
|
||||
onSelect={handleSelectPresetParameter}
|
||||
supportedParameterNames={supportedPresetParameterNames}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
|
||||
@ -0,0 +1,19 @@
|
||||
import { TONE_LIST } from '@/config'
|
||||
|
||||
export const getSupportedPresetConfig = (toneId: number, supportedParameterNames?: string[]) => {
|
||||
const tone = TONE_LIST.find(tone => tone.id === toneId)
|
||||
if (!tone?.config)
|
||||
return {}
|
||||
|
||||
if (!supportedParameterNames)
|
||||
return { ...tone.config }
|
||||
|
||||
const supportedParameterNameSet = new Set(supportedParameterNames)
|
||||
|
||||
return Object.entries(tone.config).reduce<Record<string, number>>((acc, [key, value]) => {
|
||||
if (supportedParameterNameSet.has(key))
|
||||
acc[key] = value
|
||||
|
||||
return acc
|
||||
}, {})
|
||||
}
|
||||
@ -12,6 +12,8 @@ import { Scales02 } from '@/app/components/base/icons/src/vender/solid/FinanceAn
|
||||
import { Target04 } from '@/app/components/base/icons/src/vender/solid/general'
|
||||
import { TONE_LIST } from '@/config'
|
||||
|
||||
const PRESET_TONE_LIST = TONE_LIST.slice(0, 3)
|
||||
|
||||
const toneI18nKeyMap = {
|
||||
Creative: 'model.tone.Creative',
|
||||
Balanced: 'model.tone.Balanced',
|
||||
@ -27,10 +29,18 @@ const TONE_ICONS: Record<number, ReactNode> = {
|
||||
|
||||
type PresetsParameterProps = {
|
||||
onSelect: (toneId: number) => void
|
||||
supportedParameterNames?: string[]
|
||||
}
|
||||
|
||||
function PresetsParameter({ onSelect }: PresetsParameterProps) {
|
||||
function PresetsParameter({ onSelect, supportedParameterNames }: PresetsParameterProps) {
|
||||
const { t } = useTranslation()
|
||||
const supportedParameterNameSet = supportedParameterNames ? new Set(supportedParameterNames) : undefined
|
||||
const visiblePresetTones = supportedParameterNameSet
|
||||
? PRESET_TONE_LIST.filter(tone => Object.keys(tone.config ?? {}).some(key => supportedParameterNameSet.has(key)))
|
||||
: PRESET_TONE_LIST
|
||||
|
||||
if (!visiblePresetTones.length)
|
||||
return null
|
||||
|
||||
return (
|
||||
<DropdownMenu>
|
||||
@ -47,7 +57,7 @@ function PresetsParameter({ onSelect }: PresetsParameterProps) {
|
||||
<span className="ml-0.5 i-ri-arrow-down-s-line h-3.5 w-3.5" />
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent>
|
||||
{TONE_LIST.slice(0, 3).map(tone => (
|
||||
{visiblePresetTones.map(tone => (
|
||||
<DropdownMenuItem key={tone.id} onClick={() => onSelect(tone.id)}>
|
||||
{TONE_ICONS[tone.id]}
|
||||
{t(toneI18nKeyMap[tone.name], { ns: 'common' })}
|
||||
|
||||
@ -75,14 +75,58 @@ vi.mock('@/config', () => ({
|
||||
|
||||
// Mock PresetsParameter component
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/model-parameter-modal/presets-parameter', () => ({
|
||||
default: ({ onSelect }: { onSelect: (toneId: number) => void }) => (
|
||||
<div data-testid="presets-parameter">
|
||||
<button data-testid="preset-creative" onClick={() => onSelect(1)}>Creative</button>
|
||||
<button data-testid="preset-balanced" onClick={() => onSelect(2)}>Balanced</button>
|
||||
<button data-testid="preset-precise" onClick={() => onSelect(3)}>Precise</button>
|
||||
<button data-testid="preset-custom" onClick={() => onSelect(4)}>Custom</button>
|
||||
</div>
|
||||
),
|
||||
default: ({ onSelect, supportedParameterNames }: { onSelect: (toneId: number) => void, supportedParameterNames?: string[] }) => {
|
||||
const hasSupportedParameter = !supportedParameterNames || supportedParameterNames.some(name => ['temperature', 'top_p', 'presence_penalty', 'frequency_penalty'].includes(name))
|
||||
if (!hasSupportedParameter)
|
||||
return null
|
||||
|
||||
return (
|
||||
<div data-testid="presets-parameter">
|
||||
<button data-testid="preset-creative" onClick={() => onSelect(1)}>Creative</button>
|
||||
<button data-testid="preset-balanced" onClick={() => onSelect(2)}>Balanced</button>
|
||||
<button data-testid="preset-precise" onClick={() => onSelect(3)}>Precise</button>
|
||||
<button data-testid="preset-custom" onClick={() => onSelect(4)}>Custom</button>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/model-parameter-modal/presets-parameter-utils', () => ({
|
||||
getSupportedPresetConfig: (toneId: number, supportedParameterNames?: string[]) => {
|
||||
const toneConfigMap: Record<number, Record<string, number> | undefined> = {
|
||||
1: {
|
||||
temperature: 0.8,
|
||||
top_p: 0.9,
|
||||
presence_penalty: 0.1,
|
||||
frequency_penalty: 0.1,
|
||||
},
|
||||
2: {
|
||||
temperature: 0.5,
|
||||
top_p: 0.85,
|
||||
presence_penalty: 0.2,
|
||||
frequency_penalty: 0.3,
|
||||
},
|
||||
3: {
|
||||
temperature: 0.2,
|
||||
top_p: 0.75,
|
||||
presence_penalty: 0.5,
|
||||
frequency_penalty: 0.5,
|
||||
},
|
||||
}
|
||||
const toneConfig = toneConfigMap[toneId]
|
||||
if (!toneConfig)
|
||||
return {}
|
||||
|
||||
if (!supportedParameterNames)
|
||||
return toneConfig
|
||||
|
||||
return Object.entries(toneConfig).reduce<Record<string, number>>((acc, [key, value]) => {
|
||||
if (supportedParameterNames.includes(key))
|
||||
acc[key] = value
|
||||
|
||||
return acc
|
||||
}, {})
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock ParameterItem component
|
||||
@ -148,10 +192,12 @@ const createDefaultProps = (overrides: Partial<{
|
||||
const setupModelParameterRulesMock = (config: {
|
||||
data?: ModelParameterRule[]
|
||||
isPending?: boolean
|
||||
isLoading?: boolean
|
||||
} = {}) => {
|
||||
mockUseModelParameterRules.mockReturnValue({
|
||||
data: config.data ? { data: config.data } : undefined,
|
||||
isPending: config.isPending ?? false,
|
||||
isLoading: config.isLoading ?? config.isPending ?? false,
|
||||
})
|
||||
}
|
||||
|
||||
@ -188,6 +234,19 @@ describe('LLMParamsPanel', () => {
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render loading state when model is not configured and parameter rules query is pending but disabled', () => {
|
||||
// Arrange
|
||||
setupModelParameterRulesMock({ isPending: true, isLoading: false })
|
||||
const props = createDefaultProps({ provider: '', modelId: '' })
|
||||
|
||||
// Act
|
||||
render(<LLMParamsPanel {...props} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByRole('status')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('common.modelProvider.parameters')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render parameters header', () => {
|
||||
// Arrange
|
||||
setupModelParameterRulesMock({ data: [], isPending: false })
|
||||
@ -202,7 +261,7 @@ describe('LLMParamsPanel', () => {
|
||||
|
||||
it('should render PresetsParameter for openai provider', () => {
|
||||
// Arrange
|
||||
setupModelParameterRulesMock({ data: [], isPending: false })
|
||||
setupModelParameterRulesMock({ data: [createParameterRule({ name: 'temperature' })], isPending: false })
|
||||
const props = createDefaultProps({ provider: 'langgenius/openai/openai' })
|
||||
|
||||
// Act
|
||||
@ -214,7 +273,7 @@ describe('LLMParamsPanel', () => {
|
||||
|
||||
it('should render PresetsParameter for azure_openai provider', () => {
|
||||
// Arrange
|
||||
setupModelParameterRulesMock({ data: [], isPending: false })
|
||||
setupModelParameterRulesMock({ data: [createParameterRule({ name: 'temperature' })], isPending: false })
|
||||
const props = createDefaultProps({ provider: 'langgenius/azure_openai/azure_openai' })
|
||||
|
||||
// Act
|
||||
@ -224,6 +283,18 @@ describe('LLMParamsPanel', () => {
|
||||
expect(screen.getByTestId('presets-parameter')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render PresetsParameter when no visible parameter supports presets', () => {
|
||||
// Arrange
|
||||
setupModelParameterRulesMock({ data: [createParameterRule({ name: 'max_tokens', type: 'int' })], isPending: false })
|
||||
const props = createDefaultProps({ provider: 'langgenius/openai/openai' })
|
||||
|
||||
// Act
|
||||
render(<LLMParamsPanel {...props} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByTestId('presets-parameter')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render PresetsParameter for non-preset providers', () => {
|
||||
// Arrange
|
||||
setupModelParameterRulesMock({ data: [], isPending: false })
|
||||
@ -360,7 +431,15 @@ describe('LLMParamsPanel', () => {
|
||||
it('should apply Creative preset config', () => {
|
||||
// Arrange
|
||||
const onCompletionParamsChange = vi.fn()
|
||||
setupModelParameterRulesMock({ data: [], isPending: false })
|
||||
setupModelParameterRulesMock({
|
||||
data: [
|
||||
createParameterRule({ name: 'temperature' }),
|
||||
createParameterRule({ name: 'top_p' }),
|
||||
createParameterRule({ name: 'presence_penalty' }),
|
||||
createParameterRule({ name: 'frequency_penalty' }),
|
||||
],
|
||||
isPending: false,
|
||||
})
|
||||
const props = createDefaultProps({
|
||||
provider: 'langgenius/openai/openai',
|
||||
onCompletionParamsChange,
|
||||
@ -384,7 +463,15 @@ describe('LLMParamsPanel', () => {
|
||||
it('should apply Balanced preset config', () => {
|
||||
// Arrange
|
||||
const onCompletionParamsChange = vi.fn()
|
||||
setupModelParameterRulesMock({ data: [], isPending: false })
|
||||
setupModelParameterRulesMock({
|
||||
data: [
|
||||
createParameterRule({ name: 'temperature' }),
|
||||
createParameterRule({ name: 'top_p' }),
|
||||
createParameterRule({ name: 'presence_penalty' }),
|
||||
createParameterRule({ name: 'frequency_penalty' }),
|
||||
],
|
||||
isPending: false,
|
||||
})
|
||||
const props = createDefaultProps({
|
||||
provider: 'langgenius/openai/openai',
|
||||
onCompletionParamsChange,
|
||||
@ -407,7 +494,15 @@ describe('LLMParamsPanel', () => {
|
||||
it('should apply Precise preset config', () => {
|
||||
// Arrange
|
||||
const onCompletionParamsChange = vi.fn()
|
||||
setupModelParameterRulesMock({ data: [], isPending: false })
|
||||
setupModelParameterRulesMock({
|
||||
data: [
|
||||
createParameterRule({ name: 'temperature' }),
|
||||
createParameterRule({ name: 'top_p' }),
|
||||
createParameterRule({ name: 'presence_penalty' }),
|
||||
createParameterRule({ name: 'frequency_penalty' }),
|
||||
],
|
||||
isPending: false,
|
||||
})
|
||||
const props = createDefaultProps({
|
||||
provider: 'langgenius/openai/openai',
|
||||
onCompletionParamsChange,
|
||||
@ -430,7 +525,7 @@ describe('LLMParamsPanel', () => {
|
||||
it('should apply empty config for Custom preset (spreads undefined)', () => {
|
||||
// Arrange
|
||||
const onCompletionParamsChange = vi.fn()
|
||||
setupModelParameterRulesMock({ data: [], isPending: false })
|
||||
setupModelParameterRulesMock({ data: [createParameterRule({ name: 'temperature' })], isPending: false })
|
||||
const props = createDefaultProps({
|
||||
provider: 'langgenius/openai/openai',
|
||||
onCompletionParamsChange,
|
||||
@ -444,6 +539,27 @@ describe('LLMParamsPanel', () => {
|
||||
// Assert - Custom preset has no config, so only existing params are kept
|
||||
expect(onCompletionParamsChange).toHaveBeenCalledWith({ existing: 'value' })
|
||||
})
|
||||
|
||||
it('should apply only preset config keys supported by visible parameters', () => {
|
||||
// Arrange
|
||||
const onCompletionParamsChange = vi.fn()
|
||||
setupModelParameterRulesMock({ data: [createParameterRule({ name: 'temperature' })], isPending: false })
|
||||
const props = createDefaultProps({
|
||||
provider: 'langgenius/openai/openai',
|
||||
onCompletionParamsChange,
|
||||
completionParams: { existing: 'value' },
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<LLMParamsPanel {...props} />)
|
||||
fireEvent.click(screen.getByTestId('preset-creative'))
|
||||
|
||||
// Assert
|
||||
expect(onCompletionParamsChange).toHaveBeenCalledWith({
|
||||
existing: 'value',
|
||||
temperature: 0.8,
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleParamChange', () => {
|
||||
|
||||
@ -10,7 +10,8 @@ import { useTranslation } from 'react-i18next'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import ParameterItem from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item'
|
||||
import PresetsParameter from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/presets-parameter'
|
||||
import { PROVIDER_WITH_PRESET_TONE, STOP_PARAMETER_RULE, TONE_LIST } from '@/config'
|
||||
import { getSupportedPresetConfig } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/presets-parameter-utils'
|
||||
import { PROVIDER_WITH_PRESET_TONE, STOP_PARAMETER_RULE } from '@/config'
|
||||
import { useModelParameterRules } from '@/service/use-common'
|
||||
|
||||
type Props = {
|
||||
@ -29,20 +30,21 @@ const LLMParamsPanel = ({
|
||||
onCompletionParamsChange,
|
||||
}: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const { data: parameterRulesData, isPending: isLoading } = useModelParameterRules(provider, modelId)
|
||||
const { data: parameterRulesData, isLoading } = useModelParameterRules(provider, modelId)
|
||||
const isRulesLoading = !!provider && !!modelId && isLoading
|
||||
|
||||
const parameterRules: ModelParameterRule[] = useMemo(() => {
|
||||
return parameterRulesData?.data || []
|
||||
}, [parameterRulesData])
|
||||
const supportedPresetParameterNames = useMemo(() => {
|
||||
return parameterRules.map(parameterRule => parameterRule.name)
|
||||
}, [parameterRules])
|
||||
|
||||
const handleSelectPresetParameter = (toneId: number) => {
|
||||
const tone = TONE_LIST.find(tone => tone.id === toneId)
|
||||
if (tone) {
|
||||
onCompletionParamsChange({
|
||||
...completionParams,
|
||||
...tone.config,
|
||||
})
|
||||
}
|
||||
onCompletionParamsChange({
|
||||
...completionParams,
|
||||
...getSupportedPresetConfig(toneId, supportedPresetParameterNames),
|
||||
})
|
||||
}
|
||||
const handleParamChange = (key: string, value: ParameterValue) => {
|
||||
onCompletionParamsChange({
|
||||
@ -65,7 +67,7 @@ const LLMParamsPanel = ({
|
||||
}
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
if (isRulesLoading) {
|
||||
return (
|
||||
<div className="mt-5"><Loading /></div>
|
||||
)
|
||||
@ -77,7 +79,10 @@ const LLMParamsPanel = ({
|
||||
<div className={cn('flex h-6 items-center system-sm-semibold text-text-secondary')}>{t('modelProvider.parameters', { ns: 'common' })}</div>
|
||||
{
|
||||
PROVIDER_WITH_PRESET_TONE.includes(provider) && (
|
||||
<PresetsParameter onSelect={handleSelectPresetParameter} />
|
||||
<PresetsParameter
|
||||
onSelect={handleSelectPresetParameter}
|
||||
supportedParameterNames={supportedPresetParameterNames}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
|
||||
28
web/service/annotation.spec.ts
Normal file
28
web/service/annotation.spec.ts
Normal file
@ -0,0 +1,28 @@
|
||||
import { AnnotationEnableStatus } from '@/app/components/app/annotation/type'
|
||||
import { updateAnnotationStatus } from './annotation'
|
||||
import { post } from './base'
|
||||
|
||||
vi.mock('./base', () => ({
|
||||
post: vi.fn(),
|
||||
}))
|
||||
|
||||
describe('annotation service', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should preserve zero score threshold when updating annotation status', () => {
|
||||
updateAnnotationStatus('app-1', AnnotationEnableStatus.enable, {
|
||||
embedding_model_name: 'model',
|
||||
embedding_provider_name: 'provider',
|
||||
}, 0)
|
||||
|
||||
expect(post).toHaveBeenCalledWith('apps/app-1/annotation-reply/enable', {
|
||||
body: {
|
||||
embedding_model_name: 'model',
|
||||
embedding_provider_name: 'provider',
|
||||
score_threshold: 0,
|
||||
},
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -7,7 +7,7 @@ export const fetchAnnotationConfig = (appId: string) => {
|
||||
}
|
||||
export const updateAnnotationStatus = (appId: string, action: AnnotationEnableStatus, embeddingModel?: EmbeddingModelConfig, score?: number) => {
|
||||
let body: any = {
|
||||
score_threshold: score || ANNOTATION_DEFAULT.score_threshold,
|
||||
score_threshold: score ?? ANNOTATION_DEFAULT.score_threshold,
|
||||
}
|
||||
if (embeddingModel) {
|
||||
body = {
|
||||
|
||||
Reference in New Issue
Block a user