Compare commits


1 Commit

Author SHA1 Message Date
d70acb217a add sample_vector_space_usage 2026-05-09 16:49:49 +08:00
1158 changed files with 12835 additions and 79955 deletions

View File

@ -1,709 +0,0 @@
#!/usr/bin/env bash
set -Eeuo pipefail
SCRIPT_NAME="$(basename "$0")"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd -P)"
DEFAULT_REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd -P)"
REPO_ROOT="${DIFY_REPO_ROOT:-$DEFAULT_REPO_ROOT}"
YES=false
DRY_RUN=true
SKIP_SMOKE=false
SKIP_MIGRATION=false
TIMEOUT_SECONDS="${DIFY_RESET_TIMEOUT_SECONDS:-300}"
SMOKE_URL="${DIFY_RESET_SMOKE_URL:-}"
LOCK_DIR=""
CURRENT_PHASE="init"
DELETED_PATHS=()
SKIPPED_PATHS=()
DELETED_NAMED_VOLUMES=()
SKIPPED_NAMED_VOLUMES=()
PRESERVED_PATHS=()
HEALTH_RESULTS=()
SMOKE_RESULT="not-run"
START_TIME="$(date +%s)"
RUNTIME_PATHS=(
"volumes/db/data"
"volumes/mysql/data"
"volumes/redis/data"
"volumes/app/storage"
"volumes/plugin_daemon"
"volumes/weaviate"
"volumes/qdrant"
"volumes/pgvector"
"volumes/pgvecto_rs"
"volumes/chroma"
"volumes/milvus"
"volumes/opensearch/data"
)
NAMED_VOLUMES=(
"dify_es01_data"
)
PRESERVE_PATHS=(
".env"
"middleware.env"
"docker-compose.yaml"
"docker-compose.middleware.yaml"
"nginx"
"ssrf_proxy"
"volumes/certbot"
"volumes/opensearch/opensearch_dashboards.yml"
"nginx/ssl"
)
usage() {
cat <<EOF
Usage: $SCRIPT_NAME [options]
Safely reset a Dify test environment in place. The command defaults to dry-run.
Options:
--yes Perform destructive reset. Required to delete data.
--dry-run Print planned actions without changing services or data.
--repo-root PATH Repository root. Defaults to auto-detected Dify root.
--smoke-url URL Public URL to verify after restart.
--skip-smoke Skip public-domain smoke verification.
--skip-migration Skip explicit migration gate.
--timeout SECONDS Health-check timeout. Default: $TIMEOUT_SECONDS.
-h, --help Show this help.
Required for destructive reset:
ALLOW_DIFY_TEST_RESET=true
DIFY_ENV_NAME=test
Optional:
DIFY_RESET_SMOKE_URL=https://test.example.com
RESET_TARGET_DOMAIN=test.example.com
EOF
}
log() {
printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"
}
fail() {
local message="$1"
print_report "failure"
printf 'ERROR: %s\n' "$message" >&2
exit 1
}
run_cmd() {
printf '+'
printf ' %q' "$@"
printf '\n'
if [ "$DRY_RUN" = false ]; then
set +e
"$@"
local status=$?
set -e
if [ "$status" -ne 0 ]; then
fail "Command failed with exit code $status: $(command_string "$@")"
fi
fi
}
command_string() {
local arg
local result=""
for arg in "$@"; do
result="$result $(printf '%q' "$arg")"
done
printf '%s' "${result# }"
}
read_env_value() {
local key="$1"
local default_value="$2"
local env_file="$DOCKER_DIR/.env"
local value=""
if [ -f "$env_file" ]; then
value="$(awk -F= -v key="$key" '
$0 !~ /^[[:space:]]*#/ && $1 == key {
sub(/^[^=]*=/, "")
print
}
' "$env_file" | tail -n 1)"
fi
if [ -z "$value" ]; then
printf '%s' "$default_value"
return
fi
value="${value%\"}"
value="${value#\"}"
value="${value%\'}"
value="${value#\'}"
printf '%s' "$value"
}
parse_args() {
while [ "$#" -gt 0 ]; do
case "$1" in
--yes)
YES=true
DRY_RUN=false
;;
--dry-run)
DRY_RUN=true
;;
--repo-root)
[ "$#" -ge 2 ] || fail "--repo-root requires a path"
REPO_ROOT="$2"
shift
;;
--smoke-url)
[ "$#" -ge 2 ] || fail "--smoke-url requires a URL"
SMOKE_URL="$2"
shift
;;
--skip-smoke)
SKIP_SMOKE=true
;;
--skip-migration)
SKIP_MIGRATION=true
;;
--timeout)
[ "$#" -ge 2 ] || fail "--timeout requires seconds"
TIMEOUT_SECONDS="$2"
shift
;;
-h|--help)
usage
exit 0
;;
*)
fail "Unknown option: $1"
;;
esac
shift
done
}
validate_number() {
case "$TIMEOUT_SECONDS" in
''|*[!0-9]*)
fail "--timeout must be a positive integer"
;;
esac
}
require_docker() {
command -v docker >/dev/null 2>&1 || fail "docker command not found"
docker compose version >/dev/null 2>&1 || fail "docker compose is not available"
}
validate_environment() {
CURRENT_PHASE="validate"
REPO_ROOT="$(cd "$REPO_ROOT" && pwd -P)"
DOCKER_DIR="$REPO_ROOT/docker"
[ -d "$DOCKER_DIR" ] || fail "Docker directory not found: $DOCKER_DIR"
[ -f "$DOCKER_DIR/docker-compose.yaml" ] || fail "docker-compose.yaml not found in $DOCKER_DIR"
[ -f "$DOCKER_DIR/.env" ] || fail ".env not found in $DOCKER_DIR"
if [ "$DRY_RUN" = false ]; then
[ "$YES" = true ] || fail "Destructive reset requires --yes"
[ "${ALLOW_DIFY_TEST_RESET:-}" = "true" ] || fail "ALLOW_DIFY_TEST_RESET=true is required"
[ "${DIFY_ENV_NAME:-}" = "test" ] || fail "DIFY_ENV_NAME=test is required"
require_docker
fi
}
acquire_lock() {
CURRENT_PHASE="lock"
local env_name="${DIFY_ENV_NAME:-dry-run}"
LOCK_DIR="${TMPDIR:-/tmp}/dify-test-reset-${env_name}.lock"
if ! mkdir "$LOCK_DIR" 2>/dev/null; then
fail "Reset lock is already held: $LOCK_DIR"
fi
printf '%s\n' "$$" > "$LOCK_DIR/pid"
trap cleanup EXIT
}
cleanup() {
if [ -n "$LOCK_DIR" ] && [ -d "$LOCK_DIR" ]; then
rm -rf "$LOCK_DIR"
fi
}
compose() {
local args=(compose --env-file "$DOCKER_DIR/.env" -f "$DOCKER_DIR/docker-compose.yaml")
if [ -n "${DIFY_COMPOSE_PROJECT:-}" ]; then
args+=(-p "$DIFY_COMPOSE_PROJECT")
fi
docker "${args[@]}" "$@"
}
compose_project_name() {
if [ -n "${DIFY_COMPOSE_PROJECT:-}" ]; then
printf '%s' "$DIFY_COMPOSE_PROJECT"
return
fi
if [ -n "${COMPOSE_PROJECT_NAME:-}" ]; then
printf '%s' "$COMPOSE_PROJECT_NAME"
return
fi
local env_project
env_project="$(read_env_value COMPOSE_PROJECT_NAME "")"
if [ -n "$env_project" ]; then
printf '%s' "$env_project"
return
fi
basename "$DOCKER_DIR"
}
active_db_service() {
local db_type
db_type="$(read_env_value DB_TYPE postgresql)"
case "$db_type" in
postgresql|'')
printf '%s\n' "db_postgres"
;;
mysql)
printf '%s\n' "db_mysql"
;;
oceanbase)
printf '%s\n' "oceanbase"
;;
*)
printf '%s\n' "$db_type"
;;
esac
}
active_vector_service() {
local vector_store
vector_store="$(read_env_value VECTOR_STORE weaviate)"
case "$vector_store" in
''|none|external)
return 0
;;
pgvecto-rs|pgvecto_rs)
printf '%s\n' "pgvecto-rs"
;;
milvus)
printf '%s\n' "milvus-standalone"
;;
elasticsearch|opensearch|weaviate|qdrant|pgvector|chroma|oceanbase|seekdb|couchbase-server|iris)
printf '%s\n' "$vector_store"
;;
couchbase)
printf '%s\n' "couchbase-server"
;;
*)
return 0
;;
esac
}
safe_runtime_path() {
local rel_path="$1"
case "$rel_path" in
""|"/"| "." | ".." | *".."* | /*)
return 1
;;
esac
case "$rel_path" in
volumes/*)
return 0
;;
*)
return 1
;;
esac
}
safe_named_volume() {
local volume="$1"
case "$volume" in
""|*"/"*|*" "*|*$'\t'*|*$'\n'*|*$'\r'*)
return 1
;;
*[!a-zA-Z0-9_.-]*)
return 1
;;
*)
return 0
;;
esac
}
volume_exists() {
docker volume inspect "$1" >/dev/null 2>&1
}
append_unique_volume() {
local candidate="$1"
local existing
[ -n "$candidate" ] || return 0
for existing in "${RESOLVED_VOLUME_NAMES[@]}"; do
if [ "$existing" = "$candidate" ]; then
return
fi
done
RESOLVED_VOLUME_NAMES+=("$candidate")
}
resolve_named_volume_names() {
local logical_name="$1"
local project_name
local candidate
local volume_list
local status
RESOLVED_VOLUME_NAMES=()
project_name="$(compose_project_name)"
set +e
volume_list="$(docker volume ls -q \
--filter "label=com.docker.compose.project=$project_name" \
--filter "label=com.docker.compose.volume=$logical_name" 2>/dev/null)"
status=$?
set -e
if [ "$status" -ne 0 ]; then
fail "Failed to list Docker volumes for Compose project $project_name"
fi
while IFS= read -r candidate; do
append_unique_volume "$candidate"
done <<< "$volume_list"
for candidate in "${project_name}_${logical_name}" "$logical_name"; do
if volume_exists "$candidate"; then
append_unique_volume "$candidate"
fi
done
}
collect_preserved_paths() {
PRESERVED_PATHS=()
local rel_path
for rel_path in "${PRESERVE_PATHS[@]}"; do
if [ -e "$DOCKER_DIR/$rel_path" ]; then
PRESERVED_PATHS+=("$rel_path")
fi
done
}
print_plan() {
CURRENT_PHASE="plan"
local db_service
local vector_service
db_service="$(active_db_service)"
vector_service="$(active_vector_service || true)"
collect_preserved_paths
log "Reset mode: $([ "$DRY_RUN" = true ] && printf dry-run || printf destructive)"
log "Repository root: $REPO_ROOT"
log "Docker directory: $DOCKER_DIR"
log "Compose project: $(compose_project_name)"
log "Database service: $db_service"
log "Vector service: ${vector_service:-<external-or-none>}"
log "Timeout: ${TIMEOUT_SECONDS}s"
printf '\nPlanned runtime path deletions:\n'
local rel_path
for rel_path in "${RUNTIME_PATHS[@]}"; do
printf ' - %s\n' "$rel_path"
done
printf '\nPlanned named volume deletions:\n'
local volume
for volume in "${NAMED_VOLUMES[@]}"; do
printf ' - %s (Compose project: %s)\n' "$volume" "$(compose_project_name)"
done
printf '\nPreserved configuration paths found:\n'
for rel_path in "${PRESERVED_PATHS[@]}"; do
printf ' - %s\n' "$rel_path"
done
printf '\nCommands:\n'
printf ' - docker compose down --remove-orphans\n'
printf ' - delete allowlisted runtime paths and named volumes\n'
printf ' - docker compose up -d %s redis%s\n' "$db_service" "${vector_service:+ $vector_service}"
printf ' - docker compose run --rm -e MIGRATION_ENABLED=true -e MODE=migration api\n'
printf ' - docker compose up -d\n'
printf ' - health checks and smoke check\n\n'
}
delete_runtime_paths() {
CURRENT_PHASE="delete-runtime-data"
local rel_path
local abs_path
for rel_path in "${RUNTIME_PATHS[@]}"; do
safe_runtime_path "$rel_path" || fail "Unsafe runtime path in allowlist: $rel_path"
abs_path="$DOCKER_DIR/$rel_path"
if [ ! -e "$abs_path" ]; then
SKIPPED_PATHS+=("$rel_path (absent)")
continue
fi
DELETED_PATHS+=("$rel_path")
run_cmd rm -rf -- "$abs_path"
done
}
delete_named_volumes() {
CURRENT_PHASE="delete-runtime-volumes"
local logical_name
local actual_name
for logical_name in "${NAMED_VOLUMES[@]}"; do
safe_named_volume "$logical_name" || fail "Unsafe named volume in allowlist: $logical_name"
resolve_named_volume_names "$logical_name"
if [ "${#RESOLVED_VOLUME_NAMES[@]}" -eq 0 ]; then
SKIPPED_NAMED_VOLUMES+=("$logical_name (absent)")
continue
fi
for actual_name in "${RESOLVED_VOLUME_NAMES[@]}"; do
DELETED_NAMED_VOLUMES+=("$actual_name")
run_cmd docker volume rm "$actual_name"
done
done
}
stop_stack() {
CURRENT_PHASE="stop-stack"
run_cmd compose down --remove-orphans
}
start_middleware() {
CURRENT_PHASE="start-middleware"
local db_service
local vector_service
local services=()
db_service="$(active_db_service)"
vector_service="$(active_vector_service || true)"
services+=("$db_service" "redis")
if [ -n "$vector_service" ]; then
services+=("$vector_service")
fi
run_cmd compose up -d "${services[@]}"
if [ "$DRY_RUN" = false ]; then
wait_for_services "${services[@]}"
fi
}
run_migration() {
CURRENT_PHASE="migration"
if [ "$SKIP_MIGRATION" = true ]; then
HEALTH_RESULTS+=("migration:skipped")
return
fi
run_cmd compose run --rm -e MIGRATION_ENABLED=true -e MODE=migration api
HEALTH_RESULTS+=("migration:ok")
}
start_full_stack() {
CURRENT_PHASE="start-full-stack"
run_cmd compose up -d
if [ "$DRY_RUN" = false ]; then
wait_for_services api web worker nginx
wait_if_service_exists plugin_daemon
fi
}
container_status() {
local service="$1"
local container_id
container_id="$(compose ps -q "$service" 2>/dev/null || true)"
[ -n "$container_id" ] || return 1
local health
health="$(docker inspect --format '{{if .State.Health}}{{.State.Health.Status}}{{else}}{{.State.Status}}{{end}}' "$container_id" 2>/dev/null || true)"
printf '%s' "$health"
}
wait_if_service_exists() {
local service="$1"
if [ -n "$(compose ps -q "$service" 2>/dev/null || true)" ]; then
wait_for_services "$service"
fi
}
wait_for_services() {
local service
for service in "$@"; do
wait_for_service "$service"
done
}
wait_for_service() {
local service="$1"
local deadline=$(( $(date +%s) + TIMEOUT_SECONDS ))
local status=""
log "Waiting for service: $service"
while [ "$(date +%s)" -le "$deadline" ]; do
status="$(container_status "$service" || true)"
case "$status" in
healthy|running)
HEALTH_RESULTS+=("$service:$status")
return 0
;;
unhealthy|exited|dead)
HEALTH_RESULTS+=("$service:$status")
fail "Service $service reached failure status: $status"
;;
esac
sleep 3
done
HEALTH_RESULTS+=("$service:timeout")
fail "Timed out waiting for service: $service"
}
default_smoke_url() {
if [ -n "$SMOKE_URL" ]; then
printf '%s' "$SMOKE_URL"
return
fi
if [ -n "${RESET_TARGET_DOMAIN:-}" ]; then
local https_enabled
https_enabled="$(read_env_value NGINX_HTTPS_ENABLED false)"
if [ "$https_enabled" = "true" ]; then
printf 'https://%s' "$RESET_TARGET_DOMAIN"
else
printf 'http://%s' "$RESET_TARGET_DOMAIN"
fi
return
fi
local port
port="$(read_env_value EXPOSE_NGINX_PORT 80)"
printf 'http://localhost:%s' "$port"
}
run_smoke_check() {
CURRENT_PHASE="smoke"
if [ "$SKIP_SMOKE" = true ]; then
SMOKE_RESULT="skipped"
return
fi
local url
url="$(default_smoke_url)"
if [ "$DRY_RUN" = true ]; then
SMOKE_RESULT="planned:$url"
printf '+ curl -fsS --max-time 10 %q\n' "$url"
return
fi
curl -fsS --max-time 10 "$url" >/dev/null || fail "Smoke check failed: $url"
SMOKE_RESULT="ok:$url"
}
print_report() {
local status="${1:-success}"
local end_time
end_time="$(date +%s)"
printf '\nReset report\n'
printf '============\n'
printf 'status: %s\n' "$status"
printf 'environment: %s\n' "${DIFY_ENV_NAME:-<unset>}"
printf 'repo_root: %s\n' "${REPO_ROOT:-<unset>}"
printf 'phase: %s\n' "$CURRENT_PHASE"
printf 'duration_seconds: %s\n' "$(( end_time - START_TIME ))"
printf 'mode: %s\n' "$([ "$DRY_RUN" = true ] && printf dry-run || printf destructive)"
printf '\ndeleted_runtime_paths:\n'
if [ "${#DELETED_PATHS[@]}" -eq 0 ]; then
printf ' - <none>\n'
else
printf ' - %s\n' "${DELETED_PATHS[@]}"
fi
printf '\nskipped_runtime_paths:\n'
if [ "${#SKIPPED_PATHS[@]}" -eq 0 ]; then
printf ' - <none>\n'
else
printf ' - %s\n' "${SKIPPED_PATHS[@]}"
fi
printf '\ndeleted_named_volumes:\n'
if [ "${#DELETED_NAMED_VOLUMES[@]}" -eq 0 ]; then
printf ' - <none>\n'
else
printf ' - %s\n' "${DELETED_NAMED_VOLUMES[@]}"
fi
printf '\nskipped_named_volumes:\n'
if [ "${#SKIPPED_NAMED_VOLUMES[@]}" -eq 0 ]; then
printf ' - <none>\n'
else
printf ' - %s\n' "${SKIPPED_NAMED_VOLUMES[@]}"
fi
printf '\npreserved_paths:\n'
if [ "${#PRESERVED_PATHS[@]}" -eq 0 ]; then
printf ' - <none found>\n'
else
printf ' - %s\n' "${PRESERVED_PATHS[@]}"
fi
printf '\nhealth_results:\n'
if [ "${#HEALTH_RESULTS[@]}" -eq 0 ]; then
printf ' - <not run>\n'
else
printf ' - %s\n' "${HEALTH_RESULTS[@]}"
fi
printf '\nsmoke_result: %s\n' "$SMOKE_RESULT"
}
main() {
parse_args "$@"
validate_number
validate_environment
acquire_lock
print_plan
if [ "$DRY_RUN" = true ]; then
run_smoke_check
print_report "dry-run"
return 0
fi
stop_stack
delete_runtime_paths
delete_named_volumes
start_middleware
run_migration
start_full_stack
run_smoke_check
CURRENT_PHASE="complete"
print_report "success"
}
main "$@"

View File

@ -99,7 +99,7 @@ jobs:
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env
cp docker/envs/middleware.env.example docker/middleware.env
cp docker/middleware.env.example docker/middleware.env
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh

View File

@ -116,12 +116,6 @@ jobs:
if: github.event_name != 'merge_group'
uses: ./.github/actions/setup-web
- name: Generate API docs
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
run: |
cd api
uv run dev/generate_swagger_markdown_docs.py --swagger-dir openapi --markdown-dir openapi/markdown
- name: ESLint autofix
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
run: |

View File

@ -37,7 +37,7 @@ jobs:
- name: Prepare middleware env
run: |
cd docker
cp envs/middleware.env.example middleware.env
cp middleware.env.example middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
@ -87,7 +87,7 @@ jobs:
- name: Prepare middleware env for MySQL
run: |
cd docker
cp envs/middleware.env.example middleware.env
cp middleware.env.example middleware.env
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' middleware.env
sed -i 's/DB_HOST=db_postgres/DB_HOST=db_mysql/' middleware.env
sed -i 's/DB_PORT=5432/DB_PORT=3306/' middleware.env

View File

@ -57,7 +57,7 @@ jobs:
- '.github/workflows/api-tests.yml'
- '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example'
- 'docker/envs/middleware.env.example'
- 'docker/middleware.env.example'
- 'docker/docker-compose.middleware.yaml'
- 'docker/docker-compose-template.yaml'
- 'docker/generate_docker_compose'
@ -84,7 +84,7 @@ jobs:
- 'pnpm-workspace.yaml'
- '.nvmrc'
- 'docker/docker-compose.middleware.yaml'
- 'docker/envs/middleware.env.example'
- 'docker/middleware.env.example'
- '.github/workflows/web-e2e.yml'
- '.github/actions/setup-web/**'
vdb:
@ -94,7 +94,7 @@ jobs:
- '.github/workflows/vdb-tests.yml'
- '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example'
- 'docker/envs/middleware.env.example'
- 'docker/middleware.env.example'
- 'docker/docker-compose.yaml'
- 'docker/docker-compose-template.yaml'
- 'docker/generate_docker_compose'
@ -116,7 +116,7 @@ jobs:
- '.github/workflows/db-migration-test.yml'
- '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example'
- 'docker/envs/middleware.env.example'
- 'docker/middleware.env.example'
- 'docker/docker-compose.middleware.yaml'
- 'docker/docker-compose-template.yaml'
- 'docker/generate_docker_compose'

View File

@ -77,6 +77,8 @@ jobs:
with:
files: |
web/**
e2e/**
sdks/nodejs-client/**
packages/**
package.json
pnpm-lock.yaml
@ -94,14 +96,14 @@ jobs:
id: eslint-cache-restore
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: web/.eslintcache
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
path: .eslintcache
key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
restore-keys: |
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
working-directory: .
run: vp run lint:ci
- name: Web tsslint
@ -113,7 +115,7 @@ jobs:
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
working-directory: .
run: vp run type-check
- name: Web dead code check
@ -125,7 +127,7 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: web/.eslintcache
path: .eslintcache
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
superlinter:

View File

@ -51,7 +51,7 @@ jobs:
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env
cp docker/envs/middleware.env.example docker/middleware.env
cp docker/middleware.env.example docker/middleware.env
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh

View File

@ -48,7 +48,7 @@ jobs:
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env
cp docker/envs/middleware.env.example docker/middleware.env
cp docker/middleware.env.example docker/middleware.env
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh

.gitignore (vendored, 3 changes)
View File

@ -203,6 +203,7 @@ sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json.template
!.vscode/settings.example.json
!.vscode/README.md
api/.vscode
# vscode Code History Extension
@ -249,3 +250,5 @@ scripts/stress-test/reports/
# Code Agent Folder
.qoder/*
.eslintcache

View File

@ -56,44 +56,9 @@ if $api_modified; then
fi
fi
if $web_modified; then
if $skip_web_checks; then
echo "Git operation in progress, skipping web checks"
exit 0
fi
echo "Running ESLint on web module"
if git diff --cached --quiet -- 'web/**/*.ts' 'web/**/*.tsx'; then
web_ts_modified=false
else
ts_diff_status=$?
if [ $ts_diff_status -eq 1 ]; then
web_ts_modified=true
else
echo "Unable to determine staged TypeScript changes (git exit code: $ts_diff_status)."
exit $ts_diff_status
fi
fi
cd ./web || exit 1
pnpm exec vp staged
if $web_ts_modified; then
echo "Running TypeScript type-check:tsgo"
if ! npm run type-check:tsgo; then
echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors."
exit 1
fi
else
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
fi
echo "Running knip"
if ! npm run knip; then
echo "Knip check failed. Please run 'npm run knip' to fix the errors."
exit 1
fi
cd ../
if $skip_web_checks; then
echo "Git operation in progress, skipping web checks"
exit 0
fi
vp staged

View File

@ -71,13 +71,13 @@ type-check:
@echo "📝 Running type checks (basedpyright + pyrefly + mypy)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
@./dev/pyrefly-check-local
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@echo "✅ Type checks complete"
type-check-core:
@echo "📝 Running core type checks (basedpyright + mypy)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@echo "✅ Core type checks complete"
test:

View File

@ -76,10 +76,11 @@ The easiest way to start the Dify server is through [Docker Compose](docker/dock
```bash
cd dify
cd docker
cp .env.example .env
docker compose up -d
./dify-compose up -d
```
On Windows PowerShell, run `.\dify-compose.ps1 up -d` from the `docker` directory.
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.
#### Seeking help
@ -137,7 +138,7 @@ Star Dify on GitHub and be instantly notified of new releases.
### Custom configurations
If you need to customize the configuration, edit `docker/.env`. The essential startup defaults live in [`docker/.env.example`](docker/.env.example), and optional advanced variables are split under `docker/envs/` by theme. After making any changes, re-run `docker compose up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
If you need to customize the configuration, add only the values you want to override to `docker/.env`. The default values live in [`docker/.env.default`](docker/.env.default), and the full reference remains in [`docker/.env.example`](docker/.env.example). After making any changes, re-run `./dify-compose up -d` or `.\dify-compose.ps1 up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
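A minimal sketch of that override flow; the port variable is only an illustrative example:
```bash
cd docker
# Append only the values you want to override; defaults stay in .env.default.
echo "EXPOSE_NGINX_PORT=8080" >> .env
./dify-compose up -d    # or .\dify-compose.ps1 up -d on Windows PowerShell
```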
### Metrics Monitoring with Grafana

View File

@ -98,8 +98,6 @@ DB_DATABASE=dify
SQLALCHEMY_POOL_PRE_PING=true
SQLALCHEMY_POOL_TIMEOUT=30
# Connection pool reset behavior on return
SQLALCHEMY_POOL_RESET_ON_RETURN=rollback
# Storage configuration
# use for store upload files, private keys...
@ -383,7 +381,7 @@ VIKINGDB_ACCESS_KEY=your-ak
VIKINGDB_SECRET_KEY=your-sk
VIKINGDB_REGION=cn-shanghai
VIKINGDB_HOST=api-vikingdb.xxx.volces.com
VIKINGDB_SCHEME=http
VIKINGDB_SCHEMA=http
VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30
@ -434,6 +432,8 @@ UPLOAD_FILE_EXTENSION_BLACKLIST=
# Model configuration
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
# Mail configuration, support: resend, smtp, sendgrid

View File

@ -17,14 +17,14 @@ FROM base AS packages
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
# basic environment
git g++ \
g++ \
# for building gmpy2
libmpfr-dev libmpc-dev
# Install Python dependencies (workspace members under providers/vdb/)
COPY pyproject.toml uv.lock ./
COPY providers ./providers
RUN uv sync --locked --no-dev --group evaluation
RUN uv sync --locked --no-dev
# production stage
FROM base AS production
@ -77,7 +77,6 @@ RUN \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
git \
nodejs=${NODE_PACKAGE_VERSION} \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \

View File

@ -33,6 +33,7 @@ from .vector import (
old_metadata_migration,
vdb_migrate,
)
from .vector_space import sample_vector_space_usage
__all__ = [
"add_qdrant_index",
@ -62,6 +63,7 @@ __all__ = [
"reset_encrypt_key_pair",
"reset_password",
"restore_workflow_runs",
"sample_vector_space_usage",
"setup_datasource_oauth_client",
"setup_system_tool_oauth_client",
"setup_system_trigger_oauth_client",

View File

@ -0,0 +1,558 @@
import csv
import json
from dataclasses import dataclass
from decimal import Decimal
from pathlib import Path
from typing import Any
import click
import httpx
import sqlalchemy as sa
from sqlalchemy import func, select
from configs import dify_config
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_database import db
from models.dataset import (
ChildChunk,
Dataset,
DatasetCollectionBinding,
DocumentSegment,
DocumentSegmentSummary,
SegmentAttachmentBinding,
TidbAuthBinding,
)
from models.dataset import Document as DatasetDocument
from models.enums import IndexingStatus, SegmentStatus, SummaryStatus, TidbAuthBindingStatus
from models.model import App, AppAnnotationSetting, MessageAnnotation
COMMON_EMBEDDING_MODEL_DIMS = {
# OpenAI
"text-embedding-ada-002": 1536,
"text-embedding-3-small": 1536,
"text-embedding-3-large": 3072,
# Cohere
"embed-english-v3.0": 1024,
"embed-multilingual-v3.0": 1024,
"embed-english-light-v3.0": 384,
"embed-multilingual-light-v3.0": 384,
# Google
"embedding-001": 768,
"text-embedding-004": 768,
# Voyage
"voyage-2": 1024,
"voyage-3": 1024,
"voyage-3-lite": 512,
"voyage-large-2": 1536,
"voyage-code-2": 1536,
# BAAI BGE
"bge-small-en": 384,
"bge-small-en-v1.5": 384,
"bge-small-zh": 512,
"bge-small-zh-v1.5": 512,
"bge-base-en": 768,
"bge-base-en-v1.5": 768,
"bge-base-zh": 768,
"bge-base-zh-v1.5": 768,
"bge-large-en": 1024,
"bge-large-en-v1.5": 1024,
"bge-large-zh": 1024,
"bge-large-zh-v1.5": 1024,
"bge-m3": 1024,
# E5
"multilingual-e5-small": 384,
"multilingual-e5-base": 768,
"multilingual-e5-large": 1024,
"e5-small-v2": 384,
"e5-base-v2": 768,
"e5-large-v2": 1024,
# M3E
"m3e-small": 512,
"m3e-base": 768,
"m3e-large": 1024,
# Jina
"jina-embeddings-v2-small-en": 512,
"jina-embeddings-v2-base-en": 768,
"jina-embeddings-v2-base-zh": 768,
"jina-embeddings-v3": 1024,
}
@dataclass(frozen=True)
class CollectionPointStats:
collection_name: str
source_type: str
source_id: str
model_provider: str | None
model_name: str | None
segment_points: int = 0
child_chunk_points: int = 0
summary_points: int = 0
attachment_points: int = 0
annotation_points: int = 0
@property
def total_points(self) -> int:
return (
self.segment_points
+ self.child_chunk_points
+ self.summary_points
+ self.attachment_points
+ self.annotation_points
)
def _parse_overheads(value: str) -> list[int]:
overheads = []
for item in value.split(","):
item = item.strip()
if not item:
continue
overheads.append(int(item))
if not overheads:
raise click.BadParameter("At least one overhead is required.")
return overheads
def _normalize_model_name(model_name: str) -> str:
return model_name.strip().split("/")[-1]
def _tidb_storage_usage_bytes(binding: TidbAuthBinding, timeout: float) -> int:
if not binding.qdrant_endpoint:
raise ValueError("qdrant_endpoint is empty")
endpoint = binding.qdrant_endpoint.rstrip("/")
with httpx.Client(timeout=timeout, verify=False) as client:
response = client.get(f"{endpoint}/cluster", headers={"api-key": f"{binding.account}:{binding.password}"})
response.raise_for_status()
data = response.json()
storage = data.get("usage", {}).get("storage", {})
row_based = int(storage.get("row_based") or 0)
columnar = int(storage.get("columnar") or 0)
return row_based + columnar
def _extract_vector_size(collection_payload: dict[str, Any]) -> int | None:
vectors = (
collection_payload.get("result", {})
.get("config", {})
.get("params", {})
.get("vectors")
)
if isinstance(vectors, dict):
size = vectors.get("size")
if isinstance(size, int):
return size
for vector_config in vectors.values():
if isinstance(vector_config, dict) and isinstance(vector_config.get("size"), int):
return vector_config["size"]
return None
def _qdrant_collection_dim(
binding: TidbAuthBinding,
collection_name: str,
timeout: float,
dim_cache: dict[str, int | None],
) -> int | None:
if collection_name in dim_cache:
return dim_cache[collection_name]
if not binding.qdrant_endpoint:
dim_cache[collection_name] = None
return None
endpoint = binding.qdrant_endpoint.rstrip("/")
try:
with httpx.Client(timeout=timeout, verify=False) as client:
response = client.get(
f"{endpoint}/collections/{collection_name}",
headers={"api-key": f"{binding.account}:{binding.password}"},
)
if response.status_code == 404:
dim_cache[collection_name] = None
return None
response.raise_for_status()
dim = _extract_vector_size(response.json())
dim_cache[collection_name] = dim
return dim
except Exception:
dim_cache[collection_name] = None
return None
def _dataset_vector_type(dataset: Dataset) -> str | None:
if dataset.index_struct_dict:
return dataset.index_struct_dict.get("type")
return dify_config.VECTOR_STORE
def _dataset_collection_name(dataset: Dataset) -> str:
if dataset.index_struct_dict:
vector_store = dataset.index_struct_dict.get("vector_store") or {}
collection_name = vector_store.get("class_prefix")
if collection_name:
return collection_name
if dataset.collection_binding_id:
binding = db.session.get(DatasetCollectionBinding, dataset.collection_binding_id)
if binding:
return binding.collection_name
return Dataset.gen_collection_name_by_id(dataset.id)
def _active_tidb_bindings(tenant_ids: tuple[str, ...], limit: int, offset: int) -> list[TidbAuthBinding]:
stmt = (
select(TidbAuthBinding)
.where(
TidbAuthBinding.tenant_id.is_not(None),
TidbAuthBinding.active == True,
TidbAuthBinding.status == TidbAuthBindingStatus.ACTIVE,
)
.order_by(TidbAuthBinding.created_at.desc())
)
if tenant_ids:
stmt = stmt.where(TidbAuthBinding.tenant_id.in_(tenant_ids))
else:
stmt = stmt.limit(limit).offset(offset)
return list(db.session.scalars(stmt).all())
def _completed_document_filter() -> tuple[Any, ...]:
return (
DatasetDocument.indexing_status == IndexingStatus.COMPLETED,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
def _completed_segment_filter() -> tuple[Any, ...]:
return (
DocumentSegment.status == SegmentStatus.COMPLETED,
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.is_not(None),
)
def _count_dataset_points(dataset: Dataset) -> CollectionPointStats:
segment_points = (
db.session.scalar(
select(func.count(DocumentSegment.id))
.join(DatasetDocument, DatasetDocument.id == DocumentSegment.document_id)
.where(
DocumentSegment.tenant_id == dataset.tenant_id,
DocumentSegment.dataset_id == dataset.id,
DatasetDocument.doc_form != IndexStructureType.PARENT_CHILD_INDEX,
*_completed_document_filter(),
*_completed_segment_filter(),
)
)
or 0
)
child_chunk_points = (
db.session.scalar(
select(func.count(ChildChunk.id))
.join(DatasetDocument, DatasetDocument.id == ChildChunk.document_id)
.where(
ChildChunk.tenant_id == dataset.tenant_id,
ChildChunk.dataset_id == dataset.id,
ChildChunk.index_node_id.is_not(None),
*_completed_document_filter(),
)
)
or 0
)
summary_points = (
db.session.scalar(
select(func.count(DocumentSegmentSummary.id))
.join(DatasetDocument, DatasetDocument.id == DocumentSegmentSummary.document_id)
.where(
DocumentSegmentSummary.dataset_id == dataset.id,
DocumentSegmentSummary.enabled == True,
DocumentSegmentSummary.status == SummaryStatus.COMPLETED,
DocumentSegmentSummary.summary_index_node_id.is_not(None),
*_completed_document_filter(),
)
)
or 0
)
attachment_points = 0
if dataset.is_multimodal:
attachment_points = (
db.session.scalar(
select(func.count(sa.distinct(SegmentAttachmentBinding.attachment_id)))
.join(DocumentSegment, DocumentSegment.id == SegmentAttachmentBinding.segment_id)
.join(DatasetDocument, DatasetDocument.id == SegmentAttachmentBinding.document_id)
.where(
SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
SegmentAttachmentBinding.dataset_id == dataset.id,
*_completed_document_filter(),
*_completed_segment_filter(),
)
)
or 0
)
return CollectionPointStats(
collection_name=_dataset_collection_name(dataset),
source_type="dataset",
source_id=dataset.id,
model_provider=dataset.embedding_model_provider,
model_name=dataset.embedding_model,
segment_points=int(segment_points),
child_chunk_points=int(child_chunk_points),
summary_points=int(summary_points),
attachment_points=int(attachment_points),
)
def _dataset_stats_for_tenant(tenant_id: str) -> list[CollectionPointStats]:
datasets = db.session.scalars(
select(Dataset).where(
Dataset.tenant_id == tenant_id,
Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY,
)
).all()
stats = []
for dataset in datasets:
if _dataset_vector_type(dataset) != VectorType.TIDB_ON_QDRANT:
continue
dataset_stats = _count_dataset_points(dataset)
if dataset_stats.total_points > 0:
stats.append(dataset_stats)
return stats
def _annotation_stats_for_tenant(tenant_id: str) -> list[CollectionPointStats]:
rows = db.session.execute(
select(
App.id,
DatasetCollectionBinding.provider_name,
DatasetCollectionBinding.model_name,
DatasetCollectionBinding.collection_name,
func.count(MessageAnnotation.id),
)
.join(AppAnnotationSetting, AppAnnotationSetting.app_id == App.id)
.join(DatasetCollectionBinding, DatasetCollectionBinding.id == AppAnnotationSetting.collection_binding_id)
.join(MessageAnnotation, MessageAnnotation.app_id == App.id)
.where(App.tenant_id == tenant_id)
.group_by(
App.id,
DatasetCollectionBinding.provider_name,
DatasetCollectionBinding.model_name,
DatasetCollectionBinding.collection_name,
)
).all()
return [
CollectionPointStats(
collection_name=row[3],
source_type="annotation",
source_id=row[0],
model_provider=row[1],
model_name=row[2],
annotation_points=int(row[4] or 0),
)
for row in rows
if int(row[4] or 0) > 0
]
def _resolve_dim(
stat: CollectionPointStats,
binding: TidbAuthBinding,
default_dim: int,
fetch_qdrant_dim: bool,
timeout: float,
dim_cache: dict[str, int | None],
) -> tuple[int, str]:
if stat.model_provider and stat.model_name:
builtin_dim = COMMON_EMBEDDING_MODEL_DIMS.get(_normalize_model_name(stat.model_name))
if builtin_dim:
return builtin_dim, "builtin_model_map"
if fetch_qdrant_dim:
qdrant_dim = _qdrant_collection_dim(binding, stat.collection_name, timeout, dim_cache)
if qdrant_dim:
return qdrant_dim, "qdrant"
return default_dim, "default"
def _mb(value: int | float | Decimal) -> float:
return round(float(value) / 1024 / 1024, 4)
def _log(message: str, quiet: bool) -> None:
if not quiet:
click.echo(message, err=True)
@click.command(
"sample-vector-space-usage",
help="Sample TiDB vector storage usage and compare it with local formula estimates.",
)
@click.option("--tenant-id", multiple=True, help="Tenant ID to sample. Can be repeated.")
@click.option("--limit", default=20, show_default=True, help="Number of active TiDB tenants to sample.")
@click.option("--offset", default=0, show_default=True, help="Offset when sampling active TiDB tenants.")
@click.option("--default-dim", default=3072, show_default=True, help="Fallback embedding dimension.")
@click.option(
"--overheads",
default="3584,5120,8192",
show_default=True,
help="Comma-separated per-point overhead bytes to compare.",
)
@click.option("--fetch-qdrant-dim/--no-fetch-qdrant-dim", default=True, show_default=True)
@click.option("--include-annotations/--exclude-annotations", default=True, show_default=True)
@click.option("--timeout", default=10.0, show_default=True, help="HTTP timeout for TiDB/Qdrant calls.")
@click.option("--output", type=click.Path(dir_okay=False, path_type=Path), help="CSV output path. Defaults to stdout.")
@click.option("--quiet", is_flag=True, help="Suppress progress logs. CSV output is unaffected.")
def sample_vector_space_usage(
tenant_id: tuple[str, ...],
limit: int,
offset: int,
default_dim: int,
overheads: str,
fetch_qdrant_dim: bool,
include_annotations: bool,
timeout: float,
output: Path | None,
quiet: bool,
):
overhead_values = _parse_overheads(overheads)
bindings = _active_tidb_bindings(tenant_id, limit, offset)
sample_scope = f" for tenant_id={','.join(tenant_id)}" if tenant_id else f" with limit={limit}, offset={offset}"
_log(
f"Sampling {len(bindings)} active TiDB binding(s){sample_scope}.",
quiet,
)
if not bindings:
_log("No active TiDB bindings found. Nothing to sample.", quiet)
fieldnames = [
"tenant_id",
"cluster_id",
"tidb_actual_mb",
"total_points",
"segment_points",
"child_chunk_points",
"summary_points",
"attachment_points",
"annotation_points",
"collection_count",
"dim_sources",
"dims",
"errors",
]
for overhead in overhead_values:
fieldnames.extend(
[
f"estimated_mb_o{overhead}",
f"diff_mb_o{overhead}",
f"ratio_o{overhead}",
]
)
output_file = output.open("w", newline="") if output else None
try:
writer = csv.DictWriter(output_file or click.get_text_stream("stdout"), fieldnames=fieldnames)
writer.writeheader()
for index, binding in enumerate(bindings, start=1):
assert binding.tenant_id is not None
tenant = binding.tenant_id
errors = []
dim_cache: dict[str, int | None] = {}
_log(f"[{index}/{len(bindings)}] tenant={tenant} cluster={binding.cluster_id}: fetching TiDB usage", quiet)
try:
actual_bytes = _tidb_storage_usage_bytes(binding, timeout)
_log(
f"[{index}/{len(bindings)}] tenant={tenant}: TiDB actual={_mb(actual_bytes)} MB",
quiet,
)
except Exception as exc:
actual_bytes = 0
errors.append(f"tidb_usage:{exc.__class__.__name__}:{exc}")
_log(
f"[{index}/{len(bindings)}] tenant={tenant}: failed to fetch TiDB usage: "
f"{exc.__class__.__name__}: {exc}",
quiet,
)
_log(f"[{index}/{len(bindings)}] tenant={tenant}: counting local vector points", quiet)
collection_stats = _dataset_stats_for_tenant(tenant)
if include_annotations:
collection_stats.extend(_annotation_stats_for_tenant(tenant))
total_points = 0
segment_points = 0
child_chunk_points = 0
summary_points = 0
attachment_points = 0
annotation_points = 0
dim_sources: dict[str, int] = {}
dims: dict[str, int] = {}
estimated_by_overhead = dict.fromkeys(overhead_values, 0)
for stat in collection_stats:
dim, dim_source = _resolve_dim(
stat,
binding,
default_dim,
fetch_qdrant_dim,
timeout,
dim_cache,
)
dim_sources[dim_source] = dim_sources.get(dim_source, 0) + 1
dims[str(dim)] = dims.get(str(dim), 0) + stat.total_points
total_points += stat.total_points
segment_points += stat.segment_points
child_chunk_points += stat.child_chunk_points
summary_points += stat.summary_points
attachment_points += stat.attachment_points
annotation_points += stat.annotation_points
for overhead in overhead_values:
estimated_by_overhead[overhead] += stat.total_points * (dim * 4 + overhead)
_log(
f"[{index}/{len(bindings)}] tenant={tenant}: points={total_points}, "
f"collections={len(collection_stats)}, dim_sources={json.dumps(dim_sources, sort_keys=True)}",
quiet,
)
row: dict[str, Any] = {
"tenant_id": tenant,
"cluster_id": binding.cluster_id,
"tidb_actual_mb": _mb(actual_bytes),
"total_points": total_points,
"segment_points": segment_points,
"child_chunk_points": child_chunk_points,
"summary_points": summary_points,
"attachment_points": attachment_points,
"annotation_points": annotation_points,
"collection_count": len(collection_stats),
"dim_sources": json.dumps(dim_sources, sort_keys=True),
"dims": json.dumps(dims, sort_keys=True),
"errors": ";".join(errors),
}
for overhead, estimated_bytes in estimated_by_overhead.items():
diff_bytes = estimated_bytes - actual_bytes
ratio = round(estimated_bytes / actual_bytes, 6) if actual_bytes > 0 else ""
row[f"estimated_mb_o{overhead}"] = _mb(estimated_bytes)
row[f"diff_mb_o{overhead}"] = _mb(diff_bytes)
row[f"ratio_o{overhead}"] = ratio
writer.writerow(row)
_log(f"[{index}/{len(bindings)}] tenant={tenant}: row written", quiet)
finally:
if output_file:
output_file.close()
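
The command estimates per-collection vector storage as `total_points * (dim * 4 + overhead_bytes)` and compares the sum with the usage reported by the TiDB cluster API. A minimal sketch of a run, assuming the command is exposed through the Flask CLI like the sibling commands in this package:

```bash
# Assumes the Flask CLI entry point; the flags come from the click options above.
cd api
uv run flask sample-vector-space-usage --limit 5 --overheads 3584,5120 --output usage.csv

# Per-collection estimate used by the command: total_points * (dim * 4 + overhead)
# e.g. 10,000 points at dim 1536 with a 3,584-byte overhead:
#   10000 * (1536 * 4 + 3584) = 97,280,000 bytes, roughly 92.77 MB
```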

View File

@ -1394,32 +1394,6 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
)
class EvaluationConfig(BaseSettings):
"""
Configuration for evaluation runtime
"""
EVALUATION_FRAMEWORK: str = Field(
description="Evaluation framework to use (ragas/deepeval/none)",
default="none",
)
EVALUATION_MAX_CONCURRENT_RUNS: PositiveInt = Field(
description="Maximum number of concurrent evaluation runs per tenant",
default=3,
)
EVALUATION_MAX_DATASET_ROWS: PositiveInt = Field(
description="Maximum number of rows allowed in an evaluation dataset",
default=500,
)
EVALUATION_TASK_TIMEOUT: PositiveInt = Field(
description="Timeout in seconds for a single evaluation task",
default=3600,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -1433,7 +1407,6 @@ class FeatureConfig(
MarketplaceConfig,
DataSetConfig,
EndpointConfig,
EvaluationConfig,
FileAccessConfig,
FileUploadConfig,
HttpConfig,

View File

@ -114,7 +114,7 @@ class SQLAlchemyEngineOptionsDict(TypedDict):
pool_pre_ping: bool
connect_args: dict[str, str]
pool_use_lifo: bool
pool_reset_on_return: Literal["commit", "rollback", None]
pool_reset_on_return: None
pool_timeout: int
@ -223,11 +223,6 @@ class DatabaseConfig(BaseSettings):
default=30,
)
SQLALCHEMY_POOL_RESET_ON_RETURN: Literal["commit", "rollback", None] = Field(
description="Connection pool reset behavior on return. Options: 'commit', 'rollback', or None",
default="rollback",
)
RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
description="Number of processes for the retrieval service, default to CPU cores.",
default=os.cpu_count() or 1,
@ -257,7 +252,7 @@ class DatabaseConfig(BaseSettings):
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": connect_args,
"pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO,
"pool_reset_on_return": self.SQLALCHEMY_POOL_RESET_ON_RETURN,
"pool_reset_on_return": None,
"pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT,
}
return result

View File

@ -1,36 +1,6 @@
from pydantic import BaseModel, Field, JsonValue
HUMAN_INPUT_FORM_INPUT_EXAMPLE = {
"decision": "approve",
"attachment": {
"transfer_method": "local_file",
"upload_file_id": "4e0d1b87-52f2-49f6-b8c6-95cd9c954b3e",
"type": "document",
},
"attachments": [
{
"transfer_method": "local_file",
"upload_file_id": "1a77f0df-c0e6-461c-987c-e72526f341ee",
"type": "document",
},
{
"transfer_method": "remote_url",
"url": "https://example.com/report.pdf",
"type": "document",
},
],
}
from pydantic import BaseModel, JsonValue
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict[str, JsonValue] = Field(
description=(
"Submitted human input values keyed by output variable name. "
"Use a string for paragraph or select input values, a file mapping for file inputs, "
"and a list of file mappings for file-list inputs. Local file mappings use "
"`transfer_method=local_file` with `upload_file_id`; remote file mappings use "
"`transfer_method=remote_url` with `url` or `remote_url`."
),
examples=[HUMAN_INPUT_FORM_INPUT_EXAMPLE],
)
inputs: dict[str, JsonValue]
action: str

View File

@ -1,14 +1,6 @@
"""Helpers for registering Pydantic models with Flask-RESTX namespaces.
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
Flask-RESTX treats `SchemaModel` bodies as opaque JSON schemas; it does not
promote Pydantic's nested `$defs` into top-level Swagger `definitions`.
These helpers keep that translation centralized so models registered through
`register_schema_models` emit resolvable Swagger 2.0 references.
"""
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, NotRequired, TypedDict
from flask_restx import Namespace
from pydantic import BaseModel, TypeAdapter
@ -16,52 +8,10 @@ from pydantic import BaseModel, TypeAdapter
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
QueryParamDoc = TypedDict(
"QueryParamDoc",
{
"in": NotRequired[str],
"type": NotRequired[str],
"items": NotRequired[dict[str, object]],
"required": NotRequired[bool],
"description": NotRequired[str],
"enum": NotRequired[list[object]],
"default": NotRequired[object],
"minimum": NotRequired[int | float],
"maximum": NotRequired[int | float],
"minLength": NotRequired[int],
"maxLength": NotRequired[int],
"minItems": NotRequired[int],
"maxItems": NotRequired[int],
},
)
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
nested_definitions = schema.get("$defs")
schema_to_register = dict(schema)
if isinstance(nested_definitions, dict):
schema_to_register.pop("$defs")
namespace.schema_model(name, schema_to_register)
if not isinstance(nested_definitions, dict):
return
for nested_name, nested_schema in nested_definitions.items():
if isinstance(nested_schema, dict):
_register_json_schema(namespace, nested_name, nested_schema)
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
"""Register a BaseModel and its nested schema definitions for Swagger documentation."""
"""Register a single BaseModel with a namespace for Swagger documentation."""
_register_json_schema(
namespace,
model.__name__,
model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
@ -84,111 +34,14 @@ def get_or_create_model(model_name: str, field_def):
def register_enum_models(namespace: Namespace, *models: type[StrEnum]) -> None:
"""Register multiple StrEnum with a namespace."""
for model in models:
_register_json_schema(
namespace,
model.__name__,
TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
namespace.schema_model(
model.__name__, TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
def query_params_from_model(model: type[BaseModel]) -> dict[str, QueryParamDoc]:
"""Build Flask-RESTX query parameter docs from a flat Pydantic model.
`Namespace.expect()` treats Pydantic schema models as request bodies, so GET
endpoints should keep runtime validation on the Pydantic model and feed this
derived mapping to `Namespace.doc(params=...)` for Swagger documentation.
"""
schema = model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
properties = schema.get("properties", {})
if not isinstance(properties, Mapping):
return {}
required = schema.get("required", [])
required_names = set(required) if isinstance(required, list) else set()
params: dict[str, QueryParamDoc] = {}
for name, property_schema in properties.items():
if not isinstance(name, str) or not isinstance(property_schema, Mapping):
continue
params[name] = _query_param_from_property(property_schema, required=name in required_names)
return params
def _query_param_from_property(property_schema: Mapping[str, Any], *, required: bool) -> QueryParamDoc:
param_schema = _nullable_property_schema(property_schema)
param_doc: QueryParamDoc = {"in": "query", "required": required}
description = param_schema.get("description")
if isinstance(description, str):
param_doc["description"] = description
schema_type = param_schema.get("type")
if isinstance(schema_type, str) and schema_type in {"array", "boolean", "integer", "number", "string"}:
param_doc["type"] = schema_type
if schema_type == "array":
items = param_schema.get("items")
if isinstance(items, Mapping):
item_type = items.get("type")
if isinstance(item_type, str):
param_doc["items"] = {"type": item_type}
enum = param_schema.get("enum")
if isinstance(enum, list):
param_doc["enum"] = enum
default = param_schema.get("default")
if default is not None:
param_doc["default"] = default
minimum = param_schema.get("minimum")
if isinstance(minimum, int | float):
param_doc["minimum"] = minimum
maximum = param_schema.get("maximum")
if isinstance(maximum, int | float):
param_doc["maximum"] = maximum
min_length = param_schema.get("minLength")
if isinstance(min_length, int):
param_doc["minLength"] = min_length
max_length = param_schema.get("maxLength")
if isinstance(max_length, int):
param_doc["maxLength"] = max_length
min_items = param_schema.get("minItems")
if isinstance(min_items, int):
param_doc["minItems"] = min_items
max_items = param_schema.get("maxItems")
if isinstance(max_items, int):
param_doc["maxItems"] = max_items
return param_doc
def _nullable_property_schema(property_schema: Mapping[str, Any]) -> Mapping[str, Any]:
any_of = property_schema.get("anyOf")
if not isinstance(any_of, list):
return property_schema
non_null_candidates = [
candidate for candidate in any_of if isinstance(candidate, Mapping) and candidate.get("type") != "null"
]
if len(non_null_candidates) == 1:
return {**property_schema, **non_null_candidates[0]}
return property_schema
__all__ = [
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
"get_or_create_model",
"query_params_from_model",
"register_enum_models",
"register_schema_model",
"register_schema_models",

View File

@ -108,9 +108,6 @@ from .datasets.rag_pipeline import (
rag_pipeline_workflow,
)
# Import evaluation controllers
from .evaluation import evaluation
# Import explore controllers
from .explore import (
banner,
@ -120,12 +117,8 @@ from .explore import (
saved_message,
trial,
)
# Import snippet controllers
from .snippets import snippet_workflow, snippet_workflow_draft_variable
from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport]
# Import snippet controllers
# Import tag controllers
from .tag import tags
@ -139,8 +132,6 @@ from .workspace import (
model_providers,
models,
plugin,
rbac,
snippets,
tool_providers,
trigger_providers,
workspace,
@ -178,7 +169,6 @@ __all__ = [
"datasource_content_preview",
"email_register",
"endpoint",
"evaluation",
"extension",
"external",
"feature",
@ -209,17 +199,10 @@ __all__ = [
"rag_pipeline_draft_variable",
"rag_pipeline_import",
"rag_pipeline_workflow",
"rbac",
"recommended_app",
"saved_message",
"setup",
"site",
"snippet_workflow",
"snippet_workflow",
"snippet_workflow_draft_variable",
"snippet_workflow_draft_variable",
"snippets",
"snippets",
"socketio_workflow",
"spec",
"statistic",

View File

@ -12,7 +12,6 @@ from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from configs import dify_config
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
@ -302,7 +301,15 @@ class BatchAddNotificationAccountsPayload(BaseModel):
user_email: list[str] = Field(..., description="List of account email addresses")
register_schema_models(console_ns, UpsertNotificationPayload, BatchAddNotificationAccountsPayload)
console_ns.schema_model(
UpsertNotificationPayload.__name__,
UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
BatchAddNotificationAccountsPayload.__name__,
BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/admin/upsert_notification")

View File

@ -25,7 +25,6 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
from core.db.session_factory import session_factory
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -37,7 +36,6 @@ from libs.helper import build_icon_url
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
from models.workflow import resolve_workflow_kind
from services.app_dsl_service import AppDslService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@ -352,8 +350,6 @@ class AppPartial(ResponseModel):
create_user_name: str | None = None
author_name: str | None = None
has_draft_trigger: bool | None = None
workflow_type: str | None = None
workflow_kind: str | None = None
@computed_field(return_type=str | None) # type: ignore
@property
@ -388,8 +384,6 @@ class AppDetail(ResponseModel):
updated_by: str | None = None
updated_at: int | None = None
access_mode: str | None = None
workflow_type: str | None = None
workflow_kind: str | None = None
tags: list[Tag] = Field(default_factory=list)
@field_validator("created_at", "updated_at", mode="before")
@ -532,25 +526,6 @@ class AppListApi(Resource):
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
workflow_ids = [str(app.workflow_id) for app in app_pagination.items if app.workflow_id]
workflow_info_map: dict[str, tuple[str, str]] = {}
if workflow_ids:
rows = db.session.execute(
select(Workflow.id, Workflow.type, Workflow.kind).where(Workflow.id.in_(workflow_ids))
).all()
workflow_info_map = {
str(row.id): (
row.type.value if hasattr(row.type, "value") else str(row.type),
resolve_workflow_kind(row.kind).value,
)
for row in rows
}
for app in app_pagination.items:
workflow_info = workflow_info_map.get(str(app.workflow_id)) if app.workflow_id else None
app.workflow_type = workflow_info[0] if workflow_info else None
app.workflow_kind = workflow_info[1] if workflow_info else None
pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
return pagination_model.model_dump(mode="json"), 200
@ -597,18 +572,6 @@ class AppApi(Resource):
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
app_model.access_mode = app_setting.access_mode
if app_model.workflow_id:
row = db.session.execute(
select(Workflow.type, Workflow.kind).where(Workflow.id == app_model.workflow_id)
).first()
app_model.workflow_type = (
(row.type.value if hasattr(row.type, "value") else str(row.type)) if row else None
)
app_model.workflow_kind = resolve_workflow_kind(row.kind).value if row else None
else:
app_model.workflow_type = None
app_model.workflow_kind = None
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
@ -878,8 +841,7 @@ class AppTraceApi(Resource):
@account_initialization_required
def get(self, app_id):
"""Get app trace"""
with session_factory.create_session() as session:
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id, session)
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id)
return app_trace_config

View File

@ -2,7 +2,7 @@ from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.common.schema import register_schema_models
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
@ -33,7 +33,6 @@ class AppImportPayload(BaseModel):
app_id: str | None = Field(None)
register_enum_models(console_ns, ImportStatus)
register_schema_models(console_ns, AppImportPayload, Import, CheckDependenciesResult)

View File

@ -3,7 +3,6 @@ from collections.abc import Sequence
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,
@ -20,12 +19,13 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.errors.invoke import InvokeError
from libs.login import current_account_with_tenant, login_required
from models import App
from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID")
@ -41,16 +41,16 @@ class InstructionTemplatePayload(BaseModel):
type: str = Field(..., description="Instruction template type")
register_enum_models(console_ns, LLMMode)
register_schema_models(
console_ns,
RuleGeneratePayload,
RuleCodeGeneratePayload,
RuleStructuredOutputPayload,
InstructionGeneratePayload,
InstructionTemplatePayload,
ModelConfig,
)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(RuleGeneratePayload)
reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
reg(ModelConfig)
@console_ns.route("/rule-generate")
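
The reg() helper above feeds each Pydantic model's JSON Schema to flask-restx with a Swagger 2.0 ref template so nested $refs resolve under #/definitions. A standalone sketch of that registration pattern, with an illustrative payload model in place of the project's real ones:

# Minimal sketch: exposing a Pydantic v2 model to flask-restx as a schema model.
from flask import Flask
from flask_restx import Api, Namespace
from pydantic import BaseModel, Field

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

class RulePayload(BaseModel):  # illustrative model, not the project's RuleGeneratePayload
    instruction: str = Field(..., description="Prompt instruction")
    no_variable: bool = False

app = Flask(__name__)
api = Api(app)
ns = Namespace("demo", path="/demo")
api.add_namespace(ns)

def reg(cls: type[BaseModel]) -> None:
    # schema_model() accepts a raw JSON Schema dict and stores it in ns.models.
    ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))

reg(RulePayload)
assert "RulePayload" in ns.models

Decorators such as @console_ns.expect(console_ns.models[...]) can then reference the registered schema by class name, as the surrounding file does.
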

View File

@ -18,12 +18,6 @@ from models.enums import AppMCPServerStatus
from models.model import AppMCPServer
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
@ -36,19 +30,25 @@ class MCPServerUpdatePayload(BaseModel):
status: str | None = Field(default=None, description="Server status")
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class AppMCPServerResponse(ResponseModel):
id: str
name: str
server_code: str
description: str
status: str
status: AppMCPServerStatus
parameters: dict[str, Any] | list[Any] | str
created_at: int | None = None
updated_at: int | None = None
@field_validator("parameters", mode="before")
@classmethod
def _parse_json_string(cls, value: Any) -> Any:
def _normalize_parameters(cls, value: Any) -> Any:
if isinstance(value, str):
try:
return json.loads(value)
@ -70,7 +70,9 @@ class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server")
@console_ns.doc(description="Get MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(
200, "MCP server configuration retrieved successfully", console_ns.models[AppMCPServerResponse.__name__]
)
@login_required
@account_initialization_required
@setup_required
@ -85,7 +87,9 @@ class AppMCPServerController(Resource):
@console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
@console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(
201, "MCP server configuration created successfully", console_ns.models[AppMCPServerResponse.__name__]
)
@console_ns.response(403, "Insufficient permissions")
@account_initialization_required
@get_app_model
@ -111,13 +115,15 @@ class AppMCPServerController(Resource):
)
db.session.add(server)
db.session.commit()
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json"), 201
@console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
@console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(
200, "MCP server configuration updated successfully", console_ns.models[AppMCPServerResponse.__name__]
)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@get_app_model
@ -154,7 +160,7 @@ class AppMCPServerRefreshController(Resource):
@console_ns.doc("refresh_app_mcp_server")
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
@console_ns.doc(params={"server_id": "Server ID"})
@console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(200, "MCP server refreshed successfully", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@setup_required
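
AppMCPServerResponse above serializes rows whose parameters column is stored as JSON text, so a mode="before" validator decodes strings while passing dicts through. A self-contained sketch of that normalization; the field names mirror the model and the values are made up:

# Sketch of the "parse JSON string or pass dict through" validator pattern.
import json
from typing import Any

from pydantic import BaseModel, field_validator

class MCPServerOut(BaseModel):  # simplified stand-in for AppMCPServerResponse
    parameters: dict[str, Any] | list[Any] | str

    @field_validator("parameters", mode="before")
    @classmethod
    def _normalize_parameters(cls, value: Any) -> Any:
        # Rows persist parameters as a JSON string; in-memory callers may pass a dict.
        if isinstance(value, str):
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return value  # leave unparseable strings as-is
        return value

print(MCPServerOut(parameters='{"timeout": 30}').parameters)  # {'timeout': 30}
print(MCPServerOut(parameters={"timeout": 30}).parameters)    # {'timeout': 30}
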

View File

@ -1,16 +1,16 @@
import json
import logging
from collections.abc import Sequence
from typing import Any, Literal
from typing import Any
from flask import abort, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, ValidationError, field_validator
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
from controllers.common.controller_schemas import DefaultBlockConfigQuery
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
from controllers.console import console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow_run import workflow_run_node_execution_model
@ -48,7 +48,7 @@ from libs.helper import TimestampField, uuid_value
from libs.login import current_account_with_tenant, login_required
from models import App
from models.model import AppMode
from models.workflow import Workflow, WorkflowKind
from models.workflow import Workflow
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
@ -155,23 +155,6 @@ class ConvertToWorkflowPayload(BaseModel):
icon_background: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
keyword: str | None = Field(default=None, max_length=255)
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class WorkflowTypeConvertQuery(BaseModel):
target_type: Literal["workflow", "evaluation"]
class WorkflowFeaturesPayload(BaseModel):
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
@ -208,7 +191,6 @@ reg(DefaultBlockConfigQuery)
reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(WorkflowTypeConvertQuery)
reg(WorkflowFeaturesPayload)
reg(WorkflowOnlineUsersPayload)
reg(DraftWorkflowTriggerRunPayload)
@ -883,54 +865,6 @@ class PublishedWorkflowApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/publish/evaluation")
class EvaluationPublishedWorkflowApi(Resource):
@console_ns.doc("publish_evaluation_workflow")
@console_ns.doc(description="Publish draft workflow as evaluation workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
@console_ns.response(200, "Evaluation workflow published successfully")
@console_ns.response(400, "Invalid workflow or unsupported node type")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Publish draft workflow as evaluation workflow.
Evaluation workflows cannot include trigger or human-input nodes.
"""
current_user, _ = current_account_with_tenant()
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
with Session(db.engine, expire_on_commit=False) as session:
workflow = workflow_service.publish_evaluation_workflow(
session=session,
app_model=app_model,
account=current_user,
marked_name=args.marked_name or "",
marked_comment=args.marked_comment or "",
)
# Keep workflow_id aligned with the latest published workflow.
app_model_in_session = session.get(App, app_model.id)
if app_model_in_session:
app_model_in_session.workflow_id = workflow.id
app_model_in_session.updated_by = current_user.id
app_model_in_session.updated_at = naive_utc_now()
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
return {
"result": "success",
"created_at": workflow_created_at,
}
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
class DefaultBlockConfigsApi(Resource):
@console_ns.doc("get_default_block_configs")
@ -1128,52 +1062,6 @@ class DraftWorkflowRestoreApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/convert-type")
class WorkflowTypeConvertApi(Resource):
@console_ns.doc("convert_published_workflow_type")
@console_ns.doc(description="Convert current effective published workflow type in-place")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowTypeConvertQuery.__name__])
@console_ns.response(200, "Workflow type converted successfully")
@console_ns.response(400, "Invalid workflow type or unsupported workflow graph")
@console_ns.response(404, "Workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
args = WorkflowTypeConvertQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
target_type = WorkflowKind.EVALUATION if args.target_type == "evaluation" else WorkflowKind.STANDARD
workflow_service = WorkflowService()
with Session(db.engine, expire_on_commit=False) as session:
try:
workflow = workflow_service.convert_published_workflow_type(
session=session,
app_model=app_model,
target_type=target_type,
account=current_user,
)
except WorkflowNotFoundError as exc:
raise NotFound(str(exc)) from exc
except IsDraftWorkflowError as exc:
raise BadRequest(str(exc)) from exc
except ValueError as exc:
raise BadRequest(str(exc)) from exc
session.commit()
return {
"result": "success",
"workflow_id": workflow.id,
"type": workflow.type.value,
"kind": workflow.kind_or_standard,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
class WorkflowByIdApi(Resource):
@console_ns.doc("update_workflow_by_id")
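
WorkflowListQuery and WorkflowUpdatePayload are now imported from controllers.common.controller_schemas; the removed in-file definitions above show their constraints. A hedged sketch of validating query-string arguments with such a model, where the Flask route is illustrative wiring rather than the console endpoint:

# Sketch: bounded pagination and keyword query validated with Pydantic.
from flask import Flask, request
from pydantic import BaseModel, Field, ValidationError

class WorkflowListQuery(BaseModel):  # mirrors the removed definition above
    page: int = Field(default=1, ge=1, le=99999)
    limit: int = Field(default=10, ge=1, le=100)
    user_id: str | None = None
    named_only: bool = False
    keyword: str | None = Field(default=None, max_length=255)

app = Flask(__name__)

@app.get("/workflows")
def list_workflows():
    try:
        # flat=True keeps a single value per key, matching the controller usage.
        query = WorkflowListQuery.model_validate(request.args.to_dict(flat=True))
    except ValidationError as exc:
        return {"error": str(exc)}, 400
    return {"page": query.page, "limit": query.limit, "keyword": query.keyword}
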

View File

@ -104,28 +104,10 @@ class WorkflowRunForArchivedLogResponse(ResponseModel):
return str(getattr(value, "value", value))
class WorkflowAppLogEvaluationNodeInfoResponse(ResponseModel):
node_id: str
type: str
title: str
class WorkflowAppLogEvaluationItemResponse(ResponseModel):
name: str
value: Any = None
details: dict[str, Any] | None = None
node_info: WorkflowAppLogEvaluationNodeInfoResponse | None = Field(
default=None,
validation_alias="node_info",
serialization_alias="nodeInfo",
)
class WorkflowAppLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForLogResponse | None = None
details: Any = None
evaluation: list[WorkflowAppLogEvaluationItemResponse] = Field(default_factory=list)
created_from: str | None = None
created_by_role: str | None = None
created_by_account: SimpleAccount | None = None
@ -139,11 +121,6 @@ class WorkflowAppLogPartialResponse(ResponseModel):
return int(value.timestamp())
return value
@field_validator("evaluation", mode="before")
@classmethod
def _normalize_evaluation(cls, value: Any) -> list[dict[str, Any]] | list[WorkflowAppLogEvaluationItemResponse]:
return value or []
class WorkflowArchivedLogPartialResponse(ResponseModel):
id: str
@ -182,8 +159,6 @@ register_schema_models(
WorkflowAppLogQuery,
WorkflowRunForLogResponse,
WorkflowRunForArchivedLogResponse,
WorkflowAppLogEvaluationNodeInfoResponse,
WorkflowAppLogEvaluationItemResponse,
WorkflowAppLogPartialResponse,
WorkflowArchivedLogPartialResponse,
WorkflowAppLogPaginationResponse,

View File

@ -1,6 +1,4 @@
import base64
import json
from datetime import UTC, datetime, timedelta
from typing import Literal
from flask import request
@ -12,7 +10,6 @@ from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@ -80,39 +77,3 @@ class PartnerTenants(Resource):
raise BadRequest("Invalid partner information")
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
_DEBUG_KEY = "billing:debug"
_DEBUG_TTL = timedelta(days=7)
class DebugDataPayload(BaseModel):
type: str = Field(..., min_length=1, description="Data type key")
data: str = Field(..., min_length=1, description="Data value to append")
@console_ns.route("/billing/debug/data")
class DebugData(Resource):
def post(self):
body = DebugDataPayload.model_validate(request.get_json(force=True))
item = json.dumps({
"type": body.type,
"data": body.data,
"createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
})
redis_client.lpush(_DEBUG_KEY, item)
redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
return {"result": "ok"}, 201
def get(self):
recent = request.args.get("recent", 10, type=int)
items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
return {
"data": [
json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
]
}
def delete(self):
redis_client.delete(_DEBUG_KEY)
return {"result": "ok"}
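
The removed debug endpoints above stash recent entries in a Redis list and refresh a seven-day TTL on every write. A standalone sketch of that list-plus-expiry pattern with redis-py; the connection settings are placeholders:

# Sketch: append-and-expire debug log backed by a Redis list.
import json
from datetime import UTC, datetime, timedelta

import redis

_DEBUG_KEY = "billing:debug"
_DEBUG_TTL = timedelta(days=7)
client = redis.Redis(host="localhost", port=6379, db=0)  # placeholder connection settings

def push_debug(entry_type: str, data: str) -> None:
    item = json.dumps({
        "type": entry_type,
        "data": data,
        "createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
    })
    client.lpush(_DEBUG_KEY, item)         # newest entry moves to index 0
    client.expire(_DEBUG_KEY, _DEBUG_TTL)  # sliding seven-day retention

def read_debug(recent: int = 10) -> list[dict]:
    raw = client.lrange(_DEBUG_KEY, 0, recent - 1)
    return [json.loads(i.decode("utf-8") if isinstance(i, bytes) else i) for i in raw]
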

View File

@ -1,13 +1,10 @@
import json
from typing import Any, cast
from urllib.parse import quote
from flask import Response, request
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
@ -24,7 +21,6 @@ from controllers.console.wraps import (
setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
from core.indexing_runner import IndexingRunner
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.datasource.vdb.vector_type import VectorType
@ -33,7 +29,6 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
content_fields,
@ -56,19 +51,12 @@ from fields.document_fields import document_status_fields
from graphon.model_runtime.entities.model_entities import ModelType
from libs.login import current_account_with_tenant, login_required
from libs.url_utils import normalize_api_base_url
from models import ApiToken, Dataset, Document, DocumentSegment, EvaluationRun, EvaluationTargetType, UploadFile
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import ApiTokenType, SegmentStatus
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.errors.evaluation import (
EvaluationDatasetInvalidError,
EvaluationFrameworkNotConfiguredError,
EvaluationMaxConcurrentRunsError,
EvaluationNotFoundError,
)
from services.evaluation_service import EvaluationService
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
@ -997,432 +985,3 @@ class DatasetAutoDisableLogApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
# ---- Knowledge Base Retrieval Evaluation ----
def _serialize_dataset_evaluation_run(run: EvaluationRun) -> dict[str, Any]:
return {
"id": run.id,
"tenant_id": run.tenant_id,
"target_type": run.target_type,
"target_id": run.target_id,
"evaluation_config_id": run.evaluation_config_id,
"status": run.status,
"dataset_file_id": run.dataset_file_id,
"result_file_id": run.result_file_id,
"total_items": run.total_items,
"completed_items": run.completed_items,
"failed_items": run.failed_items,
"progress": run.progress,
"metrics_summary": json.loads(run.metrics_summary) if run.metrics_summary else {},
"error": run.error,
"created_by": run.created_by,
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
}
def _serialize_dataset_evaluation_run_item(item: Any) -> dict[str, Any]:
return {
"id": item.id,
"item_index": item.item_index,
"inputs": item.inputs_dict,
"expected_output": item.expected_output,
"actual_output": item.actual_output,
"metrics": item.metrics_list,
"judgment": item.judgment_dict,
"metadata": item.metadata_dict,
"error": item.error,
"overall_score": item.overall_score,
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/template/download")
class DatasetEvaluationTemplateDownloadApi(Resource):
@console_ns.doc("download_dataset_evaluation_template")
@console_ns.response(200, "Template file streamed as XLSX attachment")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
"""Download evaluation dataset template for knowledge base retrieval."""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
xlsx_content, filename = EvaluationService.generate_retrieval_dataset_template()
encoded_filename = quote(filename)
response = Response(
xlsx_content,
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Length"] = str(len(xlsx_content))
return response
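
The template download above streams generated XLSX bytes and percent-encodes the attachment filename (filename*=UTF-8''...) so non-ASCII names survive the Content-Disposition header. A minimal Flask sketch of the same headers; the payload bytes and route are stand-ins:

# Sketch: streaming a generated file with a percent-encoded attachment filename.
from urllib.parse import quote

from flask import Flask, Response

app = Flask(__name__)

@app.post("/template/download")
def download_template():
    xlsx_content = b"PK..."  # stand-in for real workbook bytes
    filename = "retrieval evaluation template.xlsx"
    response = Response(
        xlsx_content,
        mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )
    # RFC 5987 form: quote() escapes spaces and non-ASCII characters safely.
    response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{quote(filename)}"
    response.headers["Content-Length"] = str(len(xlsx_content))
    return response
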
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation")
class DatasetEvaluationDetailApi(Resource):
@console_ns.doc("get_dataset_evaluation_config")
@console_ns.response(200, "Evaluation configuration retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
"""Get evaluation configuration for the knowledge base."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.get_evaluation_config(
session, current_tenant_id, "dataset", dataset_id_str
)
if config is None:
return {
"evaluation_model": None,
"evaluation_model_provider": None,
"default_metrics": None,
"customized_metrics": None,
"judgment_config": None,
}
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"default_metrics": config.default_metrics_list,
"customized_metrics": config.customized_metrics_dict,
"judgment_config": config.judgment_config_dict,
}
@console_ns.doc("save_dataset_evaluation_config")
@console_ns.response(200, "Evaluation configuration saved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def put(self, dataset_id):
"""Save evaluation configuration for the knowledge base."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
body = request.get_json(force=True)
try:
config_data = EvaluationConfigData.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.save_evaluation_config(
session=session,
tenant_id=current_tenant_id,
target_type="dataset",
target_id=dataset_id_str,
account_id=str(current_user.id),
data=config_data,
)
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"default_metrics": config.default_metrics_list,
"customized_metrics": config.customized_metrics_dict,
"judgment_config": config.judgment_config_dict,
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/run")
class DatasetEvaluationRunApi(Resource):
@console_ns.doc("start_dataset_evaluation_run")
@console_ns.response(200, "Evaluation run started")
@console_ns.response(400, "Invalid request")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
"""Start an evaluation run for the knowledge base retrieval."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
body = request.get_json(force=True)
if not body:
raise BadRequest("Request body is required.")
try:
run_request = EvaluationRunRequest.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
upload_file = (
db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first()
)
if not upload_file:
raise NotFound("Dataset file not found.")
try:
dataset_content = storage.load_once(upload_file.key)
except Exception:
raise BadRequest("Failed to read dataset file.")
if not dataset_content:
raise BadRequest("Dataset file is empty.")
try:
with Session(db.engine, expire_on_commit=False) as session:
evaluation_run = EvaluationService.start_evaluation_run(
session=session,
tenant_id=current_tenant_id,
target_type=EvaluationTargetType.KNOWLEDGE_BASE,
target_id=dataset_id_str,
account_id=str(current_user.id),
dataset_file_content=dataset_content,
run_request=run_request,
)
return _serialize_dataset_evaluation_run(evaluation_run), 200
except EvaluationFrameworkNotConfiguredError as e:
return {"message": str(e.description)}, 400
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except EvaluationMaxConcurrentRunsError as e:
return {"message": str(e.description)}, 429
except EvaluationDatasetInvalidError as e:
return {"message": str(e.description)}, 400
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/logs")
class DatasetEvaluationLogsApi(Resource):
@console_ns.doc("get_dataset_evaluation_logs")
@console_ns.response(200, "Evaluation logs retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
"""Get evaluation run history for the knowledge base."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
with Session(db.engine, expire_on_commit=False) as session:
runs, total = EvaluationService.get_evaluation_runs(
session=session,
tenant_id=current_tenant_id,
target_type="dataset",
target_id=dataset_id_str,
page=page,
page_size=page_size,
)
return {
"data": [_serialize_dataset_evaluation_run(run) for run in runs],
"total": total,
"page": page,
"page_size": page_size,
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>")
class DatasetEvaluationRunDetailApi(Resource):
@console_ns.doc("get_dataset_evaluation_run_detail")
@console_ns.response(200, "Evaluation run detail retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset or run not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, run_id):
"""Get evaluation run detail including per-item results."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
run_id_str = str(run_id)
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 50, type=int)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.get_evaluation_run_detail(
session=session,
tenant_id=current_tenant_id,
run_id=run_id_str,
)
items, total_items = EvaluationService.get_evaluation_run_items(
session=session,
run_id=run_id_str,
page=page,
page_size=page_size,
)
return {
"run": _serialize_dataset_evaluation_run(run),
"items": {
"data": [_serialize_dataset_evaluation_run_item(item) for item in items],
"total": total_items,
"page": page,
"page_size": page_size,
},
}
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>/cancel")
class DatasetEvaluationRunCancelApi(Resource):
@console_ns.doc("cancel_dataset_evaluation_run")
@console_ns.response(200, "Evaluation run cancelled")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset or run not found")
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id, run_id):
"""Cancel a running knowledge base evaluation."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
run_id_str = str(run_id)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.cancel_evaluation_run(
session=session,
tenant_id=current_tenant_id,
run_id=run_id_str,
)
return _serialize_dataset_evaluation_run(run)
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except ValueError as e:
return {"message": str(e)}, 400
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/metrics")
class DatasetEvaluationMetricsApi(Resource):
@console_ns.doc("get_dataset_evaluation_metrics")
@console_ns.response(200, "Available retrieval metrics retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
"""Get available evaluation metrics for knowledge base retrieval."""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
return {
"metrics": EvaluationService.get_supported_metrics(EvaluationCategory.RETRIEVAL)
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/files/<uuid:file_id>")
class DatasetEvaluationFileDownloadApi(Resource):
@console_ns.doc("download_dataset_evaluation_file")
@console_ns.response(200, "File download URL generated")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset or file not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, file_id):
"""Download evaluation test file or result file for the knowledge base."""
from core.workflow.file import helpers as file_helpers
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
file_id_str = str(file_id)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(UploadFile).where(
UploadFile.id == file_id_str,
UploadFile.tenant_id == current_tenant_id,
)
upload_file = session.execute(stmt).scalar_one_or_none()
if not upload_file:
raise NotFound("File not found.")
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
"download_url": download_url,
}

View File

@ -3,18 +3,19 @@ import logging
from argparse import ArgumentTypeError
from collections.abc import Sequence
from contextlib import ExitStack
from datetime import datetime
from typing import Any, Literal, cast
import sqlalchemy as sa
from flask import request, send_file
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import asc, desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from core.errors.error import (
LLMBadRequestError,
@ -29,11 +30,9 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from fields.dataset_fields import dataset_fields
from fields.base import ResponseModel
from fields.document_fields import (
dataset_and_document_fields,
document_fields,
document_metadata_fields,
document_status_fields,
document_with_segments_fields,
)
@ -72,27 +71,100 @@ from ..wraps import (
logger = logging.getLogger(__name__)
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_model = get_or_create_model("Dataset", dataset_fields)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)
document_fields_copy = document_fields.copy()
document_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_model = get_or_create_model("Document", document_fields_copy)
def _normalize_enum(value: Any) -> Any:
if isinstance(value, str) or value is None:
return value
return getattr(value, "value", value)
document_with_segments_fields_copy = document_with_segments_fields.copy()
document_with_segments_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
dataset_and_document_fields_copy = dataset_and_document_fields.copy()
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
class DatasetResponse(ResponseModel):
id: str
name: str
description: str | None = None
permission: str | None = None
data_source_type: str | None = None
indexing_technique: str | None = None
created_by: str | None = None
created_at: int | None = None
@field_validator("data_source_type", "indexing_technique", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> Any:
return _normalize_enum(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class DocumentMetadataResponse(ResponseModel):
id: str
name: str
type: str
value: str | None = None
class DocumentResponse(ResponseModel):
id: str
position: int | None = None
data_source_type: str | None = None
data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
data_source_detail_dict: Any = None
dataset_process_rule_id: str | None = None
name: str
created_from: str | None = None
created_by: str | None = None
created_at: int | None = None
tokens: int | None = None
indexing_status: str | None = None
error: str | None = None
enabled: bool | None = None
disabled_at: int | None = None
disabled_by: str | None = None
archived: bool | None = None
display_status: str | None = None
word_count: int | None = None
hit_count: int | None = None
doc_form: str | None = None
doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
summary_index_status: str | None = None
need_summary: bool | None = None
@field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> Any:
return _normalize_enum(value)
@field_validator("doc_metadata", mode="before")
@classmethod
def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
if value is None:
return []
return value
@field_validator("created_at", "disabled_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class DocumentWithSegmentsResponse(DocumentResponse):
process_rule_dict: Any = None
completed_segments: int | None = None
total_segments: int | None = None
class DatasetAndDocumentResponse(ResponseModel):
dataset: DatasetResponse
documents: list[DocumentResponse]
batch: str
class DocumentRetryPayload(BaseModel):
@ -107,6 +179,11 @@ class GenerateSummaryPayload(BaseModel):
document_list: list[str]
class DocumentMetadataUpdatePayload(BaseModel):
doc_type: str | None = None
doc_metadata: Any = None
class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.")
@ -124,7 +201,13 @@ register_schema_models(
DocumentRetryPayload,
DocumentRenamePayload,
GenerateSummaryPayload,
DocumentMetadataUpdatePayload,
DocumentBatchDownloadZipPayload,
DatasetResponse,
DocumentMetadataResponse,
DocumentResponse,
DocumentWithSegmentsResponse,
DatasetAndDocumentResponse,
)
@ -357,10 +440,10 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
@ -398,7 +481,9 @@ class DatasetDocumentListApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
return {"dataset": dataset, "documents": documents, "batch": batch}
return DatasetAndDocumentResponse.model_validate(
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
).model_dump(mode="json")
@setup_required
@login_required
@ -426,12 +511,13 @@ class DatasetInitApi(Resource):
@console_ns.doc("init_dataset")
@console_ns.doc(description="Initialize dataset with documents")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
@console_ns.response(
201, "Dataset initialized successfully", console_ns.models[DatasetAndDocumentResponse.__name__]
)
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
@ -479,9 +565,9 @@ class DatasetInitApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
response = {"dataset": dataset, "documents": documents, "batch": batch}
return response
return DatasetAndDocumentResponse.model_validate(
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
).model_dump(mode="json")
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
@ -988,15 +1074,7 @@ class DocumentMetadataApi(DocumentResource):
@console_ns.doc("update_document_metadata")
@console_ns.doc(description="Update document metadata")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.expect(
console_ns.model(
"UpdateDocumentMetadataRequest",
{
"doc_type": fields.String(description="Document type"),
"doc_metadata": fields.Raw(description="Document metadata"),
},
)
)
@console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__])
@console_ns.response(200, "Document metadata updated successfully")
@console_ns.response(404, "Document not found")
@console_ns.response(403, "Permission denied")
@ -1009,10 +1087,10 @@ class DocumentMetadataApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
req_data = request.get_json()
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
doc_type = req_data.get("doc_type")
doc_metadata = req_data.get("doc_metadata")
doc_type = req_data.doc_type
doc_metadata = req_data.doc_metadata
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
@ -1194,7 +1272,7 @@ class DocumentRenameApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(document_model)
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
@ -1212,7 +1290,7 @@ class DocumentRenameApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return document
return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
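
This file swaps flask-restx marshal_with for Pydantic response models: mode="before" validators normalize enums and datetimes, and views return Model.model_validate(obj, from_attributes=True).model_dump(mode="json"). A compact, self-contained sketch of that round trip, with a plain class standing in for the ORM document:

# Sketch of the ORM -> Pydantic -> JSON-ready dict pattern used above.
from datetime import UTC, datetime
from enum import Enum
from typing import Any

from pydantic import BaseModel, ConfigDict, field_validator

class DocStatus(str, Enum):  # stand-in enum for the ORM's status values
    COMPLETED = "completed"

class Doc:  # stand-in for the SQLAlchemy Document row
    def __init__(self):
        self.id = "doc-1"
        self.name = "intro.md"
        self.indexing_status = DocStatus.COMPLETED
        self.created_at = datetime(2026, 5, 9, tzinfo=UTC)

class DocumentOut(BaseModel):
    model_config = ConfigDict(from_attributes=True)
    id: str
    name: str
    indexing_status: str | None = None
    created_at: int | None = None

    @field_validator("indexing_status", mode="before")
    @classmethod
    def _normalize_enum(cls, value: Any) -> Any:
        return getattr(value, "value", value)  # Enum member -> plain string

    @field_validator("created_at", mode="before")
    @classmethod
    def _to_timestamp(cls, value: datetime | int | None) -> int | None:
        return int(value.timestamp()) if isinstance(value, datetime) else value

payload = DocumentOut.model_validate(Doc(), from_attributes=True).model_dump(mode="json")
print(payload)  # indexing_status becomes 'completed', created_at an epoch-seconds integer
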

View File

@ -1 +0,0 @@
# Evaluation controller module

View File

@ -1,993 +0,0 @@
from __future__ import annotations
import logging
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Union
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource, fields, marshal
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.workflow import WorkflowListQuery
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
from extensions.ext_database import db
from extensions.ext_storage import storage
from fields.member_fields import simple_account_fields
from graphon.file import helpers as file_helpers
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import App, Dataset
from models.evaluation import EvaluationTargetType
from models.model import UploadFile
from models.snippet import CustomizedSnippet
from services.errors.evaluation import (
EvaluationDatasetInvalidError,
EvaluationFrameworkNotConfiguredError,
EvaluationMaxConcurrentRunsError,
EvaluationNotFoundError,
)
from services.evaluation_service import EvaluationService
from services.workflow_service import WorkflowService
if TYPE_CHECKING:
from models.evaluation import EvaluationRun, EvaluationRunItem
logger = logging.getLogger(__name__)
EVALUATE_TARGET_TYPES = {
EvaluationTargetType.APPS.value,
EvaluationTargetType.SNIPPETS.value,
}
class VersionQuery(BaseModel):
"""Query parameters for version endpoint."""
version: str
register_schema_models(
console_ns,
VersionQuery,
)
# Response field definitions
file_info_fields = {
"id": fields.String,
"name": fields.String,
}
evaluation_log_fields = {
"created_at": TimestampField,
"created_by": fields.String,
"test_file": fields.Nested(
console_ns.model(
"EvaluationTestFile",
file_info_fields,
)
),
"result_file": fields.Nested(
console_ns.model(
"EvaluationResultFile",
file_info_fields,
),
allow_null=True,
),
"version": fields.String,
}
evaluation_log_list_model = console_ns.model(
"EvaluationLogList",
{
"data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))),
},
)
evaluation_default_metric_node_info_fields = {
"node_id": fields.String,
"type": fields.String,
"title": fields.String,
}
evaluation_default_metric_item_fields = {
"metric": fields.String,
"value_type": fields.String,
"node_info_list": fields.List(
fields.Nested(
console_ns.model("EvaluationDefaultMetricNodeInfo", evaluation_default_metric_node_info_fields),
),
),
}
customized_metrics_fields = {
"evaluation_workflow_id": fields.String,
"input_fields": fields.Raw,
"output_fields": fields.Raw,
}
judgment_condition_fields = {
"variable_selector": fields.List(fields.String),
"comparison_operator": fields.String,
"value": fields.String,
}
judgment_config_fields = {
"logical_operator": fields.String,
"conditions": fields.List(fields.Nested(console_ns.model("JudgmentCondition", judgment_condition_fields))),
}
evaluation_detail_fields = {
"evaluation_model": fields.String,
"evaluation_model_provider": fields.String,
"default_metrics": fields.List(
fields.Nested(console_ns.model("EvaluationDefaultMetricItem_Detail", evaluation_default_metric_item_fields)),
allow_null=True,
),
"customized_metrics": fields.Nested(
console_ns.model("EvaluationCustomizedMetrics", customized_metrics_fields),
allow_null=True,
),
"judgment_config": fields.Nested(
console_ns.model("EvaluationJudgmentConfig", judgment_config_fields),
allow_null=True,
),
}
evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields)
available_evaluation_workflow_list_fields = {
"id": fields.String,
"app_id": fields.String,
"app_name": fields.String,
"type": fields.String,
"kind": fields.String,
"version": fields.String,
"marked_name": fields.String,
"marked_comment": fields.String,
"hash": fields.String,
"created_by": fields.Nested(simple_account_fields),
"created_at": TimestampField,
"updated_by": fields.Nested(simple_account_fields, allow_null=True),
"updated_at": TimestampField,
}
available_evaluation_workflow_pagination_fields = {
"items": fields.List(fields.Nested(available_evaluation_workflow_list_fields)),
"page": fields.Integer,
"limit": fields.Integer,
"has_more": fields.Boolean,
}
available_evaluation_workflow_pagination_model = console_ns.model(
"AvailableEvaluationWorkflowPagination",
available_evaluation_workflow_pagination_fields,
)
evaluation_default_metrics_response_model = console_ns.model(
"EvaluationDefaultMetricsResponse",
{
"default_metrics": fields.List(
fields.Nested(console_ns.model("EvaluationDefaultMetricItem", evaluation_default_metric_item_fields)),
),
},
)
evaluation_dataset_columns_response_model = console_ns.model(
"EvaluationDatasetColumnsResponse",
{
"columns": fields.List(
fields.Nested(
console_ns.model(
"EvaluationTemplateColumn",
{
"name": fields.String,
"type": fields.String,
},
)
)
),
},
)
def get_evaluation_target[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
"""
Decorator to resolve polymorphic evaluation target (apps or snippets).
Validates the target_type parameter and fetches the corresponding
model (App or CustomizedSnippet) with tenant isolation.
"""
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
target_type = kwargs.get("evaluate_target_type")
target_id = kwargs.get("evaluate_target_id")
if target_type not in EVALUATE_TARGET_TYPES:
raise NotFound(f"Invalid evaluation target type: {target_type}")
_, current_tenant_id = current_account_with_tenant()
target_id = str(target_id)
# Remove path parameters
del kwargs["evaluate_target_type"]
del kwargs["evaluate_target_id"]
target: Union[App, CustomizedSnippet] | None = None
if target_type == EvaluationTargetType.APPS.value:
target = db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
elif target_type == EvaluationTargetType.SNIPPETS.value:
target = (
db.session.query(CustomizedSnippet)
.where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
.first()
)
if not target:
raise NotFound(f"{str(target_type)} not found")
kwargs["target"] = target
kwargs["target_type"] = target_type
return view_func(*args, **kwargs)
return decorated_view
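
get_evaluation_target above rewrites the view kwargs: it pops the path parameters, loads the tenant-scoped App or CustomizedSnippet, and injects the result as target and target_type. A toy sketch of that kwargs-rewriting decorator shape, with a dict standing in for the database lookup:

# Toy sketch of the kwargs-rewriting decorator pattern; FAKE_DB replaces the tenant-scoped queries.
from functools import wraps

FAKE_DB = {("apps", "a1"): {"id": "a1", "name": "demo app"}}

def resolve_target(view_func):
    @wraps(view_func)
    def decorated_view(*args, **kwargs):
        target_type = kwargs.pop("evaluate_target_type")
        target_id = kwargs.pop("evaluate_target_id")
        target = FAKE_DB.get((target_type, str(target_id)))
        if target is None:
            raise LookupError(f"{target_type} not found")
        kwargs["target"] = target
        kwargs["target_type"] = target_type
        return view_func(*args, **kwargs)
    return decorated_view

@resolve_target
def get_detail(*, target, target_type):
    return f"{target_type}: {target['name']}"

print(get_detail(evaluate_target_type="apps", evaluate_target_id="a1"))  # apps: demo app
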
def _load_evaluation_run_request_and_dataset(tenant_id: str) -> tuple[EvaluationRunRequest, bytes, str]:
"""Validate the run payload and load the uploaded dataset bytes."""
body = request.get_json(force=True)
if not body:
raise BadRequest("Request body is required.")
try:
run_request = EvaluationRunRequest.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
upload_file = db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=tenant_id).first()
if not upload_file:
raise NotFound("Dataset file not found.")
try:
dataset_content = storage.load_once(upload_file.key)
except Exception:
raise BadRequest("Failed to read dataset file.")
if not dataset_content:
raise BadRequest("Dataset file is empty.")
return run_request, dataset_content, upload_file.name
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
class EvaluationDatasetTemplateDownloadApi(Resource):
@console_ns.doc("download_evaluation_dataset_template")
@console_ns.response(200, "Template file streamed as XLSX attachment")
@console_ns.response(400, "Invalid target type or excluded app mode")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Download evaluation dataset template.
Generates an XLSX template based on the target's input parameters
and streams it directly as a file attachment.
"""
try:
xlsx_content, filename = EvaluationService.generate_dataset_template(
target=target,
target_type=target_type,
)
except ValueError as e:
return {"message": str(e)}, 400
encoded_filename = quote(filename)
response = Response(
xlsx_content,
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Length"] = str(len(xlsx_content))
return response
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
class EvaluationDetailApi(Resource):
@console_ns.doc("get_evaluation_detail")
@console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model)
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation configuration for the target.
Returns evaluation configuration including model settings,
metrics config, and judgment conditions.
"""
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.get_evaluation_config(session, current_tenant_id, target_type, str(target.id))
if config is None:
return {
"evaluation_model": None,
"evaluation_model_provider": None,
"default_metrics": None,
"customized_metrics": None,
"judgment_config": None,
}
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"default_metrics": EvaluationService.serialize_console_default_metrics(config.default_metrics_list),
"customized_metrics": config.customized_metrics_dict,
"judgment_config": EvaluationService.serialize_console_judgment_config(config.judgment_config_dict),
}
@console_ns.doc("save_evaluation_detail")
@console_ns.response(200, "Evaluation configuration saved successfully")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def put(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Save evaluation configuration for the target.
"""
current_account, current_tenant_id = current_account_with_tenant()
body = request.get_json(force=True)
try:
config_data = EvaluationConfigData.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.save_evaluation_config(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
account_id=str(current_account.id),
data=config_data,
)
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"default_metrics": EvaluationService.serialize_console_default_metrics(config.default_metrics_list),
"customized_metrics": config.customized_metrics_dict,
"judgment_config": EvaluationService.serialize_console_judgment_config(config.judgment_config_dict),
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/template-columns")
class EvaluationTemplateColumnsApi(Resource):
@console_ns.doc("get_evaluation_template_columns")
@console_ns.response(200, "Evaluation dataset columns resolved", evaluation_dataset_columns_response_model)
@console_ns.response(400, "Invalid request body")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
"""Return the dataset template columns implied by the current evaluation config."""
body = request.get_json(silent=True) or {}
try:
config_data = EvaluationConfigData.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
return {
"columns": EvaluationService.get_dataset_column_names(
target=target,
target_type=target_type,
data=config_data,
)
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/logs")
class EvaluationLogsApi(Resource):
@console_ns.doc("get_evaluation_logs")
@console_ns.response(200, "Evaluation logs retrieved successfully")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation run history for the target.
Returns a paginated list of evaluation runs.
"""
_, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
with Session(db.engine, expire_on_commit=False) as session:
runs, total = EvaluationService.get_evaluation_runs(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
page=page,
page_size=page_size,
)
return {
"data": [_serialize_evaluation_run(run) for run in runs],
"total": total,
"page": page,
"page_size": page_size,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/run1")
class EvaluationRunApi(Resource):
@console_ns.doc("start_evaluation_run")
@console_ns.response(200, "Evaluation run started")
@console_ns.response(400, "Invalid request")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet, Dataset], target_type: str):
"""
Start an evaluation run.
Expects JSON body with:
- file_id: uploaded dataset file ID
- evaluation_model: evaluation model name
- evaluation_model_provider: evaluation model provider
- default_metrics: list of default metric objects
- customized_metrics: customized metrics object (optional)
- judgment_config: judgment conditions config (optional)
"""
current_account, current_tenant_id = current_account_with_tenant()
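# Illustrative request body (values are placeholders; the shapes follow the
# field definitions earlier in this file):
# {
#   "file_id": "<uploaded-dataset-file-id>",
#   "evaluation_model": "<model-name>",
#   "evaluation_model_provider": "<provider>",
#   "default_metrics": [
#     {"metric": "<metric-name>", "value_type": "<value-type>",
#      "node_info_list": [{"node_id": "<node-id>", "type": "<node-type>", "title": "<node-title>"}]}
#   ],
#   "customized_metrics": {"evaluation_workflow_id": "<workflow-id>",
#                          "input_fields": {}, "output_fields": {}},
#   "judgment_config": {"logical_operator": "<and|or>",
#                       "conditions": [{"variable_selector": ["<node-id>", "<variable>"],
#                                       "comparison_operator": "<operator>",
#                                       "value": "<value>"}]}
# }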
run_request, dataset_content, dataset_filename = _load_evaluation_run_request_and_dataset(current_tenant_id)
try:
with Session(db.engine, expire_on_commit=False) as session:
if target_type == EvaluationTargetType.APPS.value:
evaluation_run = EvaluationService.start_stub_evaluation_run(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
account_id=str(current_account.id),
dataset_file_content=dataset_content,
dataset_filename=dataset_filename,
run_request=run_request,
)
else:
evaluation_run = EvaluationService.start_evaluation_run(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
account_id=str(current_account.id),
dataset_file_content=dataset_content,
dataset_filename=dataset_filename,
run_request=run_request,
)
return _serialize_evaluation_run(evaluation_run), 200
except EvaluationFrameworkNotConfiguredError as e:
return {"message": str(e.description)}, 400
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except EvaluationMaxConcurrentRunsError as e:
return {"message": str(e.description)}, 429
except EvaluationDatasetInvalidError as e:
return {"message": str(e.description)}, 400
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/run")
class EvaluationRunRealApi(Resource):
@console_ns.doc("start_evaluation_run_real")
@console_ns.response(200, "Evaluation run started")
@console_ns.response(400, "Invalid request")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet, Dataset], target_type: str):
"""Start the real evaluation execution flow on the temporary dev path."""
current_account, current_tenant_id = current_account_with_tenant()
run_request, dataset_content, dataset_filename = _load_evaluation_run_request_and_dataset(current_tenant_id)
try:
with Session(db.engine, expire_on_commit=False) as session:
evaluation_run = EvaluationService.start_evaluation_run(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
account_id=str(current_account.id),
dataset_file_content=dataset_content,
dataset_filename=dataset_filename,
run_request=run_request,
)
return _serialize_evaluation_run(evaluation_run), 200
except EvaluationFrameworkNotConfiguredError as e:
return {"message": str(e.description)}, 400
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except EvaluationMaxConcurrentRunsError as e:
return {"message": str(e.description)}, 429
except EvaluationDatasetInvalidError as e:
return {"message": str(e.description)}, 400
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>")
class EvaluationRunDetailApi(Resource):
@console_ns.doc("get_evaluation_run_detail")
@console_ns.response(200, "Evaluation run detail retrieved")
@console_ns.response(404, "Run not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
"""
Get evaluation run detail including items.
"""
_, current_tenant_id = current_account_with_tenant()
run_id = str(run_id)
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 50, type=int)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.get_evaluation_run_detail(
session=session,
tenant_id=current_tenant_id,
run_id=run_id,
)
items, total_items = EvaluationService.get_evaluation_run_items(
session=session,
run_id=run_id,
page=page,
page_size=page_size,
)
return {
"run": _serialize_evaluation_run(run),
"items": {
"data": [_serialize_evaluation_run_item(item) for item in items],
"total": total_items,
"page": page,
"page_size": page_size,
},
}
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>/cancel")
class EvaluationRunCancelApi(Resource):
@console_ns.doc("cancel_evaluation_run")
@console_ns.response(200, "Evaluation run cancelled")
@console_ns.response(404, "Run not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
"""Cancel a running evaluation."""
_, current_tenant_id = current_account_with_tenant()
run_id = str(run_id)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.cancel_evaluation_run(
session=session,
tenant_id=current_tenant_id,
run_id=run_id,
)
return _serialize_evaluation_run(run)
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except ValueError as e:
return {"message": str(e)}, 400
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/metrics")
class EvaluationMetricsApi(Resource):
@console_ns.doc("get_evaluation_metrics")
@console_ns.response(200, "Available metrics retrieved")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get available evaluation metrics for the current framework.
"""
result = {}
for category in EvaluationCategory:
if category in EvaluationService.CONSOLE_DISABLED_CATEGORIES:
continue
result[category.value] = EvaluationService.get_supported_metrics(category)
return {"metrics": result}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/default-metrics")
class EvaluationDefaultMetricsApi(Resource):
@console_ns.doc(
"get_evaluation_default_metrics_with_nodes",
description=(
"List default metrics supported by the current evaluation framework with matching nodes "
"from the target's published workflow only (draft is ignored)."
),
)
@console_ns.response(
200,
"Default metrics and node candidates for the published workflow",
evaluation_default_metrics_response_model,
)
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
default_metrics = EvaluationService.get_default_metrics_with_nodes_for_published_target(
target=target,
target_type=target_type,
)
return {
"default_metrics": [
m.model_dump() for m in EvaluationService.filter_console_default_metrics(default_metrics)
]
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/node-info")
class EvaluationNodeInfoApi(Resource):
@console_ns.doc("get_evaluation_node_info")
@console_ns.response(200, "Node info grouped by metric")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
"""Return workflow/snippet node info grouped by requested metrics.
Request body (JSON):
- metrics: list[str] | None. Metric names to query; omit or pass
an empty list to get all nodes under the key ``"all"``.
Response:
``{metric_or_all: [{"node_id": ..., "type": ..., "title": ...}, ...]}``
"""
body = request.get_json(silent=True) or {}
metrics: list[str] | None = body.get("metrics") or None
result = EvaluationService.get_nodes_for_metrics(
target=target,
target_type=target_type,
metrics=metrics,
)
if not metrics:
result = {
"all": [
node
for node in result.get("all", [])
if node.get("type") not in EvaluationService.CONSOLE_DISABLED_CATEGORIES
]
}
else:
result = {
metric: nodes
for metric, nodes in result.items()
if metric not in EvaluationService.CONSOLE_DISABLED_METRICS
}
return result
@console_ns.route("/evaluation/available-metrics")
class EvaluationAvailableMetricsApi(Resource):
@console_ns.doc("get_available_evaluation_metrics")
@console_ns.response(200, "Available metrics list")
@setup_required
@login_required
@account_initialization_required
def get(self):
"""Return the centrally-defined list of evaluation metrics."""
return {
"metrics": [
metric
for metric in EvaluationService.get_available_metrics()
if metric not in EvaluationService.CONSOLE_DISABLED_METRICS
]
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/files/<uuid:file_id>")
class EvaluationFileDownloadApi(Resource):
@console_ns.doc("download_evaluation_file")
@console_ns.response(200, "File download URL generated successfully")
@console_ns.response(404, "Target or file not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str):
"""
Download evaluation test file or result file.
Looks up the specified file, verifies that it belongs to the current tenant,
and returns the file metadata together with a signed download URL.
"""
file_id = str(file_id)
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(UploadFile).where(
UploadFile.id == file_id,
UploadFile.tenant_id == current_tenant_id,
)
upload_file = session.execute(stmt).scalar_one_or_none()
if not upload_file:
raise NotFound("File not found")
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
"download_url": download_url,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/version")
class EvaluationVersionApi(Resource):
@console_ns.doc("get_evaluation_version_detail")
@console_ns.expect(console_ns.models.get(VersionQuery.__name__))
@console_ns.response(200, "Version details retrieved successfully")
@console_ns.response(404, "Target or version not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation target version details.
Returns the workflow graph for the specified version.
"""
version = request.args.get("version")
if not version:
return {"message": "version parameter is required"}, 400
graph = {}
if target_type == EvaluationTargetType.SNIPPETS.value and isinstance(target, CustomizedSnippet):
graph = target.graph_dict
return {
"graph": graph,
}
@console_ns.route("/workspaces/current/available-evaluation-workflows")
class AvailableEvaluationWorkflowsApi(Resource):
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
@console_ns.doc("list_available_evaluation_workflows")
@console_ns.doc(description="List published evaluation workflows in the current workspace (all apps)")
@console_ns.response(
200,
"Available evaluation workflows retrieved",
available_evaluation_workflow_pagination_model,
)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self):
"""List published evaluation-type workflows for the current tenant (cross-app)."""
current_user, current_tenant_id = current_account_with_tenant()
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
page = args.page
limit = args.limit
user_id = args.user_id
named_only = args.named_only
keyword = args.keyword
if user_id and user_id != current_user.id:
raise Forbidden()
workflow_service = WorkflowService()
with Session(db.engine) as session:
workflows, has_more = workflow_service.list_published_evaluation_workflows(
session=session,
tenant_id=current_tenant_id,
page=page,
limit=limit,
user_id=user_id,
named_only=named_only,
keyword=keyword,
)
app_ids = {w.app_id for w in workflows}
if app_ids:
apps = session.scalars(select(App).where(App.id.in_(app_ids))).all()
app_names = {a.id: a.name for a in apps}
else:
app_names = {}
items = []
for wf in workflows:
items.append(
{
"id": wf.id,
"app_id": wf.app_id,
"app_name": app_names.get(wf.app_id, ""),
"type": wf.type.value,
"kind": wf.kind_or_standard,
"version": wf.version,
"marked_name": wf.marked_name,
"marked_comment": wf.marked_comment,
"hash": wf.unique_hash,
"created_by": wf.created_by_account,
"created_at": wf.created_at,
"updated_by": wf.updated_by_account,
"updated_at": wf.updated_at,
}
)
return (
marshal(
{"items": items, "page": page, "limit": limit, "has_more": has_more},
available_evaluation_workflow_pagination_fields,
),
200,
)
@console_ns.route("/workspaces/current/evaluation-workflows/<string:workflow_id>/associated-targets")
class EvaluationWorkflowAssociatedTargetsApi(Resource):
@console_ns.doc("list_evaluation_workflow_associated_targets")
@console_ns.doc(
description="List targets (apps / snippets / knowledge bases) that use the given workflow as customized metrics"
)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, workflow_id: str):
"""Return all evaluation targets that reference this workflow as customized metrics."""
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
configs = EvaluationService.list_targets_by_customized_workflow(
session=session,
tenant_id=current_tenant_id,
customized_workflow_id=workflow_id,
)
target_ids_by_type: dict[str, list[str]] = {}
for cfg in configs:
target_ids_by_type.setdefault(cfg.target_type, []).append(cfg.target_id)
app_names: dict[str, str] = {}
if EvaluationTargetType.APPS.value in target_ids_by_type:
apps = session.scalars(
select(App).where(App.id.in_(target_ids_by_type[EvaluationTargetType.APPS.value]))
).all()
app_names = {a.id: a.name for a in apps}
snippet_names: dict[str, str] = {}
if "snippets" in target_ids_by_type:
snippets = session.scalars(
select(CustomizedSnippet).where(CustomizedSnippet.id.in_(target_ids_by_type["snippets"]))
).all()
snippet_names = {s.id: s.name for s in snippets}
dataset_names: dict[str, str] = {}
if "knowledge_base" in target_ids_by_type:
datasets = session.scalars(
select(Dataset).where(Dataset.id.in_(target_ids_by_type["knowledge_base"]))
).all()
dataset_names = {d.id: d.name for d in datasets}
items = []
for cfg in configs:
name = ""
if cfg.target_type == EvaluationTargetType.APPS.value:
name = app_names.get(cfg.target_id, "")
elif cfg.target_type == EvaluationTargetType.SNIPPETS.value:
name = snippet_names.get(cfg.target_id, "")
elif cfg.target_type == "knowledge_base":
name = dataset_names.get(cfg.target_id, "")
items.append(
{
"target_type": cfg.target_type,
"target_id": cfg.target_id,
"target_name": name,
}
)
return {"items": items}, 200
# ---- Serialization Helpers ----
def _serialize_evaluation_run(run: EvaluationRun) -> dict[str, object]:
return {
"id": run.id,
"tenant_id": run.tenant_id,
"target_type": run.target_type,
"target_id": run.target_id,
"evaluation_config_id": run.evaluation_config_id,
"status": run.status,
"dataset_file_id": run.dataset_file_id,
"result_file_id": run.result_file_id,
"total_items": run.total_items,
"completed_items": run.completed_items,
"failed_items": run.failed_items,
"progress": run.progress,
"metrics_summary": run.metrics_summary_dict,
"error": run.error,
"created_by": run.created_by,
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
}
def _serialize_evaluation_run_item(item: EvaluationRunItem) -> dict[str, object]:
return {
"id": item.id,
"item_index": item.item_index,
"inputs": item.inputs_dict,
"expected_output": item.expected_output,
"actual_output": item.actual_output,
"metrics": item.metrics_list,
"judgment": item.judgment_dict,
"metadata": item.metadata_dict,
"error": item.error,
"overall_score": item.overall_score,
}


@ -5,7 +5,7 @@ from flask_restx import Resource
from pydantic import BaseModel, Field, computed_field, field_validator
from constants.languages import languages
from controllers.common.schema import query_params_from_model, register_schema_models
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required
from fields.base import ResponseModel
@ -15,7 +15,7 @@ from services.recommended_app_service import RecommendedAppService
class RecommendedAppsQuery(BaseModel):
language: str | None = Field(default=None, description="Language code for recommended app localization")
language: str | None = Field(default=None)
class RecommendedAppInfoResponse(ResponseModel):
@ -74,7 +74,7 @@ register_schema_models(
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
@console_ns.doc(params=query_params_from_model(RecommendedAppsQuery))
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
@console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
@login_required
@account_initialization_required


@ -1,142 +0,0 @@
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
class SnippetListQuery(BaseModel):
"""Query parameters for listing snippets."""
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=20, ge=1, le=100)
keyword: str | None = None
is_published: bool | None = Field(default=None, description="Filter by published status")
creators: list[str] | None = Field(default=None, description="Filter by creator account IDs")
@field_validator("creators", mode="before")
@classmethod
def parse_creators(cls, value: object) -> list[str] | None:
"""Normalize creators filter from query string or list input."""
if value is None:
return None
if isinstance(value, str):
return [creator.strip() for creator in value.split(",") if creator.strip()] or None
if isinstance(value, list):
return [str(creator).strip() for creator in value if str(creator).strip()] or None
return None
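# Illustrative normalization performed by parse_creators above (ids are hypothetical):
#   "id-1, id-2,"       -> ["id-1", "id-2"]
#   ["id-1", " id-2 "]  -> ["id-1", "id-2"]
#   "" or []            -> None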
class IconInfo(BaseModel):
"""Icon information model."""
icon: str | None = None
icon_type: Literal["emoji", "image"] | None = None
icon_background: str | None = None
icon_url: str | None = None
class InputFieldDefinition(BaseModel):
"""Input field definition for snippet parameters."""
default: str | None = None
hint: bool | None = None
label: str | None = None
max_length: int | None = None
options: list[str] | None = None
placeholder: str | None = None
required: bool | None = None
type: str | None = None # e.g., "text-input"
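# Illustrative input field definition (values are hypothetical):
#   {"type": "text-input", "label": "Topic", "required": True, "max_length": 48}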
class CreateSnippetPayload(BaseModel):
"""Payload for creating a new snippet."""
name: str = Field(..., min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
type: Literal["node", "group"] = "node"
icon_info: IconInfo | None = None
graph: dict[str, Any] | None = None
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
class UpdateSnippetPayload(BaseModel):
"""Payload for updating a snippet."""
name: str | None = Field(default=None, min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
icon_info: IconInfo | None = None
class SnippetDraftSyncPayload(BaseModel):
"""Payload for syncing snippet draft workflow."""
graph: dict[str, Any]
hash: str | None = None
conversation_variables: list[dict[str, Any]] | None = Field(
default=None,
description="Ignored. Snippet workflows do not persist conversation variables.",
)
input_fields: list[dict[str, Any]] | None = None
class SnippetWorkflowListQuery(BaseModel):
"""Query parameters for listing snippet published workflows."""
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
class WorkflowRunQuery(BaseModel):
"""Query parameters for workflow runs."""
last_id: str | None = None
limit: int = Field(default=20, ge=1, le=100)
class SnippetDraftRunPayload(BaseModel):
"""Payload for running snippet draft workflow."""
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
class SnippetDraftNodeRunPayload(BaseModel):
"""Payload for running a single node in snippet draft workflow."""
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
class SnippetIterationNodeRunPayload(BaseModel):
"""Payload for running an iteration node in snippet draft workflow."""
inputs: dict[str, Any] | None = None
class SnippetLoopNodeRunPayload(BaseModel):
"""Payload for running a loop node in snippet draft workflow."""
inputs: dict[str, Any] | None = None
class PublishWorkflowPayload(BaseModel):
"""Payload for publishing snippet workflow."""
knowledge_base_setting: dict[str, Any] | None = None
class SnippetImportPayload(BaseModel):
"""Payload for importing snippet from DSL."""
mode: str = Field(..., description="Import mode: yaml-content or yaml-url")
yaml_content: str | None = Field(default=None, description="YAML content (required for yaml-content mode)")
yaml_url: str | None = Field(default=None, description="YAML URL (required for yaml-url mode)")
name: str | None = Field(default=None, description="Override snippet name")
description: str | None = Field(default=None, description="Override snippet description")
snippet_id: str | None = Field(default=None, description="Snippet ID to update (optional)")
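# Illustrative import payloads (YAML content, names and URL are hypothetical):
#   {"mode": "yaml-content", "yaml_content": "app:\n  name: demo", "name": "Imported snippet"}
#   {"mode": "yaml-url", "yaml_url": "https://example.com/snippet.yml"}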
class IncludeSecretQuery(BaseModel):
"""Query parameter for including secret variables in export."""
include_secret: str = Field(default="false", description="Whether to include secret variables")


@ -1,617 +0,0 @@
import logging
from collections.abc import Callable
from functools import wraps
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow import (
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE,
workflow_model,
workflow_pagination_model,
)
from controllers.console.app.workflow_run import (
workflow_run_detail_model,
workflow_run_node_execution_list_model,
workflow_run_node_execution_model,
workflow_run_pagination_model,
)
from controllers.console.snippets.payloads import (
PublishWorkflowPayload,
SnippetDraftNodeRunPayload,
SnippetDraftRunPayload,
SnippetDraftSyncPayload,
SnippetIterationNodeRunPayload,
SnippetLoopNodeRunPayload,
SnippetWorkflowListQuery,
WorkflowRunQuery,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.snippet import CustomizedSnippet
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.snippet_generate_service import SnippetGenerateService
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
# Register Pydantic models with Swagger
register_schema_models(
console_ns,
SnippetDraftSyncPayload,
SnippetDraftNodeRunPayload,
SnippetDraftRunPayload,
SnippetIterationNodeRunPayload,
SnippetLoopNodeRunPayload,
SnippetWorkflowListQuery,
WorkflowRunQuery,
PublishWorkflowPayload,
)
snippet_workflow_model = console_ns.clone("SnippetWorkflow", workflow_model, {
"input_fields": fields.Raw(default=[]),
})
class SnippetNotFoundError(Exception):
"""Snippet not found error."""
pass
def get_snippet[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
"""Decorator to fetch and validate snippet access."""
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
if not kwargs.get("snippet_id"):
raise ValueError("missing snippet_id in path parameters")
_, current_tenant_id = current_account_with_tenant()
snippet_id = str(kwargs.get("snippet_id"))
del kwargs["snippet_id"]
snippet = SnippetService.get_snippet_by_id(
snippet_id=snippet_id,
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
kwargs["snippet"] = snippet
return view_func(*args, **kwargs)
return decorated_view
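# Hypothetical usage sketch: a resource routed with <uuid:snippet_id> and decorated with
# @get_snippet receives the resolved CustomizedSnippet as the `snippet` keyword argument
# instead of the raw id, e.g. `def get(self, snippet: CustomizedSnippet): ...`, as in the
# resources below.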
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
class SnippetDraftWorkflowApi(Resource):
@console_ns.doc("get_snippet_draft_workflow")
@console_ns.response(200, "Draft workflow retrieved successfully", snippet_workflow_model)
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@marshal_with(snippet_workflow_model)
def get(self, snippet: CustomizedSnippet):
"""Get draft workflow for snippet."""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise DraftWorkflowNotExist()
db.session.expunge(workflow)
workflow.conversation_variables = []
workflow.input_fields = snippet.input_fields_list
return workflow
@console_ns.doc("sync_snippet_draft_workflow")
@console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
@console_ns.response(200, "Draft workflow synced successfully")
@console_ns.response(400, "Hash mismatch")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""Sync draft workflow for snippet."""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
try:
snippet_service = SnippetService()
workflow = snippet_service.sync_draft_workflow(
snippet=snippet,
graph=payload.graph,
unique_hash=payload.hash,
account=current_user,
input_fields=payload.input_fields,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
except ValueError as e:
return {"message": str(e)}, 400
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
class SnippetDraftConfigApi(Resource):
@console_ns.doc("get_snippet_draft_config")
@console_ns.response(200, "Draft config retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get snippet draft workflow configuration limits."""
return {
"parallel_depth_limit": 3,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
class SnippetPublishedWorkflowApi(Resource):
@console_ns.doc("get_snippet_published_workflow")
@console_ns.response(200, "Published workflow retrieved successfully", snippet_workflow_model)
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@marshal_with(snippet_workflow_model)
def get(self, snippet: CustomizedSnippet):
"""Get published workflow for snippet."""
if not snippet.is_published:
return None
snippet_service = SnippetService()
workflow = snippet_service.get_published_workflow(snippet=snippet)
if workflow:
workflow.input_fields = snippet.input_fields_list
return workflow
@console_ns.doc("publish_snippet_workflow")
@console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
@console_ns.response(200, "Workflow published successfully")
@console_ns.response(400, "No draft workflow found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""Publish snippet workflow."""
current_user, _ = current_account_with_tenant()
snippet_service = SnippetService()
with Session(db.engine) as session:
snippet = session.merge(snippet)
try:
workflow = snippet_service.publish_workflow(
session=session,
snippet=snippet,
account=current_user,
)
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return {
"result": "success",
"created_at": workflow_created_at,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
class SnippetDefaultBlockConfigsApi(Resource):
@console_ns.doc("get_snippet_default_block_configs")
@console_ns.response(200, "Default block configs retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get default block configurations for snippet workflow."""
snippet_service = SnippetService()
return snippet_service.get_default_block_configs()
@console_ns.route("/snippets/<uuid:snippet_id>/workflows")
class SnippetPublishedAllWorkflowApi(Resource):
@console_ns.expect(console_ns.models[SnippetWorkflowListQuery.__name__])
@console_ns.doc("get_all_snippet_published_workflows")
@console_ns.doc(description="Get all published workflows for a snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID"})
@console_ns.response(200, "Published workflows retrieved successfully", workflow_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get all published workflow versions for snippet."""
args = SnippetWorkflowListQuery.model_validate(request.args.to_dict(flat=True))
snippet_service = SnippetService()
with Session(db.engine) as session:
workflows, has_more = snippet_service.get_all_published_workflows(
session=session,
snippet=snippet,
page=args.page,
limit=args.limit,
)
serialized_workflows = marshal(workflows, workflow_model)
return {
"items": serialized_workflows,
"page": args.page,
"limit": args.limit,
"has_more": has_more,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/<string:workflow_id>/restore")
class SnippetDraftWorkflowRestoreApi(Resource):
@console_ns.doc("restore_snippet_workflow_to_draft")
@console_ns.doc(description="Restore a published snippet workflow version into the draft workflow")
@console_ns.doc(params={"snippet_id": "Snippet ID", "workflow_id": "Published workflow ID"})
@console_ns.response(200, "Workflow restored successfully")
@console_ns.response(400, "Source workflow must be published")
@console_ns.response(404, "Workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, workflow_id: str):
"""Restore a published snippet workflow version into the draft workflow."""
current_user, _ = current_account_with_tenant()
snippet_service = SnippetService()
try:
workflow = snippet_service.restore_published_workflow_to_draft(
snippet=snippet,
workflow_id=workflow_id,
account=current_user,
)
except IsDraftWorkflowError as exc:
raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
except WorkflowNotFoundError as exc:
raise NotFound(str(exc)) from exc
except ValueError as exc:
raise BadRequest(str(exc)) from exc
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
class SnippetWorkflowRunsApi(Resource):
@console_ns.doc("list_snippet_workflow_runs")
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_pagination_model)
def get(self, snippet: CustomizedSnippet):
"""List workflow runs for snippet."""
query = WorkflowRunQuery.model_validate(
{
"last_id": request.args.get("last_id"),
"limit": request.args.get("limit", type=int, default=20),
}
)
args = {
"last_id": query.last_id,
"limit": query.limit,
}
snippet_service = SnippetService()
result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
return result
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
class SnippetWorkflowRunDetailApi(Resource):
@console_ns.doc("get_snippet_workflow_run_detail")
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_detail_model)
def get(self, snippet: CustomizedSnippet, run_id):
"""Get workflow run detail for snippet."""
run_id = str(run_id)
snippet_service = SnippetService()
workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
if not workflow_run:
raise NotFound("Workflow run not found")
return workflow_run
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
class SnippetWorkflowRunNodeExecutionsApi(Resource):
@console_ns.doc("list_snippet_workflow_run_node_executions")
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_list_model)
def get(self, snippet: CustomizedSnippet, run_id):
"""List node executions for a workflow run."""
run_id = str(run_id)
snippet_service = SnippetService()
node_executions = snippet_service.get_snippet_workflow_run_node_executions(
snippet=snippet,
run_id=run_id,
)
return {"data": node_executions}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
class SnippetDraftNodeRunApi(Resource):
@console_ns.doc("run_snippet_draft_node")
@console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
@console_ns.response(200, "Node run completed successfully", workflow_run_node_execution_model)
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_model)
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a single node in snippet draft workflow.
Executes a specific node with provided inputs for single-step debugging.
Returns the node execution result including status, outputs, and timing.
"""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
user_inputs = payload.inputs
# Get draft workflow for file parsing
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise NotFound("Draft workflow not found")
files = SnippetGenerateService.parse_files(draft_workflow, payload.files)
workflow_node_execution = SnippetGenerateService.run_draft_node(
snippet=snippet,
node_id=node_id,
user_inputs=user_inputs,
account=current_user,
query=payload.query,
files=files,
)
return workflow_node_execution
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
class SnippetDraftNodeLastRunApi(Resource):
@console_ns.doc("get_snippet_draft_node_last_run")
@console_ns.doc(description="Get last run result for a node in snippet draft workflow")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
@console_ns.response(404, "Snippet, draft workflow, or node last run not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_model)
def get(self, snippet: CustomizedSnippet, node_id: str):
"""
Get the last run result for a specific node in snippet draft workflow.
Returns the most recent execution record for the given node,
including status, inputs, outputs, and timing information.
"""
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise NotFound("Draft workflow not found")
node_exec = snippet_service.get_snippet_node_last_run(
snippet=snippet,
workflow=draft_workflow,
node_id=node_id,
)
if node_exec is None:
raise NotFound("Node last run not found")
return node_exec
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class SnippetDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_snippet_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node for snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
@console_ns.response(200, "Iteration node run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a draft workflow iteration node for snippet.
Iteration nodes execute their internal sub-graph multiple times over an input list.
Returns an SSE event stream with iteration progress and results.
"""
current_user, _ = current_account_with_tenant()
args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = SnippetGenerateService.generate_single_iteration(
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class SnippetDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_snippet_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node for snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
@console_ns.response(200, "Loop node run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a draft workflow loop node for snippet.
Loop nodes execute their internal sub-graph repeatedly until a condition is met.
Returns an SSE event stream with loop progress and results.
"""
current_user, _ = current_account_with_tenant()
args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
try:
response = SnippetGenerateService.generate_single_loop(
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
class SnippetDraftWorkflowRunApi(Resource):
@console_ns.doc("run_snippet_draft_workflow")
@console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
@console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""
Run draft workflow for snippet.
Executes the snippet's draft workflow with the provided inputs
and returns an SSE event stream with execution progress and results.
"""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try:
response = SnippetGenerateService.generate(
snippet=snippet,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
class SnippetWorkflowTaskStopApi(Resource):
@console_ns.doc("stop_snippet_workflow_task")
@console_ns.response(200, "Task stopped successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, task_id: str):
"""
Stop a running snippet workflow task.
Uses both the legacy stop flag mechanism and the graph engine
command channel for backward compatibility.
"""
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}


@ -1,316 +0,0 @@
"""
Snippet draft workflow variable APIs.
Mirrors console app routes under /apps/.../workflows/draft/variables for snippet scope,
using CustomizedSnippet.id as WorkflowDraftVariable.app_id (same invariant as snippet execution).
Snippet workflows do not expose system variables (`node_id == sys`) or conversation variables
(`node_id == conversation`): paginated list queries exclude those rows; single-variable GET/PATCH/DELETE/reset
reject them; `GET .../system-variables` and `GET .../conversation-variables` return empty lists for API parity.
Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`.
"""
from collections.abc import Callable
from functools import wraps
from typing import Any
from flask import Response, request
from flask_restx import Resource, marshal, marshal_with
from sqlalchemy.orm import Session
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist
from controllers.console.app.workflow_draft_variable import (
WorkflowDraftVariableListQuery,
WorkflowDraftVariableUpdatePayload,
_ensure_variable_access,
_file_access_controller,
validate_node_id,
workflow_draft_variable_list_model,
workflow_draft_variable_list_without_value_model,
workflow_draft_variable_model,
)
from controllers.console.snippets.snippet_workflow import get_snippet
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from graphon.variables.types import SegmentType
from libs.login import current_user, login_required
from models.snippet import CustomizedSnippet
from models.workflow import WorkflowDraftVariable
from services.snippet_service import SnippetService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset(
{SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID}
)
def _ensure_snippet_draft_variable_row_allowed(
*,
variable: WorkflowDraftVariable,
variable_id: str,
) -> None:
"""Snippet scope only supports canvas-node draft variables; treat sys/conversation rows as not found."""
if variable.node_id in _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS:
raise NotFoundError(description=f"variable not found, id={variable_id}")
def _snippet_draft_var_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R]:
"""Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs)."""
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return f(*args, **kwargs)
return wrapper
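# The stack above (setup, login, account initialization, snippet resolution, edit
# permission) is applied once here so each resource method below only needs
# @_snippet_draft_var_prerequisite; handlers still receive the resolved `snippet`
# keyword from get_snippet.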
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables")
class SnippetWorkflowVariableCollectionApi(Resource):
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
@console_ns.doc("get_snippet_workflow_variables")
@console_ns.doc(description="List draft workflow variables without values (paginated, snippet scope)")
@console_ns.response(
200,
"Workflow variables retrieved successfully",
workflow_draft_variable_list_without_value_model,
)
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_list_without_value_model)
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
snippet_service = SnippetService()
if snippet_service.get_draft_workflow(snippet=snippet) is None:
raise DraftWorkflowNotExist()
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(session=session)
workflow_vars = draft_var_srv.list_variables_without_values(
app_id=snippet.id,
page=args.page,
limit=args.limit,
user_id=current_user.id,
exclude_node_ids=_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS,
)
return workflow_vars
@console_ns.doc("delete_snippet_workflow_variables")
@console_ns.doc(description="Delete all draft workflow variables for the current user (snippet scope)")
@console_ns.response(204, "Workflow variables deleted successfully")
@_snippet_draft_var_prerequisite
def delete(self, snippet: CustomizedSnippet) -> Response:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
draft_var_srv.delete_user_workflow_variables(snippet.id, user_id=current_user.id)
db.session.commit()
return Response("", 204)
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/variables")
class SnippetNodeVariableCollectionApi(Resource):
@console_ns.doc("get_snippet_node_variables")
@console_ns.doc(description="Get variables for a specific node (snippet draft workflow)")
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_list_model)
def get(self, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList:
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(session=session)
node_vars = draft_var_srv.list_node_variables(snippet.id, node_id, user_id=current_user.id)
return node_vars
@console_ns.doc("delete_snippet_node_variables")
@console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)")
@console_ns.response(204, "Node variables deleted successfully")
@_snippet_draft_var_prerequisite
def delete(self, snippet: CustomizedSnippet, node_id: str) -> Response:
validate_node_id(node_id)
srv = WorkflowDraftVariableService(db.session())
srv.delete_node_variables(snippet.id, node_id, user_id=current_user.id)
db.session.commit()
return Response("", 204)
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>")
class SnippetVariableApi(Resource):
@console_ns.doc("get_snippet_workflow_variable")
@console_ns.doc(description="Get a specific draft workflow variable (snippet scope)")
@console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
@console_ns.response(404, "Variable not found")
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_model)
def get(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=snippet.id,
variable_id=variable_id,
)
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
return variable
@console_ns.doc("update_snippet_workflow_variable")
@console_ns.doc(description="Update a draft workflow variable (snippet scope)")
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
@console_ns.response(404, "Variable not found")
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_model)
def patch(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=snippet.id,
variable_id=variable_id,
)
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
new_name = args_model.name
raw_value = args_model.value
if new_name is None and raw_value is None:
return variable
new_value = None
if raw_value is not None:
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=snippet.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=snippet.tenant_id,
access_controller=_file_access_controller,
)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable
@console_ns.doc("delete_snippet_workflow_variable")
@console_ns.doc(description="Delete a draft workflow variable (snippet scope)")
@console_ns.response(204, "Variable deleted successfully")
@console_ns.response(404, "Variable not found")
@_snippet_draft_var_prerequisite
def delete(self, snippet: CustomizedSnippet, variable_id: str) -> Response:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=snippet.id,
variable_id=variable_id,
)
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class SnippetVariableResetApi(Resource):
@console_ns.doc("reset_snippet_workflow_variable")
@console_ns.doc(description="Reset a draft workflow variable to its default value (snippet scope)")
@console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
@console_ns.response(204, "Variable reset (no content)")
@console_ns.response(404, "Variable not found")
@_snippet_draft_var_prerequisite
def put(self, snippet: CustomizedSnippet, variable_id: str) -> Response | Any:
draft_var_srv = WorkflowDraftVariableService(session=db.session())
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if draft_workflow is None:
raise NotFoundError(
f"Draft workflow not found, snippet_id={snippet.id}",
)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=snippet.id,
variable_id=variable_id,
)
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()
if resetted is None:
return Response("", 204)
return marshal(resetted, workflow_draft_variable_model)
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/conversation-variables")
class SnippetConversationVariableCollectionApi(Resource):
@console_ns.doc("get_snippet_conversation_variables")
@console_ns.doc(
description="Conversation variables are not used in snippet workflows; returns an empty list for API parity"
)
@console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_list_model)
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
return WorkflowDraftVariableList(variables=[])
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/system-variables")
class SnippetSystemVariableCollectionApi(Resource):
@console_ns.doc("get_snippet_system_variables")
@console_ns.doc(
description="System variables are not used in snippet workflows; returns an empty list for API parity"
)
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
@_snippet_draft_var_prerequisite
@marshal_with(workflow_draft_variable_list_model)
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
return WorkflowDraftVariableList(variables=[])
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/environment-variables")
class SnippetEnvironmentVariableCollectionApi(Resource):
@console_ns.doc("get_snippet_environment_variables")
@console_ns.doc(description="Get environment variables from snippet draft workflow graph")
@console_ns.response(200, "Environment variables retrieved successfully")
@console_ns.response(404, "Draft workflow not found")
@_snippet_draft_var_prerequisite
def get(self, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]:
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if workflow is None:
raise DraftWorkflowNotExist()
env_vars_list: list[dict[str, Any]] = []
for v in workflow.environment_variables:
env_vars_list.append(
{
"id": v.id,
"type": "env",
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": v.value_type.exposed_type().value,
"value": v.value,
"edited": False,
"visible": True,
"editable": True,
}
)
return {"items": env_vars_list}


@ -1,511 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, ValidationError, field_validator
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import console_ns
from libs.login import current_account_with_tenant, login_required
from services.enterprise import rbac_service as svc
def _current_ids() -> tuple[str, str]:
"""Return ``(tenant_id, account_id)`` for the authenticated user, or
raise a 404 when no tenant is associated with the session.
"""
user, tenant_id = current_account_with_tenant()
if not tenant_id:
raise NotFound("Current workspace not found")
return tenant_id, user.id
def _payload(model: type[BaseModel]) -> Any:
"""Validate the JSON body against ``model`` or raise ``ValidationError``.
``ValidationError`` bubbles up as HTTP 400 thanks to
``controllers/common/helpers.py`` error handling.
"""
try:
return model.model_validate(console_ns.payload or {})
except ValidationError as exc:
# Re-raise as-is so the upstream error handler renders a 400.
raise exc
def _dump(model: BaseModel) -> dict[str, Any]:
return model.model_dump(mode="json")
class _PaginationQuery(BaseModel):
model_config = ConfigDict(extra="ignore")
page_number: int | None = Field(default=None, ge=1, validation_alias=AliasChoices("page", "page_number"))
results_per_page: int | None = Field(
default=None, ge=1, le=100, validation_alias=AliasChoices("limit", "results_per_page")
)
reverse: bool | None = None
def to_inner_options(self) -> svc.ListOption:
return svc.ListOption.model_validate(self.model_dump())
def _pagination_options() -> svc.ListOption:
return _PaginationQuery.model_validate(request.args.to_dict(flat=True)).to_inner_options()
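# Illustrative mapping (query values are hypothetical; this assumes svc.ListOption accepts
# the field names produced by _PaginationQuery.model_dump(), as to_inner_options implies):
#   ?page=2&limit=50                     -> ListOption(page_number=2, results_per_page=50, reverse=None)
#   ?page_number=2&results_per_page=50   -> the same result via the AliasChoices aliases above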
# ---------------------------------------------------------------------------
# Permission catalogs.
# ---------------------------------------------------------------------------
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog")
class RBACWorkspaceCatalogApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.Catalog.workspace(tenant_id, account_id))
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog/app")
class RBACAppCatalogApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.Catalog.app(tenant_id, account_id))
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog/dataset")
class RBACDatasetCatalogApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.Catalog.dataset(tenant_id, account_id))
# ---------------------------------------------------------------------------
# Roles.
# ---------------------------------------------------------------------------
class _RoleUpsertRequest(BaseModel):
"""Accepts the payload sent by the Create/Edit Role dialog."""
name: str
description: str = ""
permission_keys: list[str] = []
def to_mutation(self) -> svc.RoleMutation:
return svc.RoleMutation(
name=self.name,
description=self.description,
permission_keys=list(self.permission_keys),
)
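The Create/Edit Role dialog therefore sends a body along these lines; the permission key shown is a placeholder, not a documented catalog entry:
# Hypothetical request body for POST /workspaces/current/rbac/roles
{
    "name": "App Editor",
    "description": "Can edit applications",
    "permission_keys": ["<permission.key>"],
}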
@console_ns.route("/workspaces/current/rbac/roles")
class RBACRolesApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
options = _pagination_options()
return _dump(svc.RBACService.Roles.list(tenant_id, account_id, options=options))
@login_required
def post(self):
tenant_id, account_id = _current_ids()
request = _payload(_RoleUpsertRequest)
role = svc.RBACService.Roles.create(tenant_id, account_id, request.to_mutation())
return _dump(role), 201
@console_ns.route("/workspaces/current/rbac/roles/<uuid:role_id>")
class RBACRoleItemApi(Resource):
@login_required
def get(self, role_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.Roles.get(tenant_id, account_id, str(role_id)))
@login_required
def put(self, role_id):
tenant_id, account_id = _current_ids()
request = _payload(_RoleUpsertRequest)
role = svc.RBACService.Roles.update(tenant_id, account_id, str(role_id), request.to_mutation())
return _dump(role)
@login_required
def delete(self, role_id):
tenant_id, account_id = _current_ids()
svc.RBACService.Roles.delete(tenant_id, account_id, str(role_id))
return {"result": "success"}
@console_ns.route("/workspaces/current/rbac/roles/<uuid:role_id>/copy")
class RBACRoleCopyApi(Resource):
@login_required
def post(self, role_id):
tenant_id, account_id = _current_ids()
role = svc.RBACService.Roles.copy(tenant_id, account_id, str(role_id))
return _dump(role), 201
# ---------------------------------------------------------------------------
# Access policies (tenant-level permission sets).
# ---------------------------------------------------------------------------
class _AccessPolicyCreateRequest(BaseModel):
name: str
resource_type: svc.RBACResourceType
description: str = ""
permission_keys: list[str] = []
class _AccessPolicyUpdateRequest(BaseModel):
name: str
description: str = ""
permission_keys: list[str] = []
@console_ns.route("/workspaces/current/rbac/access-policies")
class RBACAccessPoliciesApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
# `resource_type` is exposed as a query argument so the UI can show
# only app-scoped or only dataset-scoped permission sets.
resource_type = request.args.get("resource_type") or None
return _dump(
svc.RBACService.AccessPolicies.list(
tenant_id,
account_id,
resource_type=resource_type,
options=_pagination_options(),
)
)
@login_required
def post(self):
tenant_id, account_id = _current_ids()
request = _payload(_AccessPolicyCreateRequest)
policy = svc.RBACService.AccessPolicies.create(
tenant_id,
account_id,
svc.AccessPolicyCreate(
name=request.name,
resource_type=request.resource_type,
description=request.description,
permission_keys=list(request.permission_keys),
),
)
return _dump(policy), 201
@console_ns.route("/workspaces/current/rbac/access-policies/<uuid:policy_id>")
class RBACAccessPolicyItemApi(Resource):
@login_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.AccessPolicies.get(tenant_id, account_id, str(policy_id)))
@login_required
def put(self, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_AccessPolicyUpdateRequest)
policy = svc.RBACService.AccessPolicies.update(
tenant_id,
account_id,
str(policy_id),
svc.AccessPolicyUpdate(
name=request.name,
description=request.description,
permission_keys=list(request.permission_keys),
),
)
return _dump(policy)
@login_required
def delete(self, policy_id):
tenant_id, account_id = _current_ids()
svc.RBACService.AccessPolicies.delete(tenant_id, account_id, str(policy_id))
return {"result": "success"}
@console_ns.route("/workspaces/current/rbac/access-policies/<uuid:policy_id>/copy")
class RBACAccessPolicyCopyApi(Resource):
@login_required
def post(self, policy_id):
tenant_id, account_id = _current_ids()
policy = svc.RBACService.AccessPolicies.copy(tenant_id, account_id, str(policy_id))
return _dump(policy), 201
# ---------------------------------------------------------------------------
# Per-app access (App Access Config).
# ---------------------------------------------------------------------------
class _ReplaceBindingsRequest(BaseModel):
role_ids: list[str] = []
account_ids: list[str] = []
@field_validator("role_ids", "account_ids", mode="before")
@classmethod
def _coerce_bindings(cls, value: Any) -> list[str]:
if value is None:
return []
return value
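All of the binding-replacement PUT endpoints further down consume this model; null arrays are coerced to empty lists by the validator above. An illustrative body (IDs are placeholders):
{"role_ids": ["<role uuid>"], "account_ids": []}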
@console_ns.route("/workspaces/current/rbac/my-permissions")
class RBACMyPermissionsApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.MyPermissions.get(
tenant_id,
account_id,
app_id=request.args.get("app_id") or None,
dataset_id=request.args.get("dataset_id") or None,
)
)
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policy")
class RBACAppMatrixApi(Resource):
@login_required
def get(self, app_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.AppAccess.matrix(tenant_id, account_id, str(app_id)))
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/role-bindings")
class RBACAppRoleBindingsApi(Resource):
@login_required
def get(self, app_id, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.AppAccess.list_role_bindings(tenant_id, account_id, str(app_id), str(policy_id))
)
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/member-bindings")
class RBACAppMemberBindingsApi(Resource):
@login_required
def get(self, app_id, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.AppAccess.list_member_bindings(tenant_id, account_id, str(app_id), str(policy_id))
)
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/bindings")
class RBACAppBindingsApi(Resource):
@login_required
def put(self, app_id, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceBindingsRequest)
return _dump(
svc.RBACService.AppAccess.replace_bindings(
tenant_id,
account_id,
str(app_id),
str(policy_id),
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
)
)
# ---------------------------------------------------------------------------
# Per-dataset access (Knowledge Base Access Config).
# ---------------------------------------------------------------------------
@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policy")
class RBACDatasetMatrixApi(Resource):
@login_required
def get(self, dataset_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.DatasetAccess.matrix(tenant_id, account_id, str(dataset_id)))
@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/role-bindings")
class RBACDatasetRoleBindingsApi(Resource):
@login_required
def get(self, dataset_id, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.DatasetAccess.list_role_bindings(
tenant_id, account_id, str(dataset_id), str(policy_id)
)
)
@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/bindings")
class RBACDatasetBindingsApi(Resource):
@login_required
def put(self, dataset_id, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceBindingsRequest)
return _dump(
svc.RBACService.DatasetAccess.replace_bindings(
tenant_id,
account_id,
str(dataset_id),
str(policy_id),
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
)
)
@console_ns.route(
"/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/member-bindings"
)
class RBACDatasetMemberBindingsApi(Resource):
@login_required
def get(self, dataset_id, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.DatasetAccess.list_member_bindings(
tenant_id, account_id, str(dataset_id), str(policy_id)
)
)
# ---------------------------------------------------------------------------
# Workspace-level access (Settings > Access Rules).
# ---------------------------------------------------------------------------
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policy")
class RBACWorkspaceAppMatrixApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
options = _pagination_options()
return _dump(svc.RBACService.WorkspaceAccess.app_matrix(tenant_id, account_id, options=options))
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/role-bindings")
class RBACWorkspaceAppRoleBindingsApi(Resource):
@login_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.WorkspaceAccess.list_app_role_bindings(tenant_id, account_id, str(policy_id))
)
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/bindings")
class RBACWorkspaceAppBindingsApi(Resource):
@login_required
def put(self, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceBindingsRequest)
return _dump(
svc.RBACService.WorkspaceAccess.replace_app_bindings(
tenant_id,
account_id,
str(policy_id),
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
)
)
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/member-bindings")
class RBACWorkspaceAppMemberBindingsApi(Resource):
@login_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.WorkspaceAccess.list_app_member_bindings(tenant_id, account_id, str(policy_id))
)
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policy")
class RBACWorkspaceDatasetMatrixApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
options = _pagination_options()
return _dump(svc.RBACService.WorkspaceAccess.dataset_matrix(tenant_id, account_id, options=options))
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/role-bindings")
class RBACWorkspaceDatasetRoleBindingsApi(Resource):
@login_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.WorkspaceAccess.list_dataset_role_bindings(tenant_id, account_id, str(policy_id))
)
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/bindings")
class RBACWorkspaceDatasetBindingsApi(Resource):
@login_required
def put(self, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceBindingsRequest)
return _dump(
svc.RBACService.WorkspaceAccess.replace_dataset_bindings(
tenant_id,
account_id,
str(policy_id),
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
)
)
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/member-bindings")
class RBACWorkspaceDatasetMemberBindingsApi(Resource):
@login_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.WorkspaceAccess.list_dataset_member_bindings(tenant_id, account_id, str(policy_id))
)
# ---------------------------------------------------------------------------
# Member ↔ role bindings (Settings > Members > Assign roles).
# ---------------------------------------------------------------------------
class _ReplaceMemberRolesRequest(BaseModel):
role_ids: list[str] = []
@field_validator("role_ids", mode="before")
@classmethod
def _coerce_role_ids(cls, value: Any) -> list[str]:
if value is None:
return []
return value
@console_ns.route("/workspaces/current/rbac/members/<uuid:member_id>/rbac-roles")
class RBACMemberRolesApi(Resource):
@login_required
def get(self, member_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.MemberRoles.get(tenant_id, account_id, str(member_id)))
@login_required
def put(self, member_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceMemberRolesRequest)
return _dump(
svc.RBACService.MemberRoles.replace(
tenant_id,
account_id,
str(member_id),
role_ids=list(request.role_ids),
)
)

View File

@ -1,380 +0,0 @@
import logging
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource, marshal
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.snippets.payloads import (
CreateSnippetPayload,
IncludeSecretQuery,
SnippetImportPayload,
SnippetListQuery,
UpdateSnippetPayload,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
from libs.login import current_account_with_tenant, login_required
from models.snippet import SnippetType
from services.app_dsl_service import ImportStatus
from services.snippet_dsl_service import SnippetDslService
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
# Register Pydantic models with Swagger
register_schema_models(
console_ns,
SnippetListQuery,
CreateSnippetPayload,
UpdateSnippetPayload,
SnippetImportPayload,
IncludeSecretQuery,
)
# Create namespace models for marshaling
snippet_model = console_ns.model("Snippet", snippet_fields)
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)
@console_ns.route("/workspaces/current/customized-snippets")
class CustomizedSnippetsApi(Resource):
@console_ns.doc("list_customized_snippets")
@console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
@console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List customized snippets with pagination and search."""
_, current_tenant_id = current_account_with_tenant()
query_params = request.args.to_dict()
query = SnippetListQuery.model_validate(query_params)
snippets, total, has_more = SnippetService.get_snippets(
tenant_id=current_tenant_id,
page=query.page,
limit=query.limit,
keyword=query.keyword,
is_published=query.is_published,
creators=query.creators,
)
return {
"data": marshal(snippets, snippet_list_fields),
"page": query.page,
"limit": query.limit,
"total": total,
"has_more": has_more,
}, 200
@console_ns.doc("create_customized_snippet")
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
@console_ns.response(201, "Snippet created successfully", snippet_model)
@console_ns.response(400, "Invalid request or name already exists")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
"""Create a new customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
try:
snippet_type = SnippetType(payload.type)
except ValueError:
snippet_type = SnippetType.NODE
try:
snippet = SnippetService.create_snippet(
tenant_id=current_tenant_id,
name=payload.name,
description=payload.description,
snippet_type=snippet_type,
icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
account=current_user,
)
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 201
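A hypothetical create request the handler above would accept; the "type" value is an assumption here (an unrecognized value silently falls back to SnippetType.NODE as shown), and the optional icon_info / input_fields are omitted because their nested shapes are not part of this diff:
# POST /workspaces/current/customized-snippets
{
    "name": "My Snippet",
    "description": "Reusable group of workflow nodes",
    "type": "node",
}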
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
class CustomizedSnippetDetailApi(Resource):
@console_ns.doc("get_customized_snippet")
@console_ns.response(200, "Snippet retrieved successfully", snippet_model)
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
def get(self, snippet_id: str):
"""Get customized snippet details."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
return marshal(snippet, snippet_fields), 200
@console_ns.doc("update_customized_snippet")
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
@console_ns.response(200, "Snippet updated successfully", snippet_model)
@console_ns.response(400, "Invalid request or name already exists")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def patch(self, snippet_id: str):
"""Update customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
update_data = payload.model_dump(exclude_unset=True)
if "icon_info" in update_data and update_data["icon_info"] is not None:
update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None
if not update_data:
return {"message": "No valid fields to update"}, 400
try:
with Session(db.engine, expire_on_commit=False) as session:
snippet = session.merge(snippet)
snippet = SnippetService.update_snippet(
session=session,
snippet=snippet,
account_id=current_user.id,
data=update_data,
)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 200
@console_ns.doc("delete_customized_snippet")
@console_ns.response(204, "Snippet deleted successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, snippet_id: str):
"""Delete customized snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
snippet = session.merge(snippet)
SnippetService.delete_snippet(
session=session,
snippet=snippet,
)
session.commit()
return "", 204
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/export")
class CustomizedSnippetExportApi(Resource):
@console_ns.doc("export_customized_snippet")
@console_ns.doc(description="Export snippet configuration as DSL")
@console_ns.doc(params={"snippet_id": "Snippet ID to export"})
@console_ns.response(200, "Snippet exported successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, snippet_id: str):
"""Export snippet as DSL."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
# Get include_secret parameter
query = IncludeSecretQuery.model_validate(request.args.to_dict())
with Session(db.engine) as session:
export_service = SnippetDslService(session)
result = export_service.export_snippet_dsl(snippet=snippet, include_secret=query.include_secret == "true")
# Set filename with .snippet extension
filename = f"{snippet.name}.snippet"
encoded_filename = quote(filename)
response = Response(
result,
mimetype="application/x-yaml",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/x-yaml"
return response
@console_ns.route("/workspaces/current/customized-snippets/imports")
class CustomizedSnippetImportApi(Resource):
@console_ns.doc("import_customized_snippet")
@console_ns.doc(description="Import snippet from DSL")
@console_ns.expect(console_ns.models.get(SnippetImportPayload.__name__))
@console_ns.response(200, "Snippet imported successfully")
@console_ns.response(202, "Import pending confirmation")
@console_ns.response(400, "Import failed")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
"""Import snippet from DSL."""
current_user, _ = current_account_with_tenant()
payload = SnippetImportPayload.model_validate(console_ns.payload or {})
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.import_snippet(
account=current_user,
import_mode=payload.mode,
yaml_content=payload.yaml_content,
yaml_url=payload.yaml_url,
snippet_id=payload.snippet_id,
name=payload.name,
description=payload.description,
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/imports/<string:import_id>/confirm")
class CustomizedSnippetImportConfirmApi(Resource):
@console_ns.doc("confirm_snippet_import")
@console_ns.doc(description="Confirm a pending snippet import")
@console_ns.doc(params={"import_id": "Import ID to confirm"})
@console_ns.response(200, "Import confirmed successfully")
@console_ns.response(400, "Import failed")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, import_id: str):
"""Confirm a pending snippet import."""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.confirm_import(import_id=import_id, account=current_user)
session.commit()
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/check-dependencies")
class CustomizedSnippetCheckDependenciesApi(Resource):
@console_ns.doc("check_snippet_dependencies")
@console_ns.doc(description="Check dependencies for a snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID"})
@console_ns.response(200, "Dependencies checked successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, snippet_id: str):
"""Check dependencies for a snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.check_dependencies(snippet=snippet)
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/use-count/increment")
class CustomizedSnippetUseCountIncrementApi(Resource):
@console_ns.doc("increment_snippet_use_count")
@console_ns.doc(description="Increment snippet use count by 1")
@console_ns.doc(params={"snippet_id": "Snippet ID"})
@console_ns.response(200, "Use count incremented successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, snippet_id: str):
"""Increment snippet use count when it is inserted into a workflow."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
snippet = session.merge(snippet)
SnippetService.increment_use_count(session=session, snippet=snippet)
session.commit()
session.refresh(snippet)
return {"result": "success", "use_count": snippet.use_count}, 200

View File

@ -20,10 +20,13 @@ class TenantUserPayload(BaseModel):
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
"""
Get current user
Get current user.
NOTE: user_id is not trusted, it could be maliciously set to any value.
As a result, it could only be considered as an end user id.
As a result, it could only be considered as an end user id. Even when a
concrete end-user ID is supplied, lookups must stay tenant-scoped so one
tenant cannot bind another tenant's user record into the plugin request
context.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
@ -42,7 +45,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
.limit(1)
)
else:
user_model = session.get(EndUser, user_id)
user_model = session.scalar(
select(EndUser)
.where(
EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.limit(1)
)
if not user_model:
user_model = EndUser(

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Literal
from typing import Literal
from dateutil.parser import isoparse
from flask import request
@ -136,28 +136,10 @@ class WorkflowRunForLogResponse(ResponseModel):
return _to_timestamp(value)
class WorkflowAppLogEvaluationNodeInfoResponse(ResponseModel):
node_id: str
type: str
title: str
class WorkflowAppLogEvaluationItemResponse(ResponseModel):
name: str
value: Any = None
details: dict[str, Any] | None = None
node_info: WorkflowAppLogEvaluationNodeInfoResponse | None = Field(
default=None,
validation_alias="node_info",
serialization_alias="nodeInfo",
)
class WorkflowAppLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForLogResponse | None = None
details: dict | list | str | int | float | bool | None = None
evaluation: list[WorkflowAppLogEvaluationItemResponse] = Field(default_factory=list)
created_from: str | None = None
created_by_role: str | None = None
created_by_account: SimpleAccount | None = None
@ -174,11 +156,6 @@ class WorkflowAppLogPartialResponse(ResponseModel):
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
@field_validator("evaluation", mode="before")
@classmethod
def _normalize_evaluation(cls, value: Any) -> list[dict[str, Any]] | list[WorkflowAppLogEvaluationItemResponse]:
return value or []
class WorkflowAppLogPaginationResponse(ResponseModel):
page: int
@ -192,8 +169,6 @@ register_schema_models(
service_api_ns,
WorkflowRunResponse,
WorkflowRunForLogResponse,
WorkflowAppLogEvaluationNodeInfoResponse,
WorkflowAppLogEvaluationItemResponse,
WorkflowAppLogPartialResponse,
WorkflowAppLogPaginationResponse,
)

View File

@ -23,7 +23,6 @@ from . import (
feature,
files,
forgot_password,
human_input_file_upload,
human_input_form,
login,
message,
@ -47,7 +46,6 @@ __all__ = [
"feature",
"files",
"forgot_password",
"human_input_file_upload",
"human_input_form",
"login",
"message",

View File

@ -1,181 +0,0 @@
import httpx
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, HttpUrl
import services
from controllers.common import helpers
from controllers.common.errors import (
BlockedFileExtensionError,
FileTooLargeError,
NoFileUploadedError,
RemoteFileUploadError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileResponse, FileWithSignedUrl
from graphon.file import helpers as file_helpers
from libs.exception import BaseHTTPException
from services.file_service import FileService
from services.human_input_file_upload_service import (
HITL_UPLOAD_TOKEN_PREFIX,
HumanInputFileUploadService,
InvalidUploadTokenError,
)
class InvalidUploadTokenBadRequestError(BaseHTTPException):
error_code = "invalid_upload_token"
description = "Invalid upload token."
code = 400
class InvalidUploadTokenUnauthorizedError(BaseHTTPException):
error_code = "invalid_upload_token"
description = "Upload token is required."
code = 401
class InvalidUploadTokenForbiddenError(BaseHTTPException):
error_code = "invalid_upload_token"
description = "Upload token is invalid or expired."
code = 403
class HumanInputRemoteFileUploadPayload(BaseModel):
url: HttpUrl = Field(description="Remote file URL")
register_schema_models(web_ns, HumanInputRemoteFileUploadPayload, FileResponse, FileWithSignedUrl)
def _extract_hitl_upload_token() -> str:
"""Read HITL upload token from Authorization without invoking other bearer auth chains."""
authorization = request.headers.get("Authorization")
if authorization is None:
raise InvalidUploadTokenUnauthorizedError()
parts = authorization.split()
if len(parts) != 2:
raise InvalidUploadTokenUnauthorizedError()
scheme, token = parts
if scheme.lower() != "bearer":
raise InvalidUploadTokenBadRequestError()
if not token:
raise InvalidUploadTokenUnauthorizedError()
if not token.startswith(HITL_UPLOAD_TOKEN_PREFIX):
raise InvalidUploadTokenBadRequestError()
return token
def _validate_context(service: HumanInputFileUploadService, token: str):
try:
return service.validate_upload_token(token)
except InvalidUploadTokenError as exc:
raise InvalidUploadTokenForbiddenError() from exc
def _parse_local_upload_file():
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.filename:
from controllers.common.errors import FilenameNotExistsError
raise FilenameNotExistsError()
return file
@web_ns.route("/form/human_input/files/upload")
class HumanInputFileUploadApi(Resource):
def post(self):
"""Upload one local file for a HITL human input form."""
token = _extract_hitl_upload_token()
upload_service = HumanInputFileUploadService(db.engine)
context = _validate_context(upload_service, token)
file = _parse_local_upload_file()
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename or "",
content=file.read(),
mimetype=file.mimetype,
user=context.owner,
source=None,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as exc:
raise BlockedFileExtensionError() from exc
upload_service.record_upload_file(context=context, file_id=upload_file.id)
response = FileResponse.model_validate(upload_file, from_attributes=True)
return response.model_dump(mode="json"), 201
@web_ns.route("/form/human_input/files/remote-upload")
class HumanInputRemoteFileUploadApi(Resource):
def post(self):
"""Upload one remote URL file for a HITL human input form."""
token = _extract_hitl_upload_token()
upload_service = HumanInputFileUploadService(db.engine)
context = _validate_context(upload_service, token)
payload = HumanInputRemoteFileUploadPayload.model_validate(request.get_json(silent=True) or {})
url = str(payload.url)
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as exc:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(exc)}")
file_info = helpers.guess_file_info_from_response(resp)
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError()
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=context.owner,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as exc:
raise BlockedFileExtensionError() from exc
upload_service.record_upload_file(context=context, file_id=upload_file.id)
payload1 = FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)
return payload1.model_dump(mode="json"), 201
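A hypothetical exchange with the remote-upload route above; the body matches HumanInputRemoteFileUploadPayload and the bearer token must pass _extract_hitl_upload_token:
# POST /form/human_input/files/remote-upload
# Authorization: Bearer <HITL upload token>
{"url": "https://example.com/report.pdf"}
# -> 201 with the FileWithSignedUrl payload built above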

View File

@ -9,13 +9,11 @@ from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload
@ -23,20 +21,11 @@ from extensions.ext_database import db
from libs.helper import RateLimiter, extract_remote_ip
from models.account import TenantStatus
from models.model import App, Site
from services.human_input_file_upload_service import HumanInputFileUploadService
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
class HumanInputUploadTokenResponse(BaseModel):
upload_token: str
expires_at: int
register_schema_models(web_ns, HumanInputUploadTokenResponse)
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
@ -47,11 +36,6 @@ _FORM_ACCESS_RATE_LIMITER = RateLimiter(
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
)
_FORM_UPLOAD_TOKEN_RATE_LIMITER = RateLimiter(
prefix="web_form_upload_token_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
)
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
@ -94,33 +78,6 @@ def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Re
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
@web_ns.route("/form/human_input/<string:form_token>/upload-token")
class HumanInputFormUploadTokenApi(Resource):
"""API for issuing HITL upload tokens for active human input forms."""
def post(self, form_token: str):
"""
Issue an upload token for a human input form.
POST /api/form/human_input/<form_token>/upload-token
"""
ip_address = extract_remote_ip(request)
if _FORM_UPLOAD_TOKEN_RATE_LIMITER.is_rate_limited(ip_address):
raise WebFormRateLimitExceededError()
_FORM_UPLOAD_TOKEN_RATE_LIMITER.increment_rate_limit(ip_address)
try:
token = HumanInputFileUploadService(db.engine).issue_upload_token(form_token)
except FormNotFoundError:
raise NotFoundError("Form not found")
response = HumanInputUploadTokenResponse(
upload_token=token.upload_token,
expires_at=_to_timestamp(token.expires_at),
)
return response.model_dump(mode="json"), 200
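Illustratively, the removed upload-token route above responded with the HumanInputUploadTokenResponse model defined earlier in this file (the timestamp is a placeholder):
# 200 response from POST /form/human_input/<form_token>/upload-token
{"upload_token": "<opaque upload token>", "expires_at": 1715241600}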
@web_ns.route("/form/human_input/<string:form_token>")
class HumanInputFormApi(Resource):
"""API for getting and submitting human input forms via the web app."""

View File

@ -408,19 +408,17 @@ class WorkflowResponseConverter:
self, *, event: QueueHumanInputFormFilledEvent, task_id: str
) -> HumanInputFormFilledResponse:
run_id = self._ensure_workflow_run_id()
data = HumanInputFormFilledResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
return HumanInputFormFilledResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputFormFilledResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
),
)
if event.submitted_data is not None:
runtime_type_converter = WorkflowRuntimeTypeConverter()
data.submitted_data = runtime_type_converter.value_to_json_encodable_recursive(event.submitted_data)
return HumanInputFormFilledResponse(task_id=task_id, workflow_run_id=run_id, data=data)
def human_input_form_timeout_to_stream_response(
self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str
@ -844,24 +842,24 @@ class WorkflowResponseConverter:
return []
files: list[Mapping[str, Any]] = []
match value:
case FileSegment():
files.append(value.value.to_dict())
case ArrayFileSegment():
files.extend([i.to_dict() for i in value.value])
case File():
files.append(value.to_dict())
case list():
for item in value:
file = cls._get_file_var_from_value(item)
if file:
files.append(file)
case dict():
file = cls._get_file_var_from_value(value)
if isinstance(value, FileSegment):
files.append(value.value.to_dict())
elif isinstance(value, ArrayFileSegment):
files.extend([i.to_dict() for i in value.value])
elif isinstance(value, File):
files.append(value.to_dict())
elif isinstance(value, list):
for item in value:
file = cls._get_file_var_from_value(item)
if file:
files.append(file)
case _:
pass
elif isinstance(
value,
dict,
):
file = cls._get_file_var_from_value(value)
if file:
files.append(file)
return files

View File

@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Literal, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
@ -58,25 +58,6 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@staticmethod
def _ensure_snippet_start_node_in_worker(*, session: Session, workflow: Workflow) -> Workflow:
"""Re-apply snippet virtual Start injection after worker reloads workflow from DB."""
if workflow.kind_or_standard != "snippet":
return workflow
from models.snippet import CustomizedSnippet
from services.snippet_generate_service import SnippetGenerateService
snippet = session.scalar(
select(CustomizedSnippet).where(
CustomizedSnippet.id == workflow.app_id,
CustomizedSnippet.tenant_id == workflow.tenant_id,
)
)
if snippet is None:
return workflow
return SnippetGenerateService.ensure_start_node_for_worker(workflow, snippet)
@staticmethod
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
@ -580,8 +561,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
if workflow is None:
raise ValueError("Workflow not found")
workflow = self._ensure_snippet_start_node_in_worker(session=session, workflow=workflow)
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,

View File

@ -10,7 +10,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.node_factory import get_default_root_node_id
from core.workflow.snippet_start import get_compatible_start_aliases
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
@ -116,15 +115,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
),
)
root_node_id = self._root_node_id or get_default_root_node_id(self._workflow.graph_dict)
add_node_inputs_to_pool(
variable_pool,
node_id=root_node_id,
inputs=inputs,
aliases=get_compatible_start_aliases(
workflow_kind=getattr(self._workflow, "kind_or_standard", None),
root_node_id=root_node_id,
),
)
add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph = self._init_graph(

View File

@ -432,7 +432,6 @@ class WorkflowBasedAppRunner:
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
submitted_data=event.submitted_data,
)
)
elif isinstance(event, NodeRunHumanInputFormTimeoutEvent):

View File

@ -11,7 +11,6 @@ from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import PauseReason
from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from graphon.variables.segments import Segment
class QueueEvent(StrEnum):
@ -509,10 +508,6 @@ class QueueHumanInputFormFilledEvent(AppQueueEvent):
action_id: str
action_text: str
# Keep the field name aligned with Graphon so the app-layer bridge does not
# need to translate between two equivalent payload names.
submitted_data: Mapping[str, Segment] | None = None
class QueueHumanInputFormTimeoutEvent(AppQueueEvent):
"""

View File

@ -10,7 +10,7 @@ from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from graphon.nodes.human_input.entities import FormInputConfig, UserActionConfig
from graphon.nodes.human_input.entities import FormInput, UserAction
class AnnotationReplyAccount(BaseModel):
@ -284,8 +284,8 @@ class HumanInputRequiredResponse(StreamResponse):
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInputConfig] = Field(default_factory=list)
actions: Sequence[UserActionConfig] = Field(default_factory=list)
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
@ -307,8 +307,8 @@ class HumanInputRequiredPauseReasonPayload(BaseModel):
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInputConfig] = Field(default_factory=list)
actions: Sequence[UserActionConfig] = Field(default_factory=list)
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
@ -342,8 +342,6 @@ class HumanInputFormFilledResponse(StreamResponse):
action_id: str
action_text: str
submitted_data: Mapping[str, Any] | None = None
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED
workflow_run_id: str
data: Data

View File

@ -3,9 +3,9 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any, TypeAlias
from pydantic import BaseModel, ConfigDict, Field, JsonValue
from pydantic import BaseModel, ConfigDict, Field
from graphon.nodes.human_input.entities import FormInputConfig, UserActionConfig
from graphon.nodes.human_input.entities import FormInput, UserAction
from models.execution_extra_content import ExecutionContentType
@ -16,11 +16,9 @@ class HumanInputFormDefinition(BaseModel):
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInputConfig] = Field(default_factory=list)
actions: Sequence[UserActionConfig] = Field(default_factory=list)
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
# `form_token` is `None` if the corresponding form has been submitted.
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int
@ -31,31 +29,16 @@ class HumanInputFormSubmissionData(BaseModel):
node_id: str
node_title: str
# deprecated: rendered_content is kept only for historical reasons.
rendered_content: str
# The identifier of the action the user has chosen.
action_id: str
# The button text of the action the user has chosen.
action_text: str
# submitted_data records the submitted form data.
# Keys correspond to `output_variable_name` of HumanInput inputs.
# Values are serialized JSON forms of runtime values, including file dictionaries.
#
# For forms submitted before this field was introduced, it is populated from
# the stored submission data.
submitted_data: Mapping[str, JsonValue] | None = None
class HumanInputContent(BaseModel):
model_config = ConfigDict(frozen=True)
workflow_run_id: str
submitted: bool
# Both the form_definition and the form_submission_data are present in
# HumanInputContent. For historical records, the
form_definition: HumanInputFormDefinition | None = None
form_submission_data: HumanInputFormSubmissionData | None = None
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)

View File

@ -1,279 +0,0 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any
from core.evaluation.entities.evaluation_entity import (
CustomizedMetrics,
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
EvaluationMetric,
NodeInfo,
)
from graphon.node_events.base import NodeRunResult
logger = logging.getLogger(__name__)
class BaseEvaluationInstance(ABC):
"""Abstract base class for evaluation framework adapters."""
@abstractmethod
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate LLM outputs using the configured framework."""
...
@abstractmethod
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate retrieval quality using the configured framework."""
...
@abstractmethod
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate agent outputs using the configured framework."""
...
@abstractmethod
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
"""Return the list of supported metric names for a given evaluation category."""
...
def evaluate_with_customized_workflow(
self,
node_run_result_mapping_list: list[dict[str, NodeRunResult]],
customized_metrics: CustomizedMetrics,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate using a published workflow as the evaluator.
The evaluator workflow's output variables are treated as metrics:
each output variable name becomes a metric name, and its value
becomes the score.
Args:
node_run_result_mapping_list: One mapping per test-data item,
where each mapping is ``{node_id: NodeRunResult}`` from the
target execution.
customized_metrics: Contains ``evaluation_workflow_id`` (the
published evaluator workflow) and ``input_fields`` (value
sources for the evaluator's input variables).
tenant_id: Tenant scope.
Returns:
A list of ``EvaluationItemResult`` with metrics extracted from
the evaluator workflow's output variables.
"""
from sqlalchemy.orm import Session
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.evaluation.runners import get_service_account_for_app
from models.engine import db
from models.model import App
from services.workflow_service import WorkflowService
workflow_id = customized_metrics.evaluation_workflow_id
if not workflow_id:
raise ValueError("customized_metrics must contain 'evaluation_workflow_id' for customized evaluator")
# Load the evaluator workflow resources using a dedicated session
with Session(db.engine, expire_on_commit=False) as session, session.begin():
app = session.query(App).filter_by(id=workflow_id, tenant_id=tenant_id).first()
if not app:
raise ValueError(f"Evaluation workflow app {workflow_id} not found in tenant {tenant_id}")
service_account = get_service_account_for_app(session, workflow_id)
workflow_service = WorkflowService()
published_workflow = workflow_service.get_published_workflow(app_model=app)
if not published_workflow:
raise ValueError(f"No published workflow found for evaluation app {workflow_id}")
eval_results: list[EvaluationItemResult] = []
for idx, node_run_result_mapping in enumerate(node_run_result_mapping_list):
try:
workflow_inputs = self._build_workflow_inputs(
customized_metrics.input_fields,
node_run_result_mapping,
)
generator = WorkflowAppGenerator()
response: Mapping[str, Any] = generator.generate(
app_model=app,
workflow=published_workflow,
user=service_account,
args={"inputs": workflow_inputs},
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
call_depth=0,
)
metrics = self._extract_workflow_metrics(response, workflow_id)
eval_results.append(
EvaluationItemResult(
index=idx,
metrics=metrics,
)
)
except Exception:
logger.exception(
"Customized evaluator failed for item %d with workflow %s",
idx,
workflow_id,
)
eval_results.append(EvaluationItemResult(index=idx))
return eval_results
@staticmethod
def _build_workflow_inputs(
input_fields: dict[str, Any],
node_run_result_mapping: dict[str, NodeRunResult],
) -> dict[str, Any]:
"""Build customized workflow inputs by resolving value sources.
Each entry in ``input_fields`` maps a workflow input variable name
to its value source, which can be:
- **Constant**: a plain string without ``{{#…#}}`` used as-is.
- **Expression**: a string containing one or more
``{{#node_id.output_key#}}`` selectors (same format as
``VariableTemplateParser``) resolved from
``node_run_result_mapping``.
"""
from graphon.nodes.base.variable_template_parser import REGEX as VARIABLE_REGEX
workflow_inputs: dict[str, Any] = {}
for field_name, value_source in input_fields.items():
if not isinstance(value_source, str):
# Non-string values (numbers, bools, dicts) are used directly.
workflow_inputs[field_name] = value_source
continue
# Check if the entire value is a single expression.
full_match = VARIABLE_REGEX.fullmatch(value_source)
if full_match:
workflow_inputs[field_name] = resolve_variable_selector(
full_match.group(1),
node_run_result_mapping,
)
elif VARIABLE_REGEX.search(value_source):
# Mixed template: interpolate all expressions as strings.
workflow_inputs[field_name] = VARIABLE_REGEX.sub(
lambda m: str(resolve_variable_selector(m.group(1), node_run_result_mapping)),
value_source,
)
else:
# Plain constant — no expression markers.
workflow_inputs[field_name] = value_source
return workflow_inputs
@staticmethod
def _extract_workflow_metrics(
response: Mapping[str, object],
evaluation_workflow_id: str,
) -> list[EvaluationMetric]:
"""Extract evaluation metrics from workflow output variables.
Each metric's ``node_info`` is set with *evaluation_workflow_id* as
the ``node_id``, so that judgment conditions can reference customized
metrics via ``variable_selector: [evaluation_workflow_id, metric_name]``.
"""
metrics: list[EvaluationMetric] = []
node_info = NodeInfo(node_id=evaluation_workflow_id, type="customized", title="customized")
data = response.get("data")
if not isinstance(data, Mapping):
logger.warning("Unexpected workflow response format: missing 'data' dict")
return metrics
outputs = data.get("outputs")
if not isinstance(outputs, dict):
logger.warning("Unexpected workflow response format: 'outputs' is not a dict")
return metrics
for key, raw_value in outputs.items():
if not isinstance(key, str):
continue
metrics.append(EvaluationMetric(name=key, value=raw_value, node_info=node_info))
return metrics
def resolve_variable_selector(
selector_raw: str,
node_run_result_mapping: dict[str, NodeRunResult],
) -> object:
"""
Resolve a ``#node_id.output_key#`` selector against node run results.
"""
#
cleaned = selector_raw.strip("#")
parts = cleaned.split(".")
if len(parts) < 2:
logger.warning(
"Selector '%s' must have at least node_id.output_key",
selector_raw,
)
return ""
node_id = parts[0]
output_path = parts[1:]
node_result = node_run_result_mapping.get(node_id)
if not node_result or not node_result.outputs:
logger.warning(
"Selector '%s': node '%s' not found or has no outputs",
selector_raw,
node_id,
)
return ""
# Traverse the output path to support nested keys.
current: object = node_result.outputs
for key in output_path:
if isinstance(current, Mapping):
next_val = current.get(key)
if next_val is None:
logger.warning(
"Selector '%s': key '%s' not found in node '%s' outputs",
selector_raw,
key,
node_id,
)
return ""
current = next_val
else:
logger.warning(
"Selector '%s': cannot traverse into non-dict value at key '%s'",
selector_raw,
key,
)
return ""
return current if current is not None else ""
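A minimal sketch of how the selector helper above resolves values, written as comments because constructing a real NodeRunResult is out of scope here; it assumes only that the node's outputs mapping looks like the one shown:
# Given outputs recorded for node "llm_1": {"text": "hello", "usage": {"total_tokens": 42}}
#   resolve_variable_selector("#llm_1.text#", mapping)                -> "hello"
#   resolve_variable_selector("#llm_1.usage.total_tokens#", mapping)  -> 42
#   resolve_variable_selector("#llm_1.missing#", mapping)             -> "" (with a warning logged)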

View File

@ -1,27 +0,0 @@
from enum import StrEnum
from pydantic import BaseModel
class EvaluationFrameworkEnum(StrEnum):
RAGAS = "ragas"
DEEPEVAL = "deepeval"
NONE = "none"
class BaseEvaluationConfig(BaseModel):
"""Base configuration for evaluation frameworks."""
pass
class RagasConfig(BaseEvaluationConfig):
"""RAGAS-specific configuration."""
pass
class DeepEvalConfig(BaseEvaluationConfig):
"""DeepEval-specific configuration."""
pass

View File

@ -1,280 +0,0 @@
import json
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from core.evaluation.entities.judgment_entity import JudgmentConfig, JudgmentResult
class EvaluationCategory(StrEnum):
LLM = "llm"
RETRIEVAL = "knowledge_retrieval"
AGENT = "agent"
WORKFLOW = "workflow"
SNIPPET = "snippet"
KNOWLEDGE_BASE = "knowledge_base"
class EvaluationMetricName(StrEnum):
"""Canonical metric names shared across all evaluation frameworks.
Each framework maps these names to its own internal implementation.
A framework that does not support a given metric should log a warning
and skip it rather than raising an error.
── LLM / general text-quality metrics ──────────────────────────────────
FAITHFULNESS
Measures whether every claim in the model's response is grounded in
the provided retrieved context. A high score means the answer
contains no hallucinated content — each statement can be traced back
to a passage in the context.
Required fields: user_input, response, retrieved_contexts.
ANSWER_RELEVANCY
Measures how well the model's response addresses the user's question.
A high score means the answer stays on-topic; a low score indicates
irrelevant content or a failure to answer the actual question.
Required fields: user_input, response.
ANSWER_CORRECTNESS
Measures the factual accuracy and completeness of the model's answer
relative to a ground-truth reference. It combines semantic similarity
with key-fact coverage, so both meaning and content matter.
Required fields: user_input, response, reference (expected_output).
SEMANTIC_SIMILARITY
Measures the cosine similarity between the model's response and the
reference answer in an embedding space. It evaluates whether the two
texts convey the same meaning, independent of factual correctness.
Required fields: response, reference (expected_output).
── Retrieval-quality metrics ────────────────────────────────────────────
CONTEXT_PRECISION
Measures the proportion of retrieved context chunks that are actually
relevant to the question (precision). A high score means the retrieval
pipeline returns little noise.
Required fields: user_input, reference, retrieved_contexts.
CONTEXT_RECALL
Measures the proportion of ground-truth information that is covered by
the retrieved context chunks (recall). A high score means the retrieval
pipeline does not miss important supporting evidence.
Required fields: user_input, reference, retrieved_contexts.
CONTEXT_RELEVANCE
Measures how relevant each individual retrieved chunk is to the query.
Similar to CONTEXT_PRECISION but evaluated at the chunk level rather
than against a reference answer.
Required fields: user_input, retrieved_contexts.
── Agent-quality metrics ────────────────────────────────────────────────
TOOL_CORRECTNESS
Measures the correctness of the tool calls made by the agent during
task execution — both the choice of tool and the arguments passed.
A high score means the agent's tool-use strategy matches the expected
behavior.
Required fields: actual tool calls vs. expected tool calls.
TASK_COMPLETION
Measures whether the agent ultimately achieves the user's stated goal.
It evaluates the reasoning chain, intermediate steps, and final output
holistically; a high score means the task was fully accomplished.
Required fields: user_input, actual_output.
"""
# LLM / general text-quality metrics
FAITHFULNESS = "faithfulness"
ANSWER_RELEVANCY = "answer_relevancy"
ANSWER_CORRECTNESS = "answer_correctness"
SEMANTIC_SIMILARITY = "semantic_similarity"
# Retrieval-quality metrics
CONTEXT_PRECISION = "context_precision"
CONTEXT_RECALL = "context_recall"
CONTEXT_RELEVANCE = "context_relevance"
# Agent-quality metrics
TOOL_CORRECTNESS = "tool_correctness"
TASK_COMPLETION = "task_completion"
# Per-category canonical metric lists used by get_supported_metrics().
LLM_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.FAITHFULNESS, # Every claim is grounded in context; no hallucinations
EvaluationMetricName.ANSWER_RELEVANCY, # Response stays on-topic and addresses the question
EvaluationMetricName.ANSWER_CORRECTNESS, # Factual accuracy and completeness vs. reference
EvaluationMetricName.SEMANTIC_SIMILARITY, # Semantic closeness to the reference answer
]
RETRIEVAL_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.CONTEXT_PRECISION, # Fraction of retrieved chunks that are relevant (precision)
EvaluationMetricName.CONTEXT_RECALL, # Fraction of ground-truth info covered by retrieval (recall)
EvaluationMetricName.CONTEXT_RELEVANCE, # Per-chunk relevance to the query
]
AGENT_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.TOOL_CORRECTNESS, # Correct tool selection and arguments
EvaluationMetricName.TASK_COMPLETION, # Whether the agent fully achieves the user's goal
]
WORKFLOW_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.FAITHFULNESS,
EvaluationMetricName.ANSWER_RELEVANCY,
EvaluationMetricName.ANSWER_CORRECTNESS,
]
METRIC_NODE_TYPE_MAPPING: dict[str, str] = {
**{m.value: "llm" for m in LLM_METRIC_NAMES},
**{m.value: "knowledge-retrieval" for m in RETRIEVAL_METRIC_NAMES},
**{m.value: "agent" for m in AGENT_METRIC_NAMES},
}
METRIC_VALUE_TYPE_MAPPING: dict[str, str] = {
EvaluationMetricName.FAITHFULNESS: "number",
EvaluationMetricName.ANSWER_RELEVANCY: "number",
EvaluationMetricName.ANSWER_CORRECTNESS: "number",
EvaluationMetricName.SEMANTIC_SIMILARITY: "number",
EvaluationMetricName.CONTEXT_PRECISION: "number",
EvaluationMetricName.CONTEXT_RECALL: "number",
EvaluationMetricName.CONTEXT_RELEVANCE: "number",
EvaluationMetricName.TOOL_CORRECTNESS: "number",
EvaluationMetricName.TASK_COMPLETION: "number",
}
class NodeInfo(BaseModel):
node_id: str
type: str
title: str
class EvaluationMetric(BaseModel):
name: str
value: Any
details: dict[str, Any] = Field(default_factory=dict)
node_info: NodeInfo | None = None
class EvaluationItemInput(BaseModel):
index: int
inputs: dict[str, Any]
output: str
expected_output: str | None = None
context: list[str] | None = None
class EvaluationDatasetInput(BaseModel):
"""Parsed dataset row used throughout evaluation execution.
``expected_output`` keeps backward compatibility with the original
single-reference template. When users upload node-specific reference
columns such as ``LLM 1 : expected_output``, they are stored in
``expected_outputs`` and resolved by node title at execution time.
"""
index: int
inputs: dict[str, Any]
expected_output: str | None = None
expected_outputs: dict[str, str] = Field(default_factory=dict)
def get_expected_output_for_node(self, node_title: str | None) -> str | None:
"""Return the best matching reference answer for the given node title."""
if node_title:
if node_title in self.expected_outputs:
return self.expected_outputs[node_title]
if self.expected_output is not None:
return self.expected_output
if len(self.expected_outputs) == 1:
return next(iter(self.expected_outputs.values()))
return None
def serialize_expected_output(self) -> str | None:
"""Serialize references for persistence and API responses.
Single-reference datasets stay unchanged, while multi-node references
are stored as JSON so history/detail APIs can still expose the full
uploaded payload without changing the database schema.
"""
if self.expected_output is not None and not self.expected_outputs:
return self.expected_output
if not self.expected_outputs:
return None
serialized_expected_outputs = dict(self.expected_outputs)
if self.expected_output is not None:
serialized_expected_outputs = {"expected_output": self.expected_output, **serialized_expected_outputs}
return json.dumps(serialized_expected_outputs, ensure_ascii=False, sort_keys=True)
def iter_expected_output_columns(self) -> list[tuple[str, str]]:
"""Return uploaded expected-output columns in display order."""
columns: list[tuple[str, str]] = []
if self.expected_output is not None:
columns.append(("expected_output", self.expected_output))
for node_title, value in self.expected_outputs.items():
columns.append((f"{node_title} : expected_output", value))
return columns
class EvaluationItemResult(BaseModel):
index: int
actual_output: str | None = None
metrics: list[EvaluationMetric] = Field(default_factory=list)
metadata: dict[str, Any] = Field(default_factory=dict)
judgment: JudgmentResult = Field(default_factory=JudgmentResult)
error: str | None = None
class DefaultMetric(BaseModel):
metric: str
value_type: str = ""
node_info_list: list[NodeInfo]
class CustomizedMetricOutputField(BaseModel):
variable: str
value_type: str
class CustomizedMetrics(BaseModel):
evaluation_workflow_id: str
input_fields: dict[str, Any]
output_fields: list[CustomizedMetricOutputField]
class EvaluationConfigData(BaseModel):
"""Structured data for saving evaluation configuration."""
evaluation_model: str = ""
evaluation_model_provider: str = ""
default_metrics: list[DefaultMetric] = Field(default_factory=list)
customized_metrics: CustomizedMetrics | None = None
judgment_config: JudgmentConfig | None = None
class EvaluationRunRequest(EvaluationConfigData):
"""Request body for starting an evaluation run."""
file_id: str
class EvaluationRunData(BaseModel):
"""Serializable data for Celery task."""
evaluation_run_id: str
tenant_id: str
target_type: str
target_id: str
evaluation_model_provider: str
evaluation_model: str
default_metrics: list[DefaultMetric] = Field(default_factory=list)
customized_metrics: CustomizedMetrics | None = None
judgment_config: JudgmentConfig | None = None
input_list: list[EvaluationDatasetInput]
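
A minimal sketch of how the reference-resolution helpers above behave, assuming the entities module is importable at the path used elsewhere in this diff; the row contents are made up for illustration.

```python
from core.evaluation.entities.evaluation_entity import EvaluationDatasetInput

# One uploaded row with a single node-specific reference column ("LLM 1 : expected_output").
row = EvaluationDatasetInput(
    index=0,
    inputs={"question": "What is Dify?"},
    expected_outputs={"LLM 1": "Dify is an open-source LLM application platform."},
)

# Exact node-title match wins; with only one uploaded column it also serves as the fallback.
assert row.get_expected_output_for_node("LLM 1") == "Dify is an open-source LLM application platform."
assert row.get_expected_output_for_node("LLM 2") == "Dify is an open-source LLM application platform."

# Multi-node references are persisted as a JSON object; single references stay as plain strings.
print(row.serialize_expected_output())
# {"LLM 1": "Dify is an open-source LLM application platform."}
```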

View File

@ -1,118 +0,0 @@
"""Judgment condition entities for evaluation metric assessment.
Condition structure mirrors the workflow if-else ``Condition`` model from
``graphon.utils.condition.entities``. The left-hand side uses
``variable_selector`` — a two-element list ``[node_id, metric_name]`` — to
uniquely identify an evaluation metric (different nodes may produce metrics
with the same name).
Operators reuse ``SupportedComparisonOperator`` from the workflow engine so
that type semantics stay consistent across the platform.
Typical usage::
judgment_config = JudgmentConfig(
logical_operator="and",
conditions=[
JudgmentCondition(
variable_selector=["node_abc", "faithfulness"],
comparison_operator=">",
value="0.8",
)
],
)
"""
from collections.abc import Sequence
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from graphon.utils.condition.entities import SupportedComparisonOperator
COMPARISON_OPERATOR_ALIASES: dict[str, str] = {
"==": "=",
"!=": "",
">=": "",
"<=": "",
"is null": "null",
"is not null": "not null",
}
class JudgmentCondition(BaseModel):
"""A single judgment condition that checks one metric value.
Mirrors ``graphon.utils.condition.entities.Condition`` with the left-hand
side being a metric selector instead of a workflow variable selector.
Attributes:
variable_selector: ``[node_id, metric_name]`` identifying the metric.
comparison_operator: Reuses workflow's ``SupportedComparisonOperator``.
value: The comparison target (right side). For unary operators such
as ``empty`` or ``null`` this can be ``None``.
"""
variable_selector: list[str]
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | bool | None = None
@field_validator("comparison_operator", mode="before")
@classmethod
def normalize_comparison_operator(cls, value: Any) -> Any:
"""Accept common ASCII/API aliases for workflow comparison operators."""
if not isinstance(value, str):
return value
normalized_value = value.strip().lower()
alias = COMPARISON_OPERATOR_ALIASES.get(normalized_value)
if alias is not None:
return alias
return value.strip()
class JudgmentConfig(BaseModel):
"""A group of judgment conditions combined with a logical operator.
Attributes:
logical_operator: How to combine condition results — "and" requires
all conditions to pass, "or" requires at least one.
conditions: The list of individual conditions to evaluate.
"""
logical_operator: Literal["and", "or"] = "and"
conditions: list[JudgmentCondition] = Field(default_factory=list)
class JudgmentConditionResult(BaseModel):
"""Result of evaluating a single judgment condition.
Attributes:
variable_selector: ``[node_id, metric_name]`` that was checked.
comparison_operator: The operator that was applied.
expected_value: The resolved comparison value.
actual_value: The actual metric value that was evaluated.
passed: Whether this individual condition passed.
error: Error message if the condition evaluation failed.
"""
variable_selector: list[str]
comparison_operator: str
expected_value: Any = None
actual_value: Any = None
passed: bool = False
error: str | None = None
class JudgmentResult(BaseModel):
"""Overall result of evaluating all judgment conditions for one item.
Attributes:
passed: Whether the overall judgment passed (based on logical_operator).
logical_operator: The logical operator used to combine conditions.
condition_results: Detailed result for each individual condition.
"""
passed: bool = False
logical_operator: Literal["and", "or"] = "and"
condition_results: list[JudgmentConditionResult] = Field(default_factory=list)
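
A small usage sketch of the alias normalization above, assuming these entities are importable; the selector and threshold are illustrative.

```python
from core.evaluation.entities.judgment_entity import JudgmentCondition, JudgmentConfig

# ASCII aliases are rewritten to the workflow operator symbols before Pydantic validation.
condition = JudgmentCondition(
    variable_selector=["node_abc", "faithfulness"],
    comparison_operator=">=",  # stored as the workflow operator "≥"
    value="0.8",
)
config = JudgmentConfig(logical_operator="and", conditions=[condition])

print(condition.comparison_operator)  # ≥
print(config.logical_operator)        # and
```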

View File

@ -1,61 +0,0 @@
import collections
import logging
from typing import Any
from configs import dify_config
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import EvaluationFrameworkEnum
from core.evaluation.entities.evaluation_entity import EvaluationCategory
logger = logging.getLogger(__name__)
class EvaluationFrameworkConfigMap(collections.UserDict[str, dict[str, Any]]):
"""Registry mapping framework enum -> {config_class, evaluator_class}."""
def __getitem__(self, framework: str) -> dict[str, Any]:
match framework:
case EvaluationFrameworkEnum.RAGAS:
from core.evaluation.entities.config_entity import RagasConfig
from core.evaluation.frameworks.ragas.ragas_evaluator import RagasEvaluator
return {
"config_class": RagasConfig,
"evaluator_class": RagasEvaluator,
}
case EvaluationFrameworkEnum.DEEPEVAL:
raise NotImplementedError("DeepEval adapter is not yet implemented.")
case _:
raise ValueError(f"Unknown evaluation framework: {framework}")
evaluation_framework_config_map = EvaluationFrameworkConfigMap()
class EvaluationManager:
"""Factory for evaluation instances based on global configuration."""
@staticmethod
def get_evaluation_instance() -> BaseEvaluationInstance | None:
"""Create and return an evaluation instance based on EVALUATION_FRAMEWORK env var."""
framework = dify_config.EVALUATION_FRAMEWORK
if not framework or framework == EvaluationFrameworkEnum.NONE:
return None
try:
config_map = evaluation_framework_config_map[framework]
evaluator_class = config_map["evaluator_class"]
config_class = config_map["config_class"]
config = config_class()
return evaluator_class(config)
except Exception:
logger.exception("Failed to create evaluation instance for framework: %s", framework)
return None
@staticmethod
def get_supported_metrics(category: EvaluationCategory) -> list[str]:
"""Return supported metrics for the current framework and given category."""
instance = EvaluationManager.get_evaluation_instance()
if instance is None:
return []
return instance.get_supported_metrics(category)
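
A hedged sketch of how callers would use the factory, assuming the import path below (the file path is not visible in this diff) and that EVALUATION_FRAMEWORK is set to "ragas".

```python
from core.evaluation.entities.evaluation_entity import EvaluationCategory
from core.evaluation.evaluation_manager import EvaluationManager  # path assumed; not shown in this diff

instance = EvaluationManager.get_evaluation_instance()
if instance is None:
    # EVALUATION_FRAMEWORK is unset or "none": evaluation is disabled.
    print("evaluation disabled")
else:
    # With the RAGAS adapter this returns context_precision, context_recall and context_relevance.
    print(EvaluationManager.get_supported_metrics(EvaluationCategory.RETRIEVAL))
```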

View File

@ -1,299 +0,0 @@
import logging
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import DeepEvalConfig
from core.evaluation.entities.evaluation_entity import (
AGENT_METRIC_NAMES,
LLM_METRIC_NAMES,
RETRIEVAL_METRIC_NAMES,
WORKFLOW_METRIC_NAMES,
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
EvaluationMetric,
EvaluationMetricName,
)
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper
logger = logging.getLogger(__name__)
# Maps canonical EvaluationMetricName to the corresponding deepeval metric class name.
# deepeval metric field requirements (LLMTestCase fields):
# - faithfulness: input, actual_output, retrieval_context
# - answer_relevancy: input, actual_output
# - context_precision: input, actual_output, expected_output, retrieval_context
# - context_recall: input, actual_output, expected_output, retrieval_context
# - context_relevance: input, actual_output, retrieval_context
# - tool_correctness: input, actual_output, expected_tools
# - task_completion: input, actual_output
# Metrics not listed here are unsupported by deepeval and will be skipped.
_DEEPEVAL_METRIC_MAP: dict[EvaluationMetricName, str] = {
EvaluationMetricName.FAITHFULNESS: "FaithfulnessMetric",
EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancyMetric",
EvaluationMetricName.CONTEXT_PRECISION: "ContextualPrecisionMetric",
EvaluationMetricName.CONTEXT_RECALL: "ContextualRecallMetric",
EvaluationMetricName.CONTEXT_RELEVANCE: "ContextualRelevancyMetric",
EvaluationMetricName.TOOL_CORRECTNESS: "ToolCorrectnessMetric",
EvaluationMetricName.TASK_COMPLETION: "TaskCompletionMetric",
}
class DeepEvalEvaluator(BaseEvaluationInstance):
"""DeepEval framework adapter for evaluation."""
def __init__(self, config: DeepEvalConfig):
self.config = config
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
match category:
case EvaluationCategory.LLM:
candidates = LLM_METRIC_NAMES
case EvaluationCategory.RETRIEVAL:
candidates = RETRIEVAL_METRIC_NAMES
case EvaluationCategory.AGENT:
candidates = AGENT_METRIC_NAMES
case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET:
candidates = WORKFLOW_METRIC_NAMES
case _:
return []
return [m for m in candidates if m in _DEEPEVAL_METRIC_MAP]
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
def evaluate_workflow(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)
def _evaluate(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Core evaluation logic using DeepEval."""
model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
requested_metrics = metric_names or self.get_supported_metrics(category)
try:
return self._evaluate_with_deepeval(items, requested_metrics, category)
except ImportError:
logger.warning("DeepEval not installed, falling back to simple evaluation")
return self._evaluate_simple(items, requested_metrics, model_wrapper)
def _evaluate_with_deepeval(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Evaluate using DeepEval library.
Builds LLMTestCase differently per category:
- LLM/Workflow: input=prompt, actual_output=output, retrieval_context=context
- Retrieval: input=query, actual_output=output, expected_output, retrieval_context=context
- Agent: input=query, actual_output=output
"""
metric_pairs = _build_deepeval_metrics(requested_metrics)
if not metric_pairs:
logger.warning("No valid DeepEval metrics found for: %s", requested_metrics)
return [EvaluationItemResult(index=item.index) for item in items]
results: list[EvaluationItemResult] = []
for item in items:
test_case = self._build_test_case(item, category)
metrics: list[EvaluationMetric] = []
for canonical_name, metric in metric_pairs:
try:
metric.measure(test_case)
if metric.score is not None:
metrics.append(EvaluationMetric(name=canonical_name, value=float(metric.score)))
except Exception:
logger.exception(
"Failed to compute metric %s for item %d",
canonical_name,
item.index,
)
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
return results
@staticmethod
def _build_test_case(item: EvaluationItemInput, category: EvaluationCategory) -> Any:
"""Build a deepeval LLMTestCase with the correct fields per category."""
from deepeval.test_case import LLMTestCase
user_input = _format_input(item.inputs, category)
match category:
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
# faithfulness needs: input, actual_output, retrieval_context
# answer_relevancy needs: input, actual_output
return LLMTestCase(
input=user_input,
actual_output=item.output,
expected_output=item.expected_output or None,
retrieval_context=item.context or None,
)
case EvaluationCategory.RETRIEVAL:
# contextual_precision/recall needs: input, actual_output, expected_output, retrieval_context
return LLMTestCase(
input=user_input,
actual_output=item.output or "",
expected_output=item.expected_output or "",
retrieval_context=item.context or [],
)
case _:
return LLMTestCase(
input=user_input,
actual_output=item.output,
)
def _evaluate_simple(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
model_wrapper: DifyModelWrapper,
) -> list[EvaluationItemResult]:
"""Simple LLM-as-judge fallback when DeepEval is not available."""
results: list[EvaluationItemResult] = []
for item in items:
metrics: list[EvaluationMetric] = []
for m_name in requested_metrics:
try:
score = self._judge_with_llm(model_wrapper, m_name, item)
metrics.append(EvaluationMetric(name=m_name, value=score))
except Exception:
logger.exception("Failed to compute metric %s for item %d", m_name, item.index)
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
return results
def _judge_with_llm(
self,
model_wrapper: DifyModelWrapper,
metric_name: str,
item: EvaluationItemInput,
) -> float:
"""Use the LLM to judge a single metric for a single item."""
prompt = self._build_judge_prompt(metric_name, item)
response = model_wrapper.invoke(prompt)
return self._parse_score(response)
@staticmethod
def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str:
"""Build a scoring prompt for the LLM judge."""
parts = [
f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
f"\nInput: {item.inputs}",
f"\nOutput: {item.output}",
]
if item.expected_output:
parts.append(f"\nExpected Output: {item.expected_output}")
if item.context:
parts.append(f"\nContext: {'; '.join(item.context)}")
parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.")
return "\n".join(parts)
@staticmethod
def _parse_score(response: str) -> float:
"""Parse a float score from LLM response."""
import re
cleaned = response.strip()
try:
score = float(cleaned)
return max(0.0, min(1.0, score))
except ValueError:
match = re.search(r"(\d+\.?\d*)", cleaned)
if match:
score = float(match.group(1))
return max(0.0, min(1.0, score))
return 0.0
def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str:
"""Extract the user-facing input string from the inputs dict."""
match category:
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
return str(inputs.get("prompt", ""))
case EvaluationCategory.RETRIEVAL:
return str(inputs.get("query", ""))
case _:
return str(next(iter(inputs.values()), "")) if inputs else ""
def _build_deepeval_metrics(requested_metrics: list[str]) -> list[tuple[str, Any]]:
"""Build DeepEval metric instances from canonical metric names.
Returns a list of (canonical_name, metric_instance) pairs so that callers
can record the canonical name rather than the framework-internal class name.
"""
try:
from deepeval.metrics import (
AnswerRelevancyMetric,
ContextualPrecisionMetric,
ContextualRecallMetric,
ContextualRelevancyMetric,
FaithfulnessMetric,
TaskCompletionMetric,
ToolCorrectnessMetric,
)
# Maps canonical name → deepeval metric class
deepeval_class_map: dict[str, Any] = {
EvaluationMetricName.FAITHFULNESS: FaithfulnessMetric,
EvaluationMetricName.ANSWER_RELEVANCY: AnswerRelevancyMetric,
EvaluationMetricName.CONTEXT_PRECISION: ContextualPrecisionMetric,
EvaluationMetricName.CONTEXT_RECALL: ContextualRecallMetric,
EvaluationMetricName.CONTEXT_RELEVANCE: ContextualRelevancyMetric,
EvaluationMetricName.TOOL_CORRECTNESS: ToolCorrectnessMetric,
EvaluationMetricName.TASK_COMPLETION: TaskCompletionMetric,
}
pairs: list[tuple[str, Any]] = []
for name in requested_metrics:
metric_class = deepeval_class_map.get(name)
if metric_class:
pairs.append((name, metric_class(threshold=0.5)))
else:
logger.warning("Metric '%s' is not supported by DeepEval, skipping", name)
return pairs
except ImportError:
logger.warning("DeepEval metrics not available")
return []
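
The score-parsing fallback above is easy to check in isolation; this standalone sketch duplicates its logic so it runs without deepeval installed.

```python
import re

def parse_score(response: str) -> float:
    # Same clamp-then-regex fallback as DeepEvalEvaluator._parse_score above.
    cleaned = response.strip()
    try:
        return max(0.0, min(1.0, float(cleaned)))
    except ValueError:
        match = re.search(r"(\d+\.?\d*)", cleaned)
        if match:
            return max(0.0, min(1.0, float(match.group(1))))
        return 0.0

print(parse_score("0.92"))                # 0.92
print(parse_score("Score: 0.75 (good)"))  # 0.75, pulled out by the regex fallback
print(parse_score("4"))                   # 1.0, clamped to the valid range
print(parse_score("no number here"))      # 0.0
```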

View File

@ -1,345 +0,0 @@
import logging
from importlib import import_module
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import RagasConfig
from core.evaluation.entities.evaluation_entity import (
AGENT_METRIC_NAMES,
LLM_METRIC_NAMES,
RETRIEVAL_METRIC_NAMES,
WORKFLOW_METRIC_NAMES,
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
EvaluationMetric,
EvaluationMetricName,
)
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper
logger = logging.getLogger(__name__)
# Maps canonical EvaluationMetricName to the corresponding ragas metric class.
# Metrics not listed here are unsupported by ragas and will be skipped.
_RAGAS_METRIC_MAP: dict[EvaluationMetricName, str] = {
EvaluationMetricName.FAITHFULNESS: "Faithfulness",
EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancy",
EvaluationMetricName.ANSWER_CORRECTNESS: "AnswerCorrectness",
EvaluationMetricName.SEMANTIC_SIMILARITY: "SemanticSimilarity",
EvaluationMetricName.CONTEXT_PRECISION: "ContextPrecision",
EvaluationMetricName.CONTEXT_RECALL: "ContextRecall",
EvaluationMetricName.CONTEXT_RELEVANCE: "ContextRelevance",
EvaluationMetricName.TOOL_CORRECTNESS: "ToolCallAccuracy",
}
class RagasEvaluator(BaseEvaluationInstance):
"""RAGAS framework adapter for evaluation."""
def __init__(self, config: RagasConfig):
self.config = config
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
match category:
case EvaluationCategory.LLM:
candidates = LLM_METRIC_NAMES
case EvaluationCategory.RETRIEVAL:
candidates = RETRIEVAL_METRIC_NAMES
case EvaluationCategory.AGENT:
candidates = AGENT_METRIC_NAMES
case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET:
candidates = WORKFLOW_METRIC_NAMES
case _:
return []
return [m for m in candidates if m in _RAGAS_METRIC_MAP]
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
def evaluate_workflow(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)
def _evaluate(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Core evaluation logic using RAGAS."""
model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
requested_metrics = metric_names or self.get_supported_metrics(category)
try:
return self._evaluate_with_ragas(items, requested_metrics, model_wrapper, category)
except ImportError:
logger.warning("RAGAS not installed, falling back to simple evaluation")
return self._evaluate_simple(items, requested_metrics, model_wrapper)
def _evaluate_with_ragas(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
model_wrapper: DifyModelWrapper,
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Evaluate using RAGAS library.
Builds SingleTurnSample differently per category to match ragas requirements:
- LLM/Workflow: user_input=prompt, response=output, reference=expected_output
- Retrieval: user_input=query, reference=expected_output, retrieved_contexts=context
- Agent: Not supported via EvaluationDataset (requires message-based API)
"""
from ragas import evaluate as ragas_evaluate
from ragas.dataset_schema import EvaluationDataset
samples: list[Any] = []
for item in items:
sample = self._build_sample(item, category)
samples.append(sample)
dataset = EvaluationDataset(samples=samples)
ragas_metrics = self._build_ragas_metrics(requested_metrics)
if not ragas_metrics:
logger.warning("No valid RAGAS metrics found for: %s", requested_metrics)
return [EvaluationItemResult(index=item.index, actual_output=item.output) for item in items]
try:
result = ragas_evaluate(
dataset=dataset,
metrics=ragas_metrics,
llm=model_wrapper,
)
results: list[EvaluationItemResult] = []
result_df = result.to_pandas()
for i, item in enumerate(items):
metrics: list[EvaluationMetric] = []
for m_name in requested_metrics:
if m_name in result_df.columns:
score = result_df.iloc[i][m_name]
if score is not None and not (isinstance(score, float) and score != score):
metrics.append(EvaluationMetric(name=m_name, value=float(score)))
results.append(EvaluationItemResult(index=item.index, metrics=metrics, actual_output=item.output))
return results
except Exception:
logger.exception("RAGAS evaluation failed, falling back to simple evaluation")
return self._evaluate_simple(items, requested_metrics, model_wrapper)
@staticmethod
def _build_sample(item: EvaluationItemInput, category: EvaluationCategory) -> Any:
"""Build a ragas SingleTurnSample with the correct fields per category.
ragas metric field requirements:
- faithfulness: user_input, response, retrieved_contexts
- answer_relevancy: user_input, response
- answer_correctness: user_input, response, reference
- semantic_similarity: user_input, response, reference
- context_precision: user_input, reference, retrieved_contexts
- context_recall: user_input, reference, retrieved_contexts
- context_relevance: user_input, retrieved_contexts
"""
from ragas.dataset_schema import SingleTurnSample
user_input = _format_input(item.inputs, category)
match category:
case EvaluationCategory.LLM:
# response = actual LLM output, reference = expected output
return SingleTurnSample(
user_input=user_input,
response=item.output,
reference=item.expected_output or "",
retrieved_contexts=item.context or [],
)
case EvaluationCategory.RETRIEVAL:
# context_precision/recall only need reference + retrieved_contexts
return SingleTurnSample(
user_input=user_input,
reference=item.expected_output or "",
retrieved_contexts=item.context or [],
)
case _:
return SingleTurnSample(
user_input=user_input,
response=item.output,
)
def _evaluate_simple(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
model_wrapper: DifyModelWrapper,
) -> list[EvaluationItemResult]:
"""Simple LLM-as-judge fallback when RAGAS is not available."""
results: list[EvaluationItemResult] = []
for item in items:
metrics: list[EvaluationMetric] = []
for m_name in requested_metrics:
try:
score = self._judge_with_llm(model_wrapper, m_name, item)
metrics.append(EvaluationMetric(name=m_name, value=score))
except Exception:
logger.exception("Failed to compute metric %s for item %d", m_name, item.index)
results.append(EvaluationItemResult(index=item.index, metrics=metrics, actual_output=item.output))
return results
def _judge_with_llm(
self,
model_wrapper: DifyModelWrapper,
metric_name: str,
item: EvaluationItemInput,
) -> float:
"""Use the LLM to judge a single metric for a single item."""
prompt = self._build_judge_prompt(metric_name, item)
response = model_wrapper.invoke(prompt)
return self._parse_score(response)
@staticmethod
def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str:
"""Build a scoring prompt for the LLM judge."""
parts = [
f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
f"\nInput: {item.inputs}",
f"\nOutput: {item.output}",
]
if item.expected_output:
parts.append(f"\nExpected Output: {item.expected_output}")
if item.context:
parts.append(f"\nContext: {'; '.join(item.context)}")
parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.")
return "\n".join(parts)
@staticmethod
def _parse_score(response: str) -> float:
"""Parse a float score from LLM response."""
import re
cleaned = response.strip()
try:
score = float(cleaned)
return max(0.0, min(1.0, score))
except ValueError:
match = re.search(r"(\d+\.?\d*)", cleaned)
if match:
score = float(match.group(1))
return max(0.0, min(1.0, score))
return 0.0
@staticmethod
def _build_ragas_metrics(requested_metrics: list[str]) -> list[Any]:
"""Build RAGAS metric instances from canonical metric names."""
try:
metrics_module = _import_ragas_metrics_module()
# Maps canonical name → ragas metric class
ragas_class_map: dict[str, Any] = {
EvaluationMetricName.FAITHFULNESS: getattr(metrics_module, "Faithfulness"),
EvaluationMetricName.ANSWER_RELEVANCY: getattr(metrics_module, "AnswerRelevancy"),
EvaluationMetricName.ANSWER_CORRECTNESS: getattr(metrics_module, "AnswerCorrectness"),
EvaluationMetricName.SEMANTIC_SIMILARITY: getattr(metrics_module, "SemanticSimilarity"),
EvaluationMetricName.CONTEXT_PRECISION: getattr(metrics_module, "ContextPrecision"),
EvaluationMetricName.CONTEXT_RECALL: getattr(metrics_module, "ContextRecall"),
EvaluationMetricName.CONTEXT_RELEVANCE: getattr(metrics_module, "ContextRelevance"),
EvaluationMetricName.TOOL_CORRECTNESS: getattr(metrics_module, "ToolCallAccuracy"),
}
metrics = []
for name in requested_metrics:
metric_class = ragas_class_map.get(name)
if metric_class:
if name == EvaluationMetricName.ANSWER_CORRECTNESS:
# ragas answer_correctness blends factuality with semantic
# similarity. The latter requires an embeddings backend,
# which is not wired through Dify's evaluation stack yet.
# Keep the metric usable by relying on the factuality
# component only for now.
metrics.append(metric_class(weights=[1.0, 0.0], embeddings=_NoopRagasEmbeddings()))
else:
metrics.append(metric_class())
else:
logger.warning("Metric '%s' is not supported by RAGAS, skipping", name)
return metrics
except ImportError:
logger.warning("RAGAS metrics not available")
return []
def _import_ragas_metrics_module() -> Any:
"""Load ragas metric classes across supported ragas versions.
ragas 0.3.x exposes metric classes from ``ragas.metrics`` while some older
versions used ``ragas.metrics.collections``. Support both so worker
environments do not silently drop all metrics because of a module path
mismatch.
"""
try:
return import_module("ragas.metrics")
except ImportError:
return import_module("ragas.metrics.collections")
class _NoopRagasEmbeddings:
"""Placeholder embeddings for ragas metrics whose embedding branch is disabled.
ragas eagerly injects a default embeddings backend for any metric that
subclasses ``MetricWithEmbeddings``. For answer_correctness we currently
disable the semantic-similarity weight, so no real embedding call should
happen. Supplying this placeholder keeps ragas from constructing its
default OpenAI embeddings client during setup.
"""
async def aembed_query(self, text: str) -> list[float]:
del text
return [0.0]
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return [[0.0] for _ in texts]
def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str:
"""Extract the user-facing input string from the inputs dict."""
match category:
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
return str(inputs.get("prompt", ""))
case EvaluationCategory.RETRIEVAL:
return str(inputs.get("query", ""))
case _:
return str(next(iter(inputs.values()), "")) if inputs else ""
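
A hedged end-to-end sketch of calling the adapter above. The provider, model and tenant values are placeholders; if ragas is not installed, the call transparently drops to the LLM-as-judge fallback.

```python
from core.evaluation.entities.config_entity import RagasConfig
from core.evaluation.entities.evaluation_entity import EvaluationItemInput, EvaluationMetricName
from core.evaluation.frameworks.ragas.ragas_evaluator import RagasEvaluator

evaluator = RagasEvaluator(RagasConfig())
items = [
    EvaluationItemInput(
        index=0,
        inputs={"prompt": "user: What is the capital of France?"},
        output="Paris is the capital of France.",
        expected_output="Paris",
        context=["France's capital city is Paris."],
    )
]
results = evaluator.evaluate_llm(
    items,
    [EvaluationMetricName.FAITHFULNESS],
    model_provider="openai",   # placeholder provider
    model_name="gpt-4o-mini",  # placeholder model
    tenant_id="tenant-0000",   # placeholder tenant
)
print(results[0].metrics)  # e.g. [EvaluationMetric(name='faithfulness', value=1.0, ...)]
```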

View File

@ -1,165 +0,0 @@
import asyncio
import logging
from typing import Any
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.prompt_values import PromptValue
try:
from ragas.llms.base import BaseRagasLLM
except ImportError:
class BaseRagasLLM: # type: ignore[no-redef]
"""Lightweight shim so the module stays importable without ragas installed."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
del args, kwargs
@staticmethod
def get_temperature(n: int) -> float:
return 0.3 if n > 1 else 1e-8
logger = logging.getLogger(__name__)
class DifyModelWrapper(BaseRagasLLM):
"""Bridge Dify model invocation to ragas and fallback LLM-as-judge flows.
ragas can accept a custom ``BaseRagasLLM`` instance. Using one here keeps
evaluation requests on Dify's provider stack instead of falling back to
ragas' default OpenAI factory, which would require standalone environment
credentials and bypass tenant-scoped model configuration.
"""
model_provider: str
model_name: str
tenant_id: str
user_id: str | None
def __init__(self, model_provider: str, model_name: str, tenant_id: str, user_id: str | None = None):
super().__init__()
self.model_provider = model_provider
self.model_name = model_name
self.tenant_id = tenant_id
self.user_id = user_id
def _get_model_instance(self) -> Any:
from core.plugin.impl.model_runtime_factory import create_plugin_model_manager
from graphon.model_runtime.entities.model_entities import ModelType
model_manager = create_plugin_model_manager(tenant_id=self.tenant_id, user_id=self.user_id)
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=self.model_provider,
model_type=ModelType.LLM,
model=self.model_name,
)
return model_instance
def invoke(self, prompt: str) -> str:
"""Invoke the configured Dify model with a plain-text evaluation prompt."""
from graphon.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
model_instance = self._get_model_instance()
result = model_instance.invoke_llm(
prompt_messages=[
SystemPromptMessage(content="You are an evaluation judge. Answer precisely and concisely."),
UserPromptMessage(content=prompt),
],
model_parameters={"temperature": 0.0},
stream=False,
)
return result.message.content
def generate_text(
self,
prompt: PromptValue,
n: int = 1,
temperature: float = 1e-8,
stop: list[str] | None = None,
callbacks: Any = None,
) -> LLMResult:
"""Implement ragas' sync LLM interface on top of Dify's model runtime."""
del callbacks # Dify's invocation path does not currently use LangChain callbacks here.
prompt_messages = _convert_prompt_value(prompt)
model_instance = self._get_model_instance()
generations: list[list[ChatGeneration]] = [[]]
completions = max(1, n)
for _ in range(completions):
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"temperature": temperature},
stop=stop,
stream=False,
)
text = result.message.content
generations[0].append(
ChatGeneration(
text=text,
message=AIMessage(content=text, response_metadata={"finish_reason": "stop"}),
generation_info={"finish_reason": "stop"},
)
)
return LLMResult(generations=generations)
async def agenerate_text(
self,
prompt: PromptValue,
n: int = 1,
temperature: float | None = None,
stop: list[str] | None = None,
callbacks: Any = None,
) -> LLMResult:
"""Async ragas hook backed by the sync Dify invocation path."""
return await asyncio.to_thread(
self.generate_text,
prompt,
n,
self.get_temperature(n) if temperature is None else temperature,
stop,
callbacks,
)
def _convert_prompt_value(prompt: PromptValue) -> list[Any]:
"""Translate LangChain prompt values into graphon prompt messages."""
from graphon.model_runtime.entities.message_entities import (
AssistantPromptMessage,
SystemPromptMessage,
UserPromptMessage,
)
prompt_messages: list[Any] = []
for message in prompt.to_messages():
content = _message_content_to_text(message)
if isinstance(message, SystemMessage):
prompt_messages.append(SystemPromptMessage(content=content))
elif isinstance(message, AIMessage):
prompt_messages.append(AssistantPromptMessage(content=content))
elif isinstance(message, HumanMessage):
prompt_messages.append(UserPromptMessage(content=content))
else:
prompt_messages.append(UserPromptMessage(content=content))
return prompt_messages
def _message_content_to_text(message: BaseMessage) -> str:
"""Flatten LangChain message content into a plain-text string for Dify."""
if isinstance(message.content, str):
return message.content
if isinstance(message.content, list):
parts: list[str] = []
for block in message.content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, dict):
text = block.get("text")
if text:
parts.append(str(text))
return "\n".join(part for part in parts if part)
return str(message.content or "")
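
A standalone sketch of the content-flattening rule above, using langchain_core (already a dependency of this wrapper); the multimodal message is invented for illustration.

```python
from langchain_core.messages import HumanMessage

message = HumanMessage(
    content=[
        {"type": "text", "text": "Score the answer"},
        "on a 0-1 scale.",
        {"type": "image_url", "image_url": {"url": "https://example.com/x.png"}},  # no "text" key, so dropped
    ]
)

# Mirrors _message_content_to_text: keep plain strings and dict "text" fields, join with newlines.
parts: list[str] = []
for block in message.content:
    if isinstance(block, str):
        parts.append(block)
    elif isinstance(block, dict):
        text = block.get("text")
        if text:
            parts.append(str(text))
print("\n".join(part for part in parts if part))
# Score the answer
# on a 0-1 scale.
```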

View File

@ -1,160 +0,0 @@
"""Judgment condition processor for evaluation metrics.
Evaluates pass/fail judgment conditions against evaluation metric values.
Each condition uses ``variable_selector`` (``[node_id, metric_name]``) to
look up the metric value, then delegates the actual comparison to the
workflow condition engine (``graphon.utils.condition.processor``).
The processor is intentionally decoupled from evaluation frameworks and
runners. It operates on plain ``dict`` mappings and can be invoked
anywhere that already has per-item metric results.
"""
import logging
from collections.abc import Sequence
from typing import Any, cast
from core.evaluation.entities.judgment_entity import (
JudgmentCondition,
JudgmentConditionResult,
JudgmentConfig,
JudgmentResult,
)
from graphon.utils.condition.entities import SupportedComparisonOperator
from graphon.utils.condition.processor import _evaluate_condition # pyright: ignore[reportPrivateUsage]
logger = logging.getLogger(__name__)
_UNARY_OPERATORS = frozenset({"null", "not null", "empty", "not empty"})
class JudgmentProcessor:
@staticmethod
def evaluate(
metric_values: dict[tuple[str, str], Any],
config: JudgmentConfig,
) -> JudgmentResult:
"""Evaluate all judgment conditions against the given metric values.
Args:
metric_values: Mapping of ``(node_id, metric_name)`` → metric
value (e.g. ``{("node_abc", "faithfulness"): 0.85}``).
config: The judgment configuration with logical_operator and
conditions.
Returns:
JudgmentResult with overall pass/fail and per-condition details.
"""
if not config.conditions:
return JudgmentResult(
passed=True,
logical_operator=config.logical_operator,
condition_results=[],
)
condition_results: list[JudgmentConditionResult] = []
for condition in config.conditions:
result = JudgmentProcessor._evaluate_single_condition(metric_values, condition)
condition_results.append(result)
if config.logical_operator == "and" and not result.passed:
return JudgmentResult(
passed=False,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
if config.logical_operator == "or" and result.passed:
return JudgmentResult(
passed=True,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
if config.logical_operator == "and":
final_passed = all(r.passed for r in condition_results)
else:
final_passed = any(r.passed for r in condition_results)
return JudgmentResult(
passed=final_passed,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
@staticmethod
def _evaluate_single_condition(
metric_values: dict[tuple[str, str], Any],
condition: JudgmentCondition,
) -> JudgmentConditionResult:
"""Evaluate a single judgment condition.
Steps:
1. Extract ``(node_id, metric_name)`` from ``variable_selector``.
2. Look up the metric value from ``metric_values``.
3. Delegate comparison to the workflow condition engine.
"""
selector = condition.variable_selector
if len(selector) < 2:
return JudgmentConditionResult(
variable_selector=selector,
comparison_operator=condition.comparison_operator,
expected_value=condition.value,
actual_value=None,
passed=False,
error=f"variable_selector must have at least 2 elements, got {len(selector)}",
)
node_id, metric_name = selector[0], selector[1]
actual_value = metric_values.get((node_id, metric_name))
if actual_value is None and condition.comparison_operator not in _UNARY_OPERATORS:
return JudgmentConditionResult(
variable_selector=selector,
comparison_operator=condition.comparison_operator,
expected_value=condition.value,
actual_value=None,
passed=False,
error=f"Metric '{metric_name}' on node '{node_id}' not found in evaluation results",
)
try:
expected = condition.value
# Numeric operators need the actual value coerced to int/float
# so that the workflow engine's numeric assertions work correctly.
coerced_actual: object = actual_value
if (
condition.comparison_operator in {"=", "≠", ">", "<", "≥", "≤"}
and actual_value is not None
and not isinstance(actual_value, (int, float, bool))
):
coerced_actual = float(actual_value)
passed = _evaluate_condition(
operator=cast(SupportedComparisonOperator, condition.comparison_operator),
value=coerced_actual,
expected=cast(str | Sequence[str] | bool | Sequence[bool] | None, expected),
)
return JudgmentConditionResult(
variable_selector=selector,
comparison_operator=condition.comparison_operator,
expected_value=expected,
actual_value=actual_value,
passed=passed,
)
except Exception as e:
logger.warning(
"Judgment condition evaluation failed for [%s, %s]: %s",
node_id,
metric_name,
str(e),
)
return JudgmentConditionResult(
variable_selector=selector,
comparison_operator=condition.comparison_operator,
expected_value=condition.value,
actual_value=actual_value,
passed=False,
error=str(e),
)
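
A hedged usage sketch of the processor above; the module path in the import is assumed (it is not visible in this diff) and the metric values are invented.

```python
from core.evaluation.entities.judgment_entity import JudgmentCondition, JudgmentConfig
from core.evaluation.judgment_processor import JudgmentProcessor  # path assumed; not shown in this diff

metric_values = {
    ("node_abc", "faithfulness"): 0.85,
    ("node_abc", "answer_relevancy"): 0.40,
}
config = JudgmentConfig(
    logical_operator="and",
    conditions=[
        JudgmentCondition(variable_selector=["node_abc", "faithfulness"], comparison_operator=">", value="0.8"),
        JudgmentCondition(variable_selector=["node_abc", "answer_relevancy"], comparison_operator=">", value="0.5"),
    ],
)

result = JudgmentProcessor.evaluate(metric_values, config)
print(result.passed)  # False: the second condition fails and short-circuits the "and" group
for cr in result.condition_results:
    print(cr.variable_selector, cr.actual_value, cr.passed)
```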

View File

@ -1,52 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from models import Account, App, CustomizedSnippet, TenantAccountJoin
def get_service_account_for_app(session: Session, app_id: str) -> Account:
"""Get the creator account for an app with tenant context set up.
This follows the same pattern as BaseTraceInstance.get_service_account_with_tenant().
"""
app = session.scalar(select(App).where(App.id == app_id))
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator")
account = session.scalar(select(Account).where(Account.id == app.created_by))
if not account:
raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
if not current_tenant:
raise ValueError(f"Current tenant not found for account {account.id}")
account.set_tenant_id(current_tenant.tenant_id)
return account
def get_service_account_for_snippet(session: Session, snippet_id: str) -> Account:
"""Get the creator account for a snippet with tenant context set up.
Mirrors :func:`get_service_account_for_app` but queries CustomizedSnippet.
"""
snippet = session.scalar(select(CustomizedSnippet).where(CustomizedSnippet.id == snippet_id))
if not snippet:
raise ValueError(f"Snippet with id {snippet_id} not found")
if not snippet.created_by:
raise ValueError(f"Snippet with id {snippet_id} has no creator")
account = session.scalar(select(Account).where(Account.id == snippet.created_by))
if not account:
raise ValueError(f"Creator account not found for snippet {snippet_id}")
current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
if not current_tenant:
raise ValueError(f"Current tenant not found for account {account.id}")
account.set_tenant_id(current_tenant.tenant_id)
return account

View File

@ -1,66 +0,0 @@
import logging
from collections.abc import Mapping
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationDatasetInput,
EvaluationItemInput,
EvaluationItemResult,
NodeInfo,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class AgentEvaluationRunner(BaseEvaluationRunner):
"""Runner for agent evaluation: collects tool calls and final output."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
dataset_items: list[EvaluationDatasetInput] | None = None,
node_info: NodeInfo | None = None,
) -> list[EvaluationItemResult]:
"""Compute agent evaluation metrics."""
if not node_run_result_list:
return []
merged_items = self._merge_results_into_items(node_run_result_list)
return self.evaluation_instance.evaluate_agent(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
"""Create EvaluationItemInput list from NodeRunResult for agent evaluation."""
merged = []
for i, item in enumerate(items):
output = _extract_agent_output(item.outputs)
merged.append(
EvaluationItemInput(
index=i,
inputs=dict(item.inputs),
output=output,
)
)
return merged
def _extract_agent_output(outputs: Mapping[str, Any]) -> str:
"""Extract the primary output text from agent NodeRunResult.outputs."""
if "answer" in outputs:
return str(outputs["answer"])
if "text" in outputs:
return str(outputs["text"])
values = list(outputs.values())
return str(values[0]) if values else ""

View File

@ -1,55 +0,0 @@
"""Base evaluation runner.
Provides the abstract interface for metric computation. Each concrete runner
(LLM, Retrieval, Agent, Workflow, Snippet) implements ``evaluate_metrics``
to compute scores for a specific node type.
Orchestration (merging results from multiple runners, applying judgment, and
persisting to the database) is handled by the evaluation task, not the runner.
"""
import logging
from abc import ABC, abstractmethod
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationDatasetInput,
EvaluationItemResult,
NodeInfo,
)
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class BaseEvaluationRunner(ABC):
"""Abstract base class for evaluation runners.
Runners are stateless metric calculators: they receive node execution
results and a metric specification, then return scored results. They
do **not** touch the database or apply judgment logic.
"""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
self.evaluation_instance = evaluation_instance
@abstractmethod
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
dataset_items: list[EvaluationDatasetInput] | None = None,
node_info: NodeInfo | None = None,
) -> list[EvaluationItemResult]:
"""Compute evaluation metrics on the collected results.
The returned ``EvaluationItemResult.index`` values are positional
(0-based) and correspond to the order of *node_run_result_list*.
The caller is responsible for mapping them back to the original
dataset indices.
"""
...

View File

@ -1,107 +0,0 @@
import logging
import re
from collections.abc import Mapping
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationDatasetInput,
EvaluationItemInput,
EvaluationItemResult,
NodeInfo,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class LLMEvaluationRunner(BaseEvaluationRunner):
"""Runner for LLM evaluation: extracts prompts/outputs then evaluates."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
dataset_items: list[EvaluationDatasetInput] | None = None,
node_info: NodeInfo | None = None,
) -> list[EvaluationItemResult]:
"""Use the evaluation instance to compute LLM metrics."""
if not node_run_result_list:
return []
merged_items = self._merge_results_into_items(node_run_result_list, dataset_items, node_info)
return self.evaluation_instance.evaluate_llm(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(
items: list[NodeRunResult],
dataset_items: list[EvaluationDatasetInput] | None = None,
node_info: NodeInfo | None = None,
) -> list[EvaluationItemInput]:
"""Create new items from NodeRunResult for ragas evaluation.
Extracts prompts from process_data and concatenates them into a single
string with role prefixes (e.g. "system: ...\nuser: ...\nassistant: ...").
The last assistant message in outputs is used as the actual output.
"""
merged = []
for i, item in enumerate(items):
prompts = item.process_data.get("prompts", [])
prompt = _format_prompts(prompts)
output = _extract_llm_output(item.outputs)
dataset_item = dataset_items[i] if dataset_items and i < len(dataset_items) else None
merged.append(
EvaluationItemInput(
index=i,
inputs={"prompt": prompt},
output=output,
expected_output=dataset_item.get_expected_output_for_node(node_info.title) if dataset_item else None,
context=_extract_context_blocks(prompts),
)
)
return merged
def _format_prompts(prompts: list[dict[str, Any]]) -> str:
"""Concatenate a list of prompt messages into a single string for evaluation.
Each message is formatted as "role: text" and joined with newlines.
"""
parts: list[str] = []
for msg in prompts:
role = msg.get("role", "unknown")
text = msg.get("text", "")
parts.append(f"{role}: {text}")
return "\n".join(parts)
def _extract_llm_output(outputs: Mapping[str, Any]) -> str:
"""Extract the LLM output text from NodeRunResult.outputs."""
if "text" in outputs:
return str(outputs["text"])
if "answer" in outputs:
return str(outputs["answer"])
values = list(outputs.values())
return str(values[0]) if values else ""
def _extract_context_blocks(prompts: list[dict[str, Any]]) -> list[str] | None:
"""Extract tagged context blocks from rendered prompts.
Evaluation only treats prompt content wrapped in ``<context>...</context>``
as retrieved evidence. This keeps faithfulness-style metrics opt-in and
avoids guessing which arbitrary prompt text should be considered context.
"""
prompt_text = "\n".join(str(prompt.get("text", "")) for prompt in prompts)
matches = re.findall(r"<context>(.*?)</context>", prompt_text, flags=re.DOTALL)
contexts = [match.strip() for match in matches if match.strip()]
return contexts or None
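
The context-extraction rule above can be checked standalone; this sketch repeats the regex on an invented prompt list.

```python
import re

prompts = [
    {"role": "system", "text": "Answer using the context.\n<context>Dify is an open-source LLM platform.</context>"},
    {"role": "user", "text": "What is Dify?"},
]

# Only text wrapped in <context>...</context> counts as retrieved evidence, as in _extract_context_blocks.
prompt_text = "\n".join(str(prompt.get("text", "")) for prompt in prompts)
matches = re.findall(r"<context>(.*?)</context>", prompt_text, flags=re.DOTALL)
contexts = [match.strip() for match in matches if match.strip()]
print(contexts or None)  # ['Dify is an open-source LLM platform.']
```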

View File

@ -1,67 +0,0 @@
import logging
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationDatasetInput,
EvaluationItemInput,
EvaluationItemResult,
NodeInfo,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class RetrievalEvaluationRunner(BaseEvaluationRunner):
"""Runner for retrieval evaluation: performs knowledge base retrieval, then evaluates."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
dataset_items: list[EvaluationDatasetInput] | None = None,
node_info: NodeInfo | None = None,
) -> list[EvaluationItemResult]:
"""Compute retrieval evaluation metrics."""
if not node_run_result_list:
return []
merged_items = []
for i, node_result in enumerate(node_run_result_list):
outputs = node_result.outputs
query = self._extract_query(dict(node_result.inputs))
result_list = outputs.get("result", [])
contexts = [item.get("content", "") for item in result_list if item.get("content")]
output = "\n---\n".join(contexts)
dataset_item = dataset_items[i] if dataset_items and i < len(dataset_items) else None
merged_items.append(
EvaluationItemInput(
index=i,
inputs={"query": query},
output=output,
expected_output=dataset_item.get_expected_output_for_node(node_info.title) if dataset_item else None,
context=contexts,
)
)
return self.evaluation_instance.evaluate_retrieval(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _extract_query(inputs: dict[str, Any]) -> str:
for key in ("query", "question", "input", "text"):
if key in inputs:
return str(inputs[key])
values = list(inputs.values())
return str(values[0]) if values else ""
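
A standalone sketch of the query-extraction precedence above, with invented inputs.

```python
# Mirrors RetrievalEvaluationRunner._extract_query: well-known keys first, then the first value.
def extract_query(inputs: dict) -> str:
    for key in ("query", "question", "input", "text"):
        if key in inputs:
            return str(inputs[key])
    values = list(inputs.values())
    return str(values[0]) if values else ""

print(extract_query({"question": "How do I reset my password?"}))  # How do I reset my password?
print(extract_query({"foo": 42}))                                  # 42
print(extract_query({}))                                           # ""
```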

View File

@ -1,72 +0,0 @@
"""Runner for Snippet evaluation.
Snippets are essentially workflows, so we reuse ``evaluate_workflow`` from
the evaluation instance for metric computation.
"""
import logging
from collections.abc import Mapping
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationDatasetInput,
EvaluationItemInput,
EvaluationItemResult,
NodeInfo,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class SnippetEvaluationRunner(BaseEvaluationRunner):
"""Runner for snippet evaluation: evaluates a published Snippet workflow."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
dataset_items: list[EvaluationDatasetInput] | None = None,
node_info: NodeInfo | None = None,
) -> list[EvaluationItemResult]:
"""Compute evaluation metrics for snippet outputs."""
if not node_run_result_list:
return []
merged_items = self._merge_results_into_items(node_run_result_list)
return self.evaluation_instance.evaluate_workflow(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
"""Create EvaluationItemInput list from NodeRunResult for snippet evaluation."""
merged = []
for i, item in enumerate(items):
output = _extract_snippet_output(item.outputs)
merged.append(
EvaluationItemInput(
index=i,
inputs=dict(item.inputs),
output=output,
)
)
return merged
def _extract_snippet_output(outputs: Mapping[str, Any]) -> str:
"""Extract the primary output text from snippet NodeRunResult.outputs."""
if "answer" in outputs:
return str(outputs["answer"])
if "text" in outputs:
return str(outputs["text"])
values = list(outputs.values())
return str(values[0]) if values else ""

View File

@ -1,66 +0,0 @@
import logging
from collections.abc import Mapping
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationDatasetInput,
EvaluationItemInput,
EvaluationItemResult,
NodeInfo,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class WorkflowEvaluationRunner(BaseEvaluationRunner):
"""Runner for workflow evaluation: executes workflow App in non-streaming mode."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
dataset_items: list[EvaluationDatasetInput] | None = None,
node_info: NodeInfo | None = None,
) -> list[EvaluationItemResult]:
"""Compute workflow evaluation metrics (end-to-end)."""
if not node_run_result_list:
return []
merged_items = self._merge_results_into_items(node_run_result_list)
return self.evaluation_instance.evaluate_workflow(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
"""Create EvaluationItemInput list from NodeRunResult for workflow evaluation."""
merged = []
for i, item in enumerate(items):
output = _extract_workflow_output(item.outputs)
merged.append(
EvaluationItemInput(
index=i,
inputs=dict(item.inputs),
output=output,
)
)
return merged
def _extract_workflow_output(outputs: Mapping[str, Any]) -> str:
"""Extract the primary output text from workflow NodeRunResult.outputs."""
if "answer" in outputs:
return str(outputs["answer"])
if "text" in outputs:
return str(outputs["text"])
values = list(outputs.values())
return str(values[0]) if values else ""
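
Both the snippet and workflow runners resolve the primary output text the same way; a minimal self-contained sketch of that rule:

def pick_output(outputs: dict) -> str:
    # Mirrors _extract_workflow_output / _extract_snippet_output above:
    # "answer" wins, then "text", then the first value, then "".
    if "answer" in outputs:
        return str(outputs["answer"])
    if "text" in outputs:
        return str(outputs["text"])
    values = list(outputs.values())
    return str(values[0]) if values else ""

assert pick_output({"answer": "42", "text": "ignored"}) == "42"
assert pick_output({"text": "hello"}) == "hello"
assert pick_output({"score": 0.9}) == "0.9"
assert pick_output({}) == ""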

View File

@ -13,8 +13,6 @@ from core.llm_generator.output_parser.rule_config_generator import RuleConfigGen
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
CONVERSATION_TITLE_PROMPT,
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
@ -217,8 +215,8 @@ class LLMGenerator:
else:
# Default-model generation keeps the built-in suggested-questions tuning.
model_parameters = {
"max_tokens": DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
"temperature": DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
"max_tokens": 2560,
"temperature": 0.0,
}
stop = []

View File

@ -104,10 +104,6 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
'["question1","question2","question3"]\n'
)
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS = 256
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE = 0.0
GENERATOR_QA_PROMPT = (
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
" in the long text. Please think step by step."

View File

@ -569,13 +569,13 @@ class OpsTraceManager:
db.session.commit()
@classmethod
def get_app_tracing_config(cls, app_id: str, session: Session):
def get_app_tracing_config(cls, app_id: str):
"""
Get app tracing config
:param app_id: app id
:return:
"""
app: App | None = session.get(App, app_id)
app: App | None = db.session.get(App, app_id)
if not app:
raise ValueError("App not found")
if not app.tracing:

View File

@ -53,27 +53,24 @@ class PromptMessageUtil:
files = []
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
match content:
case TextPromptMessageContent():
text += content.data
case ImagePromptMessageContent():
files.append(
{
"type": "image",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
"detail": content.detail.value,
}
)
case AudioPromptMessageContent():
files.append(
{
"type": "audio",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
"format": content.format,
}
)
case _:
continue
if isinstance(content, TextPromptMessageContent):
text += content.data
elif isinstance(content, ImagePromptMessageContent):
files.append(
{
"type": "image",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
"detail": content.detail.value,
}
)
elif isinstance(content, AudioPromptMessageContent):
files.append(
{
"type": "audio",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
"format": content.format,
}
)
else:
text = cast(str, prompt_message.content)
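
The image and audio branches above keep only a short preview of the payload; a quick sketch of the truncation, with a made-up stand-in for a long base64 string:

data = "A" * 10 + "B" * 100 + "C" * 10  # pretend base64 image data
preview = data[:10] + "...[TRUNCATED]..." + data[-10:]
assert preview == "AAAAAAAAAA...[TRUNCATED]...CCCCCCCCCC"
assert len(preview) == 37  # 10 + len("...[TRUNCATED]...") + 10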

View File

@ -195,6 +195,23 @@ class RetrievalService:
)
return all_documents
@classmethod
def _filter_documents_by_vector_score_threshold(
cls, documents: list[Document], score_threshold: float | None
) -> list[Document]:
"""Keep documents whose stored retrieval score meets the threshold.
Used when hybrid search skips early vector thresholding but no rerank
runner applies a threshold afterward (same rule as ``calculate_vector_score``).
"""
if score_threshold is None:
return documents
return [
document
for document in documents
if document.metadata and document.metadata.get("score", 0) >= score_threshold
]
@classmethod
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
"""Deduplicate documents in O(n) while preserving first-seen order.
@ -294,13 +311,20 @@ class RetrievalService:
vector = Vector(dataset=dataset)
documents = []
# Hybrid search merges keyword / full-text / vector hits and then reranks
# (weighted fusion or reranking model). Applying the user score threshold at
# vector retrieval time uses embedding similarity, which is not comparable to
# reranked or fused scores and incorrectly drops high-quality chunks (#35233).
embedding_score_threshold = (
0.0 if retrieval_method == RetrievalMethod.HYBRID_SEARCH else score_threshold
)
if query_type == QueryType.TEXT_QUERY:
documents.extend(
vector.search_by_vector(
query,
search_type="similarity_score_threshold",
top_k=top_k,
score_threshold=score_threshold,
score_threshold=embedding_score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
@ -312,7 +336,7 @@ class RetrievalService:
vector.search_by_file(
file_id=query,
top_k=top_k,
score_threshold=score_threshold,
score_threshold=embedding_score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
@ -845,6 +869,10 @@ class RetrievalService:
top_n=top_k,
query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY,
)
if not data_post_processor.rerank_runner and score_threshold:
all_documents_item = self._filter_documents_by_vector_score_threshold(
all_documents_item, score_threshold
)
all_documents.extend(all_documents_item)
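
A hedged sketch of the score-threshold flow implied by the hunks above: hybrid search disables the embedding-stage threshold, and the user threshold is applied afterwards only when no rerank runner will threshold later (plain dicts stand in for Document objects):

docs = [
    {"metadata": {"score": 0.82}},
    {"metadata": {"score": 0.31}},
    {"metadata": None},  # documents without metadata never pass the filter
]

retrieval_method = "hybrid_search"  # stand-in for RetrievalMethod.HYBRID_SEARCH
score_threshold = 0.5

# Embedding-stage threshold: forced to 0.0 for hybrid search so fused or
# reranked scores decide which chunks survive (see #35233).
embedding_score_threshold = 0.0 if retrieval_method == "hybrid_search" else score_threshold

def filter_by_score(documents, threshold):
    # Same rule as _filter_documents_by_vector_score_threshold above.
    if threshold is None:
        return documents
    return [d for d in documents if d["metadata"] and d["metadata"].get("score", 0) >= threshold]

assert embedding_score_threshold == 0.0
assert filter_by_score(docs, score_threshold) == [{"metadata": {"score": 0.82}}]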

View File

@ -23,37 +23,36 @@ _TOOL_FILE_URL_PATTERN = re.compile(r"(?:^|/+)files/tools/(?P<tool_file_id>[^/?#
def safe_json_value(v):
match v:
case datetime():
tz_name = "UTC"
if isinstance(current_user, Account) and current_user.timezone is not None:
tz_name = current_user.timezone
return v.astimezone(pytz.timezone(tz_name)).isoformat()
case date():
return v.isoformat()
case UUID():
return str(v)
case Decimal():
return float(v)
case bytes():
try:
return v.decode("utf-8")
except UnicodeDecodeError:
return v.hex()
case memoryview():
return v.tobytes().hex()
case np.integer():
return int(v)
case np.floating():
return float(v)
case np.ndarray():
return v.tolist()
case dict():
return safe_json_dict(v)
case list() | tuple() | set():
return [safe_json_value(i) for i in v]
case _:
return v
if isinstance(v, datetime):
tz_name = "UTC"
if isinstance(current_user, Account) and current_user.timezone is not None:
tz_name = current_user.timezone
return v.astimezone(pytz.timezone(tz_name)).isoformat()
elif isinstance(v, date):
return v.isoformat()
elif isinstance(v, UUID):
return str(v)
elif isinstance(v, Decimal):
return float(v)
elif isinstance(v, bytes):
try:
return v.decode("utf-8")
except UnicodeDecodeError:
return v.hex()
elif isinstance(v, memoryview):
return v.tobytes().hex()
elif isinstance(v, np.integer):
return int(v)
elif isinstance(v, np.floating):
return float(v)
elif isinstance(v, np.ndarray):
return v.tolist()
elif isinstance(v, dict):
return safe_json_dict(v)
elif isinstance(v, list | tuple | set):
return [safe_json_value(i) for i in v]
else:
return v
def safe_json_dict(d: dict[str, Any]):
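
Behavioural sketch of the conversions safe_json_value performs above, shown with a few representative inputs (the pytz/current_user datetime branch is omitted because it depends on request context):

from datetime import date
from decimal import Decimal
from uuid import UUID

import numpy as np

converted = {
    "date": date(2026, 5, 9).isoformat(),                       # -> "2026-05-09"
    "uuid": str(UUID("12345678-1234-5678-1234-567812345678")),  # -> plain string
    "decimal": float(Decimal("1.25")),                          # -> 1.25
    "utf8_bytes": b"hello".decode("utf-8"),                     # -> "hello"
    "binary_bytes": b"\xff\xfe".hex(),                          # undecodable bytes -> "fffe"
    "np_int": int(np.int64(7)),                                 # -> 7
    "np_array": np.array([1, 2]).tolist(),                      # -> [1, 2]
}
assert converted["decimal"] == 1.25 and converted["np_array"] == [1, 2]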

View File

@ -42,12 +42,7 @@ from graphon.model_runtime.entities.llm_entities import (
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.human_input.entities import (
FileInputConfig,
FileListInputConfig,
FormInputConfig,
HumanInputNodeData,
)
from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.llm.runtime_protocols import (
PreparedLLMProtocol,
PromptMessageSerializerProtocol,
@ -83,6 +78,7 @@ from .system_variables import SystemVariableKey, get_system_text
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage
from graphon.file import File
from graphon.nodes.llm.file_saver import LLMFileSaver
from graphon.nodes.tool.entities import ToolNodeData
@ -633,7 +629,6 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
self._run_context = resolve_dify_run_context(run_context)
self._workflow_execution_id_getter = workflow_execution_id_getter
self._form_repository = form_repository
self._file_reference_factory = DifyFileReferenceFactory(self._run_context)
def _invoke_source(self) -> str:
invoke_from = self._run_context.invoke_from
@ -687,23 +682,6 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
repo = self.build_form_repository()
return repo.get_form(node_id)
def restore_submitted_data(
self,
*,
node_data: HumanInputNodeData,
submitted_data: Mapping[str, Any],
) -> Mapping[str, Any]:
restored_data: dict[str, Any] = dict(submitted_data)
for input_config in node_data.inputs:
output_variable_name = input_config.output_variable_name
if output_variable_name not in submitted_data:
continue
restored_data[output_variable_name] = self._restore_submitted_value(
input_config=input_config,
value=submitted_data[output_variable_name],
)
return restored_data
def create_form(
self,
*,
@ -724,55 +702,6 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
)
return repo.create_form(params)
def _restore_submitted_value(
self,
*,
input_config: FormInputConfig,
value: Any,
) -> Any:
if isinstance(input_config, FileInputConfig):
return self._restore_submitted_file_value(
output_variable_name=input_config.output_variable_name,
value=value,
)
if isinstance(input_config, FileListInputConfig):
return self._restore_submitted_file_list_value(
output_variable_name=input_config.output_variable_name,
value=value,
)
return value
def _restore_submitted_file_value(
self,
*,
output_variable_name: str,
value: Any,
) -> Any:
if not isinstance(value, Mapping):
msg = (
"HumanInput file submission must be persisted as a mapping, "
f"output_variable_name={output_variable_name}"
)
raise ValueError(msg)
return self._file_reference_factory.build_from_mapping(mapping=value)
def _restore_submitted_file_list_value(
self,
*,
output_variable_name: str,
value: Any,
) -> list[Any]:
if not isinstance(value, list):
msg = (
"HumanInput file-list submission must be persisted as a list, "
f"output_variable_name={output_variable_name}"
)
raise ValueError(msg)
if any(not isinstance(item, Mapping) for item in value):
msg = f"HumanInput file-list submission must contain mappings, output_variable_name={output_variable_name}"
raise ValueError(msg)
return [self._file_reference_factory.build_from_mapping(mapping=item) for item in value]
def build_dify_llm_file_saver(
*,

View File

@ -1,20 +0,0 @@
"""Shared snippet virtual Start-node identifiers and compatibility helpers.
Snippet workflows do not persist a real canvas Start node, so the backend
injects one at runtime. Existing workflow references commonly use the public
selector shape ``#start.<var>#``; keep that contract stable by treating the
runtime-only snippet Start node as compatible with the legacy ``start`` id.
"""
from __future__ import annotations
LEGACY_START_NODE_ID = "start"
SNIPPET_VIRTUAL_START_NODE_ID = "__snippet_virtual_start__"
def get_compatible_start_aliases(*, workflow_kind: str | None, root_node_id: str | None) -> tuple[str, ...]:
"""Return additional selector ids that should mirror snippet Start inputs."""
if workflow_kind == "snippet" and root_node_id == SNIPPET_VIRTUAL_START_NODE_ID:
return (LEGACY_START_NODE_ID,)
return ()
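
A minimal self-contained sketch of the alias rule being removed above:

LEGACY_START_NODE_ID = "start"
SNIPPET_VIRTUAL_START_NODE_ID = "__snippet_virtual_start__"

def get_compatible_start_aliases(*, workflow_kind: str | None, root_node_id: str | None) -> tuple[str, ...]:
    # Same rule as the deleted helper: only snippet runs rooted at the
    # runtime-only virtual Start node mirror their inputs under "start".
    if workflow_kind == "snippet" and root_node_id == SNIPPET_VIRTUAL_START_NODE_ID:
        return (LEGACY_START_NODE_ID,)
    return ()

assert get_compatible_start_aliases(
    workflow_kind="snippet", root_node_id=SNIPPET_VIRTUAL_START_NODE_ID
) == ("start",)
assert get_compatible_start_aliases(workflow_kind="workflow", root_node_id="start") == ()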

View File

@ -10,19 +10,6 @@ def add_variables_to_pool(variable_pool: VariablePool, variables: Sequence[Varia
variable_pool.add(variable.selector, variable)
def add_node_inputs_to_pool(
variable_pool: VariablePool,
*,
node_id: str,
inputs: Mapping[str, Any],
aliases: Sequence[str] = (),
) -> None:
"""Store node inputs under the primary node id and any compatible aliases."""
node_ids: list[str] = [node_id]
for alias in aliases:
if alias not in node_ids:
node_ids.append(alias)
for current_node_id in node_ids:
for key, value in inputs.items():
variable_pool.add((current_node_id, key), value)
def add_node_inputs_to_pool(variable_pool: VariablePool, *, node_id: str, inputs: Mapping[str, Any]) -> None:
for key, value in inputs.items():
variable_pool.add((node_id, key), value)
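
Sketch of what the removed alias-aware variant did compared with the simplified replacement, using a plain dict keyed by (node_id, variable_name) in place of VariablePool:

pool: dict[tuple[str, str], object] = {}

def add_node_inputs(node_id: str, inputs: dict, aliases: tuple[str, ...] = ()) -> None:
    # The removed helper wrote each input under the node id and every alias;
    # the simplified version above keeps only the node_id loop.
    for current_node_id in (node_id, *aliases):
        for key, value in inputs.items():
            pool[(current_node_id, key)] = value

add_node_inputs("__snippet_virtual_start__", {"query": "hi"}, aliases=("start",))
assert pool[("__snippet_virtual_start__", "query")] == "hi"
assert pool[("start", "query")] == "hi"  # legacy #start.query# selectors keep working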

View File

@ -1,95 +0,0 @@
"""Generate FastOpenAPI OpenAPI 3.0 specs without booting the full backend."""
from __future__ import annotations
import argparse
import json
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
API_ROOT = Path(__file__).resolve().parents[1]
if str(API_ROOT) not in sys.path:
sys.path.insert(0, str(API_ROOT))
from dev.generate_swagger_specs import apply_runtime_defaults, drop_null_values, sort_openapi_arrays
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class FastOpenApiSpecTarget:
route: str
filename: str
FASTOPENAPI_SPEC_TARGETS: tuple[FastOpenApiSpecTarget, ...] = (
FastOpenApiSpecTarget(route="/fastopenapi/openapi.json", filename="fastopenapi-console-openapi.json"),
)
def create_fastopenapi_spec_app():
"""Build a minimal Flask app that only mounts FastOpenAPI docs routes."""
apply_runtime_defaults()
from app_factory import create_flask_app_with_configs
from extensions import ext_fastopenapi
app = create_flask_app_with_configs()
ext_fastopenapi.init_app(app)
return app
def generate_fastopenapi_specs(output_dir: Path) -> list[Path]:
"""Write FastOpenAPI specs to `output_dir` and return the written paths."""
output_dir.mkdir(parents=True, exist_ok=True)
app = create_fastopenapi_spec_app()
client = app.test_client()
written_paths: list[Path] = []
for target in FASTOPENAPI_SPEC_TARGETS:
response = client.get(target.route)
if response.status_code != 200:
raise RuntimeError(f"failed to fetch {target.route}: {response.status_code}")
payload = response.get_json()
if not isinstance(payload, dict):
raise RuntimeError(f"unexpected response payload for {target.route}")
payload = drop_null_values(payload)
payload = sort_openapi_arrays(payload)
output_path = output_dir / target.filename
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
written_paths.append(output_path)
return written_paths
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"-o",
"--output-dir",
type=Path,
default=Path("openapi"),
help="Directory where the OpenAPI JSON files will be written.",
)
return parser.parse_args()
def main() -> int:
args = parse_args()
written_paths = generate_fastopenapi_specs(args.output_dir)
for path in written_paths:
logger.debug(path)
return 0
if __name__ == "__main__":
raise SystemExit(main())
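
If the removed script were still present, local regeneration would look roughly like this; the import path assumes the script lives under the api dev/ package, as its sys.path bootstrap implies:

from pathlib import Path

# Hypothetical usage of the deleted helper; requires the api root on sys.path.
from dev.generate_fastopenapi_specs import generate_fastopenapi_specs

written = generate_fastopenapi_specs(Path("openapi"))
for path in written:
    print(path)  # e.g. openapi/fastopenapi-console-openapi.json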

View File

@ -1,161 +0,0 @@
"""Generate OpenAPI JSON specs and split Markdown API docs.
The Markdown step uses `swagger-markdown`, the same converter family as the
Swagger Markdown UI, so CI and local regeneration catch converter-incompatible
OpenAPI output early.
"""
from __future__ import annotations
import argparse
import logging
import subprocess
import sys
import tempfile
from pathlib import Path
API_ROOT = Path(__file__).resolve().parents[1]
if str(API_ROOT) not in sys.path:
sys.path.insert(0, str(API_ROOT))
from dev.generate_fastopenapi_specs import FASTOPENAPI_SPEC_TARGETS, generate_fastopenapi_specs
from dev.generate_swagger_specs import SPEC_TARGETS, generate_specs
logger = logging.getLogger(__name__)
SWAGGER_MARKDOWN_PACKAGE = "swagger-markdown@3.0.0"
CONSOLE_SWAGGER_FILENAME = "console-swagger.json"
STALE_COMBINED_MARKDOWN_FILENAME = "api-reference.md"
def _convert_spec_to_markdown(spec_path: Path, markdown_path: Path) -> None:
subprocess.run(
[
"npx",
"--yes",
SWAGGER_MARKDOWN_PACKAGE,
"-i",
str(spec_path),
"-o",
str(markdown_path),
],
check=True,
)
def _demote_markdown_headings(markdown: str, *, levels: int = 1) -> str:
"""Nest generated Markdown under another Markdown section."""
heading_prefix = "#" * levels
lines = []
for line in markdown.splitlines():
if line.startswith("#"):
lines.append(f"{heading_prefix}{line}")
else:
lines.append(line)
return "\n".join(lines).strip()
def _append_fastopenapi_markdown(console_markdown_path: Path, fastopenapi_markdown_path: Path) -> None:
"""Append FastOpenAPI console docs to the existing console API Markdown."""
console_markdown = console_markdown_path.read_text(encoding="utf-8").rstrip()
fastopenapi_markdown = _demote_markdown_headings(
fastopenapi_markdown_path.read_text(encoding="utf-8"),
levels=2,
)
console_markdown_path.write_text(
"\n\n".join(
[
console_markdown,
"## FastOpenAPI Preview (OpenAPI 3.0)",
fastopenapi_markdown,
]
)
+ "\n",
encoding="utf-8",
)
def generate_markdown_docs(
swagger_dir: Path,
markdown_dir: Path,
*,
keep_swagger_json: bool = False,
) -> list[Path]:
"""Generate intermediate specs, convert them to split Markdown API docs, and return Markdown paths."""
swagger_paths = generate_specs(swagger_dir)
fastopenapi_paths = generate_fastopenapi_specs(swagger_dir)
spec_paths = [*swagger_paths, *fastopenapi_paths]
swagger_paths_by_name = {path.name: path for path in swagger_paths}
fastopenapi_paths_by_name = {path.name: path for path in fastopenapi_paths}
markdown_dir.mkdir(parents=True, exist_ok=True)
written_paths: list[Path] = []
try:
with tempfile.TemporaryDirectory(prefix="dify-api-docs-") as temp_dir:
temp_markdown_dir = Path(temp_dir)
for target in SPEC_TARGETS:
swagger_path = swagger_paths_by_name[target.filename]
markdown_path = markdown_dir / f"{swagger_path.stem}.md"
_convert_spec_to_markdown(swagger_path, markdown_path)
written_paths.append(markdown_path)
for target in FASTOPENAPI_SPEC_TARGETS: # type: ignore
fastopenapi_path = fastopenapi_paths_by_name[target.filename]
markdown_path = temp_markdown_dir / f"{fastopenapi_path.stem}.md"
_convert_spec_to_markdown(fastopenapi_path, markdown_path)
console_markdown_path = markdown_dir / f"{Path(CONSOLE_SWAGGER_FILENAME).stem}.md"
_append_fastopenapi_markdown(console_markdown_path, markdown_path)
(markdown_dir / STALE_COMBINED_MARKDOWN_FILENAME).unlink(missing_ok=True)
finally:
if not keep_swagger_json:
for path in spec_paths:
path.unlink(missing_ok=True)
return written_paths
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--swagger-dir",
type=Path,
default=Path("openapi"),
help="Directory where intermediate JSON spec files will be written.",
)
parser.add_argument(
"--markdown-dir",
type=Path,
default=Path("openapi/markdown"),
help="Directory where split Markdown API docs will be written.",
)
parser.add_argument(
"--keep-swagger-json",
action="store_true",
help="Keep intermediate JSON spec files after Markdown generation.",
)
return parser.parse_args()
def main() -> int:
args = parse_args()
written_paths = generate_markdown_docs(
args.swagger_dir,
args.markdown_dir,
keep_swagger_json=args.keep_swagger_json,
)
for path in written_paths:
logger.debug(path)
return 0
if __name__ == "__main__":
raise SystemExit(main())
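
The heading demotion used when appending the FastOpenAPI docs is easy to illustrate; a self-contained sketch of the same rule as _demote_markdown_headings above:

def demote(markdown: str, *, levels: int = 1) -> str:
    # Prefix every heading line so the appended docs nest under the new
    # "## FastOpenAPI Preview (OpenAPI 3.0)" section.
    prefix = "#" * levels
    lines = [f"{prefix}{line}" if line.startswith("#") else line for line in markdown.splitlines()]
    return "\n".join(lines).strip()

assert demote("# Console API\n\n## Apps\nBody text", levels=2) == "### Console API\n\n#### Apps\nBody text"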

View File

@ -9,15 +9,12 @@ which is unnecessary when the goal is only to serialize the Flask-RESTX
from __future__ import annotations
import argparse
import hashlib
import json
import logging
import os
import sys
from collections.abc import MutableMapping
from dataclasses import dataclass
from pathlib import Path
from typing import Protocol, TypeGuard
from flask import Flask
from flask_restx.swagger import Swagger
@ -33,110 +30,19 @@ if str(API_ROOT) not in sys.path:
class SpecTarget:
route: str
filename: str
namespace: str
class RestxApi(Protocol):
models: MutableMapping[str, object]
def model(self, name: str, model: dict[object, object]) -> object: ...
SPEC_TARGETS: tuple[SpecTarget, ...] = (
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json", namespace="console"),
SpecTarget(route="/api/swagger.json", filename="web-swagger.json", namespace="web"),
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json", namespace="service"),
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json"),
SpecTarget(route="/api/swagger.json", filename="web-swagger.json"),
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json"),
)
_ORIGINAL_REGISTER_MODEL = Swagger.register_model
_ORIGINAL_REGISTER_FIELD = Swagger.register_field
def _is_inline_field_map(value: object) -> TypeGuard[dict[object, object]]:
"""Return whether a nested field map is an anonymous inline mapping."""
from flask_restx.model import Model, OrderedModel
return isinstance(value, dict) and not isinstance(value, (Model, OrderedModel))
def _jsonable_schema_value(value: object) -> object:
"""Return a deterministic JSON-serializable representation for schema fingerprints."""
if value is None or isinstance(value, str | int | float | bool):
return value
if isinstance(value, list | tuple):
return [_jsonable_schema_value(item) for item in value]
if isinstance(value, dict):
return {str(key): _jsonable_schema_value(item) for key, item in value.items()}
value_type = type(value)
return f"<{value_type.__module__}.{value_type.__qualname__}>"
def _field_signature(field: object) -> object:
"""Build a stable signature for a Flask-RESTX field object."""
from flask_restx import fields
from flask_restx.model import instance
field_instance = instance(field)
signature: dict[str, object] = {
"class": f"{field_instance.__class__.__module__}.{field_instance.__class__.__qualname__}"
}
if isinstance(field_instance, fields.Nested):
nested = getattr(field_instance, "nested", None)
if _is_inline_field_map(nested):
signature["nested"] = _inline_model_signature(nested)
else:
signature["nested"] = getattr(
nested,
"name",
f"<{type(nested).__module__}.{type(nested).__qualname__}>",
)
elif hasattr(field_instance, "container"):
signature["container"] = _field_signature(field_instance.container)
else:
schema = getattr(field_instance, "__schema__", None)
if isinstance(schema, dict):
signature["schema"] = _jsonable_schema_value(schema)
for attr_name in (
"attribute",
"default",
"description",
"example",
"max",
"min",
"nullable",
"readonly",
"required",
"title",
):
if hasattr(field_instance, attr_name):
signature[attr_name] = _jsonable_schema_value(getattr(field_instance, attr_name))
return signature
def _inline_model_signature(nested_fields: dict[object, object]) -> object:
"""Build a stable signature for an anonymous inline model."""
return [
(str(field_name), _field_signature(field))
for field_name, field in sorted(nested_fields.items(), key=lambda item: str(item[0]))
]
def _inline_model_name(nested_fields: dict[object, object]) -> str:
"""Return a stable Swagger model name for an anonymous inline field map."""
signature = json.dumps(_inline_model_signature(nested_fields), sort_keys=True, separators=(",", ":"))
digest = hashlib.sha1(signature.encode("utf-8")).hexdigest()[:12]
return f"_AnonymousInlineModel_{digest}"
def apply_runtime_defaults() -> None:
def _apply_runtime_defaults() -> None:
"""Force the small config surface required for Swagger generation."""
os.environ.setdefault("SECRET_KEY", "spec-export")
@ -168,26 +74,25 @@ def _patch_swagger_for_inline_nested_dicts() -> None:
anonymous_models = getattr(self, "_anonymous_inline_models", None)
if anonymous_models is None:
anonymous_models = {}
self.__dict__["_anonymous_inline_models"] = anonymous_models
self._anonymous_inline_models = anonymous_models
anonymous_name = anonymous_models.get(id(nested_fields))
if anonymous_name is None:
anonymous_name = _inline_model_name(nested_fields)
anonymous_name = f"_AnonymousInlineModel{len(anonymous_models) + 1}"
anonymous_models[id(nested_fields)] = anonymous_name
if anonymous_name not in self.api.models:
self.api.model(anonymous_name, nested_fields)
self.api.model(anonymous_name, nested_fields)
return self.api.models[anonymous_name]
def register_model_with_inline_dict_support(self: Swagger, model: object) -> dict[str, str]:
if _is_inline_field_map(model):
if isinstance(model, dict):
model = get_or_create_inline_model(self, model)
return _ORIGINAL_REGISTER_MODEL(self, model)
def register_field_with_inline_dict_support(self: Swagger, field: object) -> None:
nested = getattr(field, "nested", None)
if _is_inline_field_map(nested):
if isinstance(nested, dict):
field.model = get_or_create_inline_model(self, nested) # type: ignore
_ORIGINAL_REGISTER_FIELD(self, field)
@ -200,169 +105,22 @@ def _patch_swagger_for_inline_nested_dicts() -> None:
def create_spec_app() -> Flask:
"""Build a minimal Flask app that only mounts the Swagger-producing blueprints."""
apply_runtime_defaults()
_apply_runtime_defaults()
_patch_swagger_for_inline_nested_dicts()
app = Flask(__name__)
from controllers.console import bp as console_bp
from controllers.console import console_ns
from controllers.service_api import bp as service_api_bp
from controllers.service_api import service_api_ns
from controllers.web import bp as web_bp
from controllers.web import web_ns
app.register_blueprint(console_bp)
app.register_blueprint(web_bp)
app.register_blueprint(service_api_bp)
for namespace in (console_ns, web_ns, service_api_ns):
for api in namespace.apis:
_materialize_inline_model_definitions(api)
return app
def _registered_models(namespace: str) -> dict[str, object]:
"""Return the Flask-RESTX models registered for a Swagger namespace."""
if namespace == "console":
from controllers.console import console_ns
models = dict(console_ns.models)
for api in console_ns.apis:
models.update(api.models)
return models
if namespace == "web":
from controllers.web import web_ns
models = dict(web_ns.models)
for api in web_ns.apis:
models.update(api.models)
return models
if namespace == "service":
from controllers.service_api import service_api_ns
models = dict(service_api_ns.models)
for api in service_api_ns.apis:
models.update(api.models)
return models
raise ValueError(f"unknown Swagger namespace: {namespace}")
def _materialize_inline_model_definitions(api: RestxApi) -> None:
"""Convert inline `fields.Nested({...})` maps into named API models."""
from flask_restx import fields
from flask_restx.model import Model, OrderedModel, instance
inline_models: dict[int, dict[object, object]] = {}
inline_model_names: dict[int, str] = {}
def collect_field(field: object) -> None:
field_instance = instance(field)
if isinstance(field_instance, fields.Nested):
nested = getattr(field_instance, "nested", None)
if _is_inline_field_map(nested) and id(nested) not in inline_models:
inline_models[id(nested)] = nested
for nested_field in nested.values():
collect_field(nested_field)
container = getattr(field_instance, "container", None)
if container is not None:
collect_field(container)
for model in list(api.models.values()):
if isinstance(model, (Model, OrderedModel)):
for field in model.values():
collect_field(field)
for nested_fields in sorted(inline_models.values(), key=_inline_model_name):
anonymous_name = _inline_model_name(nested_fields)
inline_model_names[id(nested_fields)] = anonymous_name
if anonymous_name not in api.models:
api.model(anonymous_name, nested_fields)
def model_name_for(nested_fields: dict[object, object]) -> str:
anonymous_name = inline_model_names.get(id(nested_fields))
if anonymous_name is None:
anonymous_name = _inline_model_name(nested_fields)
inline_model_names[id(nested_fields)] = anonymous_name
if anonymous_name not in api.models:
api.model(anonymous_name, nested_fields)
return anonymous_name
def materialize_field(field: object) -> None:
field_instance = instance(field)
if isinstance(field_instance, fields.Nested):
nested = getattr(field_instance, "nested", None)
if _is_inline_field_map(nested):
field_instance.model = api.models[model_name_for(nested)] # type: ignore[attr-defined]
container = getattr(field_instance, "container", None)
if container is not None:
materialize_field(container)
index = 0
while index < len(api.models):
model = list(api.models.values())[index]
index += 1
if isinstance(model, (Model, OrderedModel)):
for field in model.values():
materialize_field(field)
def drop_null_values(value: object) -> object:
"""Remove JSON null values that make the Markdown converter crash."""
if isinstance(value, dict):
return {key: drop_null_values(item) for key, item in value.items() if item is not None}
if isinstance(value, list):
return [drop_null_values(item) for item in value]
return value
def sort_openapi_arrays(value: object, *, parent_key: str | None = None) -> object:
"""Sort order-insensitive Swagger arrays so generated Markdown is stable."""
if isinstance(value, dict):
return {key: sort_openapi_arrays(item, parent_key=key) for key, item in value.items()}
if not isinstance(value, list):
return value
sorted_items = [sort_openapi_arrays(item, parent_key=parent_key) for item in value]
if parent_key == "parameters":
return sorted(
sorted_items,
key=lambda item: (
item.get("in", "") if isinstance(item, dict) else "",
item.get("name", "") if isinstance(item, dict) else "",
json.dumps(item, sort_keys=True, default=str),
),
)
if parent_key in {"enum", "required", "schemes", "tags"}:
string_items = [item for item in sorted_items if isinstance(item, str)]
if len(string_items) == len(sorted_items):
return sorted(string_items)
return sorted_items
def _merge_registered_definitions(payload: dict[str, object], namespace: str) -> dict[str, object]:
"""Include registered but route-indirect models in the exported Swagger definitions."""
definitions = payload.setdefault("definitions", {})
if not isinstance(definitions, dict):
raise RuntimeError("unexpected Swagger definitions payload")
for name, model in _registered_models(namespace).items():
schema = getattr(model, "__schema__", None)
if isinstance(schema, dict):
definitions.setdefault(name, schema)
return payload
def generate_specs(output_dir: Path) -> list[Path]:
"""Write all Swagger specs to `output_dir` and return the written paths."""
@ -380,9 +138,6 @@ def generate_specs(output_dir: Path) -> list[Path]:
payload = response.get_json()
if not isinstance(payload, dict):
raise RuntimeError(f"unexpected response payload for {target.route}")
payload = _merge_registered_definitions(payload, target.namespace)
payload = drop_null_values(payload)
payload = sort_openapi_arrays(payload)
output_path = output_dir / target.filename
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
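
The removed naming scheme derives anonymous model names from a content hash rather than a registration counter; a rough sketch of the idea (the signature value below is a placeholder, not real _inline_model_signature output):

import hashlib
import json

signature = json.dumps([["name", {"class": "flask_restx.fields.String"}]], sort_keys=True, separators=(",", ":"))
digest = hashlib.sha1(signature.encode("utf-8")).hexdigest()[:12]
model_name = f"_AnonymousInlineModel_{digest}"

# The same inline field map always hashes to the same name, so regenerated
# specs stay byte-stable; the counter-based "_AnonymousInlineModel{n}" names
# above depend on registration order instead.
print(model_name)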

View File

@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention,workflow_based_app_execution,evaluation"
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention,workflow_based_app_execution"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention,workflow_based_app_execution,evaluation"
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention,workflow_based_app_execution"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"

View File

@ -6,9 +6,9 @@ import click
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from events.document_index_event import document_index_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Document
from models.enums import IndexingStatus
@ -22,24 +22,25 @@ def handle(sender, **kwargs):
document_ids = kwargs.get("document_ids", [])
documents = []
start_at = time.perf_counter()
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
with session_factory.create_session() as session:
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = db.session.scalar(
select(Document).where(
Document.id == document_id,
Document.dataset_id == dataset_id,
document = session.scalar(
select(Document).where(
Document.id == document_id,
Document.dataset_id == dataset_id,
)
)
)
if not document:
raise NotFound("Document not found")
if not document:
raise NotFound("Document not found")
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
documents.append(document)
session.add(document)
session.commit()
with contextlib.suppress(Exception):
try:

View File

@ -1,5 +1,5 @@
from core.db.session_factory import session_factory
from events.app_event import app_was_created
from extensions.ext_database import db
from models.enums import CustomizeTokenStrategy
from models.model import Site
@ -22,6 +22,6 @@ def handle(sender, **kwargs):
created_by=app.created_by,
updated_by=app.updated_by,
)
db.session.add(site)
db.session.commit()
with session_factory.create_session() as session:
session.add(site)
session.commit()
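
The handlers in the two hunks above switch from the request-bound db.session to a short-lived session from session_factory; the general shape of that pattern, with the handler body reduced to a placeholder:

from core.db.session_factory import session_factory  # as imported above

def handle_event(obj) -> None:
    # create_session() yields a session that is closed when the block exits,
    # so event handlers no longer depend on the global db.session lifecycle.
    with session_factory.create_session() as session:
        session.add(obj)
        session.commit()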

View File

@ -148,7 +148,6 @@ def init_app(app: DifyApp) -> Celery:
"tasks.trigger_processing_tasks", # async trigger processing
"tasks.generate_summary_index_task", # summary index generation
"tasks.regenerate_summary_index_task", # summary index regeneration
"tasks.evaluation_task", # evaluation run execution
]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME

View File

@ -28,6 +28,7 @@ def init_app(app: DifyApp):
reset_encrypt_key_pair,
reset_password,
restore_workflow_runs,
sample_vector_space_usage,
setup_datasource_oauth_client,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
@ -68,6 +69,7 @@ def init_app(app: DifyApp):
clean_workflow_runs,
clean_expired_messages,
export_app_messages,
sample_vector_space_usage,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)
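
The new command is wired in through the same click registration path as the other CLI entries; a hedged sketch of that pattern (the command name and body below are placeholders, since the diff only shows the registration, not the implementation of sample_vector_space_usage):

import click

@click.command("sample-vector-space-usage")
def sample_vector_space_usage():
    # Placeholder body for illustration only; the real command's behaviour
    # is not shown in this diff.
    click.echo("sampling vector space usage ...")

# Registration mirrors init_app above: app.cli.add_command(sample_vector_space_usage)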

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import mimetypes
import uuid
from collections.abc import Mapping, Sequence
from typing import Any, Literal, NotRequired, TypedDict, assert_never, cast
from typing import Any
from sqlalchemy import select
@ -19,58 +19,10 @@ from .common import resolve_mapping_file_id
from .remote import get_remote_file_info
from .validation import is_file_valid_with_config
type FileTypeValue = FileType | Literal["image", "document", "audio", "video", "custom"]
type _LocalFileTransferMethod = Literal["local_file", FileTransferMethod.LOCAL_FILE]
type _RemoteUrlTransferMethod = Literal["remote_url", FileTransferMethod.REMOTE_URL]
type _ToolFileTransferMethod = Literal["tool_file", FileTransferMethod.TOOL_FILE]
type _DatasourceFileTransferMethod = Literal["datasource_file", FileTransferMethod.DATASOURCE_FILE]
class LocalFileMapping(TypedDict):
transfer_method: _LocalFileTransferMethod
id: NotRequired[str | None] # Read as the graph-layer File.file_id.
type: NotRequired[FileTypeValue | None] # Read for type override and upload config validation.
upload_file_id: NotRequired[str | None] # File id lookup priority 1.
reference: NotRequired[str | None] # File id lookup priority 2; may be an opaque file reference.
related_id: NotRequired[str | None] # File id lookup priority 3; legacy persisted field.
class RemoteUrlMapping(TypedDict):
transfer_method: _RemoteUrlTransferMethod
id: NotRequired[str | None] # Read as the graph-layer File.file_id.
type: NotRequired[FileTypeValue | None] # Read for type override and upload config validation.
upload_file_id: NotRequired[str | None] # Persisted UploadFile lookup priority 1.
reference: NotRequired[str | None] # Persisted UploadFile lookup priority 2; may be an opaque file reference.
related_id: NotRequired[str | None] # Persisted UploadFile lookup priority 3; legacy persisted field.
url: NotRequired[str | None] # External URL lookup priority 1 when no UploadFile id is resolved.
remote_url: NotRequired[str | None] # External URL lookup priority 2 when no UploadFile id is resolved.
class ToolFileMapping(TypedDict):
transfer_method: _ToolFileTransferMethod
id: NotRequired[str | None] # Read as the graph-layer File.file_id.
type: NotRequired[FileTypeValue | None] # Read for type override and upload config validation.
tool_file_id: NotRequired[str | None] # ToolFile lookup priority 1.
reference: NotRequired[str | None] # ToolFile lookup priority 2; may be an opaque file reference.
related_id: NotRequired[str | None] # ToolFile lookup priority 3; legacy persisted field.
class DatasourceFileMapping(TypedDict):
transfer_method: _DatasourceFileTransferMethod
type: NotRequired[FileTypeValue | None] # Read for type override and upload config validation.
datasource_file_id: NotRequired[str | None] # UploadFile lookup priority 1 for datasource-backed files.
reference: NotRequired[str | None] # UploadFile lookup priority 2; may be an opaque file reference.
related_id: NotRequired[str | None] # UploadFile lookup priority 3; legacy persisted field.
type FileMapping = LocalFileMapping | RemoteUrlMapping | ToolFileMapping | DatasourceFileMapping
type FileMappingInput = FileMapping | Mapping[str, Any]
def build_from_mapping(
*,
mapping: FileMappingInput,
mapping: Mapping[str, Any],
tenant_id: str,
config: FileUploadConfig | None = None,
strict_type_validation: bool = False,
@ -80,45 +32,18 @@ def build_from_mapping(
if not transfer_method_value:
raise ValueError("transfer_method is required in file mapping")
transfer_method = FileTransferMethod.value_of(str(transfer_method_value))
match transfer_method:
case FileTransferMethod.LOCAL_FILE:
file = _build_from_local_file(
mapping=cast(LocalFileMapping, mapping),
tenant_id=tenant_id,
transfer_method=transfer_method,
strict_type_validation=strict_type_validation,
access_controller=access_controller,
)
case FileTransferMethod.REMOTE_URL:
file = _build_from_remote_url(
mapping=cast(RemoteUrlMapping, mapping),
tenant_id=tenant_id,
transfer_method=transfer_method,
strict_type_validation=strict_type_validation,
access_controller=access_controller,
)
case FileTransferMethod.TOOL_FILE:
file = _build_from_tool_file(
mapping=cast(ToolFileMapping, mapping),
tenant_id=tenant_id,
transfer_method=transfer_method,
strict_type_validation=strict_type_validation,
access_controller=access_controller,
)
case FileTransferMethod.DATASOURCE_FILE:
file = _build_from_datasource_file(
mapping=cast(DatasourceFileMapping, mapping),
tenant_id=tenant_id,
transfer_method=transfer_method,
strict_type_validation=strict_type_validation,
access_controller=access_controller,
)
case _:
assert_never(transfer_method)
transfer_method = FileTransferMethod.value_of(transfer_method_value)
build_func = _get_build_function(transfer_method)
file = build_func(
mapping=mapping,
tenant_id=tenant_id,
transfer_method=transfer_method,
strict_type_validation=strict_type_validation,
access_controller=access_controller,
)
if config and not is_file_valid_with_config(
input_file_type=mapping.get("type") or FileType.CUSTOM,
input_file_type=mapping.get("type", FileType.CUSTOM),
file_extension=file.extension or "",
file_transfer_method=file.transfer_method,
config=config,
@ -162,6 +87,19 @@ def build_from_mappings(
return files
def _get_build_function(transfer_method: FileTransferMethod):
build_functions = {
FileTransferMethod.LOCAL_FILE: _build_from_local_file,
FileTransferMethod.REMOTE_URL: _build_from_remote_url,
FileTransferMethod.TOOL_FILE: _build_from_tool_file,
FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file,
}
build_func = build_functions.get(transfer_method)
if build_func is None:
raise ValueError(f"Invalid file transfer method: {transfer_method}")
return build_func
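# Illustrative sketch of the same dict-based dispatch idea with placeholder
# builders; names below are assumptions for illustration, not the real
# FileTransferMethod members or _build_from_* signatures.
from enum import StrEnum

class _Method(StrEnum):
    LOCAL_FILE = "local_file"
    REMOTE_URL = "remote_url"

def _build_local(**kwargs):
    return ("local", kwargs)

def _build_remote(**kwargs):
    return ("remote", kwargs)

_BUILDERS = {_Method.LOCAL_FILE: _build_local, _Method.REMOTE_URL: _build_remote}

def _lookup(method: _Method):
    build_func = _BUILDERS.get(method)
    if build_func is None:
        # Unknown methods fail fast, matching the ValueError raised above.
        raise ValueError(f"Invalid file transfer method: {method}")
    return build_func

assert _lookup(_Method.LOCAL_FILE)(id="abc") == ("local", {"id": "abc"})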
def _resolve_file_type(
*,
detected_file_type: FileType,
@ -178,7 +116,7 @@ def _resolve_file_type(
def _build_from_local_file(
*,
mapping: LocalFileMapping,
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
@ -225,7 +163,7 @@ def _build_from_local_file(
def _build_from_remote_url(
*,
mapping: RemoteUrlMapping,
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
@ -297,7 +235,7 @@ def _build_from_remote_url(
def _build_from_tool_file(
*,
mapping: ToolFileMapping,
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
@ -340,7 +278,7 @@ def _build_from_tool_file(
def _build_from_datasource_file(
*,
mapping: DatasourceFileMapping,
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,

View File

@ -80,7 +80,6 @@ app_detail_fields = {
"updated_at": TimestampField,
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_fields)),
"permission_keys": fields.List(fields.String),
}
prompt_config_fields = {
@ -118,7 +117,6 @@ app_partial_fields = {
"create_user_name": fields.String,
"author_name": fields.String,
"has_draft_trigger": fields.Boolean,
"permission_keys": fields.List(fields.String),
}
@ -199,7 +197,6 @@ app_detail_fields_with_site = {
"deleted_tools": fields.List(fields.Nested(deleted_tool_fields)),
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_fields)),
"permission_keys": fields.List(fields.String),
"site": fields.Nested(site_fields),
}

View File

@ -11,7 +11,6 @@ dataset_fields = {
"indexing_technique": fields.String,
"created_by": fields.String,
"created_at": TimestampField,
"permission_keys": fields.List(fields.String),
}
reranking_model_fields = {"reranking_provider_name": fields.String, "reranking_model_name": fields.String}
@ -108,7 +107,6 @@ dataset_detail_fields = {
"total_available_documents": fields.Integer,
"enable_api": fields.Boolean,
"is_multimodal": fields.Boolean,
"permission_keys": fields.List(fields.String),
}
file_info_fields = {

View File

@ -1,47 +0,0 @@
from flask_restx import fields
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
# Snippet list item fields (lightweight for list display)
snippet_list_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"version": fields.Integer,
"use_count": fields.Integer,
"is_published": fields.Boolean,
"icon_info": fields.Raw,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
}
# Full snippet fields (includes creator info and graph data)
snippet_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"version": fields.Integer,
"use_count": fields.Integer,
"is_published": fields.Boolean,
"icon_info": fields.Raw,
"graph": fields.Raw(attribute="graph_dict"),
"input_fields": fields.Raw(attribute="input_fields_list"),
"created_by": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_at": TimestampField,
"updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
"updated_at": TimestampField,
}
# Pagination response fields
snippet_pagination_fields = {
"data": fields.List(fields.Nested(snippet_list_fields)),
"page": fields.Integer,
"limit": fields.Integer,
"total": fields.Integer,
"has_more": fields.Boolean,
}

View File

@ -4,7 +4,7 @@ from datetime import datetime
from typing import Any
from flask_restx import Namespace, fields
from pydantic import Field, field_validator
from pydantic import field_validator
from fields.base import ResponseModel
from fields.end_user_fields import SimpleEndUser, simple_end_user_fields
@ -23,7 +23,6 @@ workflow_app_log_partial_fields = {
"id": fields.String,
"workflow_run": fields.Nested(workflow_run_for_log_fields, attribute="workflow_run", allow_null=True),
"details": fields.Raw(attribute="details"),
"evaluation": fields.Raw(attribute="evaluation", default=None),
"created_from": fields.String,
"created_by_role": fields.String,
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
@ -103,28 +102,10 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
return value
class WorkflowAppLogEvaluationNodeInfoResponse(ResponseModel):
node_id: str
type: str
title: str
class WorkflowAppLogEvaluationItemResponse(ResponseModel):
name: str
value: Any = None
details: dict[str, Any] | None = None
node_info: WorkflowAppLogEvaluationNodeInfoResponse | None = Field(
default=None,
validation_alias="node_info",
serialization_alias="nodeInfo",
)
class WorkflowAppLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForLogResponse | None = None
details: Any = None
evaluation: list[WorkflowAppLogEvaluationItemResponse] = Field(default_factory=list)
created_from: str | None = None
created_by_role: str | None = None
created_by_account: SimpleAccount | None = None
@ -136,11 +117,6 @@ class WorkflowAppLogPartialResponse(ResponseModel):
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
@field_validator("evaluation", mode="before")
@classmethod
def _normalize_evaluation(cls, value: Any) -> list[dict[str, Any]] | list[WorkflowAppLogEvaluationItemResponse]:
return value or []
class WorkflowArchivedLogPartialResponse(ResponseModel):
id: str

View File

@ -68,7 +68,6 @@ pipeline_variable_fields = {
workflow_fields = {
"id": fields.String,
"kind": fields.String(attribute="kind_or_standard"),
"graph": fields.Raw(attribute="graph_dict"),
"features": fields.Raw(attribute="features_dict"),
"hash": fields.String(attribute="unique_hash"),

View File

@ -1,83 +0,0 @@
"""add_customized_snippets_table
Revision ID: 1c05e80d2380
Revises: 788d3099ae3a
Create Date: 2026-01-29 12:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
import models as models
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = "1c05e80d2380"
down_revision = "788d3099ae3a"
branch_labels = None
depends_on = None
def upgrade():
conn = op.get_bind()
if _is_pg(conn):
op.create_table(
"customized_snippets",
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
sa.Column("icon_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("graph", sa.Text(), nullable=True),
sa.Column("input_fields", sa.Text(), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
)
else:
op.create_table(
"customized_snippets",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("description", models.types.LongText(), nullable=True),
sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
sa.Column("icon_info", models.types.AdjustedJSON(astext_type=sa.Text()), nullable=True),
sa.Column("graph", models.types.LongText(), nullable=True),
sa.Column("input_fields", models.types.LongText(), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
)
with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
batch_op.create_index("customized_snippet_tenant_idx", ["tenant_id"], unique=False)
def downgrade():
with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
batch_op.drop_index("customized_snippet_tenant_idx")
op.drop_table("customized_snippets")

View File

@ -1,116 +0,0 @@
"""add_evaluation_tables
Revision ID: a1b2c3d4e5f6
Revises: 1c05e80d2380
Create Date: 2026-03-03 00:01:00.000000
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f6"
down_revision = "1c05e80d2380"
branch_labels = None
depends_on = None
def upgrade():
# evaluation_configurations
op.create_table(
"evaluation_configurations",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("target_type", sa.String(length=20), nullable=False),
sa.Column("target_id", models.types.StringUUID(), nullable=False),
sa.Column("evaluation_model_provider", sa.String(length=255), nullable=True),
sa.Column("evaluation_model", sa.String(length=255), nullable=True),
sa.Column("metrics_config", models.types.LongText(), nullable=True),
sa.Column("judgement_conditions", models.types.LongText(), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="evaluation_configuration_pkey"),
sa.UniqueConstraint("tenant_id", "target_type", "target_id", name="evaluation_configuration_unique"),
)
with op.batch_alter_table("evaluation_configurations", schema=None) as batch_op:
batch_op.create_index(
"evaluation_configuration_target_idx", ["tenant_id", "target_type", "target_id"], unique=False
)
# evaluation_runs
op.create_table(
"evaluation_runs",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("target_type", sa.String(length=20), nullable=False),
sa.Column("target_id", models.types.StringUUID(), nullable=False),
sa.Column("evaluation_config_id", models.types.StringUUID(), nullable=False),
sa.Column("status", sa.String(length=20), nullable=False, server_default=sa.text("'pending'")),
sa.Column("dataset_file_id", models.types.StringUUID(), nullable=True),
sa.Column("result_file_id", models.types.StringUUID(), nullable=True),
sa.Column("total_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("completed_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("failed_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("metrics_summary", models.types.LongText(), nullable=True),
sa.Column("error", sa.Text(), nullable=True),
sa.Column("celery_task_id", sa.String(length=255), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=False),
sa.Column("started_at", sa.DateTime(), nullable=True),
sa.Column("completed_at", sa.DateTime(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="evaluation_run_pkey"),
)
with op.batch_alter_table("evaluation_runs", schema=None) as batch_op:
batch_op.create_index(
"evaluation_run_target_idx", ["tenant_id", "target_type", "target_id"], unique=False
)
batch_op.create_index("evaluation_run_status_idx", ["tenant_id", "status"], unique=False)
# evaluation_run_items
op.create_table(
"evaluation_run_items",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("evaluation_run_id", models.types.StringUUID(), nullable=False),
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=True),
sa.Column("item_index", sa.Integer(), nullable=False),
sa.Column("inputs", models.types.LongText(), nullable=True),
sa.Column("expected_output", models.types.LongText(), nullable=True),
sa.Column("context", models.types.LongText(), nullable=True),
sa.Column("actual_output", models.types.LongText(), nullable=True),
sa.Column("metrics", models.types.LongText(), nullable=True),
sa.Column("metadata_json", models.types.LongText(), nullable=True),
sa.Column("error", sa.Text(), nullable=True),
sa.Column("overall_score", sa.Float(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="evaluation_run_item_pkey"),
)
with op.batch_alter_table("evaluation_run_items", schema=None) as batch_op:
batch_op.create_index("evaluation_run_item_run_idx", ["evaluation_run_id"], unique=False)
batch_op.create_index(
"evaluation_run_item_index_idx", ["evaluation_run_id", "item_index"], unique=False
)
batch_op.create_index("evaluation_run_item_workflow_run_idx", ["workflow_run_id"], unique=False)
def downgrade():
with op.batch_alter_table("evaluation_run_items", schema=None) as batch_op:
batch_op.drop_index("evaluation_run_item_workflow_run_idx")
batch_op.drop_index("evaluation_run_item_index_idx")
batch_op.drop_index("evaluation_run_item_run_idx")
op.drop_table("evaluation_run_items")
with op.batch_alter_table("evaluation_runs", schema=None) as batch_op:
batch_op.drop_index("evaluation_run_status_idx")
batch_op.drop_index("evaluation_run_target_idx")
op.drop_table("evaluation_runs")
with op.batch_alter_table("evaluation_configurations", schema=None) as batch_op:
batch_op.drop_index("evaluation_configuration_target_idx")
op.drop_table("evaluation_configurations")


@ -1,25 +0,0 @@
"""merge migration heads
Revision ID: 4c60d8d3ee74
Revises: fce013ca180e, a1b2c3d4e5f6
Create Date: 2026-03-17 17:21:12.105536
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4c60d8d3ee74'
down_revision = ('fce013ca180e', 'a1b2c3d4e5f6')
branch_labels = None
depends_on = None

def upgrade():
pass

def downgrade():
pass
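A merge revision like the one above is intentionally empty: it only joins two divergent heads (fce013ca180e and a1b2c3d4e5f6) into a single lineage so that "alembic upgrade head" is unambiguous again. Such a file is normally generated rather than written by hand; the sketch below shows one way to do that with Alembic's Python API. The alembic.ini path is an assumption for illustration, not taken from this repository.

from alembic import command
from alembic.config import Config

# Assumed config location; point this at the project's real Alembic config.
cfg = Config("alembic.ini")

# List current head revisions; two entries mean the history has diverged.
command.heads(cfg, verbose=True)

# Generate an empty merge revision joining the two heads named in the file above.
command.merge(cfg, revisions=["fce013ca180e", "a1b2c3d4e5f6"], message="merge migration heads")

# Bring the database up to the new single head.
command.upgrade(cfg, "head")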


@ -1,64 +0,0 @@
"""Add human input upload token and file association tables
Revision ID: 8d4c2a1b9f03
Revises: 227822d22895
Create Date: 2026-05-06 12:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
import models
# revision identifiers, used by Alembic.
revision = "8d4c2a1b9f03"
down_revision = "227822d22895"
branch_labels = None
depends_on = None

def upgrade():
op.create_table(
"human_input_form_upload_tokens",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("app_id", models.types.StringUUID(), nullable=False),
sa.Column("form_id", models.types.StringUUID(), nullable=False),
sa.Column("recipient_id", models.types.StringUUID(), nullable=False),
sa.Column("token", sa.String(length=255), nullable=False),
sa.PrimaryKeyConstraint("id", name="human_input_form_upload_tokens_pkey"),
sa.UniqueConstraint("token", name="human_input_form_upload_tokens_token_key"),
)
with op.batch_alter_table("human_input_form_upload_tokens", schema=None) as batch_op:
batch_op.create_index("human_input_form_upload_tokens_form_id_idx", ["form_id"], unique=False)
op.create_table(
"human_input_form_upload_files",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("app_id", models.types.StringUUID(), nullable=False),
sa.Column("form_id", models.types.StringUUID(), nullable=False),
sa.Column("upload_file_id", models.types.StringUUID(), nullable=False),
sa.Column("upload_token_id", models.types.StringUUID(), nullable=False),
sa.PrimaryKeyConstraint("id", name="human_input_form_upload_files_pkey"),
sa.UniqueConstraint("upload_file_id", name="human_input_form_upload_files_upload_file_id_key"),
)
with op.batch_alter_table("human_input_form_upload_files", schema=None) as batch_op:
batch_op.create_index("human_input_form_upload_files_form_id_idx", ["form_id"], unique=False)
batch_op.create_index("human_input_form_upload_files_upload_token_id_idx", ["upload_token_id"], unique=False)

def downgrade():
with op.batch_alter_table("human_input_form_upload_files", schema=None) as batch_op:
batch_op.drop_index("human_input_form_upload_files_upload_token_id_idx")
batch_op.drop_index("human_input_form_upload_files_form_id_idx")
op.drop_table("human_input_form_upload_files")
with op.batch_alter_table("human_input_form_upload_tokens", schema=None) as batch_op:
batch_op.drop_index("human_input_form_upload_tokens_form_id_idx")
op.drop_table("human_input_form_upload_tokens")

Some files were not shown because too many files have changed in this diff.