Compare commits

..

1 Commits

Author SHA1 Message Date
yyh
cf9e649e11 feat(dev-proxy): reload env file changes 2026-05-19 16:08:21 +08:00
66 changed files with 1021 additions and 1050 deletions

View File

@ -1,6 +1,5 @@
[run]
omit =
api/conftest.py
api/tests/*
api/migrations/*
api/core/rag/datasource/vdb/*

View File

@ -48,23 +48,10 @@ jobs:
run: uv sync --project api --dev
- name: Run dify config tests
run: uv run --project api pytest api/tests/unit_tests/configs/test_env_consistency.py
run: uv run --project api dev/pytest/pytest_config_tests.py
- name: Run Unit Tests
run: |
uv run --project api pytest \
-p no:benchmark \
--timeout "${PYTEST_TIMEOUT:-20}" \
-n auto \
api/tests/unit_tests \
api/providers/vdb/*/tests/unit_tests \
api/providers/trace/*/tests/unit_tests \
--ignore=api/tests/unit_tests/controllers
# Controller tests register Flask routes at import time, so keep them out of xdist.
uv run --project api pytest \
--timeout "${PYTEST_TIMEOUT:-20}" \
--cov-append \
api/tests/unit_tests/controllers
run: uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Upload unit coverage data
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
@ -109,11 +96,32 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env
cp docker/envs/middleware.env.example docker/middleware.env
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db_postgres
redis
sandbox
ssrf_proxy
- name: setup test config
run: |
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- name: Run Integration Tests
run: |
uv run --project api pytest \
-p no:benchmark \
--start-middleware \
-n auto \
--timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/workflow \

17
.github/workflows/expose_service_ports.sh vendored Executable file
View File

@ -0,0 +1,17 @@
#!/bin/bash
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml
yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"

View File

@ -55,6 +55,7 @@ jobs:
api:
- 'api/**'
- '.github/workflows/api-tests.yml'
- '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example'
- 'docker/envs/middleware.env.example'
- 'docker/docker-compose.middleware.yaml'
@ -89,13 +90,11 @@ jobs:
vdb:
- 'api/core/rag/datasource/**'
- 'api/tests/integration_tests/vdb/**'
- 'api/conftest.py'
- 'api/tests/pytest_dify.py'
- 'api/providers/vdb/*/tests/**'
- '.github/workflows/vdb-tests.yml'
- '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example'
- 'docker/envs/middleware.env.example'
- 'docker/docker-compose.pytest.ports.yaml'
- 'docker/docker-compose.yaml'
- 'docker/docker-compose-template.yaml'
- 'docker/generate_docker_compose'
@ -115,6 +114,7 @@ jobs:
- 'api/migrations/**'
- 'api/.env.example'
- '.github/workflows/db-migration-test.yml'
- '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example'
- 'docker/envs/middleware.env.example'
- 'docker/docker-compose.middleware.yaml'

View File

@ -48,6 +48,14 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env
cp docker/envs/middleware.env.example docker/middleware.env
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh
# - name: Set up Vector Store (TiDB)
# uses: hoverkraft-tech/compose-action@v2.0.2
# with:
@ -56,13 +64,32 @@ jobs:
# tidb
# tiflash
- name: Set up Full Vector Store Matrix
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.yaml
services: |
weaviate
qdrant
couchbase-server
etcd
minio
milvus-standalone
pgvecto-rs
pgvector
chroma
elasticsearch
oceanbase
- name: setup test config
run: |
echo $(pwd)
ls -lah .
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
- name: Test Vector Stores
run: |
uv run --project api pytest \
--start-vdb \
--vdb-services "weaviate,qdrant,couchbase-server,etcd,minio,milvus-standalone,pgvecto-rs,pgvector,chroma,elasticsearch,oceanbase" \
--timeout "${PYTEST_TIMEOUT:-180}" \
api/providers/vdb/*/tests/integration_tests
run: uv run --project api bash dev/pytest/pytest_vdb.sh

View File

@ -45,6 +45,14 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env
cp docker/envs/middleware.env.example docker/middleware.env
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh
# - name: Set up Vector Store (TiDB)
# uses: hoverkraft-tech/compose-action@v2.0.2
# with:
@ -53,14 +61,31 @@ jobs:
# tidb
# tiflash
- name: Set up Vector Stores for Smoke Coverage
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.yaml
services: |
db_postgres
redis
weaviate
qdrant
pgvector
chroma
- name: setup test config
run: |
echo $(pwd)
ls -lah .
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
- name: Test Vector Stores
run: |
uv run --project api pytest \
--start-vdb \
--timeout "${PYTEST_TIMEOUT:-180}" \
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
api/providers/vdb/vdb-chroma/tests/integration_tests \
api/providers/vdb/vdb-pgvector/tests/integration_tests \
api/providers/vdb/vdb-qdrant/tests/integration_tests \

View File

@ -85,13 +85,13 @@ lint:
type-check:
@echo "📝 Running type checks (pyrefly + mypy)..."
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
@uv --directory api run mypy --exclude-gitignore --exclude '(^|/)conftest\.py$$' --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
@echo "✅ Type checks complete"
type-check-core:
@echo "📝 Running core type checks (pyrefly + mypy)..."
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
@uv --directory api run mypy --exclude-gitignore --exclude '(^|/)conftest\.py$$' --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
@echo "✅ Core type checks complete"
test:
@ -100,46 +100,7 @@ test:
echo "Target: $(TARGET_TESTS)"; \
uv run --project api --dev pytest $(TARGET_TESTS); \
else \
echo "Running backend unit tests"; \
uv run --project api --dev pytest -p no:benchmark --timeout "$${PYTEST_TIMEOUT:-20}" -n auto \
api/tests/unit_tests \
api/providers/vdb/*/tests/unit_tests \
api/providers/trace/*/tests/unit_tests \
--ignore=api/tests/unit_tests/controllers; \
uv run --project api --dev pytest --timeout "$${PYTEST_TIMEOUT:-20}" --cov-append \
api/tests/unit_tests/controllers; \
fi
@echo "✅ Unit tests complete"
test-all:
@echo "🧪 Running full backend test suite..."
@if [ -n "$(TARGET_TESTS)" ]; then \
echo "Target: $(TARGET_TESTS)"; \
uv run --project api --dev pytest $(TARGET_TESTS); \
else \
echo "Running backend unit tests"; \
uv run --project api --dev pytest -p no:benchmark --timeout "$${PYTEST_TIMEOUT:-20}" -n auto \
api/tests/unit_tests \
api/providers/vdb/*/tests/unit_tests \
api/providers/trace/*/tests/unit_tests \
--ignore=api/tests/unit_tests/controllers; \
uv run --project api --dev pytest --timeout "$${PYTEST_TIMEOUT:-20}" --cov-append \
api/tests/unit_tests/controllers; \
echo "Running backend integration tests"; \
uv run --project api --dev pytest -p no:benchmark --start-middleware -n auto \
--timeout "$${PYTEST_TIMEOUT:-180}" \
--cov-append \
api/tests/integration_tests/workflow \
api/tests/integration_tests/tools \
api/tests/test_containers_integration_tests; \
echo "Running VDB smoke tests"; \
uv run --project api --dev pytest --start-vdb \
--timeout "$${PYTEST_TIMEOUT:-180}" \
--cov-append \
api/providers/vdb/vdb-chroma/tests/integration_tests \
api/providers/vdb/vdb-pgvector/tests/integration_tests \
api/providers/vdb/vdb-qdrant/tests/integration_tests \
api/providers/vdb/vdb-weaviate/tests/integration_tests; \
PYTEST_XDIST_ARGS="-n auto" uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
fi
@echo "✅ Tests complete"
@ -194,7 +155,6 @@ help:
@echo " make type-check - Run type checks (pyrefly, mypy)"
@echo " make type-check-core - Run core type checks (pyrefly, mypy)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@echo " make test-all - Run full backend tests, including Docker-backed suites"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"
@ -204,4 +164,4 @@ help:
@echo " make build-push-all - Build and push all Docker images"
# Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test test-all
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test

View File

@ -180,8 +180,6 @@ Quick checks while iterating:
- Format: `make format`
- Lint (includes auto-fix): `make lint`
- Type check: `make type-check`
- Unit tests: `make test`
- Full backend tests, including Docker-backed suites: `make test-all`
- Targeted tests: `make test TARGET_TESTS=./api/tests/<target_tests>`
Before opening a PR / submitting:

View File

@ -1,91 +0,0 @@
"""Global pytest hooks for Dify backend tests.
This root conftest is loaded before package-specific conftests, which lets tests opt
into Docker-backed middleware before application modules read environment config.
It intentionally lives at the API root because pytest applies conftest.py files to
tests below their directory, and this setup is shared by api/tests and api/providers.
"""
from __future__ import annotations
from pathlib import Path
import pytest
from tests.pytest_dify import (
DEFAULT_MIDDLEWARE_SERVICES,
DEFAULT_VDB_SERVICES,
DockerComposeStack,
build_middleware_stack,
build_vdb_stack,
ensure_backend_test_environment,
ensure_compose_env_files,
parse_services,
)
_REPO_ROOT = Path(__file__).resolve().parent.parent
_DIFY_COMPOSE_STACKS_KEY = pytest.StashKey[list[DockerComposeStack]]()
# This must run at import time because package-specific conftests can import the
# Flask app before pytest_configure hooks from this file are called.
ensure_backend_test_environment(_REPO_ROOT)
def pytest_addoption(parser: pytest.Parser) -> None:
group = parser.getgroup("dify")
group.addoption(
"--start-middleware",
action="store_true",
default=False,
help="Start the Docker middleware services needed by API integration tests.",
)
group.addoption(
"--middleware-services",
default=",".join(DEFAULT_MIDDLEWARE_SERVICES),
help="Comma-separated services from docker/docker-compose.middleware.yaml to start.",
)
group.addoption(
"--start-vdb",
action="store_true",
default=False,
help="Start vector-store Docker services for VDB integration tests.",
)
group.addoption(
"--vdb-services",
default=",".join(DEFAULT_VDB_SERVICES),
help="Comma-separated services from docker/docker-compose.yaml to start for VDB tests.",
)
def pytest_configure(config: pytest.Config) -> None:
config.stash[_DIFY_COMPOSE_STACKS_KEY] = []
def pytest_sessionstart(session: pytest.Session) -> None:
config = session.config
if hasattr(config, "workerinput"):
return
stacks: list[DockerComposeStack] = []
if config.getoption("start_middleware"):
ensure_compose_env_files(_REPO_ROOT)
stack = build_middleware_stack(_REPO_ROOT, parse_services(config.getoption("middleware_services")))
stack.up()
stacks.append(stack)
if config.getoption("start_vdb"):
ensure_compose_env_files(_REPO_ROOT)
stack = build_vdb_stack(_REPO_ROOT, parse_services(config.getoption("vdb_services")))
stack.up()
stacks.append(stack)
config.stash[_DIFY_COMPOSE_STACKS_KEY] = stacks
def pytest_unconfigure(config: pytest.Config) -> None:
if hasattr(config, "workerinput"):
return
stacks = config.stash.get(_DIFY_COMPOSE_STACKS_KEY, [])
for stack in reversed(stacks):
stack.down()

View File

@ -1,67 +0,0 @@
"""Dify event package.
The package name intentionally stays as ``events`` for existing Dify imports. Some
third-party clients also import ``Events`` from a top-level ``events`` package, so
we expose a small compatible implementation to avoid import shadowing failures.
"""
from collections.abc import Callable, Iterator
from typing import Any
class EventsError(Exception):
"""Raised for invalid event slot operations."""
EventsException = EventsError
class _EventSlot:
"""A dynamically-created event slot supporting ``+=`` and call dispatch."""
targets: list[Callable[..., Any]]
__name__: str
def __init__(self, name: str) -> None:
self.targets = []
self.__name__ = name
def __call__(self, *args: Any, **kwargs: Any) -> None:
for target in tuple(self.targets):
target(*args, **kwargs)
def __iadd__(self, target: Callable[..., Any]) -> "_EventSlot":
self.targets.append(target)
return self
def __isub__(self, target: Callable[..., Any]) -> "_EventSlot":
while target in self.targets:
self.targets.remove(target)
return self
def __iter__(self) -> Iterator[Callable[..., Any]]:
return iter(self.targets)
def __len__(self) -> int:
return len(self.targets)
class Events:
"""A minimal C#-style event container compatible with the external Events package."""
_slots: dict[str, _EventSlot]
def __init__(self, *event_names: str) -> None:
self._slots = {}
for event_name in event_names:
self._slots[event_name] = _EventSlot(event_name)
def __getattr__(self, name: str) -> _EventSlot:
if name.startswith("_"):
raise AttributeError(name)
slot = _EventSlot(name)
self._slots[name] = slot
return slot
__all__ = ["Events", "EventsError", "EventsException"]

View File

@ -13,7 +13,7 @@ class ChromaVectorTest(AbstractVectorTest):
self.vector = ChromaVector(
collection_name=self.collection_name,
config=ChromaConfig(
host="127.0.0.1",
host="localhost",
port=8000,
tenant=chromadb.DEFAULT_TENANT,
database=chromadb.DEFAULT_DATABASE,

View File

@ -16,7 +16,7 @@ class QdrantVectorTest(AbstractVectorTest):
collection_name=self.collection_name,
group_id=self.dataset_id,
config=QdrantConfig(
endpoint="http://127.0.0.1:6333",
endpoint="http://localhost:6333",
api_key="difyai123456",
),
)

View File

@ -26,24 +26,20 @@ _logger = logging.getLogger(__name__)
# Loading the .env file if it exists
def _load_env():
current_file_path = pathlib.Path(__file__).absolute()
# Items later in the list have higher precedence.
env_file_paths = [
pathlib.Path(os.getenv("DIFY_TEST_ENV_FILE", str(current_file_path.parent / _DEFUALT_TEST_ENV))),
os.getenv("DIFY_TEST_ENV_FILE", str(current_file_path.parent / _DEFUALT_TEST_ENV)),
os.getenv("DIFY_VDB_TEST_ENV_FILE", str(current_file_path.parent / _DEFAULT_VDB_TEST_ENV)),
]
vdb_env_path = pathlib.Path(
os.getenv("DIFY_VDB_TEST_ENV_FILE", str(current_file_path.parent / _DEFAULT_VDB_TEST_ENV))
)
if vdb_env_path.exists() or "DIFY_VDB_TEST_ENV_FILE" in os.environ:
env_file_paths.append(vdb_env_path)
for env_path in env_file_paths:
if not env_path.exists():
_logger.warning("specified configuration file %s not exist", env_path)
continue
for env_path_str in env_file_paths:
if not pathlib.Path(env_path_str).exists():
_logger.warning("specified configuration file %s not exist", env_path_str)
from dotenv import load_dotenv
# Set `override=True` to ensure values from `vdb.env` take priority over values from `.env`
load_dotenv(str(env_path), override=True)
load_dotenv(str(env_path_str), override=True)
_load_env()

View File

@ -1,198 +0,0 @@
"""Pytest support helpers for Dify backend test environment setup.
The helpers in this module keep Docker and environment preparation behind explicit
pytest options so ordinary unit-test runs do not start external services.
"""
from __future__ import annotations
import os
import shutil
import subprocess
import time
import urllib.error
import urllib.request
from dataclasses import dataclass
from pathlib import Path
DEFAULT_LOG_FORMAT = "%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s"
DEFAULT_MIDDLEWARE_SERVICES = ("db_postgres", "redis", "sandbox", "ssrf_proxy")
DEFAULT_VDB_SERVICES = ("db_postgres", "redis", "weaviate", "qdrant", "pgvector", "chroma")
VDB_SERVICE_PROFILES = {
"db_postgres": "postgresql",
"weaviate": "weaviate",
"qdrant": "qdrant",
"couchbase-server": "couchbase",
"etcd": "milvus",
"minio": "milvus",
"milvus-standalone": "milvus",
"pgvecto-rs": "pgvecto-rs",
"pgvector": "pgvector",
"chroma": "chroma",
"elasticsearch": "elasticsearch",
"oceanbase": "oceanbase",
}
def parse_services(value: str) -> list[str]:
"""Parse a comma-separated service list from a pytest option."""
return [service.strip() for service in value.split(",") if service.strip()]
def ensure_backend_test_environment(repo_root: Path) -> None:
"""Set deterministic defaults needed before test conftests import application config."""
integration_tests_dir = repo_root / "api" / "tests" / "integration_tests"
test_env_file = integration_tests_dir / ".env"
test_env_example_file = integration_tests_dir / ".env.example"
vdb_env_file = integration_tests_dir / "vdb.env"
if "DIFY_TEST_ENV_FILE" not in os.environ:
os.environ["DIFY_TEST_ENV_FILE"] = str(test_env_file if test_env_file.exists() else test_env_example_file)
if "DIFY_VDB_TEST_ENV_FILE" not in os.environ and vdb_env_file.exists():
os.environ["DIFY_VDB_TEST_ENV_FILE"] = str(vdb_env_file)
os.environ["LOG_OUTPUT_FORMAT"] = "text"
os.environ["LOG_FORMAT"] = DEFAULT_LOG_FORMAT
os.environ.setdefault("STORAGE_TYPE", "opendal")
os.environ.setdefault("OPENDAL_SCHEME", "fs")
os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage")
Path(os.environ["OPENDAL_FS_ROOT"]).mkdir(parents=True, exist_ok=True)
def ensure_compose_env_files(repo_root: Path) -> None:
"""Create ignored Docker env files from examples when Docker-backed tests request compose."""
docker_dir = repo_root / "docker"
env_file = docker_dir / ".env"
env_example_file = docker_dir / ".env.example"
middleware_env_file = docker_dir / "middleware.env"
middleware_env_example_file = docker_dir / "envs" / "middleware.env.example"
if not env_file.exists():
shutil.copyfile(env_example_file, env_file)
if not middleware_env_file.exists():
shutil.copyfile(middleware_env_example_file, middleware_env_file)
@dataclass(frozen=True)
class DockerComposeStack:
"""A docker compose project that pytest can start before collection and stop at shutdown."""
name: str
project_name: str
repo_root: Path
compose_files: tuple[Path, ...]
env_file: Path
services: tuple[str, ...]
profiles: tuple[str, ...] = ()
ready_delay_seconds: float = 0.0
warmup_urls: tuple[str, ...] = ()
def _compose_command(self) -> list[str]:
command = [
"docker",
"compose",
"--project-name",
self.project_name,
"--env-file",
str(self.env_file),
]
for profile in self.profiles:
command.extend(("--profile", profile))
for compose_file in self.compose_files:
command.extend(("-f", str(compose_file)))
return command
def up(self) -> None:
"""Start the configured services and wait for compose healthchecks when supported."""
wait_command = self._compose_command() + [
"up",
"-d",
"--wait",
"--wait-timeout",
"180",
*self.services,
]
completed = subprocess.run(wait_command, cwd=self.repo_root, text=True, capture_output=True)
if completed.returncode == 0:
if self.ready_delay_seconds > 0:
time.sleep(self.ready_delay_seconds)
self._warm_up()
return
combined_output = f"{completed.stdout}\n{completed.stderr}"
if "unknown flag: --wait" in combined_output or "unknown flag: wait-timeout" in combined_output:
subprocess.run(self._compose_command() + ["up", "-d", *self.services], cwd=self.repo_root, check=True)
time.sleep(5)
self._warm_up()
return
raise subprocess.CalledProcessError(
returncode=completed.returncode,
cmd=wait_command,
output=completed.stdout,
stderr=completed.stderr,
)
def down(self) -> None:
"""Stop services started for this pytest run."""
subprocess.run(self._compose_command() + ["down"], cwd=self.repo_root, check=True)
def _warm_up(self) -> None:
for url in self.warmup_urls:
deadline = time.monotonic() + 30.0
last_error: Exception | None = None
while time.monotonic() < deadline:
try:
with urllib.request.urlopen(url, timeout=5) as response:
if 200 <= response.status < 300:
break
except urllib.error.HTTPError as error:
if error.code < 500:
break
last_error = error
except (OSError, urllib.error.URLError) as error:
last_error = error
time.sleep(1)
else:
raise RuntimeError(f"Timed out waiting for {self.name} warmup URL {url}") from last_error
def build_middleware_stack(repo_root: Path, services: list[str]) -> DockerComposeStack:
"""Build the middleware compose stack used by API integration tests."""
return DockerComposeStack(
name="middleware",
project_name="dify-pytest-middleware",
repo_root=repo_root,
compose_files=(repo_root / "docker" / "docker-compose.middleware.yaml",),
env_file=repo_root / "docker" / "middleware.env",
services=tuple(services),
ready_delay_seconds=5.0,
warmup_urls=("http://127.0.0.1:8194/health",),
)
def build_vdb_stack(repo_root: Path, services: list[str]) -> DockerComposeStack:
"""Build the vector-store compose stack used by VDB integration tests."""
profiles = tuple(
dict.fromkeys(profile for service in services if (profile := VDB_SERVICE_PROFILES.get(service)) is not None)
)
service_names = set(services)
warmup_urls = []
if "qdrant" in service_names:
warmup_urls.append("http://127.0.0.1:6333/collections")
if "chroma" in service_names:
warmup_urls.append("http://127.0.0.1:8000/api/v2/auth/identity")
return DockerComposeStack(
name="vdb",
project_name="dify-pytest-vdb",
repo_root=repo_root,
compose_files=(
repo_root / "docker" / "docker-compose.yaml",
repo_root / "docker" / "docker-compose.pytest.ports.yaml",
),
env_file=repo_root / "docker" / ".env",
services=tuple(services),
profiles=profiles,
warmup_urls=tuple(warmup_urls),
)

View File

@ -252,9 +252,6 @@ class TestRedisBroadcastChannelIntegration:
def consumer_thread() -> set[bytes]:
received_msgs: set[bytes] = set()
with subscription:
# Prime the subscription before producers publish. Redis Pub/Sub does not
# replay messages sent before the subscribe command is active.
subscription.receive(timeout=0.1)
consumer_ready.set()
while True:
try:
@ -281,10 +278,8 @@ class TestRedisBroadcastChannelIntegration:
for future in as_completed(producer_futures, timeout=30.0):
sent_msgs.update(future.result())
try:
consumer_received_msgs = consumer_future.result(timeout=30.0)
finally:
subscription.close()
subscription.close()
consumer_received_msgs = consumer_future.result(timeout=30.0)
# Verify message content
assert sent_msgs == consumer_received_msgs

View File

@ -37,9 +37,9 @@ class TestAppGenerateService:
"services.app_generate_service.MessageBasedAppGenerator", autospec=True
) as mock_message_based_generator,
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
patch("services.app_generate_service.dify_config") as mock_dify_config,
patch("services.quota_service.dify_config") as mock_quota_dify_config,
patch("configs.dify_config") as mock_global_dify_config,
patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config,
patch("configs.dify_config", autospec=True) as mock_global_dify_config,
):
# Setup default mock returns for billing service
mock_billing_service.quota_reserve.return_value = {

View File

@ -84,7 +84,7 @@ def _mock_factory_for_apps(
class TestRecommendedAppServiceGetApps:
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_success_with_apps(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
expected = _apps_response()
@ -103,7 +103,7 @@ class TestRecommendedAppServiceGetApps:
mock_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
empty_response = {"recommended_apps": [], "categories": []}
@ -126,7 +126,7 @@ class TestRecommendedAppServiceGetApps:
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db"
none_response = {"recommended_apps": None, "categories": ["test"]}
@ -146,7 +146,7 @@ class TestRecommendedAppServiceGetApps:
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_different_languages(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
@ -164,7 +164,7 @@ class TestRecommendedAppServiceGetApps:
mock_instance.get_recommended_apps_and_categories.assert_called_with(language)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_uses_correct_factory_mode(self, mock_config, mock_factory_class):
for mode in ["remote", "builtin", "db"]:
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
@ -183,7 +183,7 @@ class TestRecommendedAppServiceGetApps:
class TestRecommendedAppServiceGetDetail:
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_success(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
expected = _app_detail(app_id="app-123", name="Productivity App", description="A great app")
@ -199,7 +199,7 @@ class TestRecommendedAppServiceGetDetail:
mock_instance.get_recommend_app_detail.assert_called_once_with("app-123")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_different_modes(self, mock_config, mock_factory_class):
for mode in ["remote", "builtin", "db"]:
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
@ -214,7 +214,7 @@ class TestRecommendedAppServiceGetDetail:
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_returns_none_when_not_found(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
mock_instance = MagicMock()
@ -227,7 +227,7 @@ class TestRecommendedAppServiceGetDetail:
mock_instance.get_recommend_app_detail.assert_called_once_with("nonexistent")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_returns_empty_dict(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
mock_instance = MagicMock()
@ -239,7 +239,7 @@ class TestRecommendedAppServiceGetDetail:
assert result == {}
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_complex_model_config(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
complex_config = {

View File

@ -60,7 +60,7 @@ class TestMailInviteMemberTask:
with (
patch("tasks.mail_invite_member_task.mail", autospec=True) as mock_mail,
patch("tasks.mail_invite_member_task.get_email_i18n_service", autospec=True) as mock_email_service,
patch("tasks.mail_invite_member_task.dify_config") as mock_config,
patch("tasks.mail_invite_member_task.dify_config", autospec=True) as mock_config,
):
# Setup mail service mock
mock_mail.is_inited.return_value = True

View File

@ -90,7 +90,7 @@ class TestMailRegisterTask:
to_email = fake.email()
account_name = fake.name()
with patch("tasks.mail_register_task.dify_config") as mock_config:
with patch("tasks.mail_register_task.dify_config", autospec=True) as mock_config:
mock_config.CONSOLE_WEB_URL = "https://console.dify.ai"
send_email_register_mail_task_when_account_exist(language=language, to=to_email, account_name=account_name)

View File

@ -15,11 +15,6 @@ from models.model import App, AppMode
from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result
def _configure_current_app_mock(mock_current_app):
mock_current_app.login_manager = Mock()
mock_current_app._get_current_object = Mock(return_value=Mock())
class TestAppParameterApi:
"""Test suite for AppParameterApi"""
@ -50,7 +45,7 @@ class TestAppParameterApi:
):
"""Test retrieving parameters for a chat app."""
# Arrange
_configure_current_app_mock(mock_current_app)
mock_current_app.login_manager = Mock()
mock_config = Mock()
mock_config.id = str(uuid.uuid4())
@ -100,7 +95,7 @@ class TestAppParameterApi:
):
"""Test retrieving parameters for a workflow app."""
# Arrange
_configure_current_app_mock(mock_current_app)
mock_current_app.login_manager = Mock()
mock_app_model.mode = AppMode.WORKFLOW
mock_workflow = Mock()
@ -145,7 +140,7 @@ class TestAppParameterApi:
):
"""Test that AppUnavailableError is raised when chat app has no config."""
# Arrange
_configure_current_app_mock(mock_current_app)
mock_current_app.login_manager = Mock()
mock_app_model.app_model_config = None
mock_app_model.workflow = None
@ -183,7 +178,7 @@ class TestAppParameterApi:
):
"""Test that AppUnavailableError is raised when workflow app has no workflow."""
# Arrange
_configure_current_app_mock(mock_current_app)
mock_current_app.login_manager = Mock()
mock_app_model.mode = AppMode.WORKFLOW
mock_app_model.workflow = None
@ -250,7 +245,7 @@ class TestAppMetaApi:
):
"""Test retrieving app metadata via AppService."""
# Arrange
_configure_current_app_mock(mock_current_app)
mock_current_app.login_manager = Mock()
mock_service_instance = Mock()
mock_service_instance.get_app_meta.return_value = {
@ -325,7 +320,7 @@ class TestAppInfoApi:
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test retrieving basic app information."""
_configure_current_app_mock(mock_current_app)
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
@ -366,7 +361,7 @@ class TestAppInfoApi:
):
"""Test retrieving app info with multiple tags."""
# Arrange
_configure_current_app_mock(mock_current_app)
mock_current_app.login_manager = Mock()
mock_app = Mock(spec=App)
mock_app.id = str(uuid.uuid4())
@ -419,7 +414,7 @@ class TestAppInfoApi:
):
"""Test retrieving app info when app has no tags."""
# Arrange
_configure_current_app_mock(mock_current_app)
mock_current_app.login_manager = Mock()
mock_app = Mock(spec=App)
mock_app.id = str(uuid.uuid4())
@ -471,7 +466,7 @@ class TestAppInfoApi:
):
"""Test that all app modes are correctly returned."""
# Arrange
_configure_current_app_mock(mock_current_app)
mock_current_app.login_manager = Mock()
mock_app = Mock(spec=App)
mock_app.id = str(uuid.uuid4())

View File

@ -2,29 +2,31 @@
Unit tests for Service API Index endpoint
"""
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.service_api import index as index_module
from controllers.service_api.index import IndexApi
def _get_index_response(app: Flask, version: str) -> dict[str, str]:
with patch.object(index_module.dify_config.project, "version", version):
with app.test_request_context("/", method="GET"):
index_api = IndexApi()
return index_api.get()
class TestIndexApi:
"""Test suite for IndexApi resource."""
def test_get_returns_api_info(self, app: Flask):
@patch("controllers.service_api.index.dify_config", autospec=True)
def test_get_returns_api_info(self, mock_config, app: Flask):
"""Test that GET returns API metadata with correct structure."""
# Arrange
mock_config.project.version = "1.0.0-test"
# Act
response = _get_index_response(app, "1.0.0-test")
with app.test_request_context("/", method="GET"):
index_api = IndexApi()
response = index_api.get()
with patch("controllers.service_api.index.dify_config", mock_config):
with app.test_request_context("/", method="GET"):
index_api = IndexApi()
response = index_api.get()
# Assert
assert response["welcome"] == "Dify OpenAPI"
@ -33,8 +35,15 @@ class TestIndexApi:
def test_get_response_has_required_fields(self, app: Flask):
"""Test that response contains all required fields."""
# Arrange
mock_config = MagicMock()
mock_config.project.version = "1.11.4"
# Act
response = _get_index_response(app, "1.11.4")
with patch("controllers.service_api.index.dify_config", mock_config):
with app.test_request_context("/", method="GET"):
index_api = IndexApi()
response = index_api.get()
# Assert
assert "welcome" in response
@ -47,8 +56,15 @@ class TestIndexApi:
@pytest.mark.parametrize("version", ["0.0.1", "1.0.0", "2.0.0-beta", "1.11.4"])
def test_get_returns_correct_version(self, app: Flask, version):
"""Test that server_version matches config version."""
# Arrange
mock_config = MagicMock()
mock_config.project.version = version
# Act
response = _get_index_response(app, version)
with patch("controllers.service_api.index.dify_config", mock_config):
with app.test_request_context("/", method="GET"):
index_api = IndexApi()
response = index_api.get()
# Assert
assert response["server_version"] == version

View File

@ -29,11 +29,6 @@ from tests.unit_tests.conftest import (
)
def _configure_current_app_mock(mock_current_app):
mock_current_app.login_manager = Mock()
mock_current_app._get_current_object = Mock(return_value=Mock())
class TestValidateAndGetApiToken:
"""Test suite for validate_and_get_api_token function"""
@ -125,7 +120,8 @@ class TestValidateAppToken:
):
"""Test that valid app token allows access to decorated view."""
# Arrange
_configure_current_app_mock(mock_current_app)
# Use standard Mock for login_manager to avoid AsyncMockMixin warnings
mock_current_app.login_manager = Mock()
mock_api_token = Mock()
mock_api_token.app_id = str(uuid.uuid4())
@ -452,7 +448,8 @@ class TestValidateDatasetToken:
def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app: Flask):
"""Test that valid dataset token allows access."""
# Arrange
_configure_current_app_mock(mock_current_app)
# Use standard Mock for login_manager
mock_current_app.login_manager = Mock()
tenant_id = str(uuid.uuid4())
mock_api_token = Mock()

View File

@ -32,7 +32,7 @@ class TestConstants:
class TestCreateSSRFProxyMCPHTTPClient:
"""Test create_ssrf_proxy_mcp_http_client function."""
@patch("core.mcp.utils.dify_config")
@patch("core.mcp.utils.dify_config", autospec=True)
def test_create_client_with_all_url_proxy(self, mock_config):
"""Test client creation with SSRF_PROXY_ALL_URL configured."""
mock_config.SSRF_PROXY_ALL_URL = "http://proxy.example.com:8080"
@ -50,7 +50,7 @@ class TestCreateSSRFProxyMCPHTTPClient:
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
@patch("core.mcp.utils.dify_config", autospec=True)
def test_create_client_with_http_https_proxies(self, mock_config):
"""Test client creation with separate HTTP/HTTPS proxies."""
mock_config.SSRF_PROXY_ALL_URL = None
@ -66,7 +66,7 @@ class TestCreateSSRFProxyMCPHTTPClient:
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
@patch("core.mcp.utils.dify_config", autospec=True)
def test_create_client_without_proxy(self, mock_config):
"""Test client creation without proxy configuration."""
mock_config.SSRF_PROXY_ALL_URL = None
@ -88,7 +88,7 @@ class TestCreateSSRFProxyMCPHTTPClient:
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
@patch("core.mcp.utils.dify_config", autospec=True)
def test_create_client_default_params(self, mock_config):
"""Test client creation with default parameters."""
mock_config.SSRF_PROXY_ALL_URL = None
@ -140,7 +140,7 @@ class TestSSRFProxySSEConnect:
@patch("core.mcp.utils.connect_sse", autospec=True)
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client", autospec=True)
@patch("core.mcp.utils.dify_config")
@patch("core.mcp.utils.dify_config", autospec=True)
def test_sse_connect_without_client(self, mock_config, mock_create_client, mock_connect_sse):
"""Test SSE connection without pre-configured client."""
# Setup config

View File

@ -88,7 +88,7 @@ class TestOutputModeration:
def test_start_thread(self, output_moderation):
mock_app = MagicMock(spec=Flask)
with patch("core.moderation.output_moderation.current_app") as mock_current_app:
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
mock_current_app._get_current_object.return_value = mock_app
with patch("threading.Thread") as mock_thread_class:
mock_thread_instance = MagicMock()
mock_thread_class.return_value = mock_thread_instance

View File

@ -106,7 +106,7 @@ class TestQAIndexProcessor:
patch.object(processor, "_format_qa_document", side_effect=_append_document) as mock_format,
patch("core.rag.index_processor.processor.qa_index_processor.current_app") as mock_current_app,
):
mock_current_app._get_current_object = Mock(return_value=fake_flask_app)
mock_current_app._get_current_object.return_value = fake_flask_app
result = processor.transform(
[document],
process_rule=process_rule,
@ -155,7 +155,7 @@ class TestQAIndexProcessor:
"core.rag.index_processor.processor.qa_index_processor.threading.Thread", side_effect=_ImmediateThread
),
):
mock_current_app._get_current_object = Mock(return_value=fake_flask_app)
mock_current_app._get_current_object.return_value = fake_flask_app
result = processor.transform(documents, process_rule=process_rule, preview=False, tenant_id="tenant-1")
assert len(result) == 2

View File

@ -594,7 +594,6 @@ class TestIndexingRunnerLoad:
patch("core.indexing_runner.threading.Thread") as mock_thread,
patch("core.indexing_runner.concurrent.futures.ThreadPoolExecutor") as mock_executor,
):
mock_app._get_current_object = Mock(return_value=Mock())
yield {
"db": mock_db,
"model_manager": mock_model_manager,

View File

@ -51,7 +51,7 @@ class TestRepositoryFactory:
import_string("invalidpath")
assert "doesn't look like a module path" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
@patch("core.repositories.factory.dify_config", autospec=True)
def test_create_workflow_execution_repository_success(self, mock_config):
"""Test successful WorkflowExecutionRepository creation."""
# Setup mock configuration
@ -86,7 +86,7 @@ class TestRepositoryFactory:
)
assert result is mock_repository_instance
@patch("core.repositories.factory.dify_config")
@patch("core.repositories.factory.dify_config", autospec=True)
def test_create_workflow_execution_repository_import_error(self, mock_config):
"""Test WorkflowExecutionRepository creation with import error."""
# Setup mock configuration with invalid class path
@ -104,7 +104,7 @@ class TestRepositoryFactory:
)
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
@patch("core.repositories.factory.dify_config", autospec=True)
def test_create_workflow_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowExecutionRepository creation with instantiation error."""
# Setup mock configuration
@ -128,7 +128,7 @@ class TestRepositoryFactory:
)
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
@patch("core.repositories.factory.dify_config", autospec=True)
def test_create_workflow_node_execution_repository_success(self, mock_config):
"""Test successful WorkflowNodeExecutionRepository creation."""
# Setup mock configuration
@ -163,7 +163,7 @@ class TestRepositoryFactory:
)
assert result is mock_repository_instance
@patch("core.repositories.factory.dify_config")
@patch("core.repositories.factory.dify_config", autospec=True)
def test_create_workflow_node_execution_repository_import_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with import error."""
# Setup mock configuration with invalid class path
@ -181,7 +181,7 @@ class TestRepositoryFactory:
)
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
@patch("core.repositories.factory.dify_config", autospec=True)
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
# Setup mock configuration
@ -211,7 +211,7 @@ class TestRepositoryFactory:
error = RepositoryImportError(error_message)
assert str(error) == error_message
@patch("core.repositories.factory.dify_config")
@patch("core.repositories.factory.dify_config", autospec=True)
def test_create_with_engine_instead_of_sessionmaker(self, mock_config):
"""Test repository creation with Engine instead of sessionmaker."""
# Setup mock configuration

View File

@ -1,4 +1,3 @@
import copy
import time
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import MagicMock, patch
@ -473,7 +472,7 @@ class TestSchemaResolverClass:
assert resolved[2]["title"] == "Q&A Structure"
def test_cache_performance(self):
"""Test that repeated references share cached schema lookups."""
"""Test that caching improves performance"""
SchemaResolver.clear_cache()
# Create a schema with many references to the same schema
@ -485,16 +484,36 @@ class TestSchemaResolverClass:
},
}
registry = SchemaRegistry.default_registry()
file_schema = registry.get_schema("https://dify.ai/schemas/v1/file.json")
assert file_schema is not None
# First run (no cache) - run multiple times to warm up
results1 = []
for _ in range(3):
SchemaResolver.clear_cache()
start = time.perf_counter()
result1 = resolve_dify_schema_refs(schema)
time_no_cache = time.perf_counter() - start
results1.append(time_no_cache)
with patch.object(registry, "get_schema", wraps=registry.get_schema) as mock_get:
result1 = resolve_dify_schema_refs(copy.deepcopy(schema), registry=registry)
result2 = resolve_dify_schema_refs(copy.deepcopy(schema), registry=registry)
avg_time_no_cache = sum(results1) / len(results1)
# Second run (with cache) - run multiple times
# Warm up cache first
resolve_dify_schema_refs(schema)
results2 = []
for _ in range(3):
start = time.perf_counter()
result2 = resolve_dify_schema_refs(schema)
time_with_cache = time.perf_counter() - start
results2.append(time_with_cache)
avg_time_with_cache = sum(results2) / len(results2)
# Cache should make it faster (more lenient check)
assert result1 == result2
mock_get.assert_called_once_with("https://dify.ai/schemas/v1/file.json")
# Cache should provide some performance benefit (allow for measurement variance)
# We expect cache to be faster, but allow for small timing variations
performance_ratio = avg_time_with_cache / avg_time_no_cache if avg_time_no_cache > 0 else 1.0
assert performance_ratio <= 2.0, f"Cache performance degraded too much: {performance_ratio}"
def test_fast_path_performance_no_refs(self):
"""Test that schemas without $refs use fast path and avoid deep copying"""

View File

@ -1,42 +0,0 @@
import pytest
from events import Events, EventsError, EventsException
def test_events_package_exposes_opensearchpy_compatible_events_class():
calls: list[str] = []
events = Events()
events.request_start += lambda: calls.append("start")
events.request_end += lambda: calls.append("end")
events.request_start()
events.request_end()
assert calls == ["start", "end"]
def test_events_package_supports_named_slots_iteration_removal_and_private_attrs():
calls: list[str] = []
def handler() -> None:
calls.append("handled")
events = Events("request_start")
events.request_start += handler
events.request_start += handler
assert len(events.request_start) == 2
assert list(events.request_start) == [handler, handler]
events.request_start -= handler
assert len(events.request_start) == 0
events.request_start()
assert calls == []
with pytest.raises(AttributeError):
_ = events._private # type: ignore[attr-defined]
assert EventsException is EventsError

View File

@ -11,7 +11,7 @@ class TestSupabaseStorage:
def test_init_success_with_all_config(self):
"""Test successful initialization when all required config is provided."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
@ -31,7 +31,7 @@ class TestSupabaseStorage:
def test_init_raises_error_when_url_missing(self):
"""Test initialization raises ValueError when SUPABASE_URL is None."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = None
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
@ -41,7 +41,7 @@ class TestSupabaseStorage:
def test_init_raises_error_when_api_key_missing(self):
"""Test initialization raises ValueError when SUPABASE_API_KEY is None."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = None
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
@ -51,7 +51,7 @@ class TestSupabaseStorage:
def test_init_raises_error_when_bucket_name_missing(self):
"""Test initialization raises ValueError when SUPABASE_BUCKET_NAME is None."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = None
@ -61,7 +61,7 @@ class TestSupabaseStorage:
def test_create_bucket_when_not_exists(self):
"""Test create_bucket creates bucket when it doesn't exist."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
@ -77,7 +77,7 @@ class TestSupabaseStorage:
def test_create_bucket_when_exists(self):
"""Test create_bucket does not create bucket when it already exists."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
@ -94,7 +94,7 @@ class TestSupabaseStorage:
@pytest.fixture
def storage_with_mock_client(self):
"""Fixture providing SupabaseStorage with mocked client."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
@ -209,7 +209,7 @@ class TestSupabaseStorage:
def test_bucket_exists_returns_true_when_bucket_found(self):
"""Test bucket_exists returns True when bucket is found in list."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
@ -229,7 +229,7 @@ class TestSupabaseStorage:
def test_bucket_exists_returns_false_when_bucket_not_found(self):
"""Test bucket_exists returns False when bucket is not found in list."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
@ -252,7 +252,7 @@ class TestSupabaseStorage:
def test_bucket_exists_returns_false_when_no_buckets(self):
"""Test bucket_exists returns False when no buckets exist."""
with patch("extensions.storage.supabase_storage.dify_config") as mock_config:
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
mock_config.SUPABASE_URL = "https://test.supabase.co"
mock_config.SUPABASE_API_KEY = "test-api-key"
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"

View File

@ -15,7 +15,7 @@ from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
class TestWorkflowRunArchiver:
"""Tests for the WorkflowRunArchiver class."""
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config")
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config", autospec=True)
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage", autospec=True)
def test_archiver_initialization(self, mock_get_storage, mock_config):
"""Test archiver can be initialized with various options."""

View File

@ -403,7 +403,7 @@ class TestBillingDisabledPolicyFilterMessageIds:
class TestCreateMessageCleanPolicy:
"""Unit tests for create_message_clean_policy factory function."""
@patch("services.retention.conversation.messages_clean_policy.dify_config")
@patch("services.retention.conversation.messages_clean_policy.dify_config", autospec=True)
def test_billing_disabled_returns_billing_disabled_policy(self, mock_config):
"""Test that BILLING_ENABLED=False returns BillingDisabledPolicy."""
# Arrange
@ -416,7 +416,7 @@ class TestCreateMessageCleanPolicy:
assert isinstance(policy, BillingDisabledPolicy)
@patch("services.retention.conversation.messages_clean_policy.BillingService", autospec=True)
@patch("services.retention.conversation.messages_clean_policy.dify_config")
@patch("services.retention.conversation.messages_clean_policy.dify_config", autospec=True)
def test_billing_enabled_policy_has_correct_internals(self, mock_config, mock_billing_service):
"""Test that BillingSandboxPolicy is created with correct internal values."""
# Arrange

View File

@ -13,7 +13,7 @@ from services.recommended_app_service import RecommendedAppService
class TestGetRecommendAppDetailNullCheck:
@patch("services.recommended_app_service.FeatureService", autospec=True)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_returns_none_when_retrieval_returns_none_and_trial_disabled(
self, mock_config, mock_factory_class, mock_feature_service
):
@ -29,7 +29,7 @@ class TestGetRecommendAppDetailNullCheck:
@patch("services.recommended_app_service.FeatureService", autospec=True)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_returns_none_when_retrieval_returns_none_and_trial_enabled(
self, mock_config, mock_factory_class, mock_feature_service
):

View File

@ -1,41 +0,0 @@
import subprocess
from pathlib import Path
def test_default_make_test_runs_backend_unit_suites():
repo_root = Path(__file__).resolve().parents[3]
completed = subprocess.run(["make", "-n", "test"], cwd=repo_root, check=True, capture_output=True, text=True)
dry_run_output = completed.stdout
assert "api/tests/unit_tests" in dry_run_output
assert "api/providers/vdb/*/tests/unit_tests" in dry_run_output
assert "api/providers/trace/*/tests/unit_tests" in dry_run_output
assert "-p no:benchmark" in dry_run_output
assert "api/tests/unit_tests/controllers" in dry_run_output
assert "--start-middleware" not in dry_run_output
assert "api/tests/integration_tests/workflow" not in dry_run_output
assert "api/tests/test_containers_integration_tests" not in dry_run_output
assert "--start-vdb" not in dry_run_output
assert "api/providers/vdb/vdb-chroma/tests/integration_tests" not in dry_run_output
def test_make_test_all_runs_backend_pytest_suites():
repo_root = Path(__file__).resolve().parents[3]
completed = subprocess.run(["make", "-n", "test-all"], cwd=repo_root, check=True, capture_output=True, text=True)
dry_run_output = completed.stdout
assert "api/tests/unit_tests" in dry_run_output
assert "api/providers/vdb/*/tests/unit_tests" in dry_run_output
assert "api/providers/trace/*/tests/unit_tests" in dry_run_output
assert "-p no:benchmark" in dry_run_output
assert "--start-middleware" in dry_run_output
assert "api/tests/integration_tests/workflow" in dry_run_output
assert "api/tests/integration_tests/tools" in dry_run_output
assert "api/tests/test_containers_integration_tests" in dry_run_output
assert "--start-vdb" in dry_run_output
assert "api/providers/vdb/vdb-chroma/tests/integration_tests" in dry_run_output
assert "api/providers/vdb/vdb-pgvector/tests/integration_tests" in dry_run_output
assert "api/providers/vdb/vdb-qdrant/tests/integration_tests" in dry_run_output
assert "api/providers/vdb/vdb-weaviate/tests/integration_tests" in dry_run_output

View File

@ -1,115 +0,0 @@
import os
import subprocess
from pathlib import Path
from tests.pytest_dify import (
DEFAULT_LOG_FORMAT,
DockerComposeStack,
build_middleware_stack,
build_vdb_stack,
ensure_backend_test_environment,
ensure_compose_env_files,
parse_services,
)
def test_ensure_backend_test_environment_uses_example_env_and_stable_logging(
tmp_path: Path,
monkeypatch,
):
repo_root = tmp_path
integration_tests_dir = repo_root / "api" / "tests" / "integration_tests"
integration_tests_dir.mkdir(parents=True)
env_example = integration_tests_dir / ".env.example"
env_example.write_text("LOG_LEVEL=INFO\n")
storage_root = repo_root / "storage"
monkeypatch.setenv("LOG_FORMAT", "json")
monkeypatch.delenv("LOG_OUTPUT_FORMAT", raising=False)
monkeypatch.delenv("DIFY_TEST_ENV_FILE", raising=False)
monkeypatch.delenv("DIFY_VDB_TEST_ENV_FILE", raising=False)
monkeypatch.setenv("OPENDAL_FS_ROOT", str(storage_root))
ensure_backend_test_environment(repo_root)
assert os.environ["DIFY_TEST_ENV_FILE"] == str(env_example)
assert "DIFY_VDB_TEST_ENV_FILE" not in os.environ
assert os.environ["LOG_OUTPUT_FORMAT"] == "text"
assert os.environ["LOG_FORMAT"] == DEFAULT_LOG_FORMAT
assert os.environ["STORAGE_TYPE"] == "opendal"
assert os.environ["OPENDAL_SCHEME"] == "fs"
assert storage_root.is_dir()
def test_ensure_compose_env_files_copies_missing_env_files(tmp_path: Path):
docker_dir = tmp_path / "docker"
envs_dir = docker_dir / "envs"
envs_dir.mkdir(parents=True)
(docker_dir / ".env.example").write_text("APP_WEB_URL=http://localhost\n")
(envs_dir / "middleware.env.example").write_text("DB_PASSWORD=difyai123456\n")
ensure_compose_env_files(tmp_path)
assert (docker_dir / ".env").read_text() == "APP_WEB_URL=http://localhost\n"
assert (docker_dir / "middleware.env").read_text() == "DB_PASSWORD=difyai123456\n"
def test_parse_services_discards_empty_items():
assert parse_services(" db_postgres, redis,, sandbox ") == ["db_postgres", "redis", "sandbox"]
def test_stack_up_uses_waiting_compose_command(monkeypatch, tmp_path: Path):
calls: list[list[str]] = []
def fake_run(args, **kwargs):
calls.append(args)
return subprocess.CompletedProcess(args=args, returncode=0)
monkeypatch.setattr(subprocess, "run", fake_run)
monkeypatch.setattr("time.sleep", lambda _: None)
stack = DockerComposeStack(
name="middleware",
project_name="dify-pytest-middleware",
repo_root=tmp_path,
compose_files=(tmp_path / "docker-compose.yaml",),
env_file=tmp_path / "middleware.env",
services=("db_postgres", "redis"),
)
stack.up()
assert calls == [
[
"docker",
"compose",
"--project-name",
"dify-pytest-middleware",
"--env-file",
str(tmp_path / "middleware.env"),
"-f",
str(tmp_path / "docker-compose.yaml"),
"up",
"-d",
"--wait",
"--wait-timeout",
"180",
"db_postgres",
"redis",
]
]
def test_builders_use_expected_compose_files(tmp_path: Path):
middleware = build_middleware_stack(tmp_path, ["db_postgres"])
vdb = build_vdb_stack(tmp_path, ["weaviate", "qdrant"])
assert middleware.compose_files == (tmp_path / "docker" / "docker-compose.middleware.yaml",)
assert middleware.env_file == tmp_path / "docker" / "middleware.env"
assert middleware.ready_delay_seconds == 5.0
assert vdb.compose_files == (
tmp_path / "docker" / "docker-compose.yaml",
tmp_path / "docker" / "docker-compose.pytest.ports.yaml",
)
assert vdb.env_file == tmp_path / "docker" / ".env"
assert vdb.profiles == ("weaviate", "qdrant")

View File

@ -1,5 +1,6 @@
from pathlib import Path
import yaml # type: ignore
from dotenv import dotenv_values
BASE_API_AND_DOCKER_CONFIG_SET_DIFF: frozenset[str] = frozenset(
@ -90,29 +91,34 @@ BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF: frozenset[str] = frozenset(
)
)
REPO_ROOT = Path(__file__).resolve().parents[4]
API_CONFIG_SET = set(dotenv_values(Path("api") / Path(".env.example")).keys())
DOCKER_CONFIG_SET = set(dotenv_values(Path("docker") / Path(".env.example")).keys())
DOCKER_COMPOSE_CONFIG_SET = set(DOCKER_CONFIG_SET)
# Read environment variables from the split env files used by docker-compose
# Walk through all .env.example files in subdirectories (per-module structure)
envs_dir = Path("docker") / Path("envs")
if envs_dir.exists():
for env_file_path in envs_dir.rglob("*.env.example"):
env_keys = set(dotenv_values(env_file_path).keys())
DOCKER_CONFIG_SET.update(env_keys)
DOCKER_COMPOSE_CONFIG_SET.update(env_keys)
def _api_config_set() -> set[str]:
return set(dotenv_values(REPO_ROOT / "api" / ".env.example").keys())
def test_yaml_config():
# python set == operator is used to compare two sets
DIFF_API_WITH_DOCKER = API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
if DIFF_API_WITH_DOCKER:
print(f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}")
raise Exception("API and Docker config sets are different")
DIFF_API_WITH_DOCKER_COMPOSE = (
API_CONFIG_SET - DOCKER_COMPOSE_CONFIG_SET - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
)
if DIFF_API_WITH_DOCKER_COMPOSE:
print(f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}")
raise Exception("API and Docker Compose config sets are different")
print("All tests passed!")
def _docker_config_set() -> set[str]:
docker_config_set = set(dotenv_values(REPO_ROOT / "docker" / ".env.example").keys())
envs_dir = REPO_ROOT / "docker" / "envs"
if envs_dir.exists():
for env_file_path in envs_dir.rglob("*.env.example"):
docker_config_set.update(dotenv_values(env_file_path).keys())
return docker_config_set
def test_api_env_keys_exist_in_docker_env_examples():
diff = _api_config_set() - _docker_config_set() - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
assert not diff, f"API and Docker config sets are different with keys: {sorted(diff)}"
def test_api_env_keys_exist_in_docker_compose_env_examples():
diff = _api_config_set() - _docker_config_set() - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
assert not diff, f"API and Docker Compose config sets are different with keys: {sorted(diff)}"
if __name__ == "__main__":
test_yaml_config()

58
dev/pytest/pytest_full.sh Executable file
View File

@ -0,0 +1,58 @@
#!/bin/bash
set -euo pipefail
set -ex
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."
PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}"
# Ensure OpenDAL local storage works even if .env isn't loaded
export STORAGE_TYPE=${STORAGE_TYPE:-opendal}
export OPENDAL_SCHEME=${OPENDAL_SCHEME:-fs}
export OPENDAL_FS_ROOT=${OPENDAL_FS_ROOT:-/tmp/dify-storage}
mkdir -p "${OPENDAL_FS_ROOT}"
# Prepare env files like CI
cp -n docker/.env.example docker/.env || true
cp -n docker/envs/middleware.env.example docker/middleware.env || true
cp -n api/tests/integration_tests/.env.example api/tests/integration_tests/.env || true
# Expose service ports (same as CI) without leaving the repo dirty
EXPOSE_BACKUPS=()
for f in docker/docker-compose.yaml docker/tidb/docker-compose.yaml; do
if [[ -f "$f" ]]; then
cp "$f" "$f.ci.bak"
EXPOSE_BACKUPS+=("$f")
fi
done
if command -v yq >/dev/null 2>&1; then
sh .github/workflows/expose_service_ports.sh || true
else
echo "skip expose_service_ports (yq not installed)" >&2
fi
# Optionally start middleware stack (db, redis, sandbox, ssrf proxy) to mirror CI
STARTED_MIDDLEWARE=0
if [[ "${SKIP_MIDDLEWARE:-0}" != "1" ]]; then
docker compose -f docker/docker-compose.middleware.yaml --env-file docker/middleware.env up -d db_postgres redis sandbox ssrf_proxy
STARTED_MIDDLEWARE=1
# Give services a moment to come up
sleep 5
fi
cleanup() {
if [[ $STARTED_MIDDLEWARE -eq 1 ]]; then
docker compose -f docker/docker-compose.middleware.yaml --env-file docker/middleware.env down
fi
for f in "${EXPOSE_BACKUPS[@]}"; do
mv "$f.ci.bak" "$f"
done
}
trap cleanup EXIT
pytest --timeout "${PYTEST_TIMEOUT}" \
api/tests/integration_tests/workflow \
api/tests/integration_tests/tools \
api/tests/test_containers_integration_tests \
api/tests/unit_tests

21
dev/pytest/pytest_unit_tests.sh Executable file
View File

@ -0,0 +1,21 @@
#!/bin/bash
set -euxo pipefail
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."
PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-20}"
PYTEST_XDIST_ARGS="${PYTEST_XDIST_ARGS:--n auto}"
# Run most tests in parallel (excluding controllers which have import conflicts with xdist)
# Controller tests have module-level side effects (Flask route registration) that cause
# race conditions when imported concurrently by multiple pytest-xdist workers.
pytest --timeout "${PYTEST_TIMEOUT}" ${PYTEST_XDIST_ARGS} \
api/tests/unit_tests \
api/providers/vdb/*/tests/unit_tests \
api/providers/trace/*/tests/unit_tests \
--ignore=api/tests/unit_tests/controllers
# Run controller tests sequentially to avoid import race conditions
pytest --timeout "${PYTEST_TIMEOUT}" --cov-append api/tests/unit_tests/controllers

12
dev/pytest/pytest_vdb.sh Executable file
View File

@ -0,0 +1,12 @@
#!/bin/bash
set -x
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."
PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}"
uv sync --project api --group dev
uv run --project api pytest --timeout "${PYTEST_TIMEOUT}" \
api/providers/vdb/*/tests/integration_tests \

View File

@ -1,30 +0,0 @@
services:
weaviate:
ports:
- "${EXPOSE_WEAVIATE_PORT:-8080}:8080"
- "${EXPOSE_WEAVIATE_GRPC_PORT:-50051}:50051"
qdrant:
ports:
- "${EXPOSE_QDRANT_PORT:-6333}:6333"
pgvector:
ports:
- "${EXPOSE_PGVECTOR_PORT:-5433}:5432"
pgvecto-rs:
ports:
- "${EXPOSE_PGVECTO_RS_PORT:-5431}:5432"
chroma:
ports:
- "${EXPOSE_CHROMA_PORT:-8000}:8000"
couchbase-server:
ports:
- "${EXPOSE_COUCHBASE_WEB_PORT_RANGE:-8091-8096}:8091-8096"
- "${EXPOSE_COUCHBASE_DATA_PORT:-11210}:11210"
opensearch:
ports:
- "${EXPOSE_OPENSEARCH_PORT:-9200}:9200"

View File

@ -1986,6 +1986,9 @@
},
"react-refresh/only-export-components": {
"count": 1
},
"react/use-memo": {
"count": 1
}
},
"web/app/components/datasets/documents/detail/completed/segment-list.tsx": {

View File

@ -13,7 +13,7 @@ Add a script in your frontend project:
```json
{
"scripts": {
"dev:proxy": "dev-proxy --config ./dev-proxy.config.ts --env-file ./.env"
"dev:proxy": "dev-proxy --config ./dev-proxy.config.ts --env-file ./.env.local"
}
}
```
@ -36,10 +36,14 @@ Supported options:
- `--env-file`: load environment variables before evaluating the config file.
- `--host`: override `server.host` from config.
- `--port`: override `server.port` from config.
- `--watch`: reload config and env file changes. Enabled by default.
- `--no-watch`: disable config and env file reloads.
- `--help`, `-h`: print help.
`--target` is not supported. Put targets in the config file so routes and upstreams stay explicit.
The CLI watches the config file and the explicit `--env-file` by default. Route, CORS, target, and cookie rewrite changes are applied in the running process. If the resolved host or port changes, the proxy closes the old server and starts a new one.
## Config Shape
```ts
@ -108,9 +112,11 @@ DEV_PROXY_PORT=5001
Command:
```bash
dev-proxy --config ./dev-proxy.config.ts --env-file ./.env
dev-proxy --config ./dev-proxy.config.ts --env-file ./.env.local
```
Edits to `./.env.local` reload the proxy automatically.
## Scenario 2: Proxy Two Route Groups To Two Local Backends
Use this when one frontend needs to talk to two different local services. For example:

View File

@ -30,6 +30,7 @@
"dependencies": {
"@hono/node-server": "catalog:",
"c12": "catalog:",
"chokidar": "catalog:",
"hono": "catalog:"
},
"devDependencies": {

View File

@ -2,10 +2,12 @@
* @vitest-environment node
*/
import type { ChildProcessByStdio } from 'node:child_process'
import type { Server } from 'node:http'
import type { Readable } from 'node:stream'
import { spawn } from 'node:child_process'
import { once } from 'node:events'
import fs from 'node:fs/promises'
import http from 'node:http'
import net from 'node:net'
import os from 'node:os'
import path from 'node:path'
@ -16,6 +18,7 @@ const tempDirs: string[] = []
type DevProxyCliProcess = ChildProcessByStdio<null, Readable, Readable>
const childProcesses: DevProxyCliProcess[] = []
const httpServers: Server[] = []
const binPath = fileURLToPath(new URL('../bin/dev-proxy.js', import.meta.url))
const createTempDir = async () => {
@ -86,6 +89,23 @@ const waitForOutput = (
onData()
})
const fetchTextWithRetry = async (url: string) => {
let lastError: unknown
for (let attempt = 0; attempt < 10; attempt += 1) {
try {
const response = await fetch(url)
return response.text()
}
catch (error) {
lastError = error
await new Promise(resolve => setTimeout(resolve, 50))
}
}
throw lastError
}
const spawnCli = (args: readonly string[], cwd: string) => {
const child = spawn(process.execPath, [binPath, ...args], {
cwd,
@ -107,9 +127,45 @@ const stopChildProcess = async (child: DevProxyCliProcess) => {
await once(child, 'exit')
}
const closeHttpServer = async (server: Server) => {
if (!server.listening)
return
await new Promise<void>((resolve, reject) => {
server.close((error) => {
if (error)
reject(error)
else
resolve()
})
})
}
const startTextServer = async (body: string) => {
const server = http.createServer((_, response) => {
response.writeHead(200, { 'content-type': 'text/plain' })
response.end(body)
})
await new Promise<void>((resolve, reject) => {
server.once('error', reject)
server.listen(0, '127.0.0.1', resolve)
})
const address = server.address()
if (!address || typeof address === 'string')
throw new Error('Failed to start test server.')
httpServers.push(server)
return {
port: address.port,
}
}
describe('dev proxy CLI', () => {
afterEach(async () => {
await Promise.all(childProcesses.splice(0).map(stopChildProcess))
await Promise.all(httpServers.splice(0).map(closeHttpServer))
await Promise.all(tempDirs.splice(0).map(tempDir => fs.rm(tempDir, {
force: true,
recursive: true,
@ -155,4 +211,49 @@ describe('dev proxy CLI', () => {
expect(child.signalCode).toBeNull()
expect(response.status).toBe(404)
})
// Scenario: editing the configured env file should reload route targets without restarting the CLI process.
it('should reload proxy config when the env file changes', async () => {
// Arrange
const tempDir = await createTempDir()
const port = await getFreePort()
const firstTarget = await startTextServer('first target')
const secondTarget = await startTextServer('second target')
await fs.writeFile(path.join(tempDir, '.env.proxy'), `DEV_PROXY_TEST_TARGET=http://127.0.0.1:${firstTarget.port}\n`)
await fs.writeFile(path.join(tempDir, 'dev-proxy.config.ts'), `
export default {
routes: [{ paths: '/api', target: process.env.DEV_PROXY_TEST_TARGET }],
}
`)
let output = ''
const child = spawnCli([
'--config',
'./dev-proxy.config.ts',
'--env-file',
'./.env.proxy',
'--host',
'127.0.0.1',
'--port',
String(port),
], tempDir)
child.stdout.on('data', chunk => output += chunk.toString())
child.stderr.on('data', chunk => output += chunk.toString())
const proxyUrl = `http://127.0.0.1:${port}/api/ping`
// Act
await waitForOutput(child, () => output, `[dev-proxy] listening on http://127.0.0.1:${port}`)
const firstResponse = await fetchTextWithRetry(proxyUrl)
await fs.writeFile(path.join(tempDir, '.env.proxy'), `DEV_PROXY_TEST_TARGET=http://127.0.0.1:${secondTarget.port}\n`)
await waitForOutput(child, () => output, '[dev-proxy] reloaded env file changes')
const secondResponse = await fetchTextWithRetry(proxyUrl)
// Assert
expect(firstResponse).toBe('first target')
expect(secondResponse).toBe('second target')
expect(child.exitCode).toBeNull()
expect(child.signalCode).toBeNull()
})
})

View File

@ -1,6 +1,9 @@
import type { ServerType } from '@hono/node-server'
import type { DevProxyCliOptions, DevProxyConfig } from './types'
import process from 'node:process'
import { serve } from '@hono/node-server'
import { loadDevProxyConfig, parseDevProxyCliArgs, resolveDevProxyServerOptions } from './config'
import { watch } from 'chokidar'
import { assertDevProxyConfig, loadDevProxyConfig, parseDevProxyCliArgs, resolveDevProxyServerOptions, watchDevProxyConfig } from './config'
import { createDevProxyApp } from './server'
function printUsage() {
@ -12,6 +15,8 @@ Options:
--env-file <path> Load environment variables before evaluating the config file.
--host <host> Override the configured host.
--port <port> Override the configured port.
--watch Reload config and env file changes. Enabled by default.
--no-watch Disable config and env file reloads.
--help, -h Show this help message.`)
}
@ -22,6 +27,78 @@ async function flushStandardStreams() {
])
}
const closeServer = (server: ServerType) => new Promise<void>((resolve, reject) => {
server.close((error) => {
if (error)
reject(error)
else
resolve()
})
})
const startDevProxyServer = (config: DevProxyConfig, cliOptions: DevProxyCliOptions) => {
let app = createDevProxyApp(config)
const { host, port } = resolveDevProxyServerOptions(config.server, cliOptions)
const server = serve({
fetch: (request, env) => app.fetch(request, env),
hostname: host,
port,
})
return {
host,
port,
server,
updateConfig(nextConfig: DevProxyConfig) {
app = createDevProxyApp(nextConfig)
},
}
}
const createDevProxyRuntime = (initialConfig: DevProxyConfig, cliOptions: DevProxyCliOptions) => {
let runtime = startDevProxyServer(initialConfig, cliOptions)
let reloadTask = Promise.resolve()
console.log(`[dev-proxy] listening on http://${runtime.host}:${runtime.port}`)
const reload = async (nextConfig: unknown, reason: string) => {
assertDevProxyConfig(nextConfig)
const nextServerOptions = resolveDevProxyServerOptions(nextConfig.server, cliOptions)
if (runtime.host === nextServerOptions.host && runtime.port === nextServerOptions.port) {
runtime.updateConfig(nextConfig)
console.log(`[dev-proxy] reloaded ${reason}`)
return
}
await closeServer(runtime.server)
runtime = startDevProxyServer(nextConfig, cliOptions)
console.log(`[dev-proxy] restarted on http://${runtime.host}:${runtime.port} after ${reason}`)
}
const enqueueReload = (loadConfig: () => Promise<unknown> | unknown, reason: string) => {
reloadTask = reloadTask.then(async () => {
try {
await reload(await loadConfig(), reason)
}
catch (error) {
console.error(`[dev-proxy] failed to reload ${reason}`)
console.error(error instanceof Error ? error.message : error)
}
})
return reloadTask
}
return {
enqueueReload,
close: async () => {
await reloadTask
await closeServer(runtime.server)
},
}
}
async function main() {
const cliOptions = parseDevProxyCliArgs(process.argv.slice(2))
@ -33,16 +110,44 @@ async function main() {
const config = await loadDevProxyConfig(cliOptions.config, process.cwd(), {
envFile: cliOptions.envFile,
})
const { host, port } = resolveDevProxyServerOptions(config.server, cliOptions)
const app = createDevProxyApp(config)
const runtime = createDevProxyRuntime(config, cliOptions)
serve({
fetch: app.fetch,
hostname: host,
port,
if (cliOptions.watch === false)
return
const configWatcher = await watchDevProxyConfig(cliOptions.config, process.cwd(), {
envFile: cliOptions.envFile,
onUpdate: ({ newConfig }) => runtime.enqueueReload(() => newConfig.config, 'config changes'),
})
console.log(`[dev-proxy] listening on http://${host}:${port}`)
const envWatcher = cliOptions.envFile
? watch(cliOptions.envFile, {
cwd: process.cwd(),
ignoreInitial: true,
})
: undefined
envWatcher?.on('all', () => {
void runtime.enqueueReload(
() => loadDevProxyConfig(cliOptions.config, process.cwd(), {
envFile: cliOptions.envFile,
}),
'env file changes',
)
})
const cleanup = async () => {
await envWatcher?.close()
await configWatcher.unwatch()
await runtime.close()
}
process.once('SIGINT', () => {
void cleanup().finally(() => process.exit(0))
})
process.once('SIGTERM', () => {
void cleanup().finally(() => process.exit(0))
})
}
try {

View File

@ -37,6 +37,7 @@ describe('dev proxy config', () => {
'0.0.0.0',
'--port',
'8083',
'--no-watch',
])
// Assert
@ -45,6 +46,7 @@ describe('dev proxy config', () => {
envFile: './.env.proxy',
host: '0.0.0.0',
port: '8083',
watch: false,
})
})

View File

@ -1,7 +1,7 @@
import type { DotenvOptions } from 'c12'
import type { DotenvOptions, LoadConfigOptions, WatchConfigOptions } from 'c12'
import type { DevProxyCliOptions, DevProxyConfig, DevProxyConfigLoadOptions, DevProxyServerConfig, ResolvedDevProxyServerOptions } from './types'
import path from 'node:path'
import { loadConfig } from 'c12'
import { loadConfig, watchConfig } from 'c12'
const DEFAULT_CONFIG_FILE = 'dev-proxy.config.ts'
const DEFAULT_PROXY_HOST = '127.0.0.1'
@ -40,6 +40,16 @@ export const parseDevProxyCliArgs = (argv: readonly string[]): DevProxyCliOption
continue
}
if (arg === '--watch') {
options.watch = true
continue
}
if (arg === '--no-watch') {
options.watch = false
continue
}
const [rawName, inlineValue] = arg.split('=', 2)
const name = rawName ?? ''
@ -105,14 +115,15 @@ const resolveDotenvOptions = (
}
}
export const loadDevProxyConfig = async (
const createC12ConfigOptions = (
configPath = DEFAULT_CONFIG_FILE,
cwd = process.cwd(),
options: DevProxyConfigLoadOptions = {},
): Promise<DevProxyConfig> => {
): LoadConfigOptions<DevProxyConfig> => {
const resolvedConfigPath = path.resolve(cwd, configPath)
const parsedPath = path.parse(resolvedConfigPath)
const { config: loadedConfig } = await loadConfig({
return {
configFile: parsedPath.name,
cwd: parsedPath.dir,
dotenv: resolveDotenvOptions(options.envFile, cwd),
@ -120,10 +131,34 @@ export const loadDevProxyConfig = async (
globalRc: false,
packageJson: false,
rcFile: false,
}
}
export const loadDevProxyConfig = async (
configPath = DEFAULT_CONFIG_FILE,
cwd = process.cwd(),
options: DevProxyConfigLoadOptions = {},
): Promise<DevProxyConfig> => {
const { config: loadedConfig } = await loadConfig({
...createC12ConfigOptions(configPath, cwd, options),
})
assertDevProxyConfig(loadedConfig)
return loadedConfig
}
export const watchDevProxyConfig = async (
configPath = DEFAULT_CONFIG_FILE,
cwd = process.cwd(),
options: DevProxyConfigLoadOptions & Pick<WatchConfigOptions<DevProxyConfig>, 'onUpdate'> = {},
) => {
const watcher = await watchConfig<DevProxyConfig>({
...createC12ConfigOptions(configPath, cwd, options),
onUpdate: options.onUpdate,
})
assertDevProxyConfig(watcher.config)
return watcher
}
export const defineDevProxyConfig = (config: DevProxyConfig) => config

View File

@ -39,6 +39,7 @@ export type DevProxyCliOptions = {
envFile?: string
host?: string
port?: string
watch?: boolean
help?: boolean
}

7
pnpm-lock.yaml generated
View File

@ -255,6 +255,9 @@ catalogs:
c12:
specifier: 4.0.0-beta.5
version: 4.0.0-beta.5
chokidar:
specifier: 5.0.0
version: 5.0.0
class-variance-authority:
specifier: 0.7.1
version: 0.7.1
@ -699,6 +702,9 @@ importers:
c12:
specifier: 'catalog:'
version: 4.0.0-beta.5(chokidar@5.0.0)(dotenv@17.4.2)(giget@3.2.0)(jiti@2.7.0)(magicast@0.5.2)
chokidar:
specifier: 'catalog:'
version: 5.0.0
hono:
specifier: 'catalog:'
version: 4.12.18
@ -16265,6 +16271,7 @@ time:
agentation@3.0.2: '2026-03-25T16:24:19.682Z'
ahooks@3.9.7: '2026-03-23T15:49:13.605Z'
c12@4.0.0-beta.5: '2026-05-06T17:28:34.367Z'
chokidar@5.0.0: '2025-11-25T23:28:06.854Z'
class-variance-authority@0.7.1: '2024-11-26T08:20:34.604Z'
client-only@0.0.1: '2022-09-03T01:07:11.981Z'
clsx@2.1.1: '2024-04-23T05:26:04.645Z'

View File

@ -142,6 +142,7 @@ catalog:
agentation: 3.0.2
ahooks: 3.9.7
c12: 4.0.0-beta.5
chokidar: 5.0.0
class-variance-authority: 0.7.1
client-only: 0.0.1
clsx: 2.1.1

View File

@ -111,27 +111,77 @@ describe('Segment CRUD Flow', () => {
})
describe('Segment Selection → Batch Operations', () => {
const segments = [
createSegment('seg-1'),
createSegment('seg-2'),
createSegment('seg-3'),
]
it('should manage individual segment selection', () => {
const { result } = renderHook(() => useSegmentSelection())
const { result } = renderHook(() => useSegmentSelection(segments))
act(() => {
result.current.onSelectedSegmentIdsChange(['seg-1'])
result.current.onSelected('seg-1')
})
expect(result.current.selectedSegmentIds).toContain('seg-1')
act(() => {
result.current.onSelectedSegmentIdsChange(['seg-1', 'seg-2'])
result.current.onSelected('seg-2')
})
expect(result.current.selectedSegmentIds).toContain('seg-1')
expect(result.current.selectedSegmentIds).toContain('seg-2')
expect(result.current.selectedSegmentIds).toHaveLength(2)
})
it('should clear selection via onCancelBatchOperation', () => {
const { result } = renderHook(() => useSegmentSelection())
it('should toggle selection on repeated click', () => {
const { result } = renderHook(() => useSegmentSelection(segments))
act(() => {
result.current.onSelectedSegmentIdsChange(['seg-1', 'seg-2'])
result.current.onSelected('seg-1')
})
expect(result.current.selectedSegmentIds).toContain('seg-1')
act(() => {
result.current.onSelected('seg-1')
})
expect(result.current.selectedSegmentIds).not.toContain('seg-1')
})
it('should support select all toggle', () => {
const { result } = renderHook(() => useSegmentSelection(segments))
act(() => {
result.current.onSelectedAll()
})
expect(result.current.selectedSegmentIds).toHaveLength(3)
expect(result.current.isAllSelected).toBe(true)
act(() => {
result.current.onSelectedAll()
})
expect(result.current.selectedSegmentIds).toHaveLength(0)
expect(result.current.isAllSelected).toBe(false)
})
it('should detect partial selection via isSomeSelected', () => {
const { result } = renderHook(() => useSegmentSelection(segments))
act(() => {
result.current.onSelected('seg-1')
})
// After selecting one of three, isSomeSelected should be true
expect(result.current.selectedSegmentIds).toEqual(['seg-1'])
expect(result.current.isSomeSelected).toBe(true)
expect(result.current.isAllSelected).toBe(false)
})
it('should clear selection via onCancelBatchOperation', () => {
const { result } = renderHook(() => useSegmentSelection(segments))
act(() => {
result.current.onSelected('seg-1')
result.current.onSelected('seg-2')
})
expect(result.current.selectedSegmentIds).toHaveLength(2)
@ -221,7 +271,7 @@ describe('Segment CRUD Flow', () => {
useSearchFilter({ onPageChange: vi.fn() }),
)
const { result: selectionResult } = renderHook(() =>
useSegmentSelection(),
useSegmentSelection(segments),
)
const { result: modalResult } = renderHook(() =>
useModalState({ onNewSegmentModalChange: vi.fn() }),
@ -234,7 +284,7 @@ describe('Segment CRUD Flow', () => {
// Select a segment
act(() => {
selectionResult.current.onSelectedSegmentIdsChange(['seg-1'])
selectionResult.current.onSelected('seg-1')
})
// Open detail modal

View File

@ -137,40 +137,36 @@ vi.mock('../hooks/use-child-segment-data', () => ({
},
}))
vi.mock('../components/menu-bar', async () => {
const { Checkbox } = await import('@langgenius/dify-ui/checkbox')
return {
default: ({ hasSelectableSegments, totalText, onInputChange, inputValue, isLoading, onChangeStatus }: {
hasSelectableSegments: boolean
totalText: string
onInputChange: (value: string) => void
inputValue: string
isLoading: boolean
onChangeStatus?: (item: { value: string | number, name: string }) => void
}) => (
<div data-testid="menu-bar">
<span data-testid="total-text">{totalText}</span>
<input
data-testid="search-input"
value={inputValue}
onChange={e => onInputChange(e.target.value)}
disabled={isLoading}
/>
{hasSelectableSegments
? <Checkbox parent data-testid="select-all-button" aria-label="Select All" disabled={isLoading} />
: <span data-testid="select-all-spacer" aria-hidden />}
{onChangeStatus && (
<>
<button data-testid="status-enabled" onClick={() => onChangeStatus({ value: 1, name: 'Enabled' })}>Enabled</button>
<button data-testid="status-disabled" onClick={() => onChangeStatus({ value: 0, name: 'Disabled' })}>Disabled</button>
<button data-testid="status-all" onClick={() => onChangeStatus({ value: 'all', name: 'All' })}>All</button>
</>
)}
</div>
),
}
})
vi.mock('../components/menu-bar', () => ({
default: ({ totalText, onInputChange, inputValue, isLoading, onSelectedAll, onChangeStatus }: {
totalText: string
onInputChange: (value: string) => void
inputValue: string
isLoading: boolean
onSelectedAll?: () => void
onChangeStatus?: (item: { value: string | number, name: string }) => void
}) => (
<div data-testid="menu-bar">
<span data-testid="total-text">{totalText}</span>
<input
data-testid="search-input"
value={inputValue}
onChange={e => onInputChange(e.target.value)}
disabled={isLoading}
/>
{onSelectedAll && (
<button data-testid="select-all-button" onClick={onSelectedAll}>Select All</button>
)}
{onChangeStatus && (
<>
<button data-testid="status-enabled" onClick={() => onChangeStatus({ value: 1, name: 'Enabled' })}>Enabled</button>
<button data-testid="status-disabled" onClick={() => onChangeStatus({ value: 0, name: 'Disabled' })}>Disabled</button>
<button data-testid="status-all" onClick={() => onChangeStatus({ value: 'all', name: 'All' })}>All</button>
</>
)}
</div>
),
}))
vi.mock('../components/drawer-group', () => ({
DrawerGroup: () => <div data-testid="drawer-group" />,
@ -755,17 +751,6 @@ describe('Batch Action Callbacks', () => {
})
})
it('should not render select all when there are no current page segments', () => {
mockSegmentListData.data = []
mockSegmentListData.total = 0
render(<Completed {...defaultProps} />, { wrapper: createWrapper() })
expect(screen.queryByTestId('select-all-button')).not.toBeInTheDocument()
expect(screen.getByTestId('select-all-spacer')).toBeInTheDocument()
expect(screen.queryByTestId('batch-action')).not.toBeInTheDocument()
})
it('should call onChangeSwitch with true when batch enable is clicked', async () => {
render(<Completed {...defaultProps} />, { wrapper: createWrapper() })

View File

@ -123,6 +123,8 @@ describe('SegmentList', () => {
ref: null,
isLoading: false,
items: [createMockSegment('seg-1', 'Segment 1 content')],
selectedSegmentIds: [],
onSelected: vi.fn(),
onClick: vi.fn(),
onChangeSwitch: vi.fn(),
onDelete: vi.fn(),
@ -287,10 +289,18 @@ describe('SegmentList', () => {
expect(screen.getAllByRole('checkbox')).toHaveLength(defaultProps.items.length)
})
it('should label each segment checkbox', () => {
render(<SegmentList {...defaultProps} />)
it('should pass selectedSegmentIds to check state', () => {
const { container } = render(<SegmentList {...defaultProps} selectedSegmentIds={['seg-1']} />)
expect(screen.getByRole('checkbox', { name: 'datasetDocuments.segment.chunk 1' })).toBeInTheDocument()
// Assert - component should render with selected state
expect(container.firstChild).toBeInTheDocument()
})
it('should handle empty selectedSegmentIds', () => {
const { container } = render(<SegmentList {...defaultProps} selectedSegmentIds={[]} />)
// Assert - component should render
expect(container.firstChild).toBeInTheDocument()
})
})

View File

@ -1,4 +1,3 @@
import { CheckboxGroup } from '@langgenius/dify-ui/checkbox-group'
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import MenuBar from '../menu-bar'
@ -17,7 +16,9 @@ vi.mock('../../status-item', () => ({
describe('MenuBar', () => {
const defaultProps = {
hasSelectableSegments: true,
isAllSelected: false,
isSomeSelected: false,
onSelectedAll: vi.fn(),
isLoading: false,
totalText: '10 Chunks',
statusList: [
@ -37,63 +38,49 @@ describe('MenuBar', () => {
vi.clearAllMocks()
})
const renderMenuBar = (props: Partial<typeof defaultProps> = {}) => {
return render(
<CheckboxGroup value={[]} onValueChange={vi.fn()} allValues={['seg-1']}>
<MenuBar {...defaultProps} {...props} />
</CheckboxGroup>,
)
}
it('should render total text', () => {
renderMenuBar()
render(<MenuBar {...defaultProps} />)
expect(screen.getByText('10 Chunks')).toBeInTheDocument()
})
it('should render checkbox', () => {
renderMenuBar()
render(<MenuBar {...defaultProps} />)
expect(screen.getByRole('checkbox', { name: 'common.operation.selectAll' })).toBeInTheDocument()
})
it('should not render select all checkbox when there are no selectable segments', () => {
renderMenuBar({ hasSelectableSegments: false })
expect(screen.queryByRole('checkbox', { name: 'common.operation.selectAll' })).not.toBeInTheDocument()
})
it('should call onInputChange when input changes', () => {
renderMenuBar()
render(<MenuBar {...defaultProps} />)
const input = screen.getByRole('textbox')
fireEvent.change(input, { target: { value: 'test search' } })
expect(defaultProps.onInputChange).toHaveBeenCalledWith('test search')
})
it('should render display toggle', () => {
renderMenuBar()
render(<MenuBar {...defaultProps} />)
expect(screen.getByTestId('display-toggle')).toBeInTheDocument()
})
it('should call toggleCollapsed when display toggle clicked', () => {
renderMenuBar()
render(<MenuBar {...defaultProps} />)
fireEvent.click(screen.getByTestId('display-toggle'))
expect(defaultProps.toggleCollapsed).toHaveBeenCalled()
})
it('should call onInputChange with empty string when input is cleared', () => {
renderMenuBar({ inputValue: 'some text' })
render(<MenuBar {...defaultProps} inputValue="some text" />)
const clearButton = screen.getByRole('button', { name: 'common.operation.clear' })
fireEvent.click(clearButton)
expect(defaultProps.onInputChange).toHaveBeenCalledWith('')
})
it('should render select with status items via renderOption', () => {
renderMenuBar()
render(<MenuBar {...defaultProps} />)
expect(screen.getByText('All')).toBeInTheDocument()
})
it('should call renderOption for each item when dropdown is opened', async () => {
renderMenuBar()
render(<MenuBar {...defaultProps} />)
const selectButton = screen.getByRole('combobox')
fireEvent.click(selectButton)

View File

@ -84,6 +84,8 @@ describe('GeneralModeContent', () => {
embeddingAvailable: true,
isLoadingSegmentList: false,
segments: [{ id: 'seg-1' }, { id: 'seg-2' }] as SegmentDetailModel[],
selectedSegmentIds: [],
onSelected: vi.fn(),
onChangeSwitch: vi.fn(),
onDelete: vi.fn(),
onClickCard: vi.fn(),

View File

@ -1,4 +1,5 @@
'use client'
import type { FC } from 'react'
import { Checkbox } from '@langgenius/dify-ui/checkbox'
import { cn } from '@langgenius/dify-ui/cn'
import { Select, SelectContent, SelectItem, SelectItemIndicator, SelectItemText, SelectTrigger } from '@langgenius/dify-ui/select'
@ -15,7 +16,9 @@ type Item = {
} & Record<string, unknown>
type MenuBarProps = {
hasSelectableSegments: boolean
isAllSelected: boolean
isSomeSelected: boolean
onSelectedAll: () => void
isLoading: boolean
totalText: string
statusList: Item[]
@ -27,8 +30,10 @@ type MenuBarProps = {
toggleCollapsed: () => void
}
function MenuBar({
hasSelectableSegments,
const MenuBar: FC<MenuBarProps> = ({
isAllSelected,
isSomeSelected,
onSelectedAll,
isLoading,
totalText,
statusList,
@ -38,24 +43,20 @@ function MenuBar({
onInputChange,
isCollapsed,
toggleCollapsed,
}: MenuBarProps) {
}) => {
const { t } = useTranslation()
const selectedStatus = statusList.find(item => item.value === selectDefaultValue) ?? null
return (
<div className={s.docSearchWrapper}>
{hasSelectableSegments
? (
<Checkbox
className="shrink-0"
parent
aria-label={t('operation.selectAll', { ns: 'common' })}
disabled={isLoading}
/>
)
: (
<span className="size-4 shrink-0" aria-hidden />
)}
<Checkbox
className="shrink-0"
checked={isAllSelected}
indeterminate={!isAllSelected && isSomeSelected}
aria-label={t('operation.selectAll', { ns: 'common' })}
onCheckedChange={() => onSelectedAll()}
disabled={isLoading}
/>
<div className="flex-1 pl-5 system-sm-semibold-uppercase text-text-secondary">{totalText}</div>
<Select
value={selectedStatus ? String(selectedStatus.value) : null}

View File

@ -78,6 +78,8 @@ type GeneralModeContentProps = {
embeddingAvailable: boolean
isLoadingSegmentList: boolean
segments: SegmentDetailModel[]
selectedSegmentIds: string[]
onSelected: (segId: string) => void
onChangeSwitch: (enable: boolean, segId?: string) => Promise<void>
onDelete: (segId?: string) => Promise<void>
onClickCard: (detail: SegmentDetailModel, isEditMode?: boolean) => void
@ -93,6 +95,8 @@ export const GeneralModeContent: FC<GeneralModeContentProps> = ({
embeddingAvailable,
isLoadingSegmentList,
segments,
selectedSegmentIds,
onSelected,
onChangeSwitch,
onDelete,
onClickCard,
@ -108,6 +112,8 @@ export const GeneralModeContent: FC<GeneralModeContentProps> = ({
embeddingAvailable={embeddingAvailable}
isLoading={isLoadingSegmentList}
items={segments}
selectedSegmentIds={selectedSegmentIds}
onSelected={onSelected}
onChangeSwitch={onChangeSwitch}
onDelete={onDelete}
onClick={onClickCard}

View File

@ -1,33 +1,83 @@
import type { SegmentDetailModel } from '@/models/datasets'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { mergeCurrentPageSelectedSegmentIds, useSegmentSelection } from '../use-segment-selection'
import { useSegmentSelection } from '../use-segment-selection'
describe('useSegmentSelection', () => {
const segments = [
{ id: 'seg-1', content: 'A' },
{ id: 'seg-2', content: 'B' },
{ id: 'seg-3', content: 'C' },
] as unknown as SegmentDetailModel[]
beforeEach(() => {
vi.clearAllMocks()
})
it('should initialize with empty selection', () => {
const { result } = renderHook(() => useSegmentSelection())
const { result } = renderHook(() => useSegmentSelection(segments))
expect(result.current.selectedSegmentIds).toEqual([])
expect(result.current.isAllSelected).toBe(false)
expect(result.current.isSomeSelected).toBe(false)
})
it('should select a segment', () => {
const { result } = renderHook(() => useSegmentSelection(segments))
act(() => {
result.current.onSelected('seg-1')
})
expect(result.current.selectedSegmentIds).toEqual(['seg-1'])
expect(result.current.isSomeSelected).toBe(true)
expect(result.current.isAllSelected).toBe(false)
})
it('should deselect a selected segment', () => {
const { result } = renderHook(() => useSegmentSelection(segments))
act(() => {
result.current.onSelected('seg-1')
})
act(() => {
result.current.onSelected('seg-1')
})
expect(result.current.selectedSegmentIds).toEqual([])
})
it('should update selected segment ids', () => {
const { result } = renderHook(() => useSegmentSelection())
it('should select all segments', () => {
const { result } = renderHook(() => useSegmentSelection(segments))
act(() => {
result.current.onSelectedSegmentIdsChange(['seg-1', 'seg-2'])
result.current.onSelectedAll()
})
expect(result.current.selectedSegmentIds).toEqual(['seg-1', 'seg-2'])
expect(result.current.selectedSegmentIds).toEqual(['seg-1', 'seg-2', 'seg-3'])
expect(result.current.isAllSelected).toBe(true)
})
it('should deselect all when all are selected', () => {
const { result } = renderHook(() => useSegmentSelection(segments))
act(() => {
result.current.onSelectedAll()
})
act(() => {
result.current.onSelectedAll()
})
expect(result.current.selectedSegmentIds).toEqual([])
expect(result.current.isAllSelected).toBe(false)
})
it('should cancel batch operation', () => {
const { result } = renderHook(() => useSegmentSelection())
const { result } = renderHook(() => useSegmentSelection(segments))
act(() => {
result.current.onSelectedSegmentIdsChange(['seg-1'])
result.current.onSelected('seg-1')
result.current.onSelected('seg-2')
})
act(() => {
result.current.onCancelBatchOperation()
@ -37,10 +87,10 @@ describe('useSegmentSelection', () => {
})
it('should clear selection', () => {
const { result } = renderHook(() => useSegmentSelection())
const { result } = renderHook(() => useSegmentSelection(segments))
act(() => {
result.current.onSelectedSegmentIdsChange(['seg-1'])
result.current.onSelected('seg-1')
})
act(() => {
result.current.clearSelection()
@ -49,19 +99,61 @@ describe('useSegmentSelection', () => {
expect(result.current.selectedSegmentIds).toEqual([])
})
it('should merge current page selection without dropping selected ids from other pages', () => {
expect(mergeCurrentPageSelectedSegmentIds({
selectedSegmentIds: ['page-1-a', 'page-1-b'],
currentPageSegmentIds: ['page-2-a', 'page-2-b'],
nextCurrentPageSelectedSegmentIds: ['page-2-a'],
})).toEqual(['page-1-a', 'page-1-b', 'page-2-a'])
it('should handle empty segments array', () => {
const { result } = renderHook(() => useSegmentSelection([]))
expect(result.current.isAllSelected).toBe(false)
expect(result.current.isSomeSelected).toBe(false)
})
it('should replace only current page selected ids when current page selection changes', () => {
expect(mergeCurrentPageSelectedSegmentIds({
selectedSegmentIds: ['page-1-a', 'page-2-a', 'page-2-b'],
currentPageSegmentIds: ['page-2-a', 'page-2-b'],
nextCurrentPageSelectedSegmentIds: ['page-2-b'],
})).toEqual(['page-1-a', 'page-2-b'])
it('should allow multiple selections', () => {
const { result } = renderHook(() => useSegmentSelection(segments))
act(() => {
result.current.onSelected('seg-1')
})
act(() => {
result.current.onSelected('seg-2')
})
expect(result.current.selectedSegmentIds).toEqual(['seg-1', 'seg-2'])
expect(result.current.isSomeSelected).toBe(true)
expect(result.current.isAllSelected).toBe(false)
})
it('should preserve selection of segments not in current list', () => {
const { result, rerender } = renderHook(
({ segs }) => useSegmentSelection(segs),
{ initialProps: { segs: segments } },
)
act(() => {
result.current.onSelected('seg-1')
})
// Rerender with different segment list (simulating page change)
const newSegments = [
{ id: 'seg-4', content: 'D' },
{ id: 'seg-5', content: 'E' },
] as unknown as SegmentDetailModel[]
rerender({ segs: newSegments })
// Previously selected segment should still be in selectedSegmentIds
expect(result.current.selectedSegmentIds).toContain('seg-1')
})
it('should select remaining unselected segments when onSelectedAll is called with partial selection', () => {
const { result } = renderHook(() => useSegmentSelection(segments))
act(() => {
result.current.onSelected('seg-1')
})
act(() => {
result.current.onSelectedAll()
})
expect(result.current.selectedSegmentIds).toEqual(expect.arrayContaining(['seg-1', 'seg-2', 'seg-3']))
expect(result.current.isAllSelected).toBe(true)
})
})

View File

@ -6,4 +6,4 @@ export { useSearchFilter } from './use-search-filter'
export { useSegmentListData } from './use-segment-list-data'
export { mergeCurrentPageSelectedSegmentIds, useSegmentSelection } from './use-segment-selection'
export { useSegmentSelection } from './use-segment-selection'

View File

@ -1,39 +1,43 @@
import { useCallback, useState } from 'react'
import type { SegmentDetailModel } from '@/models/datasets'
import { useCallback, useMemo, useState } from 'react'
type UseSegmentSelectionReturn = {
selectedSegmentIds: string[]
onSelectedSegmentIdsChange: (segmentIds: string[]) => void
isAllSelected: boolean
isSomeSelected: boolean
onSelected: (segId: string) => void
onSelectedAll: () => void
onCancelBatchOperation: () => void
clearSelection: () => void
}
type MergeCurrentPageSelectedSegmentIdsOptions = {
selectedSegmentIds: string[]
currentPageSegmentIds: string[]
nextCurrentPageSelectedSegmentIds: string[]
}
export const mergeCurrentPageSelectedSegmentIds = ({
selectedSegmentIds,
currentPageSegmentIds,
nextCurrentPageSelectedSegmentIds,
}: MergeCurrentPageSelectedSegmentIdsOptions) => {
const currentPageSegmentIdSet = new Set(currentPageSegmentIds)
const selectedSegmentIdsOutsideCurrentPage = selectedSegmentIds.filter(segmentId => !currentPageSegmentIdSet.has(segmentId))
return [
...selectedSegmentIdsOutsideCurrentPage,
...nextCurrentPageSelectedSegmentIds,
]
}
export const useSegmentSelection = (): UseSegmentSelectionReturn => {
export const useSegmentSelection = (segments: SegmentDetailModel[]): UseSegmentSelectionReturn => {
const [selectedSegmentIds, setSelectedSegmentIds] = useState<string[]>([])
const onSelectedSegmentIdsChange = useCallback((segmentIds: string[]) => {
setSelectedSegmentIds(segmentIds)
const onSelected = useCallback((segId: string) => {
setSelectedSegmentIds(prev =>
prev.includes(segId)
? prev.filter(id => id !== segId)
: [...prev, segId],
)
}, [])
const isAllSelected = useMemo(() => {
return segments.length > 0 && segments.every(seg => selectedSegmentIds.includes(seg.id))
}, [segments, selectedSegmentIds])
const isSomeSelected = useMemo(() => {
return segments.some(seg => selectedSegmentIds.includes(seg.id))
}, [segments, selectedSegmentIds])
const onSelectedAll = useCallback(() => {
setSelectedSegmentIds((prev) => {
const currentAllSegIds = segments.map(seg => seg.id)
const prevSelectedIds = prev.filter(item => !currentAllSegIds.includes(item))
return [...prevSelectedIds, ...(isAllSelected ? [] : currentAllSegIds)]
})
}, [segments, isAllSelected])
const onCancelBatchOperation = useCallback(() => {
setSelectedSegmentIds([])
}, [])
@ -44,7 +48,10 @@ export const useSegmentSelection = (): UseSegmentSelectionReturn => {
return {
selectedSegmentIds,
onSelectedSegmentIdsChange,
isAllSelected,
isSomeSelected,
onSelected,
onSelectedAll,
onCancelBatchOperation,
clearSelection,
}

View File

@ -2,9 +2,7 @@
import type { FC } from 'react'
import type { SegmentListContextValue } from './segment-list-context'
import type { SegmentImportStatus } from '@/types/dataset'
import { CheckboxGroup } from '@langgenius/dify-ui/checkbox-group'
import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Divider from '@/app/components/base/divider'
import Pagination from '@/app/components/base/pagination'
import {
@ -19,7 +17,6 @@ import { DrawerGroup } from './components/drawer-group'
import MenuBar from './components/menu-bar'
import { FullDocModeContent, GeneralModeContent } from './components/segment-list-content'
import {
mergeCurrentPageSelectedSegmentIds,
useChildSegmentData,
useModalState,
useSearchFilter,
@ -52,7 +49,6 @@ const Completed: FC<ICompletedProps> = ({
importStatus,
archived,
}) => {
const { t } = useTranslation()
const docForm = useDocumentContext(s => s.docForm)
// Pagination state
@ -69,8 +65,8 @@ const Completed: FC<ICompletedProps> = ({
onNewSegmentModalChange,
})
// Selection state
const selectionState = useSegmentSelection()
// Selection state (need segments first, so we use a placeholder initially)
const [segmentsForSelection, setSegmentsForSelection] = useState<string[]>([])
// Invalidation hooks for child segment data
const invalidChunkListAll = useInvalid(useChunkListAllKey)
@ -99,29 +95,21 @@ const Completed: FC<ICompletedProps> = ({
const segmentListDataHook = useSegmentListData({
searchValue: searchFilter.searchValue,
selectedStatus: searchFilter.selectedStatus,
selectedSegmentIds: selectionState.selectedSegmentIds,
selectedSegmentIds: segmentsForSelection,
importStatus,
currentPage,
limit,
onCloseSegmentDetail: modalState.onCloseSegmentDetail,
clearSelection: selectionState.clearSelection,
clearSelection: () => setSegmentsForSelection([]),
})
const segmentIds = useMemo(
() => segmentListDataHook.segments.map(segment => segment.id),
[segmentListDataHook.segments],
)
const currentPageSegmentIdSet = useMemo(() => new Set(segmentIds), [segmentIds])
const currentPageSelectedSegmentIds = useMemo(() => {
return selectionState.selectedSegmentIds.filter(segmentId => currentPageSegmentIdSet.has(segmentId))
}, [currentPageSegmentIdSet, selectionState.selectedSegmentIds])
const handleCurrentPageSelectedSegmentIdsChange = useCallback((nextCurrentPageSelectedSegmentIds: string[]) => {
selectionState.onSelectedSegmentIdsChange(mergeCurrentPageSelectedSegmentIds({
selectedSegmentIds: selectionState.selectedSegmentIds,
currentPageSegmentIds: segmentIds,
nextCurrentPageSelectedSegmentIds,
}))
}, [segmentIds, selectionState.selectedSegmentIds, selectionState.onSelectedSegmentIdsChange])
// Selection state (with actual segments)
const selectionState = useSegmentSelection(segmentListDataHook.segments)
// Sync selection state for segment list data hook
useMemo(() => {
setSegmentsForSelection(selectionState.selectedSegmentIds)
}, [selectionState.selectedSegmentIds])
// Child segment data
const childSegmentDataHook = useChildSegmentData({
@ -165,6 +153,24 @@ const Completed: FC<ICompletedProps> = ({
return (
<SegmentListContext.Provider value={contextValue}>
{/* Menu Bar */}
{!segmentListDataHook.isFullDocMode && (
<MenuBar
isAllSelected={selectionState.isAllSelected}
isSomeSelected={selectionState.isSomeSelected}
onSelectedAll={selectionState.onSelectedAll}
isLoading={segmentListDataHook.isLoadingSegmentList}
totalText={segmentListDataHook.totalText}
statusList={searchFilter.statusList}
selectDefaultValue={searchFilter.selectDefaultValue}
onChangeStatus={searchFilter.onChangeStatus}
inputValue={searchFilter.inputValue}
onInputChange={searchFilter.handleInputChange}
isCollapsed={modalState.isCollapsed}
toggleCollapsed={modalState.toggleCollapsed}
/>
)}
{/* Segment list */}
{segmentListDataHook.isFullDocMode
? (
@ -186,40 +192,22 @@ const Completed: FC<ICompletedProps> = ({
/>
)
: (
<CheckboxGroup
aria-label={t('segment.chunk', { ns: 'datasetDocuments' })}
value={currentPageSelectedSegmentIds}
onValueChange={nextSegmentIds => handleCurrentPageSelectedSegmentIdsChange(nextSegmentIds)}
allValues={segmentIds}
className="flex min-h-0 grow flex-col"
>
<MenuBar
hasSelectableSegments={segmentIds.length > 0}
isLoading={segmentListDataHook.isLoadingSegmentList}
totalText={segmentListDataHook.totalText}
statusList={searchFilter.statusList}
selectDefaultValue={searchFilter.selectDefaultValue}
onChangeStatus={searchFilter.onChangeStatus}
inputValue={searchFilter.inputValue}
onInputChange={searchFilter.handleInputChange}
isCollapsed={modalState.isCollapsed}
toggleCollapsed={modalState.toggleCollapsed}
/>
<GeneralModeContent
segmentListRef={segmentListDataHook.segmentListRef}
embeddingAvailable={embeddingAvailable}
isLoadingSegmentList={segmentListDataHook.isLoadingSegmentList}
segments={segmentListDataHook.segments}
onChangeSwitch={segmentListDataHook.onChangeSwitch}
onDelete={segmentListDataHook.onDelete}
onClickCard={modalState.onClickCard}
archived={archived}
onDeleteChildChunk={childSegmentDataHook.onDeleteChildChunk}
handleAddNewChildChunk={modalState.handleAddNewChildChunk}
onClickSlice={modalState.onClickSlice}
onClearFilter={searchFilter.onClearFilter}
/>
</CheckboxGroup>
<GeneralModeContent
segmentListRef={segmentListDataHook.segmentListRef}
embeddingAvailable={embeddingAvailable}
isLoadingSegmentList={segmentListDataHook.isLoadingSegmentList}
segments={segmentListDataHook.segments}
selectedSegmentIds={selectionState.selectedSegmentIds}
onSelected={selectionState.onSelected}
onChangeSwitch={segmentListDataHook.onChangeSwitch}
onDelete={segmentListDataHook.onDelete}
onClickCard={modalState.onClickCard}
archived={archived}
onDeleteChildChunk={childSegmentDataHook.onDeleteChildChunk}
handleAddNewChildChunk={modalState.handleAddNewChildChunk}
onClickSlice={modalState.onClickSlice}
onClearFilter={searchFilter.onClearFilter}
/>
)}
{/* Pagination */}

View File

@ -15,6 +15,8 @@ import ParagraphListSkeleton from './skeleton/paragraph-list-skeleton'
type ISegmentListProps = {
isLoading: boolean
items: SegmentDetailModel[]
selectedSegmentIds: string[]
onSelected: (segId: string) => void
onClick: (detail: SegmentDetailModel, isEditMode?: boolean) => void
onChangeSwitch: (enabled: boolean, segId?: string) => Promise<void>
onDelete: (segId: string) => Promise<void>
@ -31,6 +33,8 @@ const SegmentList = (
ref,
isLoading,
items,
selectedSegmentIds,
onSelected,
onClick: onClickCard,
onChangeSwitch,
onDelete,
@ -80,8 +84,9 @@ const SegmentList = (
<Checkbox
key={`${segItem.id}-checkbox`}
className="mt-3.5 shrink-0"
value={segItem.id}
checked={selectedSegmentIds.includes(segItem.id)}
aria-label={`${t('segment.chunk', { ns: 'datasetDocuments' })} ${segItem.position}`}
onCheckedChange={() => onSelected(segItem.id)}
/>
<div className="min-w-0 grow">
<SegmentCard

View File

@ -1,7 +1,6 @@
'use client'
import { Checkbox } from '@langgenius/dify-ui/checkbox'
import { CheckboxGroup } from '@langgenius/dify-ui/checkbox-group'
import {
Popover,
PopoverContent,
@ -29,6 +28,12 @@ const TagsFilter = ({
const [searchText, setSearchText] = useState('')
const { tags: options, tagsMap } = useTags()
const filteredOptions = options.filter(option => option.label.toLowerCase().includes(searchText.toLowerCase()))
const handleCheck = (id: string) => {
if (tags.includes(id))
onTagsChange(tags.filter((tag: string) => tag !== id))
else
onTagsChange([...tags, id])
}
const selectedTagsLength = tags.length
return (
@ -80,12 +85,7 @@ const TagsFilter = ({
placeholder={t('searchTags', { ns: 'pluginTags' }) || ''}
/>
</div>
<CheckboxGroup
aria-label={t('allTags', { ns: 'pluginTags' })}
value={tags}
onValueChange={nextTags => onTagsChange(nextTags)}
className="max-h-[448px] overflow-y-auto p-1"
>
<div className="max-h-[448px] overflow-y-auto p-1">
{
filteredOptions.map(option => (
<label
@ -94,7 +94,8 @@ const TagsFilter = ({
>
<Checkbox
className="mr-1"
value={option.name}
checked={tags.includes(option.name)}
onCheckedChange={() => handleCheck(option.name)}
/>
<div className="px-1 system-sm-medium text-text-secondary">
{option.label}
@ -102,7 +103,7 @@ const TagsFilter = ({
</label>
))
}
</CheckboxGroup>
</div>
</div>
</PopoverContent>
</Popover>

View File

@ -1,7 +1,6 @@
'use client'
import { Checkbox } from '@langgenius/dify-ui/checkbox'
import { CheckboxGroup } from '@langgenius/dify-ui/checkbox-group'
import { cn } from '@langgenius/dify-ui/cn'
import {
Popover,
@ -30,6 +29,12 @@ const CategoriesFilter = ({
const [searchText, setSearchText] = useState('')
const { categories: options, categoriesMap } = useCategories()
const filteredOptions = options.filter(option => option.name.toLowerCase().includes(searchText.toLowerCase()))
const handleCheck = (id: string) => {
if (value.includes(id))
onChange(value.filter(tag => tag !== id))
else
onChange([...value, id])
}
const selectedTagsLength = value.length
return (
@ -100,12 +105,7 @@ const CategoriesFilter = ({
placeholder={t('searchCategories', { ns: 'plugin' })}
/>
</div>
<CheckboxGroup
aria-label={t('allCategories', { ns: 'plugin' })}
value={value}
onValueChange={nextValue => onChange(nextValue)}
className="max-h-[448px] overflow-y-auto p-1"
>
<div className="max-h-[448px] overflow-y-auto p-1">
{
filteredOptions.map(option => (
<label
@ -114,7 +114,8 @@ const CategoriesFilter = ({
>
<Checkbox
className="mr-1"
value={option.name}
checked={value.includes(option.name)}
onCheckedChange={() => handleCheck(option.name)}
/>
<div className="px-1 system-sm-medium text-text-secondary">
{option.label}
@ -122,7 +123,7 @@ const CategoriesFilter = ({
</label>
))
}
</CheckboxGroup>
</div>
</div>
</PopoverContent>
</Popover>

View File

@ -1,7 +1,6 @@
'use client'
import { Checkbox } from '@langgenius/dify-ui/checkbox'
import { CheckboxGroup } from '@langgenius/dify-ui/checkbox-group'
import { cn } from '@langgenius/dify-ui/cn'
import {
Popover,
@ -30,6 +29,12 @@ const TagsFilter = ({
const [searchText, setSearchText] = useState('')
const { tags: options, getTagLabel } = useTags()
const filteredOptions = options.filter(option => option.name.toLowerCase().includes(searchText.toLowerCase()))
const handleCheck = (id: string) => {
if (value.includes(id))
onChange(value.filter(tag => tag !== id))
else
onChange([...value, id])
}
const selectedTagsLength = value.length
return (
@ -98,12 +103,7 @@ const TagsFilter = ({
placeholder={t('searchTags', { ns: 'pluginTags' })}
/>
</div>
<CheckboxGroup
aria-label={t('allTags', { ns: 'pluginTags' })}
value={value}
onValueChange={nextValue => onChange(nextValue)}
className="max-h-[448px] overflow-y-auto p-1"
>
<div className="max-h-[448px] overflow-y-auto p-1">
{
filteredOptions.map(option => (
<label
@ -112,7 +112,8 @@ const TagsFilter = ({
>
<Checkbox
className="mr-1"
value={option.name}
checked={value.includes(option.name)}
onCheckedChange={() => handleCheck(option.name)}
/>
<div className="px-1 system-sm-medium text-text-secondary">
{option.label}
@ -120,7 +121,7 @@ const TagsFilter = ({
</label>
))
}
</CheckboxGroup>
</div>
</div>
</PopoverContent>
</Popover>

View File

@ -1,5 +1,6 @@
import type { FC } from 'react'
import type { Label } from '@/app/components/tools/labels/constant'
import { Checkbox } from '@langgenius/dify-ui/checkbox'
import { CheckboxGroup } from '@langgenius/dify-ui/checkbox-group'
import { cn } from '@langgenius/dify-ui/cn'
import {
Popover,
@ -7,7 +8,7 @@ import {
PopoverTrigger,
} from '@langgenius/dify-ui/popover'
import { useDebounceFn } from 'ahooks'
import { useState } from 'react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { Tag03 } from '@/app/components/base/icons/src/vender/line/financeAndECommerce'
import Input from '@/app/components/base/input'
@ -18,10 +19,10 @@ type LabelSelectorProps = {
onChange: (v: string[]) => void
}
function LabelSelector({
const LabelSelector: FC<LabelSelectorProps> = ({
value,
onChange,
}: LabelSelectorProps) {
}) => {
const { t } = useTranslation()
const [open, setOpen] = useState(false)
@ -38,8 +39,20 @@ function LabelSelector({
handleSearch()
}
const filteredLabelList = labelList.filter(label => label.name.includes(searchKeywords))
const selectedLabels = value.map(v => labelList.find(l => l.name === v)?.label).join(', ')
const filteredLabelList = useMemo(() => {
return labelList.filter(label => label.name.includes(searchKeywords))
}, [labelList, searchKeywords])
const selectedLabels = useMemo(() => {
return value.map(v => labelList.find(l => l.name === v)?.label).join(', ')
}, [value, labelList])
const selectLabel = (label: Label) => {
if (value.includes(label.name))
onChange(value.filter(v => v !== label.name))
else
onChange([...value, label.name])
}
return (
<Popover open={open} onOpenChange={setOpen}>
@ -73,12 +86,7 @@ function LabelSelector({
onClear={() => handleKeywordsChange('')}
/>
</div>
<CheckboxGroup
aria-label={t('createTool.toolInput.labelPlaceholder', { ns: 'tools' })}
value={value}
onValueChange={nextValue => onChange(nextValue)}
className="max-h-[264px] overflow-y-auto p-1"
>
<div className="max-h-[264px] overflow-y-auto p-1">
{filteredLabelList.map(label => (
<label
key={label.name}
@ -86,7 +94,8 @@ function LabelSelector({
>
<Checkbox
className="shrink-0"
value={label.name}
checked={value.includes(label.name)}
onCheckedChange={() => selectLabel(label)}
/>
<div title={label.label} className="grow truncate text-sm leading-5 text-text-secondary">{label.label}</div>
</label>
@ -97,7 +106,7 @@ function LabelSelector({
<div className="text-xs leading-[14px] text-text-tertiary">{t('tag.noTag', { ns: 'common' })}</div>
</div>
)}
</CheckboxGroup>
</div>
</div>
</PopoverContent>
</div>