Compare commits

..

52 Commits

Author SHA1 Message Date
e96b18acf2 fix: use vp for contracts generation in autofix workflow
Agent-Logs-Url: https://github.com/langgenius/dify/sessions/96ddc59c-309e-4f87-bb28-f326fbed9778

Co-authored-by: fatelei <961094+fatelei@users.noreply.github.com>
2026-05-18 03:38:25 +00:00
b85cdc37f5 Initial plan 2026-05-18 03:36:33 +00:00
1925d58369 chore: generate contract in ci (#36286) 2026-05-18 03:13:40 +00:00
b79fc5d6b4 fix: add missing phase field to _TokenData TypedDict (#36261) 2026-05-18 02:08:56 +00:00
yyh
6649e4025e feat(dify-ui): add Checkbox/CheckboxGroup primitives (#36271) 2026-05-18 02:01:56 +00:00
b96f372f45 chore(api): upgrade graphon to 0.4.0 (#36124)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
2026-05-18 00:34:17 +00:00
127fbf2c9a refactor: use match cases for workflow stream responses (#36267) 2026-05-17 21:38:20 +00:00
3c70d28064 fix(auth): preserve phase field in _TokenData so reset-password / change-email phase-bound checks don't 400 (#36116) (#36117)
Signed-off-by: vuko <alexander.vukovic@seqis.com>
2026-05-17 19:55:00 +00:00
cd4d6f8a22 fix(web): migrate metadata picker to combobox (#36255)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-17 03:35:20 +00:00
9d0906c684 chore: improve swagger markdown optional fields typing (#36247) 2026-05-16 16:40:20 +00:00
41b6f894c0 fix: fetch memory of LLM node may cause out of flask context (#36253) 2026-05-16 16:38:48 +00:00
e7e6fe8813 refactor: convert isinstance chains to match/case (part 3) (#36242)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-16 08:16:14 +00:00
c0bdd6792f refactor: convert isinstance chains to match/case in easy_ui_based_generate_task_pipeline.py (#36222) 2026-05-15 13:51:49 +00:00
27b084c4d4 fix: reduce db roundtrips of message update (#36213) 2026-05-15 08:39:48 +00:00
3f7a68fc77 fix(api): avoid dify-agent path lookup during Docker build (#36187) 2026-05-15 08:25:58 +00:00
yyh
a252fbddfa feat: initialize user timezone and language from browser (#36170)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-15 08:12:52 +00:00
ff02636a4b fix(web): app icon in webapp (#36206) 2026-05-15 07:44:09 +00:00
63946d829e fix(web): web app description missing (#36209) 2026-05-15 07:43:44 +00:00
cdcfd2ef2c fix: regenerate document summary after update via API (#35950) (#36035)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-15 07:26:29 +00:00
b04a3851cc refactor: enhance layout and scrolling behavior in various modals for improved user experience (#36210)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-15 07:17:23 +00:00
b41338cd08 chore(layout): reintroduce AmplitudeProvider in common layouts for analytics tracking (#36208)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-05-15 06:33:31 +00:00
28153df4d3 chore: enchance copywriting in none education plan warning modal (#36201)
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
2026-05-15 05:08:06 +00:00
3bc3386535 refactor(install): improve layout and scrolling behavior for plugin installation step (#36199)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-05-15 03:12:14 +00:00
7654f14241 fix: replace deprecated testcontainers log waits (#36125) 2026-05-15 01:30:59 +00:00
yyh
194b54bae4 fix: allow tag rename without type payload (#36182)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-15 01:29:17 +00:00
0e16d36edb fix(commands): purge tenant tool credentials on reset-encrypt-key-pair (#35396) (#35843)
Co-authored-by: xr843 <xianren843@protonmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-05-14 16:25:54 +00:00
432a6412a3 fix(security): tenant-scope FilePreviewApi text-extract endpoint (GHSA-2qwc-c2cc-2xwv) (#35797)
Signed-off-by: xr843 <137012659+xr843@users.noreply.github.com>
Co-authored-by: Ido Shani <ido@zafran.io>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-05-14 16:13:04 +00:00
55d05fe52d fix(security): enforce tenant scoping on app trace-config endpoints (GHSA-48xc-wmw8-3jr3) (#35793)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Ido Shani <ido@zafran.io>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-05-14 15:59:31 +00:00
0d500e6965 fix(api): allow LLM nodes to access retrieved knowledge files (#36175) 2026-05-14 13:09:25 +00:00
5798610f27 refactor(api): migrate console.app.workflow_comment to BaseModel (#36180) 2026-05-14 12:13:47 +00:00
a35b28dbef refactor: cleanup duplicate code (#36173) 2026-05-14 10:34:31 +00:00
1a4288c811 fix: action btn is hidden if there are many packages to install (#36176) 2026-05-14 10:21:32 +00:00
9dc32f2318 chore: increase default graph engine min workers (#35650)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
2026-05-14 09:27:45 +00:00
7210f856c9 fix: pipeline template render (#36168)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-14 09:11:18 +00:00
ebcc1200a3 feat(MessageLogModal): refactor modal structure and improve tab handling (#36169)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-14 08:50:21 +00:00
e660d7af38 fix(api): gracefully handle credential fetch failures in rag pipeline (#36165)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-14 08:27:19 +00:00
d9ccfcbc6e fix: fix delete logs failed (#36151)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-14 08:02:31 +00:00
a9bcec013f feat: allow disabling run time cred check (#36031)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-14 07:12:10 +00:00
aeb7687e2c fix: add null check in get_recommend_app_detail before accessing result['id'] (#36153) 2026-05-14 06:42:22 +00:00
9355d36718 chore(deps): bump urllib3 from 2.6.3 to 2.7.0 in /dify-agent (#36160)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-14 06:41:22 +00:00
a03ee828a3 fix: get recommend_app categories should not re-order it (#36161) 2026-05-14 06:36:28 +00:00
7066372892 feat(workflow): enhance workflow run callbacks with additional data tracking (#36149)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-14 06:20:12 +00:00
55f95dbc36 feat(agent): init agent server (#36087)
Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-14 06:04:44 +00:00
8b40de3c4e chore: enchance notify link ui (#36155) 2026-05-14 06:03:44 +00:00
af4b9bfa8f chore: Remove pyright in favor of pyrefly (#36154) 2026-05-14 05:49:08 +00:00
b9e3130388 chore: drop unnecessary | None on LLMError and Mail.send (#36147)
Co-authored-by: Brian Wang <20699847+BrianWang1990@users.noreply.github.com>
2026-05-14 03:22:00 +00:00
12d33652b6 fix(errors): clean unnecessary | None in error classes (#36135) 2026-05-14 03:21:41 +00:00
fe8cf2aff4 fix: fix pydantic union type error (#36134) 2026-05-14 02:54:23 +00:00
d1d190374d fix: credit pool access outside flask context (#36143) 2026-05-14 02:45:53 +00:00
e1be4e6aa8 chore(deps): bump langsmith from 0.7.31 to 0.8.0 in /api (#36142)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-14 02:36:02 +00:00
301a470e7a fix: isolate Langfuse v3 SDK TracerProvider to prevent cross-task interference (#36136)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-14 01:46:23 +00:00
91251ad5a5 chore(deps): bump authlib from 1.6.11 to 1.6.12 in /api (#36121)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-13 14:56:20 +00:00
1134 changed files with 31552 additions and 62411 deletions

View File

@ -1,709 +0,0 @@
#!/usr/bin/env bash
set -Eeuo pipefail
SCRIPT_NAME="$(basename "$0")"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd -P)"
DEFAULT_REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd -P)"
REPO_ROOT="${DIFY_REPO_ROOT:-$DEFAULT_REPO_ROOT}"
YES=false
DRY_RUN=true
SKIP_SMOKE=false
SKIP_MIGRATION=false
TIMEOUT_SECONDS="${DIFY_RESET_TIMEOUT_SECONDS:-300}"
SMOKE_URL="${DIFY_RESET_SMOKE_URL:-}"
LOCK_DIR=""
CURRENT_PHASE="init"
DELETED_PATHS=()
SKIPPED_PATHS=()
DELETED_NAMED_VOLUMES=()
SKIPPED_NAMED_VOLUMES=()
PRESERVED_PATHS=()
HEALTH_RESULTS=()
SMOKE_RESULT="not-run"
START_TIME="$(date +%s)"
RUNTIME_PATHS=(
"volumes/db/data"
"volumes/mysql/data"
"volumes/redis/data"
"volumes/app/storage"
"volumes/plugin_daemon"
"volumes/weaviate"
"volumes/qdrant"
"volumes/pgvector"
"volumes/pgvecto_rs"
"volumes/chroma"
"volumes/milvus"
"volumes/opensearch/data"
)
NAMED_VOLUMES=(
"dify_es01_data"
)
PRESERVE_PATHS=(
".env"
"middleware.env"
"docker-compose.yaml"
"docker-compose.middleware.yaml"
"nginx"
"ssrf_proxy"
"volumes/certbot"
"volumes/opensearch/opensearch_dashboards.yml"
"nginx/ssl"
)
usage() {
cat <<EOF
Usage: $SCRIPT_NAME [options]
Safely reset a Dify test environment in place. The command defaults to dry-run.
Options:
--yes Perform destructive reset. Required to delete data.
--dry-run Print planned actions without changing services or data.
--repo-root PATH Repository root. Defaults to auto-detected Dify root.
--smoke-url URL Public URL to verify after restart.
--skip-smoke Skip public-domain smoke verification.
--skip-migration Skip explicit migration gate.
--timeout SECONDS Health-check timeout. Default: $TIMEOUT_SECONDS.
-h, --help Show this help.
Required for destructive reset:
ALLOW_DIFY_TEST_RESET=true
DIFY_ENV_NAME=test
Optional:
DIFY_RESET_SMOKE_URL=https://test.example.com
RESET_TARGET_DOMAIN=test.example.com
EOF
}
log() {
printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"
}
fail() {
local message="$1"
print_report "failure"
printf 'ERROR: %s\n' "$message" >&2
exit 1
}
run_cmd() {
printf '+'
printf ' %q' "$@"
printf '\n'
if [ "$DRY_RUN" = false ]; then
set +e
"$@"
local status=$?
set -e
if [ "$status" -ne 0 ]; then
fail "Command failed with exit code $status: $(command_string "$@")"
fi
fi
}
command_string() {
local arg
local result=""
for arg in "$@"; do
result="$result $(printf '%q' "$arg")"
done
printf '%s' "${result# }"
}
read_env_value() {
local key="$1"
local default_value="$2"
local env_file="$DOCKER_DIR/.env"
local value=""
if [ -f "$env_file" ]; then
value="$(awk -F= -v key="$key" '
$0 !~ /^[[:space:]]*#/ && $1 == key {
sub(/^[^=]*=/, "")
print
}
' "$env_file" | tail -n 1)"
fi
if [ -z "$value" ]; then
printf '%s' "$default_value"
return
fi
value="${value%\"}"
value="${value#\"}"
value="${value%\'}"
value="${value#\'}"
printf '%s' "$value"
}
parse_args() {
while [ "$#" -gt 0 ]; do
case "$1" in
--yes)
YES=true
DRY_RUN=false
;;
--dry-run)
DRY_RUN=true
;;
--repo-root)
[ "$#" -ge 2 ] || fail "--repo-root requires a path"
REPO_ROOT="$2"
shift
;;
--smoke-url)
[ "$#" -ge 2 ] || fail "--smoke-url requires a URL"
SMOKE_URL="$2"
shift
;;
--skip-smoke)
SKIP_SMOKE=true
;;
--skip-migration)
SKIP_MIGRATION=true
;;
--timeout)
[ "$#" -ge 2 ] || fail "--timeout requires seconds"
TIMEOUT_SECONDS="$2"
shift
;;
-h|--help)
usage
exit 0
;;
*)
fail "Unknown option: $1"
;;
esac
shift
done
}
validate_number() {
case "$TIMEOUT_SECONDS" in
''|*[!0-9]*)
fail "--timeout must be a positive integer"
;;
esac
}
require_docker() {
command -v docker >/dev/null 2>&1 || fail "docker command not found"
docker compose version >/dev/null 2>&1 || fail "docker compose is not available"
}
validate_environment() {
CURRENT_PHASE="validate"
REPO_ROOT="$(cd "$REPO_ROOT" && pwd -P)"
DOCKER_DIR="$REPO_ROOT/docker"
[ -d "$DOCKER_DIR" ] || fail "Docker directory not found: $DOCKER_DIR"
[ -f "$DOCKER_DIR/docker-compose.yaml" ] || fail "docker-compose.yaml not found in $DOCKER_DIR"
[ -f "$DOCKER_DIR/.env" ] || fail ".env not found in $DOCKER_DIR"
if [ "$DRY_RUN" = false ]; then
[ "$YES" = true ] || fail "Destructive reset requires --yes"
[ "${ALLOW_DIFY_TEST_RESET:-}" = "true" ] || fail "ALLOW_DIFY_TEST_RESET=true is required"
[ "${DIFY_ENV_NAME:-}" = "test" ] || fail "DIFY_ENV_NAME=test is required"
require_docker
fi
}
acquire_lock() {
CURRENT_PHASE="lock"
local env_name="${DIFY_ENV_NAME:-dry-run}"
LOCK_DIR="${TMPDIR:-/tmp}/dify-test-reset-${env_name}.lock"
if ! mkdir "$LOCK_DIR" 2>/dev/null; then
fail "Reset lock is already held: $LOCK_DIR"
fi
printf '%s\n' "$$" > "$LOCK_DIR/pid"
trap cleanup EXIT
}
cleanup() {
if [ -n "$LOCK_DIR" ] && [ -d "$LOCK_DIR" ]; then
rm -rf "$LOCK_DIR"
fi
}
compose() {
local args=(compose --env-file "$DOCKER_DIR/.env" -f "$DOCKER_DIR/docker-compose.yaml")
if [ -n "${DIFY_COMPOSE_PROJECT:-}" ]; then
args+=(-p "$DIFY_COMPOSE_PROJECT")
fi
docker "${args[@]}" "$@"
}
compose_project_name() {
if [ -n "${DIFY_COMPOSE_PROJECT:-}" ]; then
printf '%s' "$DIFY_COMPOSE_PROJECT"
return
fi
if [ -n "${COMPOSE_PROJECT_NAME:-}" ]; then
printf '%s' "$COMPOSE_PROJECT_NAME"
return
fi
local env_project
env_project="$(read_env_value COMPOSE_PROJECT_NAME "")"
if [ -n "$env_project" ]; then
printf '%s' "$env_project"
return
fi
basename "$DOCKER_DIR"
}
active_db_service() {
local db_type
db_type="$(read_env_value DB_TYPE postgresql)"
case "$db_type" in
postgresql|'')
printf '%s\n' "db_postgres"
;;
mysql)
printf '%s\n' "db_mysql"
;;
oceanbase)
printf '%s\n' "oceanbase"
;;
*)
printf '%s\n' "$db_type"
;;
esac
}
active_vector_service() {
local vector_store
vector_store="$(read_env_value VECTOR_STORE weaviate)"
case "$vector_store" in
''|none|external)
return 0
;;
pgvecto-rs|pgvecto_rs)
printf '%s\n' "pgvecto-rs"
;;
milvus)
printf '%s\n' "milvus-standalone"
;;
elasticsearch|opensearch|weaviate|qdrant|pgvector|chroma|oceanbase|seekdb|couchbase-server|iris)
printf '%s\n' "$vector_store"
;;
couchbase)
printf '%s\n' "couchbase-server"
;;
*)
return 0
;;
esac
}
safe_runtime_path() {
local rel_path="$1"
case "$rel_path" in
""|"/"| "." | ".." | *".."* | /*)
return 1
;;
esac
case "$rel_path" in
volumes/*)
return 0
;;
*)
return 1
;;
esac
}
safe_named_volume() {
local volume="$1"
case "$volume" in
""|*"/"*|*" "*|*$'\t'*|*$'\n'*|*$'\r'*)
return 1
;;
*[!a-zA-Z0-9_.-]*)
return 1
;;
*)
return 0
;;
esac
}
volume_exists() {
docker volume inspect "$1" >/dev/null 2>&1
}
append_unique_volume() {
local candidate="$1"
local existing
[ -n "$candidate" ] || return 0
for existing in "${RESOLVED_VOLUME_NAMES[@]}"; do
if [ "$existing" = "$candidate" ]; then
return
fi
done
RESOLVED_VOLUME_NAMES+=("$candidate")
}
resolve_named_volume_names() {
local logical_name="$1"
local project_name
local candidate
local volume_list
local status
RESOLVED_VOLUME_NAMES=()
project_name="$(compose_project_name)"
set +e
volume_list="$(docker volume ls -q \
--filter "label=com.docker.compose.project=$project_name" \
--filter "label=com.docker.compose.volume=$logical_name" 2>/dev/null)"
status=$?
set -e
if [ "$status" -ne 0 ]; then
fail "Failed to list Docker volumes for Compose project $project_name"
fi
while IFS= read -r candidate; do
append_unique_volume "$candidate"
done <<< "$volume_list"
for candidate in "${project_name}_${logical_name}" "$logical_name"; do
if volume_exists "$candidate"; then
append_unique_volume "$candidate"
fi
done
}
collect_preserved_paths() {
PRESERVED_PATHS=()
local rel_path
for rel_path in "${PRESERVE_PATHS[@]}"; do
if [ -e "$DOCKER_DIR/$rel_path" ]; then
PRESERVED_PATHS+=("$rel_path")
fi
done
}
print_plan() {
CURRENT_PHASE="plan"
local db_service
local vector_service
db_service="$(active_db_service)"
vector_service="$(active_vector_service || true)"
collect_preserved_paths
log "Reset mode: $([ "$DRY_RUN" = true ] && printf dry-run || printf destructive)"
log "Repository root: $REPO_ROOT"
log "Docker directory: $DOCKER_DIR"
log "Compose project: $(compose_project_name)"
log "Database service: $db_service"
log "Vector service: ${vector_service:-<external-or-none>}"
log "Timeout: ${TIMEOUT_SECONDS}s"
printf '\nPlanned runtime path deletions:\n'
local rel_path
for rel_path in "${RUNTIME_PATHS[@]}"; do
printf ' - %s\n' "$rel_path"
done
printf '\nPlanned named volume deletions:\n'
local volume
for volume in "${NAMED_VOLUMES[@]}"; do
printf ' - %s (Compose project: %s)\n' "$volume" "$(compose_project_name)"
done
printf '\nPreserved configuration paths found:\n'
for rel_path in "${PRESERVED_PATHS[@]}"; do
printf ' - %s\n' "$rel_path"
done
printf '\nCommands:\n'
printf ' - docker compose down --remove-orphans\n'
printf ' - delete allowlisted runtime paths and named volumes\n'
printf ' - docker compose up -d %s redis%s\n' "$db_service" "${vector_service:+ $vector_service}"
printf ' - docker compose run --rm -e MIGRATION_ENABLED=true -e MODE=migration api\n'
printf ' - docker compose up -d\n'
printf ' - health checks and smoke check\n\n'
}
delete_runtime_paths() {
CURRENT_PHASE="delete-runtime-data"
local rel_path
local abs_path
for rel_path in "${RUNTIME_PATHS[@]}"; do
safe_runtime_path "$rel_path" || fail "Unsafe runtime path in allowlist: $rel_path"
abs_path="$DOCKER_DIR/$rel_path"
if [ ! -e "$abs_path" ]; then
SKIPPED_PATHS+=("$rel_path (absent)")
continue
fi
DELETED_PATHS+=("$rel_path")
run_cmd rm -rf -- "$abs_path"
done
}
delete_named_volumes() {
CURRENT_PHASE="delete-runtime-volumes"
local logical_name
local actual_name
for logical_name in "${NAMED_VOLUMES[@]}"; do
safe_named_volume "$logical_name" || fail "Unsafe named volume in allowlist: $logical_name"
resolve_named_volume_names "$logical_name"
if [ "${#RESOLVED_VOLUME_NAMES[@]}" -eq 0 ]; then
SKIPPED_NAMED_VOLUMES+=("$logical_name (absent)")
continue
fi
for actual_name in "${RESOLVED_VOLUME_NAMES[@]}"; do
DELETED_NAMED_VOLUMES+=("$actual_name")
run_cmd docker volume rm "$actual_name"
done
done
}
stop_stack() {
CURRENT_PHASE="stop-stack"
run_cmd compose down --remove-orphans
}
start_middleware() {
CURRENT_PHASE="start-middleware"
local db_service
local vector_service
local services=()
db_service="$(active_db_service)"
vector_service="$(active_vector_service || true)"
services+=("$db_service" "redis")
if [ -n "$vector_service" ]; then
services+=("$vector_service")
fi
run_cmd compose up -d "${services[@]}"
if [ "$DRY_RUN" = false ]; then
wait_for_services "${services[@]}"
fi
}
run_migration() {
CURRENT_PHASE="migration"
if [ "$SKIP_MIGRATION" = true ]; then
HEALTH_RESULTS+=("migration:skipped")
return
fi
run_cmd compose run --rm -e MIGRATION_ENABLED=true -e MODE=migration api
HEALTH_RESULTS+=("migration:ok")
}
start_full_stack() {
CURRENT_PHASE="start-full-stack"
run_cmd compose up -d
if [ "$DRY_RUN" = false ]; then
wait_for_services api web worker nginx
wait_if_service_exists plugin_daemon
fi
}
container_status() {
local service="$1"
local container_id
container_id="$(compose ps -q "$service" 2>/dev/null || true)"
[ -n "$container_id" ] || return 1
local health
health="$(docker inspect --format '{{if .State.Health}}{{.State.Health.Status}}{{else}}{{.State.Status}}{{end}}' "$container_id" 2>/dev/null || true)"
printf '%s' "$health"
}
wait_if_service_exists() {
local service="$1"
if [ -n "$(compose ps -q "$service" 2>/dev/null || true)" ]; then
wait_for_services "$service"
fi
}
wait_for_services() {
local service
for service in "$@"; do
wait_for_service "$service"
done
}
wait_for_service() {
local service="$1"
local deadline=$(( $(date +%s) + TIMEOUT_SECONDS ))
local status=""
log "Waiting for service: $service"
while [ "$(date +%s)" -le "$deadline" ]; do
status="$(container_status "$service" || true)"
case "$status" in
healthy|running)
HEALTH_RESULTS+=("$service:$status")
return 0
;;
unhealthy|exited|dead)
HEALTH_RESULTS+=("$service:$status")
fail "Service $service reached failure status: $status"
;;
esac
sleep 3
done
HEALTH_RESULTS+=("$service:timeout")
fail "Timed out waiting for service: $service"
}
default_smoke_url() {
if [ -n "$SMOKE_URL" ]; then
printf '%s' "$SMOKE_URL"
return
fi
if [ -n "${RESET_TARGET_DOMAIN:-}" ]; then
local https_enabled
https_enabled="$(read_env_value NGINX_HTTPS_ENABLED false)"
if [ "$https_enabled" = "true" ]; then
printf 'https://%s' "$RESET_TARGET_DOMAIN"
else
printf 'http://%s' "$RESET_TARGET_DOMAIN"
fi
return
fi
local port
port="$(read_env_value EXPOSE_NGINX_PORT 80)"
printf 'http://localhost:%s' "$port"
}
run_smoke_check() {
CURRENT_PHASE="smoke"
if [ "$SKIP_SMOKE" = true ]; then
SMOKE_RESULT="skipped"
return
fi
local url
url="$(default_smoke_url)"
if [ "$DRY_RUN" = true ]; then
SMOKE_RESULT="planned:$url"
printf '+ curl -fsS --max-time 10 %q\n' "$url"
return
fi
curl -fsS --max-time 10 "$url" >/dev/null || fail "Smoke check failed: $url"
SMOKE_RESULT="ok:$url"
}
print_report() {
local status="${1:-success}"
local end_time
end_time="$(date +%s)"
printf '\nReset report\n'
printf '============\n'
printf 'status: %s\n' "$status"
printf 'environment: %s\n' "${DIFY_ENV_NAME:-<unset>}"
printf 'repo_root: %s\n' "${REPO_ROOT:-<unset>}"
printf 'phase: %s\n' "$CURRENT_PHASE"
printf 'duration_seconds: %s\n' "$(( end_time - START_TIME ))"
printf 'mode: %s\n' "$([ "$DRY_RUN" = true ] && printf dry-run || printf destructive)"
printf '\ndeleted_runtime_paths:\n'
if [ "${#DELETED_PATHS[@]}" -eq 0 ]; then
printf ' - <none>\n'
else
printf ' - %s\n' "${DELETED_PATHS[@]}"
fi
printf '\nskipped_runtime_paths:\n'
if [ "${#SKIPPED_PATHS[@]}" -eq 0 ]; then
printf ' - <none>\n'
else
printf ' - %s\n' "${SKIPPED_PATHS[@]}"
fi
printf '\ndeleted_named_volumes:\n'
if [ "${#DELETED_NAMED_VOLUMES[@]}" -eq 0 ]; then
printf ' - <none>\n'
else
printf ' - %s\n' "${DELETED_NAMED_VOLUMES[@]}"
fi
printf '\nskipped_named_volumes:\n'
if [ "${#SKIPPED_NAMED_VOLUMES[@]}" -eq 0 ]; then
printf ' - <none>\n'
else
printf ' - %s\n' "${SKIPPED_NAMED_VOLUMES[@]}"
fi
printf '\npreserved_paths:\n'
if [ "${#PRESERVED_PATHS[@]}" -eq 0 ]; then
printf ' - <none found>\n'
else
printf ' - %s\n' "${PRESERVED_PATHS[@]}"
fi
printf '\nhealth_results:\n'
if [ "${#HEALTH_RESULTS[@]}" -eq 0 ]; then
printf ' - <not run>\n'
else
printf ' - %s\n' "${HEALTH_RESULTS[@]}"
fi
printf '\nsmoke_result: %s\n' "$SMOKE_RESULT"
}
main() {
parse_args "$@"
validate_number
validate_environment
acquire_lock
print_plan
if [ "$DRY_RUN" = true ]; then
run_smoke_check
print_report "dry-run"
return 0
fi
stop_stack
delete_runtime_paths
delete_named_volumes
start_middleware
run_migration
start_full_stack
run_smoke_check
CURRENT_PHASE="complete"
print_report "success"
}
main "$@"

View File

@ -120,7 +120,11 @@ jobs:
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
run: |
cd api
uv run dev/generate_swagger_markdown_docs.py --swagger-dir openapi --markdown-dir openapi/markdown
uv run dev/generate_swagger_markdown_docs.py --swagger-dir ../packages/contracts/openapi --markdown-dir openapi/markdown --keep-swagger-json
- name: Generate frontend contracts
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
run: cd packages/contracts && vp run gen-api-contract-from-openapi
- name: ESLint autofix
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'

View File

@ -77,6 +77,8 @@ jobs:
with:
files: |
web/**
e2e/**
sdks/nodejs-client/**
packages/**
package.json
pnpm-lock.yaml
@ -94,14 +96,14 @@ jobs:
id: eslint-cache-restore
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: web/.eslintcache
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
path: .eslintcache
key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
restore-keys: |
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
working-directory: .
run: vp run lint:ci
- name: Web tsslint
@ -113,7 +115,7 @@ jobs:
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
working-directory: .
run: vp run type-check
- name: Web dead code check
@ -125,7 +127,7 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: web/.eslintcache
path: .eslintcache
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
superlinter:

3
.gitignore vendored
View File

@ -203,6 +203,7 @@ sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json.template
!.vscode/settings.example.json
!.vscode/README.md
api/.vscode
# vscode Code History Extension
@ -249,3 +250,5 @@ scripts/stress-test/reports/
# Code Agent Folder
.qoder/*
.context/*
.eslintcache

View File

@ -56,44 +56,9 @@ if $api_modified; then
fi
fi
if $web_modified; then
if $skip_web_checks; then
echo "Git operation in progress, skipping web checks"
exit 0
fi
echo "Running ESLint on web module"
if git diff --cached --quiet -- 'web/**/*.ts' 'web/**/*.tsx'; then
web_ts_modified=false
else
ts_diff_status=$?
if [ $ts_diff_status -eq 1 ]; then
web_ts_modified=true
else
echo "Unable to determine staged TypeScript changes (git exit code: $ts_diff_status)."
exit $ts_diff_status
fi
fi
cd ./web || exit 1
pnpm exec vp staged
if $web_ts_modified; then
echo "Running TypeScript type-check:tsgo"
if ! npm run type-check:tsgo; then
echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors."
exit 1
fi
else
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
fi
echo "Running knip"
if ! npm run knip; then
echo "Knip check failed. Please run 'npm run knip' to fix the errors."
exit 1
fi
cd ../
if $skip_web_checks; then
echo "Git operation in progress, skipping web checks"
exit 0
fi
vp staged

View File

@ -9,6 +9,7 @@ The codebase is split into:
- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
- **Frontend Web** (`/web`): Next.js application using TypeScript and React
- **Docker deployment** (`/docker`): Containerized deployment configurations
- **Dify Agent Backend** (`/dify-agent`): Backend services for managing and executing agent
## Backend Workflow

View File

@ -83,16 +83,15 @@ lint:
@echo "✅ Linting complete"
type-check:
@echo "📝 Running type checks (basedpyright + pyrefly + mypy)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
@./dev/pyrefly-check-local
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
@echo "📝 Running type checks (pyrefly + mypy)..."
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
@echo "✅ Type checks complete"
type-check-core:
@echo "📝 Running core type checks (basedpyright + mypy)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
@echo "📝 Running core type checks (pyrefly + mypy)..."
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
@echo "✅ Core type checks complete"
test:
@ -153,8 +152,8 @@ help:
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)"
@echo " make type-check-core - Run core type checks (basedpyright, mypy)"
@echo " make type-check - Run type checks (pyrefly, mypy)"
@echo " make type-check-core - Run core type checks (pyrefly, mypy)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@echo ""
@echo "Docker Build Targets:"

View File

@ -557,7 +557,7 @@ MAX_VARIABLE_SIZE=204800
# GraphEngine Worker Pool Configuration
# Minimum number of workers per GraphEngine instance (default: 1)
GRAPH_ENGINE_MIN_WORKERS=1
GRAPH_ENGINE_MIN_WORKERS=3
# Maximum number of workers per GraphEngine instance (default: 10)
GRAPH_ENGINE_MAX_WORKERS=10
# Queue depth threshold that triggers worker scale up (default: 3)

View File

@ -17,14 +17,15 @@ FROM base AS packages
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
# basic environment
git g++ \
g++ \
# for building gmpy2
libmpfr-dev libmpc-dev
# Install Python dependencies (workspace members under providers/vdb/)
COPY pyproject.toml uv.lock ./
COPY providers ./providers
RUN uv sync --locked --no-dev --group evaluation
# Trust the checked-in lock during image builds; dev-only path sources live outside the api/ context.
RUN uv sync --frozen --no-dev
# production stage
FROM base AS production
@ -77,7 +78,6 @@ RUN \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
git \
nodejs=${NODE_PACKAGE_VERSION} \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \

View File

@ -99,7 +99,7 @@ The scripts resolve paths relative to their location, so you can run them from a
./dev/reformat # Run all formatters and linters
uv run ruff check --fix ./ # Fix linting issues
uv run ruff format ./ # Format code
uv run basedpyright . # Type checking
uv run pyrefly check # Type checking
```
## Generate TS stub

View File

@ -117,7 +117,7 @@ def create_flask_app_with_configs() -> DifyApp:
logger.warning("Failed to add trace headers to response", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction
# Capture the decorator return values so static checkers do not treat the hooks as unused.
_ = before_request
_ = add_trace_headers

View File

@ -4,7 +4,6 @@ CLI command modules extracted from `commands.py`.
from .account import create_tenant, reset_email, reset_password
from .plugin import (
backfill_plugin_auto_upgrade,
extract_plugins,
extract_unique_plugins,
install_plugins,
@ -15,7 +14,6 @@ from .plugin import (
setup_system_trigger_oauth_client,
transform_datasource_credentials,
)
from .rbac import migrate_member_roles_to_rbac
from .retention import (
archive_workflow_runs,
clean_expired_messages,
@ -39,7 +37,6 @@ from .vector import (
__all__ = [
"add_qdrant_index",
"archive_workflow_runs",
"backfill_plugin_auto_upgrade",
"clean_expired_messages",
"clean_workflow_runs",
"cleanup_orphaned_draft_variables",
@ -58,7 +55,6 @@ __all__ = [
"migrate_annotation_vector_database",
"migrate_data_for_plugin",
"migrate_knowledge_vector_database",
"migrate_member_roles_to_rbac",
"migrate_oss",
"old_metadata_migration",
"remove_orphaned_files_on_storage",

View File

@ -1,11 +1,10 @@
import json
import logging
import time
from typing import Any, cast
import click
from pydantic import TypeAdapter
from sqlalchemy import delete, func, select
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from configs import dify_config
@ -15,13 +14,11 @@ from core.plugin.impl.plugin import PluginInstaller
from core.tools.utils.system_encryption import encrypt_system_params
from extensions.ext_database import db
from models import Tenant
from models.account import TenantPluginAutoUpgradeStrategy
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider_ids import DatasourceProviderID, ToolProviderID
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from models.tools import ToolOAuthSystemClient
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
@ -188,9 +185,9 @@ def transform_datasource_credentials(environment: str):
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
jina_plugin_id = "langgenius/jina_datasource"
if environment == "online":
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)
else:
notion_plugin_unique_identifier = None
firecrawl_plugin_unique_identifier = None
@ -405,110 +402,6 @@ def migrate_data_for_plugin():
click.echo(click.style("Migrate data for plugin completed.", fg="green"))
def _candidate_auto_upgrade_strategy_tenant_ids_stmt(limit: int | None = None):
category_count = len(TenantPluginAutoUpgradeStrategy.PluginCategory)
stmt = (
select(TenantPluginAutoUpgradeStrategy.tenant_id)
.group_by(TenantPluginAutoUpgradeStrategy.tenant_id)
.having(func.count(func.distinct(TenantPluginAutoUpgradeStrategy.category)) < category_count)
.order_by(TenantPluginAutoUpgradeStrategy.tenant_id)
)
if limit is not None:
stmt = stmt.limit(limit)
return stmt
def _count_auto_upgrade_strategy_tenant_ids(limit: int | None) -> int:
candidate_stmt = _candidate_auto_upgrade_strategy_tenant_ids_stmt(limit).subquery()
return db.session.scalar(select(func.count()).select_from(candidate_stmt)) or 0
def _iter_auto_upgrade_strategy_tenant_ids(limit: int | None):
stmt = _candidate_auto_upgrade_strategy_tenant_ids_stmt(limit).execution_options(yield_per=1000)
yield from db.session.scalars(stmt)
@click.command(
"backfill-plugin-auto-upgrade",
help="Backfill category-scoped plugin auto-upgrade strategies and normalize plugin lists.",
)
@click.option("--tenant-id", multiple=True, help="Tenant ID to backfill. Can be passed multiple times.")
@click.option("--limit", type=int, default=None, help="Maximum number of candidate tenants to process.")
@click.option("--batch-size", type=int, default=500, show_default=True, help="Progress reporting batch size.")
@click.option("--dry-run", is_flag=True, help="Only print candidate tenant count.")
def backfill_plugin_auto_upgrade(
tenant_id: tuple[str, ...],
limit: int | None,
batch_size: int,
dry_run: bool,
):
"""
Backfill historical auto-upgrade strategies after the category column exists.
Missing category rows are created from the tenant's tool/default row. Pure default
strategies become latest for model plugins and fix-only for all other categories.
Tenants with include/exclude plugin IDs are split
by installed plugin category using plugin daemon metadata.
"""
start_at = time.perf_counter()
candidate_count = len(tenant_id) if tenant_id else _count_auto_upgrade_strategy_tenant_ids(limit)
click.echo(click.style(f"Found {candidate_count} candidate tenants.", fg="yellow"))
if dry_run:
elapsed = time.perf_counter() - start_at
click.echo(click.style(f"Dry run completed. elapsed={elapsed:.2f}s", fg="green"))
return
tenant_ids = list(tenant_id) if tenant_id else _iter_auto_upgrade_strategy_tenant_ids(limit)
backfilled_count = 0
created_count = 0
normalized_count = 0
skipped_count = 0
failed_count = 0
for index, current_tenant_id in enumerate(tenant_ids, start=1):
try:
result = PluginAutoUpgradeService.backfill_strategy_categories(
current_tenant_id,
)
except Exception as e:
failed_count += 1
click.echo(click.style(f"Failed tenant {current_tenant_id}: {str(e)}", fg="red"))
continue
if result.created_count > 0:
backfilled_count += 1
created_count += result.created_count
elif not result.normalized:
skipped_count += 1
if result.normalized:
normalized_count += 1
if batch_size > 0 and index % batch_size == 0:
click.echo(
click.style(
f"Processed {index}/{candidate_count} tenants. "
f"backfilled={backfilled_count}, created_rows={created_count}, "
f"normalized={normalized_count}, skipped={skipped_count}, failed={failed_count}, "
f"elapsed={time.perf_counter() - start_at:.2f}s",
fg="yellow",
)
)
elapsed = time.perf_counter() - start_at
click.echo(
click.style(
f"Backfill plugin auto-upgrade strategy categories completed. "
f"backfilled={backfilled_count}, created_rows={created_count}, "
f"normalized={normalized_count}, skipped={skipped_count}, failed={failed_count}, "
f"elapsed={elapsed:.2f}s",
fg="green",
)
)
@click.command("extract-plugins", help="Extract plugins.")
@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl")
@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10)

View File

@ -1,109 +0,0 @@
from __future__ import annotations
import click
from sqlalchemy import select
from core.db.session_factory import session_factory
from models import TenantAccountJoin, TenantAccountRole
from services.enterprise.rbac_service import ListOption, RBACService
def _resolve_builtin_role_id(tenant_id: str, operator_account_id: str, legacy_role: str) -> str:
"""Resolve a legacy workspace role to the current tenant's builtin RBAC role id.
The migration replays the old `TenantAccountJoin.role` values onto the
RBAC member-role binding API. Builtin RBAC roles are tenant-scoped and
identified by runtime ids, so the command must look them up per tenant.
"""
expected_builtin_name = {
TenantAccountRole.OWNER.value: "所有者",
TenantAccountRole.ADMIN.value: "管理者",
TenantAccountRole.EDITOR.value: "编辑者",
TenantAccountRole.NORMAL.value: "普通用户",
TenantAccountRole.DATASET_OPERATOR.value: "知识库操作员",
}.get(legacy_role)
if not expected_builtin_name:
raise ValueError(f"Unsupported legacy workspace role: {legacy_role}")
roles = RBACService.Roles.list(
tenant_id=tenant_id,
account_id=operator_account_id,
options=ListOption(page_number=1, results_per_page=100),
).data
for role in roles:
if role.is_builtin and role.category == "global_system_default" and role.name == expected_builtin_name:
return str(role.id)
raise ValueError(f"Builtin RBAC role not found for tenant={tenant_id}, legacy_role={legacy_role}")
@click.command("rbac-migrate-member-roles", help="Migrate legacy workspace member roles into RBAC member-role bindings.")
@click.option("--tenant-id", help="Only migrate a single workspace.")
@click.option("--dry-run", is_flag=True, default=False, help="Preview the migration without writing RBAC bindings.")
def migrate_member_roles_to_rbac(tenant_id: str | None, dry_run: bool) -> None:
"""Backfill RBAC member-role bindings from legacy `TenantAccountJoin.role` data.
This is an offline migration command for workspaces that already have
members in the legacy role model but need matching records in the RBAC
member-role binding store.
"""
click.echo(click.style("Starting RBAC member-role migration.", fg="green"))
with session_factory.create_session() as session:
stmt = select(TenantAccountJoin).order_by(TenantAccountJoin.tenant_id.asc(), TenantAccountJoin.id.asc())
if tenant_id:
stmt = stmt.where(TenantAccountJoin.tenant_id == tenant_id)
joins = list(session.scalars(stmt).all())
if not joins:
click.echo(click.style("No workspace members found for migration.", fg="yellow"))
return
owner_account_by_tenant: dict[str, str] = {}
resolved_role_ids: dict[tuple[str, str], str] = {}
migrated_count = 0
for join in joins:
workspace_id = str(join.tenant_id)
member_account_id = str(join.account_id)
legacy_role = str(join.role)
if workspace_id not in owner_account_by_tenant:
owner_join = next(
(
item
for item in joins
if str(item.tenant_id) == workspace_id and str(item.role) == TenantAccountRole.OWNER.value
),
None,
)
if not owner_join:
raise ValueError(f"Workspace owner not found for tenant={workspace_id}")
owner_account_by_tenant[workspace_id] = str(owner_join.account_id)
operator_account_id = owner_account_by_tenant[workspace_id]
cache_key = (workspace_id, legacy_role)
if cache_key not in resolved_role_ids:
resolved_role_ids[cache_key] = _resolve_builtin_role_id(workspace_id, operator_account_id, legacy_role)
resolved_role_id = resolved_role_ids[cache_key]
click.echo(
f"tenant={workspace_id} member={member_account_id} legacy_role={legacy_role} -> rbac_role_id={resolved_role_id}"
)
if dry_run:
continue
RBACService.MemberRoles.replace(
tenant_id=workspace_id,
account_id=operator_account_id,
member_account_id=member_account_id,
role_ids=[resolved_role_id],
)
migrated_count += 1
if dry_run:
click.echo(click.style("Dry run completed. No RBAC bindings were written.", fg="yellow"))
else:
click.echo(click.style(f"RBAC member-role migration completed. Migrated {migrated_count} members.", fg="green"))

View File

@ -14,6 +14,7 @@ from libs.rsa import generate_key_pair
from models import Tenant
from models.model import App, AppMode, Conversation
from models.provider import Provider, ProviderModel
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider
logger = logging.getLogger(__name__)
@ -23,13 +24,16 @@ DB_UPGRADE_LOCK_TTL_SECONDS = 60
@click.command(
"reset-encrypt-key-pair",
help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
"After the reset, all LLM credentials will become invalid, "
"requiring re-entry."
"After the reset, all LLM credentials and tool provider credentials "
"(builtin / API / MCP) will be purged, requiring re-entry. "
"Only support SELF_HOSTED mode.",
)
@click.confirmation_option(
prompt=click.style(
"Are you sure you want to reset encrypt key pair? This operation cannot be rolled back!", fg="red"
"Are you sure you want to reset encrypt key pair? "
"This will also purge builtin / API / MCP tool provider records for every tenant. "
"This operation cannot be rolled back!",
fg="red",
)
)
def reset_encrypt_key_pair():
@ -53,6 +57,13 @@ def reset_encrypt_key_pair():
session.execute(delete(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id))
session.execute(delete(ProviderModel).where(ProviderModel.tenant_id == tenant.id))
# Purge tool provider records that hold credentials encrypted under the
# tenant key. Leaving them in place causes /console/api/workspaces/current/
# tool-providers to 500 because decryption fails on stale ciphertext (#35396).
session.execute(delete(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant.id))
session.execute(delete(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant.id))
session.execute(delete(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant.id))
click.echo(
click.style(
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",

View File

@ -23,9 +23,10 @@ class EnterpriseFeatureConfig(BaseSettings):
ge=1, description="Maximum timeout in seconds for enterprise requests", default=5
)
RBAC_ENABLED: bool = Field(
description="Enable enterprise RBAC APIs. When disabled, compatibility responses fall back to legacy roles.",
ENTERPRISE_DISABLE_RUNTIME_CREDENTIAL_CHECK: bool = Field(
default=False,
description="If disabled, credential policy check is only performed when saving workflows."
"This helps gain runtime performance by trading off consistency.",
)

View File

@ -761,7 +761,7 @@ class WorkflowConfig(BaseSettings):
# GraphEngine Worker Pool Configuration
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
description="Minimum number of workers per GraphEngine instance",
default=1,
default=3,
)
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
@ -1406,32 +1406,6 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
)
class EvaluationConfig(BaseSettings):
"""
Configuration for evaluation runtime
"""
EVALUATION_FRAMEWORK: str = Field(
description="Evaluation framework to use (ragas/deepeval/none)",
default="none",
)
EVALUATION_MAX_CONCURRENT_RUNS: PositiveInt = Field(
description="Maximum number of concurrent evaluation runs per tenant",
default=3,
)
EVALUATION_MAX_DATASET_ROWS: PositiveInt = Field(
description="Maximum number of rows allowed in an evaluation dataset",
default=500,
)
EVALUATION_TASK_TIMEOUT: PositiveInt = Field(
description="Timeout in seconds for a single evaluation task",
default=3600,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -1445,7 +1419,6 @@ class FeatureConfig(
MarketplaceConfig,
DataSetConfig,
EndpointConfig,
EvaluationConfig,
FileAccessConfig,
FileUploadConfig,
HttpConfig,

View File

@ -1,5 +1,5 @@
import os
from typing import Any, Literal, TypedDict
from typing import Any, Literal, TypedDict, cast
from urllib.parse import parse_qsl, quote_plus
from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
@ -50,28 +50,30 @@ from .vdb.vastbase_vector_config import VastbaseVectorConfig
from .vdb.vikingdb_config import VikingDBConfig
from .vdb.weaviate_config import WeaviateConfig
_VALID_STORAGE_TYPE = Literal[
"opendal",
"s3",
"aliyun-oss",
"azure-blob",
"baidu-obs",
"clickzetta-volume",
"google-storage",
"huawei-obs",
"oci-storage",
"tencent-cos",
"volcengine-tos",
"supabase",
"local",
]
class StorageConfig(BaseSettings):
STORAGE_TYPE: Literal[
"opendal",
"s3",
"aliyun-oss",
"azure-blob",
"baidu-obs",
"clickzetta-volume",
"google-storage",
"huawei-obs",
"oci-storage",
"tencent-cos",
"volcengine-tos",
"supabase",
"local",
] = Field(
STORAGE_TYPE: _VALID_STORAGE_TYPE = Field(
description="Type of storage to use."
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', "
"'clickzetta-volume', 'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', "
"'volcengine-tos', 'supabase'. Default is 'opendal'.",
default="opendal",
default=cast(_VALID_STORAGE_TYPE, "opendal"),
)
STORAGE_LOCAL_PATH: str = Field(

View File

@ -1,36 +1,21 @@
from pydantic import BaseModel, Field, JsonValue
import json
HUMAN_INPUT_FORM_INPUT_EXAMPLE = {
"decision": "approve",
"attachment": {
"transfer_method": "local_file",
"upload_file_id": "4e0d1b87-52f2-49f6-b8c6-95cd9c954b3e",
"type": "document",
},
"attachments": [
{
"transfer_method": "local_file",
"upload_file_id": "1a77f0df-c0e6-461c-987c-e72526f341ee",
"type": "document",
},
{
"transfer_method": "remote_url",
"url": "https://example.com/report.pdf",
"type": "document",
},
],
}
from pydantic import BaseModel, JsonValue
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict[str, JsonValue] = Field(
description=(
"Submitted human input values keyed by output variable name. "
"Use a string for paragraph or select input values, a file mapping for file inputs, "
"and a list of file mappings for file-list inputs. Local file mappings use "
"`transfer_method=local_file` with `upload_file_id`; remote file mappings use "
"`transfer_method=remote_url` with `url` or `remote_url`."
),
examples=[HUMAN_INPUT_FORM_INPUT_EXAMPLE],
)
inputs: dict[str, JsonValue]
action: str
def stringify_form_default_values(values: dict[str, object]) -> dict[str, str]:
"""Serialize default values into strings expected by human-input form clients."""
result: dict[str, str] = {}
for key, value in values.items():
if value is None:
result[key] = ""
elif isinstance(value, (dict, list)):
result[key] = json.dumps(value, ensure_ascii=False)
else:
result[key] = str(value)
return result

View File

@ -39,6 +39,7 @@ QueryParamDoc = TypedDict(
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
schema = _swagger_2_compatible_schema(schema)
nested_definitions = schema.get("$defs")
schema_to_register = dict(schema)
if isinstance(nested_definitions, dict):
@ -65,6 +66,35 @@ def _register_schema_model(namespace: Namespace, model: type[BaseModel], *, mode
)
def _swagger_2_compatible_schema(value: Any) -> Any:
if isinstance(value, list):
return [_swagger_2_compatible_schema(item) for item in value]
if not isinstance(value, dict):
return value
converted = {key: _swagger_2_compatible_schema(child) for key, child in value.items()}
any_of = value.get("anyOf")
if not isinstance(any_of, list):
return converted
non_null_candidates = [
candidate for candidate in any_of if isinstance(candidate, Mapping) and candidate.get("type") != "null"
]
has_null_candidate = any(isinstance(candidate, Mapping) and candidate.get("type") == "null" for candidate in any_of)
if not has_null_candidate or len(non_null_candidates) != 1:
return converted
non_null_schema = _swagger_2_compatible_schema(dict(non_null_candidates[0]))
if not isinstance(non_null_schema, dict):
return converted
converted.pop("anyOf", None)
converted.update(non_null_schema)
converted["x-nullable"] = True
return converted
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
"""Register a BaseModel and its nested schema definitions for Swagger documentation."""

View File

@ -33,7 +33,6 @@ for module_name in RESOURCE_MODULES:
# Ensure resource modules are imported so route decorators are evaluated.
# Import other controllers
from . import (
admin,
apikey,
extension,
feature,
@ -108,9 +107,6 @@ from .datasets.rag_pipeline import (
rag_pipeline_workflow,
)
# Import evaluation controllers
from .evaluation import evaluation
# Import explore controllers
from .explore import (
banner,
@ -120,12 +116,8 @@ from .explore import (
saved_message,
trial,
)
from .socketio import workflow as socketio_workflow
# Import snippet controllers
from .snippets import snippet_workflow, snippet_workflow_draft_variable
from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport]
# Import snippet controllers
# Import tag controllers
from .tag import tags
@ -139,8 +131,6 @@ from .workspace import (
model_providers,
models,
plugin,
rbac,
snippets,
tool_providers,
trigger_providers,
workspace,
@ -151,7 +141,6 @@ api.add_namespace(console_ns)
__all__ = [
"account",
"activate",
"admin",
"advanced_prompt_template",
"agent",
"agent_providers",
@ -178,7 +167,6 @@ __all__ = [
"datasource_content_preview",
"email_register",
"endpoint",
"evaluation",
"extension",
"external",
"feature",
@ -209,17 +197,10 @@ __all__ = [
"rag_pipeline_draft_variable",
"rag_pipeline_import",
"rag_pipeline_workflow",
"rbac",
"recommended_app",
"saved_message",
"setup",
"site",
"snippet_workflow",
"snippet_workflow",
"snippet_workflow_draft_variable",
"snippet_workflow_draft_variable",
"snippets",
"snippets",
"socketio_workflow",
"spec",
"statistic",

View File

@ -1,64 +1,11 @@
import csv
import io
from collections.abc import Callable
from functools import wraps
from typing import cast
from uuid import UUID
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from werkzeug.exceptions import Unauthorized
from configs import dify_config
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from services.billing_service import BillingService, LangContentDict
class InsertExploreAppPayload(BaseModel):
app_id: str = Field(...)
desc: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
can_trial: bool = Field(default=False)
trial_limit: int = Field(default=0)
@field_validator("language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
class InsertExploreBannerPayload(BaseModel):
category: str = Field(...)
title: str = Field(...)
description: str = Field(...)
img_src: str = Field(..., alias="img-src")
language: str = Field(default="en-US")
link: str = Field(...)
sort: int = Field(...)
@field_validator("language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
model_config = {"populate_by_name": True}
register_schema_models(console_ns, InsertExploreAppPayload, InsertExploreBannerPayload)
def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@ -76,353 +23,3 @@ def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return view(*args, **kwargs)
return decorated
@console_ns.route("/admin/insert-explore-apps")
class InsertExploreAppListApi(Resource):
@console_ns.doc("insert_explore_app")
@console_ns.doc(description="Insert or update an app in the explore list")
@console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
@console_ns.response(200, "App updated successfully")
@console_ns.response(201, "App inserted successfully")
@console_ns.response(404, "App not found")
@only_edition_cloud
@admin_required
def post(self):
payload = InsertExploreAppPayload.model_validate(console_ns.payload)
app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
if not app:
raise NotFound(f"App '{payload.app_id}' is not found")
site = app.site
if not site:
desc = payload.desc or ""
copy_right = payload.copyright or ""
privacy_policy = payload.privacy_policy or ""
custom_disclaimer = payload.custom_disclaimer or ""
else:
desc = site.description or payload.desc or ""
copy_right = site.copyright or payload.copyright or ""
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
).scalar_one_or_none()
if not recommended_app:
recommended_app = RecommendedApp(
app_id=app.id,
description=desc,
copyright=copy_right,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
language=payload.language,
category=payload.category,
position=payload.position,
)
db.session.add(recommended_app)
if payload.can_trial:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == payload.app_id)
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=payload.app_id,
tenant_id=app.tenant_id,
trial_limit=payload.trial_limit,
)
)
else:
trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
return {"result": "success"}, 201
else:
recommended_app.description = desc
recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = payload.language
recommended_app.category = payload.category
recommended_app.position = payload.position
if payload.can_trial:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == payload.app_id)
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=payload.app_id,
tenant_id=app.tenant_id,
trial_limit=payload.trial_limit,
)
)
else:
trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
return {"result": "success"}, 200
@console_ns.route("/admin/insert-explore-apps/<uuid:app_id>")
class InsertExploreAppApi(Resource):
@console_ns.doc("delete_explore_app")
@console_ns.doc(description="Remove an app from the explore list")
@console_ns.doc(params={"app_id": "Application ID to remove"})
@console_ns.response(204, "App removed successfully")
@only_edition_cloud
@admin_required
def delete(self, app_id: UUID):
with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
if not recommended_app:
return {"result": "success"}, 204
with session_factory.create_session() as session:
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
if app:
app.is_public = False
with session_factory.create_session() as session:
installed_apps = (
session.execute(
select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
)
.scalars()
.all()
)
for installed_app in installed_apps:
session.delete(installed_app)
trial_app = session.execute(
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
).scalar_one_or_none()
if trial_app:
session.delete(trial_app)
db.session.delete(recommended_app)
db.session.commit()
return {"result": "success"}, 204
@console_ns.route("/admin/insert-explore-banner")
class InsertExploreBannerApi(Resource):
@console_ns.doc("insert_explore_banner")
@console_ns.doc(description="Insert an explore banner")
@console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
@console_ns.response(201, "Banner inserted successfully")
@only_edition_cloud
@admin_required
def post(self):
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
banner = ExporleBanner(
content={
"category": payload.category,
"title": payload.title,
"description": payload.description,
"img-src": payload.img_src,
},
link=payload.link,
sort=payload.sort,
language=payload.language,
)
db.session.add(banner)
db.session.commit()
return {"result": "success"}, 201
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
class DeleteExploreBannerApi(Resource):
@console_ns.doc("delete_explore_banner")
@console_ns.doc(description="Delete an explore banner")
@console_ns.doc(params={"banner_id": "Banner ID to delete"})
@console_ns.response(204, "Banner deleted successfully")
@only_edition_cloud
@admin_required
def delete(self, banner_id):
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
if not banner:
raise NotFound(f"Banner '{banner_id}' is not found")
db.session.delete(banner)
db.session.commit()
return {"result": "success"}, 204
class LangContentPayload(BaseModel):
lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'")
title: str = Field(...)
subtitle: str | None = Field(default=None)
body: str = Field(...)
title_pic_url: str | None = Field(default=None)
class UpsertNotificationPayload(BaseModel):
notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update")
contents: list[LangContentPayload] = Field(..., min_length=1)
start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z")
end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z")
frequency: str = Field(default="once", description="'once' | 'every_page_load'")
status: str = Field(default="active", description="'active' | 'inactive'")
class BatchAddNotificationAccountsPayload(BaseModel):
notification_id: str = Field(...)
user_email: list[str] = Field(..., description="List of account email addresses")
register_schema_models(console_ns, UpsertNotificationPayload, BatchAddNotificationAccountsPayload)
@console_ns.route("/admin/upsert_notification")
class UpsertNotificationApi(Resource):
@console_ns.doc("upsert_notification")
@console_ns.doc(
description=(
"Create or update an in-product notification. "
"Supply notification_id to update an existing one; omit it to create a new one. "
"Pass at least one language variant in contents (zh / en / jp)."
)
)
@console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__])
@console_ns.response(200, "Notification upserted successfully")
@only_edition_cloud
@admin_required
def post(self):
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
result = BillingService.upsert_notification(
contents=[cast(LangContentDict, c.model_dump()) for c in payload.contents],
frequency=payload.frequency,
status=payload.status,
notification_id=payload.notification_id,
start_time=payload.start_time,
end_time=payload.end_time,
)
return {"result": "success", "notification_id": result.get("notificationId")}, 200
@console_ns.route("/admin/batch_add_notification_accounts")
class BatchAddNotificationAccountsApi(Resource):
@console_ns.doc("batch_add_notification_accounts")
@console_ns.doc(
description=(
"Register target accounts for a notification by email address. "
'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. '
"File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) "
"plus a 'notification_id' field. "
"Emails that do not match any account are silently skipped."
)
)
@console_ns.response(200, "Accounts added successfully")
@only_edition_cloud
@admin_required
def post(self):
from models.account import Account
if "file" in request.files:
notification_id = request.form.get("notification_id", "").strip()
if not notification_id:
raise BadRequest("notification_id is required.")
emails = self._parse_emails_from_file()
else:
payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload)
notification_id = payload.notification_id
emails = payload.user_email
if not emails:
raise BadRequest("No valid email addresses provided.")
# Resolve emails → account IDs in chunks to avoid large IN-clause
account_ids: list[str] = []
chunk_size = 500
for i in range(0, len(emails), chunk_size):
chunk = emails[i : i + chunk_size]
rows = db.session.execute(select(Account.id, Account.email).where(Account.email.in_(chunk))).all()
account_ids.extend(str(row.id) for row in rows)
if not account_ids:
raise BadRequest("None of the provided emails matched an existing account.")
# Send to dify-saas in batches of 1000
total_count = 0
batch_size = 1000
for i in range(0, len(account_ids), batch_size):
batch = account_ids[i : i + batch_size]
result = BillingService.batch_add_notification_accounts(
notification_id=notification_id,
account_ids=batch,
)
total_count += result.get("count", 0)
return {
"result": "success",
"emails_provided": len(emails),
"accounts_matched": len(account_ids),
"count": total_count,
}, 200
@staticmethod
def _parse_emails_from_file() -> list[str]:
"""Parse email addresses from an uploaded CSV or TXT file."""
file = request.files["file"]
if not file.filename:
raise BadRequest("Uploaded file has no filename.")
filename_lower = file.filename.lower()
if not filename_lower.endswith((".csv", ".txt")):
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
try:
content = file.stream.read().decode("utf-8")
except UnicodeDecodeError:
try:
file.stream.seek(0)
content = file.stream.read().decode("gbk")
except UnicodeDecodeError:
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
emails: list[str] = []
if filename_lower.endswith(".csv"):
reader = csv.reader(io.StringIO(content))
for row in reader:
for cell in row:
cell = cell.strip()
if cell:
emails.append(cell)
else:
for line in content.splitlines():
line = line.strip()
if line:
emails.append(line)
# Deduplicate while preserving order
seen: set[str] = set()
unique_emails: list[str] = []
for email in emails:
if email.lower() not in seen:
seen.add(email.lower())
unique_emails.append(email)
return unique_emails

View File

@ -11,6 +11,7 @@ from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.enums import ApiTokenType
@ -21,12 +22,6 @@ from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class ApiKeyItem(ResponseModel):
id: str
type: str
@ -37,7 +32,7 @@ class ApiKeyItem(ResponseModel):
@field_validator("last_used_at", "created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
return to_timestamp(value)
class ApiKeyList(ResponseModel):

View File

@ -3,7 +3,6 @@ import re
import uuid
from datetime import datetime
from typing import Any, Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -31,19 +30,16 @@ from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES
from configs import dify_config
from extensions.ext_database import db
from fields.base import ResponseModel
from graphon.enums import WorkflowExecutionStatus
from libs.helper import build_icon_url
from libs.helper import build_icon_url, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
from models.workflow import resolve_workflow_kind
from services.app_dsl_service import AppDslService
from services.app_service import AppListParams, AppService, CreateAppParams
from services.enterprise.enterprise_service import EnterpriseService
from services.enterprise import rbac_service as enterprise_rbac_service
from services.entities.dsl_entities import ImportMode, ImportStatus
from services.entities.knowledge_entities.knowledge_entities import (
DataSource,
@ -181,12 +177,6 @@ class AppTracePayload(BaseModel):
type JSONValue = Any
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class Tag(ResponseModel):
id: str
name: str
@ -203,7 +193,7 @@ class WorkflowPartial(ResponseModel):
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
return to_timestamp(value)
class ModelConfigPartial(ResponseModel):
@ -217,7 +207,7 @@ class ModelConfigPartial(ResponseModel):
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
return to_timestamp(value)
class ModelConfig(ResponseModel):
@ -278,7 +268,7 @@ class ModelConfig(ResponseModel):
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
return to_timestamp(value)
class Site(ResponseModel):
@ -321,7 +311,7 @@ class Site(ResponseModel):
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
return to_timestamp(value)
class DeletedTool(ResponseModel):
@ -355,9 +345,6 @@ class AppPartial(ResponseModel):
create_user_name: str | None = None
author_name: str | None = None
has_draft_trigger: bool | None = None
workflow_type: str | None = None
workflow_kind: str | None = None
permission_keys: list[str] = Field(default_factory=list)
@computed_field(return_type=str | None) # type: ignore
@property
@ -367,7 +354,7 @@ class AppPartial(ResponseModel):
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
return to_timestamp(value)
class AppDetail(ResponseModel):
@ -392,15 +379,12 @@ class AppDetail(ResponseModel):
updated_by: str | None = None
updated_at: int | None = None
access_mode: str | None = None
workflow_type: str | None = None
workflow_kind: str | None = None
tags: list[Tag] = Field(default_factory=list)
permission_keys: list[str] = Field(default_factory=list)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
return to_timestamp(value)
class AppDetailWithSite(AppDetail):
@ -428,22 +412,6 @@ class AppExportResponse(ResponseModel):
data: str
def _collect_app_access_permission_keys(access_matrix: enterprise_rbac_service.AppAccessMatrix) -> list[str]:
permission_keys: list[str] = []
seen_permission_keys: set[str] = set()
for item in access_matrix.items:
if not item.policy:
continue
for permission_key in item.policy.permission_keys:
if permission_key in seen_permission_keys:
continue
seen_permission_keys.add(permission_key)
permission_keys.append(permission_key)
return permission_keys
register_enum_models(console_ns, RetrievalMethod, WorkflowExecutionStatus, DatasetPermissionEnum)
register_schema_models(
@ -529,20 +497,6 @@ class AppListApi(Resource):
if str(app.id) in res:
app.access_mode = res[str(app.id)].access_mode
if app_pagination.items:
if dify_config.RBAC_ENABLED:
app_ids = [str(app.id) for app in app_pagination.items]
permission_keys_map = enterprise_rbac_service.RBACService.AppPermissions.batch_get(
str(current_tenant_id),
current_user.id,
app_ids,
)
for app in app_pagination.items:
app.permission_keys = permission_keys_map.get(str(app.id), [])
else:
for app in app_pagination.items:
app.permission_keys = []
workflow_capable_app_ids = [
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
]
@ -574,25 +528,6 @@ class AppListApi(Resource):
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
workflow_ids = [str(app.workflow_id) for app in app_pagination.items if app.workflow_id]
workflow_info_map: dict[str, tuple[str, str]] = {}
if workflow_ids:
rows = db.session.execute(
select(Workflow.id, Workflow.type, Workflow.kind).where(Workflow.id.in_(workflow_ids))
).all()
workflow_info_map = {
str(row.id): (
row.type.value if hasattr(row.type, "value") else str(row.type),
resolve_workflow_kind(row.kind).value,
)
for row in rows
}
for app in app_pagination.items:
workflow_info = workflow_info_map.get(str(app.workflow_id)) if app.workflow_id else None
app.workflow_type = workflow_info[0] if workflow_info else None
app.workflow_kind = workflow_info[1] if workflow_info else None
pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
return pagination_model.model_dump(mode="json"), 200
@ -639,7 +574,6 @@ class AppApi(Resource):
@get_app_model(mode=None)
def get(self, app_model):
"""Get app detail"""
current_user, current_tenant_id = current_account_with_tenant()
app_service = AppService()
app_model = app_service.get_app(app_model)
@ -648,29 +582,6 @@ class AppApi(Resource):
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
app_model.access_mode = app_setting.access_mode
if app_model.workflow_id:
row = db.session.execute(
select(Workflow.type, Workflow.kind).where(Workflow.id == app_model.workflow_id)
).first()
app_model.workflow_type = (
(row.type.value if hasattr(row.type, "value") else str(row.type)) if row else None
)
app_model.workflow_kind = resolve_workflow_kind(row.kind).value if row else None
else:
app_model.workflow_type = None
app_model.workflow_kind = None
if dify_config.RBAC_ENABLED:
app_access_matrix = enterprise_rbac_service.RBACService.AppAccess.matrix(
str(current_tenant_id),
current_user.id,
str(app_model.id),
)
app_model.permission_keys = _collect_app_access_permission_keys(app_access_matrix)
else:
app_model.permission_keys = []
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
@ -938,10 +849,11 @@ class AppTraceApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id: UUID):
@get_app_model
def get(self, app_model):
"""Get app trace"""
with session_factory.create_session() as session:
app_trace_config = OpsTraceManager.get_app_tracing_config(str(app_id), session)
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
return app_trace_config
@ -955,12 +867,13 @@ class AppTraceApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_id: UUID):
@get_app_model
def post(self, app_model):
# add app trace
args = AppTracePayload.model_validate(console_ns.payload)
OpsTraceManager.update_app_tracing_config(
app_id=str(app_id),
app_id=app_model.id,
enabled=args.enabled,
tracing_provider=args.tracing_provider,
)

View File

@ -16,6 +16,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from extensions.ext_database import db
from fields._value_type_serializer import serialize_value_type
from fields.base import ResponseModel
from libs.helper import to_timestamp
from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
@ -25,12 +26,6 @@ class ConversationVariablesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID to filter variables")
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class ConversationVariableResponse(ResponseModel):
id: str
name: str
@ -65,7 +60,7 @@ class ConversationVariableResponse(ResponseModel):
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
return to_timestamp(value)
class PaginatedConversationVariableResponse(ResponseModel):

View File

@ -13,17 +13,12 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from models.enums import AppMCPServerStatus
from models.model import AppMCPServer
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
@ -41,14 +36,14 @@ class AppMCPServerResponse(ResponseModel):
name: str
server_code: str
description: str
status: str
status: AppMCPServerStatus
parameters: dict[str, Any] | list[Any] | str
created_at: int | None = None
updated_at: int | None = None
@field_validator("parameters", mode="before")
@classmethod
def _parse_json_string(cls, value: Any) -> Any:
def _normalize_parameters(cls, value: Any) -> Any:
if isinstance(value, str):
try:
return json.loads(value)
@ -59,7 +54,7 @@ class AppMCPServerResponse(ResponseModel):
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
return to_timestamp(value)
register_schema_models(console_ns, MCPServerCreatePayload, MCPServerUpdatePayload, AppMCPServerResponse)
@ -70,7 +65,9 @@ class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server")
@console_ns.doc(description="Get MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(
200, "MCP server configuration retrieved successfully", console_ns.models[AppMCPServerResponse.__name__]
)
@login_required
@account_initialization_required
@setup_required
@ -85,7 +82,9 @@ class AppMCPServerController(Resource):
@console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
@console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(
201, "MCP server configuration created successfully", console_ns.models[AppMCPServerResponse.__name__]
)
@console_ns.response(403, "Insufficient permissions")
@account_initialization_required
@get_app_model
@ -111,13 +110,15 @@ class AppMCPServerController(Resource):
)
db.session.add(server)
db.session.commit()
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json"), 201
@console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
@console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(
200, "MCP server configuration updated successfully", console_ns.models[AppMCPServerResponse.__name__]
)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@get_app_model
@ -154,7 +155,7 @@ class AppMCPServerRefreshController(Resource):
@console_ns.doc("refresh_app_mcp_server")
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
@console_ns.doc(params={"server_id": "Server ID"})
@console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(200, "MCP server refreshed successfully", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@setup_required

View File

@ -37,10 +37,9 @@ from fields.conversation_fields import (
JSONValue,
MessageFile,
format_files_contained,
to_timestamp,
)
from graphon.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from libs.helper import to_timestamp, uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.enums import FeedbackFromSource, FeedbackRating
@ -144,9 +143,7 @@ class MessageDetailResponse(ResponseModel):
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
return to_timestamp(value)
class MessageInfiniteScrollPaginationResponse(ResponseModel):

View File

@ -1,5 +1,4 @@
from typing import Any
from uuid import UUID
from flask import request
from flask_restx import Resource, fields
@ -9,8 +8,10 @@ from werkzeug.exceptions import BadRequest
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from models import App
from services.ops_service import OpsService
@ -43,11 +44,14 @@ class TraceAppConfigApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id: UUID):
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True))
@get_app_model
def get(self, app_model: App):
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
trace_config = OpsService.get_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider)
trace_config = OpsService.get_tracing_app_config(
app_id=app_model.id, tracing_provider=args.tracing_provider
)
if not trace_config:
return {"has_not_configured": True}
return trace_config
@ -65,13 +69,14 @@ class TraceAppConfigApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, app_id: UUID):
@get_app_model
def post(self, app_model: App):
"""Create a new trace app configuration"""
args = TraceConfigPayload.model_validate(console_ns.payload)
try:
result = OpsService.create_tracing_app_config(
app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
app_id=app_model.id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
)
if not result:
raise TracingConfigIsExist()
@ -90,13 +95,14 @@ class TraceAppConfigApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, app_id: UUID):
@get_app_model
def patch(self, app_model: App):
"""Update an existing trace app configuration"""
args = TraceConfigPayload.model_validate(console_ns.payload)
try:
result = OpsService.update_tracing_app_config(
app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
app_id=app_model.id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
)
if not result:
raise TracingConfigNotExist()
@ -113,12 +119,13 @@ class TraceAppConfigApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, app_id: UUID):
@get_app_model
def delete(self, app_model: App):
"""Delete an existing trace app configuration"""
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True))
try:
result = OpsService.delete_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider)
result = OpsService.delete_tracing_app_config(app_id=app_model.id, tracing_provider=args.tracing_provider)
if not result:
raise TracingConfigNotExist()
return {"result": "success"}, 204

View File

@ -1,21 +1,17 @@
import json
import logging
from collections.abc import Sequence
from typing import Any, Literal
from typing import Any
from flask import abort, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, ValidationError, field_validator
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
from controllers.common.schema import (
DEFAULT_REF_TEMPLATE_SWAGGER_2_0,
register_response_schema_model,
register_schema_models,
)
from controllers.common.schema import register_response_schema_model, register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model
@ -53,7 +49,7 @@ from libs.helper import TimestampField, uuid_value
from libs.login import current_account_with_tenant, login_required
from models import App
from models.model import AppMode
from models.workflow import Workflow, WorkflowKind
from models.workflow import Workflow
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
@ -161,23 +157,6 @@ class ConvertToWorkflowPayload(BaseModel):
icon_background: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
keyword: str | None = Field(default=None, max_length=255)
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class WorkflowTypeConvertQuery(BaseModel):
target_type: Literal["workflow", "evaluation"]
class WorkflowFeaturesPayload(BaseModel):
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
@ -199,26 +178,24 @@ class DraftWorkflowTriggerRunAllPayload(BaseModel):
node_ids: list[str]
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(SyncDraftWorkflowPayload)
reg(AdvancedChatWorkflowRunPayload)
reg(IterationNodeRunPayload)
reg(LoopNodeRunPayload)
reg(DraftWorkflowRunPayload)
reg(DraftWorkflowNodeRunPayload)
reg(PublishWorkflowPayload)
reg(DefaultBlockConfigQuery)
reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(WorkflowTypeConvertQuery)
reg(WorkflowFeaturesPayload)
reg(WorkflowOnlineUsersPayload)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)
register_schema_models(
console_ns,
SyncDraftWorkflowPayload,
AdvancedChatWorkflowRunPayload,
IterationNodeRunPayload,
LoopNodeRunPayload,
DraftWorkflowRunPayload,
DraftWorkflowNodeRunPayload,
PublishWorkflowPayload,
DefaultBlockConfigQuery,
ConvertToWorkflowPayload,
WorkflowListQuery,
WorkflowUpdatePayload,
WorkflowFeaturesPayload,
WorkflowOnlineUsersPayload,
DraftWorkflowTriggerRunPayload,
DraftWorkflowTriggerRunAllPayload,
)
register_response_schema_model(console_ns, WorkflowRunNodeExecutionResponse)
@ -898,54 +875,6 @@ class PublishedWorkflowApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/publish/evaluation")
class EvaluationPublishedWorkflowApi(Resource):
@console_ns.doc("publish_evaluation_workflow")
@console_ns.doc(description="Publish draft workflow as evaluation workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
@console_ns.response(200, "Evaluation workflow published successfully")
@console_ns.response(400, "Invalid workflow or unsupported node type")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Publish draft workflow as evaluation workflow.
Evaluation workflows cannot include trigger or human-input nodes.
"""
current_user, _ = current_account_with_tenant()
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
with Session(db.engine, expire_on_commit=False) as session:
workflow = workflow_service.publish_evaluation_workflow(
session=session,
app_model=app_model,
account=current_user,
marked_name=args.marked_name or "",
marked_comment=args.marked_comment or "",
)
# Keep workflow_id aligned with the latest published workflow.
app_model_in_session = session.get(App, app_model.id)
if app_model_in_session:
app_model_in_session.workflow_id = workflow.id
app_model_in_session.updated_by = current_user.id
app_model_in_session.updated_at = naive_utc_now()
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
return {
"result": "success",
"created_at": workflow_created_at,
}
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
class DefaultBlockConfigsApi(Resource):
@console_ns.doc("get_default_block_configs")
@ -1143,52 +1072,6 @@ class DraftWorkflowRestoreApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/convert-type")
class WorkflowTypeConvertApi(Resource):
@console_ns.doc("convert_published_workflow_type")
@console_ns.doc(description="Convert current effective published workflow type in-place")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowTypeConvertQuery.__name__])
@console_ns.response(200, "Workflow type converted successfully")
@console_ns.response(400, "Invalid workflow type or unsupported workflow graph")
@console_ns.response(404, "Workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
args = WorkflowTypeConvertQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
target_type = WorkflowKind.EVALUATION if args.target_type == "evaluation" else WorkflowKind.STANDARD
workflow_service = WorkflowService()
with Session(db.engine, expire_on_commit=False) as session:
try:
workflow = workflow_service.convert_published_workflow_type(
session=session,
app_model=app_model,
target_type=target_type,
account=current_user,
)
except WorkflowNotFoundError as exc:
raise NotFound(str(exc)) from exc
except IsDraftWorkflowError as exc:
raise BadRequest(str(exc)) from exc
except ValueError as exc:
raise BadRequest(str(exc)) from exc
session.commit()
return {
"result": "success",
"workflow_id": workflow.id,
"type": workflow.type.value,
"kind": workflow.kind_or_standard,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
class WorkflowByIdApi(Resource):
@console_ns.doc("update_workflow_by_id")

View File

@ -16,6 +16,7 @@ from fields.base import ResponseModel
from fields.end_user_fields import SimpleEndUser
from fields.member_fields import SimpleAccount
from graphon.enums import WorkflowExecutionStatus
from libs.helper import to_timestamp
from libs.login import login_required
from models import App
from models.model import AppMode
@ -82,9 +83,7 @@ class WorkflowRunForLogResponse(ResponseModel):
@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
return to_timestamp(value)
class WorkflowRunForArchivedLogResponse(ResponseModel):
@ -104,28 +103,10 @@ class WorkflowRunForArchivedLogResponse(ResponseModel):
return str(getattr(value, "value", value))
class WorkflowAppLogEvaluationNodeInfoResponse(ResponseModel):
node_id: str
type: str
title: str
class WorkflowAppLogEvaluationItemResponse(ResponseModel):
name: str
value: Any = None
details: dict[str, Any] | None = None
node_info: WorkflowAppLogEvaluationNodeInfoResponse | None = Field(
default=None,
validation_alias="node_info",
serialization_alias="nodeInfo",
)
class WorkflowAppLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForLogResponse | None = None
details: Any = None
evaluation: list[WorkflowAppLogEvaluationItemResponse] = Field(default_factory=list)
created_from: str | None = None
created_by_role: str | None = None
created_by_account: SimpleAccount | None = None
@ -135,14 +116,7 @@ class WorkflowAppLogPartialResponse(ResponseModel):
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
@field_validator("evaluation", mode="before")
@classmethod
def _normalize_evaluation(cls, value: Any) -> list[dict[str, Any]] | list[WorkflowAppLogEvaluationItemResponse]:
return value or []
return to_timestamp(value)
class WorkflowArchivedLogPartialResponse(ResponseModel):
@ -156,9 +130,7 @@ class WorkflowArchivedLogPartialResponse(ResponseModel):
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
return to_timestamp(value)
class WorkflowAppLogPaginationResponse(ResponseModel):
@ -182,8 +154,6 @@ register_schema_models(
WorkflowAppLogQuery,
WorkflowRunForLogResponse,
WorkflowRunForArchivedLogResponse,
WorkflowAppLogEvaluationNodeInfoResponse,
WorkflowAppLogEvaluationItemResponse,
WorkflowAppLogPartialResponse,
WorkflowArchivedLogPartialResponse,
WorkflowAppLogPaginationResponse,

View File

@ -1,22 +1,16 @@
import logging
from datetime import datetime
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, TypeAdapter
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, computed_field, field_validator
from controllers.common.schema import register_schema_models
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.base import ResponseModel
from fields.member_fields import AccountWithRole
from fields.workflow_comment_fields import (
workflow_comment_basic_fields,
workflow_comment_create_fields,
workflow_comment_detail_fields,
workflow_comment_reply_create_fields,
workflow_comment_reply_update_fields,
workflow_comment_resolve_fields,
workflow_comment_update_fields,
)
from libs.helper import build_avatar_url, dump_response, to_timestamp
from libs.login import current_user, login_required
from models import App
from services.account_service import TenantService
@ -51,6 +45,138 @@ class WorkflowCommentMentionUsersPayload(BaseModel):
users: list[AccountWithRole]
class WorkflowCommentAccount(ResponseModel):
id: str
name: str
email: str
avatar: str | None = Field(default=None, exclude=True)
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
@property
def avatar_url(self) -> str | None:
return build_avatar_url(self.avatar)
class WorkflowCommentReply(ResponseModel):
id: str
content: str
created_by: str
created_by_account: WorkflowCommentAccount | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class WorkflowCommentMention(ResponseModel):
mentioned_user_id: str
mentioned_user_account: WorkflowCommentAccount | None = None
reply_id: str | None = None
class WorkflowCommentBasic(ResponseModel):
id: str
position_x: float
position_y: float
content: str
created_by: str
created_by_account: WorkflowCommentAccount | None = None
created_at: int | None = None
updated_at: int | None = None
resolved: bool
resolved_at: int | None = None
resolved_by: str | None = None
resolved_by_account: WorkflowCommentAccount | None = None
reply_count: int
mention_count: int
participants: list[WorkflowCommentAccount]
@field_validator("created_at", "updated_at", "resolved_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class WorkflowCommentBasicList(ResponseModel):
data: list[WorkflowCommentBasic]
class WorkflowCommentDetail(ResponseModel):
id: str
position_x: float
position_y: float
content: str
created_by: str
created_by_account: WorkflowCommentAccount | None = None
created_at: int | None = None
updated_at: int | None = None
resolved: bool
resolved_at: int | None = None
resolved_by: str | None = None
resolved_by_account: WorkflowCommentAccount | None = None
replies: list[WorkflowCommentReply]
mentions: list[WorkflowCommentMention]
@field_validator("created_at", "updated_at", "resolved_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class WorkflowCommentCreate(ResponseModel):
id: str
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class WorkflowCommentUpdate(ResponseModel):
id: str
updated_at: int | None = None
@field_validator("updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class WorkflowCommentResolve(ResponseModel):
id: str
resolved: bool
resolved_at: int | None = None
resolved_by: str | None = None
@field_validator("resolved_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class WorkflowCommentReplyCreate(ResponseModel):
id: str
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class WorkflowCommentReplyUpdate(ResponseModel):
id: str
updated_at: int | None = None
@field_validator("updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
register_schema_models(
console_ns,
AccountWithRole,
@ -59,17 +185,19 @@ register_schema_models(
WorkflowCommentUpdatePayload,
WorkflowCommentReplyPayload,
)
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
workflow_comment_reply_create_model = console_ns.model(
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
)
workflow_comment_reply_update_model = console_ns.model(
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
register_response_schema_models(
console_ns,
WorkflowCommentAccount,
WorkflowCommentReply,
WorkflowCommentMention,
WorkflowCommentBasic,
WorkflowCommentBasicList,
WorkflowCommentDetail,
WorkflowCommentCreate,
WorkflowCommentUpdate,
WorkflowCommentResolve,
WorkflowCommentReplyCreate,
WorkflowCommentReplyUpdate,
)
@ -80,28 +208,26 @@ class WorkflowCommentListApi(Resource):
@console_ns.doc("list_workflow_comments")
@console_ns.doc(description="Get all comments for a workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
@console_ns.response(200, "Comments retrieved successfully", console_ns.models[WorkflowCommentBasicList.__name__])
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_basic_model, envelope="data")
def get(self, app_model: App):
"""Get all comments for a workflow."""
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
return comments
return WorkflowCommentBasicList.model_validate({"data": comments}).model_dump(mode="json")
@console_ns.doc("create_workflow_comment")
@console_ns.doc(description="Create a new workflow comment")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
@console_ns.response(201, "Comment created successfully", console_ns.models[WorkflowCommentCreate.__name__])
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_create_model)
@edit_permission_required
def post(self, app_model: App):
"""Create a new workflow comment."""
@ -117,7 +243,7 @@ class WorkflowCommentListApi(Resource):
mentioned_user_ids=payload.mentioned_user_ids,
)
return result, 201
return dump_response(WorkflowCommentCreate, result), 201
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
@ -127,30 +253,28 @@ class WorkflowCommentDetailApi(Resource):
@console_ns.doc("get_workflow_comment")
@console_ns.doc(description="Get a specific workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
@console_ns.response(200, "Comment retrieved successfully", console_ns.models[WorkflowCommentDetail.__name__])
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_detail_model)
def get(self, app_model: App, comment_id: str):
"""Get a specific workflow comment."""
comment = WorkflowCommentService.get_comment(
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
)
return comment
return dump_response(WorkflowCommentDetail, comment)
@console_ns.doc("update_workflow_comment")
@console_ns.doc(description="Update a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
@console_ns.response(200, "Comment updated successfully", console_ns.models[WorkflowCommentUpdate.__name__])
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_update_model)
@edit_permission_required
def put(self, app_model: App, comment_id: str):
"""Update a workflow comment."""
@ -167,7 +291,7 @@ class WorkflowCommentDetailApi(Resource):
mentioned_user_ids=payload.mentioned_user_ids,
)
return result
return dump_response(WorkflowCommentUpdate, result)
@console_ns.doc("delete_workflow_comment")
@console_ns.doc(description="Delete a workflow comment")
@ -197,12 +321,11 @@ class WorkflowCommentResolveApi(Resource):
@console_ns.doc("resolve_workflow_comment")
@console_ns.doc(description="Resolve a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
@console_ns.response(200, "Comment resolved successfully", console_ns.models[WorkflowCommentResolve.__name__])
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_resolve_model)
@edit_permission_required
def post(self, app_model: App, comment_id: str):
"""Resolve a workflow comment."""
@ -213,7 +336,7 @@ class WorkflowCommentResolveApi(Resource):
user_id=current_user.id,
)
return comment
return dump_response(WorkflowCommentResolve, comment)
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
@ -224,12 +347,11 @@ class WorkflowCommentReplyApi(Resource):
@console_ns.doc(description="Add a reply to a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
@console_ns.response(201, "Reply created successfully", console_ns.models[WorkflowCommentReplyCreate.__name__])
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_reply_create_model)
@edit_permission_required
def post(self, app_model: App, comment_id: str):
"""Add a reply to a workflow comment."""
@ -247,7 +369,7 @@ class WorkflowCommentReplyApi(Resource):
mentioned_user_ids=payload.mentioned_user_ids,
)
return result, 201
return dump_response(WorkflowCommentReplyCreate, result), 201
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
@ -258,12 +380,11 @@ class WorkflowCommentReplyDetailApi(Resource):
@console_ns.doc(description="Update a comment reply")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
@console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
@console_ns.response(200, "Reply updated successfully", console_ns.models[WorkflowCommentReplyUpdate.__name__])
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_reply_update_model)
@edit_permission_required
def put(self, app_model: App, comment_id: str, reply_id: str):
"""Update a comment reply."""
@ -284,7 +405,7 @@ class WorkflowCommentReplyDetailApi(Resource):
mentioned_user_ids=payload.mentioned_user_ids,
)
return reply
return dump_response(WorkflowCommentReplyUpdate, reply)
@console_ns.doc("delete_workflow_comment_reply")
@console_ns.doc(description="Delete a comment reply")

View File

@ -1,6 +1,4 @@
import base64
import json
from datetime import UTC, datetime, timedelta
from typing import Literal
from flask import request
@ -12,7 +10,6 @@ from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@ -80,39 +77,3 @@ class PartnerTenants(Resource):
raise BadRequest("Invalid partner information")
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
_DEBUG_KEY = "billing:debug"
_DEBUG_TTL = timedelta(days=7)
class DebugDataPayload(BaseModel):
type: str = Field(..., min_length=1, description="Data type key")
data: str = Field(..., min_length=1, description="Data value to append")
@console_ns.route("/billing/debug/data")
class DebugData(Resource):
def post(self):
body = DebugDataPayload.model_validate(request.get_json(force=True))
item = json.dumps({
"type": body.type,
"data": body.data,
"createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
})
redis_client.lpush(_DEBUG_KEY, item)
redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
return {"result": "ok"}, 201
def get(self):
recent = request.args.get("recent", 10, type=int)
items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
return {
"data": [
json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
]
}
def delete(self):
redis_client.delete(_DEBUG_KEY)
return {"result": "ok"}

View File

@ -1,13 +1,10 @@
import json
from typing import Any, cast
from urllib.parse import quote
from flask import Response, request
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
@ -24,7 +21,6 @@ from controllers.console.wraps import (
setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
from core.indexing_runner import IndexingRunner
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.datasource.vdb.vector_type import VectorType
@ -33,7 +29,6 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
content_fields,
@ -56,22 +51,13 @@ from fields.document_fields import document_status_fields
from graphon.model_runtime.entities.model_entities import ModelType
from libs.login import current_account_with_tenant, login_required
from libs.url_utils import normalize_api_base_url
from models import ApiToken, Dataset, Document, DocumentSegment, EvaluationRun, EvaluationTargetType, UploadFile
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import ApiTokenType, SegmentStatus
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.errors.evaluation import (
EvaluationDatasetInvalidError,
EvaluationFrameworkNotConfiguredError,
EvaluationMaxConcurrentRunsError,
EvaluationNotFoundError,
)
from services.evaluation_service import EvaluationService
from services.enterprise import rbac_service as enterprise_rbac_service
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
@ -141,14 +127,6 @@ def _validate_doc_form(value: str | None) -> str | None:
return value
def _ensure_permission_keys(dataset: Dataset, *, enabled: bool) -> None:
if not enabled:
setattr(dataset, "permission_keys", [])
return
if not isinstance(getattr(dataset, "permission_keys", None), list):
setattr(dataset, "permission_keys", [])
class DatasetCreatePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field("", max_length=400)
@ -351,19 +329,6 @@ class DatasetListApi(Resource):
query.include_all,
)
for dataset in datasets:
_ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
if dify_config.RBAC_ENABLED and datasets:
dataset_ids = [str(dataset.id) for dataset in datasets]
permission_keys_map = enterprise_rbac_service.RBACService.DatasetPermissions.batch_get(
str(current_tenant_id),
current_user.id,
dataset_ids,
)
for dataset in datasets:
setattr(dataset, "permission_keys", permission_keys_map.get(str(dataset.id), []))
# check embedding setting
provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
@ -445,7 +410,6 @@ class DatasetListApi(Resource):
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
_ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
return marshal(dataset, dataset_detail_fields), 201
@ -470,7 +434,6 @@ class DatasetApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
_ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.embedding_model_provider:
@ -540,7 +503,6 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
_ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_tenant_id
@ -1023,432 +985,3 @@ class DatasetAutoDisableLogApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
# ---- Knowledge Base Retrieval Evaluation ----
def _serialize_dataset_evaluation_run(run: EvaluationRun) -> dict[str, Any]:
return {
"id": run.id,
"tenant_id": run.tenant_id,
"target_type": run.target_type,
"target_id": run.target_id,
"evaluation_config_id": run.evaluation_config_id,
"status": run.status,
"dataset_file_id": run.dataset_file_id,
"result_file_id": run.result_file_id,
"total_items": run.total_items,
"completed_items": run.completed_items,
"failed_items": run.failed_items,
"progress": run.progress,
"metrics_summary": json.loads(run.metrics_summary) if run.metrics_summary else {},
"error": run.error,
"created_by": run.created_by,
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
}
def _serialize_dataset_evaluation_run_item(item: Any) -> dict[str, Any]:
return {
"id": item.id,
"item_index": item.item_index,
"inputs": item.inputs_dict,
"expected_output": item.expected_output,
"actual_output": item.actual_output,
"metrics": item.metrics_list,
"judgment": item.judgment_dict,
"metadata": item.metadata_dict,
"error": item.error,
"overall_score": item.overall_score,
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/template/download")
class DatasetEvaluationTemplateDownloadApi(Resource):
@console_ns.doc("download_dataset_evaluation_template")
@console_ns.response(200, "Template file streamed as XLSX attachment")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
"""Download evaluation dataset template for knowledge base retrieval."""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
xlsx_content, filename = EvaluationService.generate_retrieval_dataset_template()
encoded_filename = quote(filename)
response = Response(
xlsx_content,
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Length"] = str(len(xlsx_content))
return response
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation")
class DatasetEvaluationDetailApi(Resource):
@console_ns.doc("get_dataset_evaluation_config")
@console_ns.response(200, "Evaluation configuration retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
"""Get evaluation configuration for the knowledge base."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.get_evaluation_config(
session, current_tenant_id, "dataset", dataset_id_str
)
if config is None:
return {
"evaluation_model": None,
"evaluation_model_provider": None,
"default_metrics": None,
"customized_metrics": None,
"judgment_config": None,
}
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"default_metrics": config.default_metrics_list,
"customized_metrics": config.customized_metrics_dict,
"judgment_config": config.judgment_config_dict,
}
@console_ns.doc("save_dataset_evaluation_config")
@console_ns.response(200, "Evaluation configuration saved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def put(self, dataset_id):
"""Save evaluation configuration for the knowledge base."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
body = request.get_json(force=True)
try:
config_data = EvaluationConfigData.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.save_evaluation_config(
session=session,
tenant_id=current_tenant_id,
target_type="dataset",
target_id=dataset_id_str,
account_id=str(current_user.id),
data=config_data,
)
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"default_metrics": config.default_metrics_list,
"customized_metrics": config.customized_metrics_dict,
"judgment_config": config.judgment_config_dict,
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/run")
class DatasetEvaluationRunApi(Resource):
@console_ns.doc("start_dataset_evaluation_run")
@console_ns.response(200, "Evaluation run started")
@console_ns.response(400, "Invalid request")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
"""Start an evaluation run for the knowledge base retrieval."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
body = request.get_json(force=True)
if not body:
raise BadRequest("Request body is required.")
try:
run_request = EvaluationRunRequest.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
upload_file = (
db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first()
)
if not upload_file:
raise NotFound("Dataset file not found.")
try:
dataset_content = storage.load_once(upload_file.key)
except Exception:
raise BadRequest("Failed to read dataset file.")
if not dataset_content:
raise BadRequest("Dataset file is empty.")
try:
with Session(db.engine, expire_on_commit=False) as session:
evaluation_run = EvaluationService.start_evaluation_run(
session=session,
tenant_id=current_tenant_id,
target_type=EvaluationTargetType.KNOWLEDGE_BASE,
target_id=dataset_id_str,
account_id=str(current_user.id),
dataset_file_content=dataset_content,
run_request=run_request,
)
return _serialize_dataset_evaluation_run(evaluation_run), 200
except EvaluationFrameworkNotConfiguredError as e:
return {"message": str(e.description)}, 400
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except EvaluationMaxConcurrentRunsError as e:
return {"message": str(e.description)}, 429
except EvaluationDatasetInvalidError as e:
return {"message": str(e.description)}, 400
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/logs")
class DatasetEvaluationLogsApi(Resource):
@console_ns.doc("get_dataset_evaluation_logs")
@console_ns.response(200, "Evaluation logs retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
"""Get evaluation run history for the knowledge base."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
with Session(db.engine, expire_on_commit=False) as session:
runs, total = EvaluationService.get_evaluation_runs(
session=session,
tenant_id=current_tenant_id,
target_type="dataset",
target_id=dataset_id_str,
page=page,
page_size=page_size,
)
return {
"data": [_serialize_dataset_evaluation_run(run) for run in runs],
"total": total,
"page": page,
"page_size": page_size,
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>")
class DatasetEvaluationRunDetailApi(Resource):
@console_ns.doc("get_dataset_evaluation_run_detail")
@console_ns.response(200, "Evaluation run detail retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset or run not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, run_id):
"""Get evaluation run detail including per-item results."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
run_id_str = str(run_id)
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 50, type=int)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.get_evaluation_run_detail(
session=session,
tenant_id=current_tenant_id,
run_id=run_id_str,
)
items, total_items = EvaluationService.get_evaluation_run_items(
session=session,
run_id=run_id_str,
page=page,
page_size=page_size,
)
return {
"run": _serialize_dataset_evaluation_run(run),
"items": {
"data": [_serialize_dataset_evaluation_run_item(item) for item in items],
"total": total_items,
"page": page,
"page_size": page_size,
},
}
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>/cancel")
class DatasetEvaluationRunCancelApi(Resource):
@console_ns.doc("cancel_dataset_evaluation_run")
@console_ns.response(200, "Evaluation run cancelled")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset or run not found")
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id, run_id):
"""Cancel a running knowledge base evaluation."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
run_id_str = str(run_id)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.cancel_evaluation_run(
session=session,
tenant_id=current_tenant_id,
run_id=run_id_str,
)
return _serialize_dataset_evaluation_run(run)
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except ValueError as e:
return {"message": str(e)}, 400
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/metrics")
class DatasetEvaluationMetricsApi(Resource):
@console_ns.doc("get_dataset_evaluation_metrics")
@console_ns.response(200, "Available retrieval metrics retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
"""Get available evaluation metrics for knowledge base retrieval."""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
return {
"metrics": EvaluationService.get_supported_metrics(EvaluationCategory.RETRIEVAL)
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/files/<uuid:file_id>")
class DatasetEvaluationFileDownloadApi(Resource):
@console_ns.doc("download_dataset_evaluation_file")
@console_ns.response(200, "File download URL generated")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset or file not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, file_id):
"""Download evaluation test file or result file for the knowledge base."""
from core.workflow.file import helpers as file_helpers
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
file_id_str = str(file_id)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(UploadFile).where(
UploadFile.id == file_id_str,
UploadFile.tenant_id == current_tenant_id,
)
upload_file = session.execute(stmt).scalar_one_or_none()
if not upload_file:
raise NotFound("File not found.")
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
"download_url": download_url,
}

View File

@ -3,18 +3,19 @@ import logging
from argparse import ArgumentTypeError
from collections.abc import Sequence
from contextlib import ExitStack
from datetime import datetime
from typing import Any, Literal, cast
import sqlalchemy as sa
from flask import request, send_file
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import asc, desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from core.errors.error import (
LLMBadRequestError,
@ -29,17 +30,16 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from fields.dataset_fields import dataset_fields
from fields.base import ResponseModel
from fields.document_fields import (
dataset_and_document_fields,
document_fields,
document_metadata_fields,
document_status_fields,
document_with_segments_fields,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
from libs.datetime_utils import naive_utc_now
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
@ -72,27 +72,94 @@ from ..wraps import (
logger = logging.getLogger(__name__)
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_model = get_or_create_model("Dataset", dataset_fields)
def _normalize_enum(value: Any) -> Any:
if isinstance(value, str) or value is None:
return value
return getattr(value, "value", value)
document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)
document_fields_copy = document_fields.copy()
document_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_model = get_or_create_model("Document", document_fields_copy)
class DatasetResponse(ResponseModel):
id: str
name: str
description: str | None = None
permission: str | None = None
data_source_type: str | None = None
indexing_technique: str | None = None
created_by: str | None = None
created_at: int | None = None
document_with_segments_fields_copy = document_with_segments_fields.copy()
document_with_segments_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
@field_validator("data_source_type", "indexing_technique", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> Any:
return _normalize_enum(value)
dataset_and_document_fields_copy = dataset_and_document_fields.copy()
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class DocumentMetadataResponse(ResponseModel):
id: str
name: str
type: str
value: str | None = None
class DocumentResponse(ResponseModel):
id: str
position: int | None = None
data_source_type: str | None = None
data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
data_source_detail_dict: Any = None
dataset_process_rule_id: str | None = None
name: str
created_from: str | None = None
created_by: str | None = None
created_at: int | None = None
tokens: int | None = None
indexing_status: str | None = None
error: str | None = None
enabled: bool | None = None
disabled_at: int | None = None
disabled_by: str | None = None
archived: bool | None = None
display_status: str | None = None
word_count: int | None = None
hit_count: int | None = None
doc_form: str | None = None
doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
summary_index_status: str | None = None
need_summary: bool | None = None
@field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> Any:
return _normalize_enum(value)
@field_validator("doc_metadata", mode="before")
@classmethod
def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
if value is None:
return []
return value
@field_validator("created_at", "disabled_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class DocumentWithSegmentsResponse(DocumentResponse):
process_rule_dict: Any = None
completed_segments: int | None = None
total_segments: int | None = None
class DatasetAndDocumentResponse(ResponseModel):
dataset: DatasetResponse
documents: list[DocumentResponse]
batch: str
class DocumentRetryPayload(BaseModel):
@ -107,6 +174,11 @@ class GenerateSummaryPayload(BaseModel):
document_list: list[str]
class DocumentMetadataUpdatePayload(BaseModel):
doc_type: str | None = None
doc_metadata: Any = None
class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.")
@ -124,7 +196,13 @@ register_schema_models(
DocumentRetryPayload,
DocumentRenamePayload,
GenerateSummaryPayload,
DocumentMetadataUpdatePayload,
DocumentBatchDownloadZipPayload,
DatasetResponse,
DocumentMetadataResponse,
DocumentResponse,
DocumentWithSegmentsResponse,
DatasetAndDocumentResponse,
)
@ -360,10 +438,10 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
@ -401,7 +479,9 @@ class DatasetDocumentListApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
return {"dataset": dataset, "documents": documents, "batch": batch}
return DatasetAndDocumentResponse.model_validate(
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
).model_dump(mode="json")
@setup_required
@login_required
@ -429,12 +509,13 @@ class DatasetInitApi(Resource):
@console_ns.doc("init_dataset")
@console_ns.doc(description="Initialize dataset with documents")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
@console_ns.response(
201, "Dataset initialized successfully", console_ns.models[DatasetAndDocumentResponse.__name__]
)
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
@ -482,9 +563,9 @@ class DatasetInitApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
response = {"dataset": dataset, "documents": documents, "batch": batch}
return response
return DatasetAndDocumentResponse.model_validate(
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
).model_dump(mode="json")
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
@ -991,15 +1072,7 @@ class DocumentMetadataApi(DocumentResource):
@console_ns.doc("update_document_metadata")
@console_ns.doc(description="Update document metadata")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.expect(
console_ns.model(
"UpdateDocumentMetadataRequest",
{
"doc_type": fields.String(description="Document type"),
"doc_metadata": fields.Raw(description="Document metadata"),
},
)
)
@console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__])
@console_ns.response(200, "Document metadata updated successfully")
@console_ns.response(404, "Document not found")
@console_ns.response(403, "Permission denied")
@ -1012,10 +1085,10 @@ class DocumentMetadataApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
req_data = request.get_json()
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
doc_type = req_data.get("doc_type")
doc_metadata = req_data.get("doc_metadata")
doc_type = req_data.doc_type
doc_metadata = req_data.doc_metadata
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
@ -1197,7 +1270,7 @@ class DocumentRenameApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(document_model)
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
@ -1215,7 +1288,7 @@ class DocumentRenameApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return document
return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")

View File

@ -8,6 +8,7 @@ from pydantic import Field, field_validator
from controllers.common.schema import register_schema_models
from fields.base import ResponseModel
from libs.helper import to_timestamp
from libs.login import login_required
from .. import console_ns
@ -19,12 +20,6 @@ from ..wraps import (
)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class HitTestingDocument(ResponseModel):
id: str | None = None
data_source_type: str | None = None
@ -61,7 +56,7 @@ class HitTestingSegment(ResponseModel):
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
return to_timestamp(value)
class HitTestingChildChunk(ResponseModel):

View File

@ -39,11 +39,8 @@ class HitTestingPayload(BaseModel):
class DatasetsHitTestingBase:
@staticmethod
def _normalize_hit_testing_query(query: Any) -> str:
"""Return the user-visible query string from legacy and current response shapes."""
if isinstance(query, str):
return query
def _extract_hit_testing_query(query: Any) -> str:
"""Return the query string from the service response shape."""
if isinstance(query, dict):
content = query.get("content")
if isinstance(content, str):
@ -52,15 +49,15 @@ class DatasetsHitTestingBase:
raise ValueError("Invalid hit testing query response")
@staticmethod
def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]:
"""Coerce nullable collection fields into lists before response validation."""
def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]:
"""Ensure collection fields match the API schema before response validation."""
if not isinstance(records, list):
return []
raise ValueError("Invalid hit testing records response")
normalized_records: list[dict[str, Any]] = []
for record in records:
if not isinstance(record, dict):
continue
raise ValueError("Invalid hit testing record response")
normalized_record = dict(record)
segment = normalized_record.get("segment")
@ -118,8 +115,8 @@ class DatasetsHitTestingBase:
limit=10,
)
return {
"query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")),
"records": DatasetsHitTestingBase._normalize_hit_testing_records(
"query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")),
"records": DatasetsHitTestingBase._prepare_hit_testing_records(
marshal(response.get("records", []), hit_testing_record_fields)
),
}

View File

@ -1 +0,0 @@
# Evaluation controller module

View File

@ -1,993 +0,0 @@
from __future__ import annotations
import logging
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Union
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource, fields, marshal
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.workflow import WorkflowListQuery
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
from extensions.ext_database import db
from extensions.ext_storage import storage
from fields.member_fields import simple_account_fields
from graphon.file import helpers as file_helpers
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import App, Dataset
from models.evaluation import EvaluationTargetType
from models.model import UploadFile
from models.snippet import CustomizedSnippet
from services.errors.evaluation import (
EvaluationDatasetInvalidError,
EvaluationFrameworkNotConfiguredError,
EvaluationMaxConcurrentRunsError,
EvaluationNotFoundError,
)
from services.evaluation_service import EvaluationService
from services.workflow_service import WorkflowService
if TYPE_CHECKING:
from models.evaluation import EvaluationRun, EvaluationRunItem
logger = logging.getLogger(__name__)
EVALUATE_TARGET_TYPES = {
EvaluationTargetType.APPS.value,
EvaluationTargetType.SNIPPETS.value,
}
class VersionQuery(BaseModel):
"""Query parameters for version endpoint."""
version: str
register_schema_models(
console_ns,
VersionQuery,
)
# Response field definitions
file_info_fields = {
"id": fields.String,
"name": fields.String,
}
evaluation_log_fields = {
"created_at": TimestampField,
"created_by": fields.String,
"test_file": fields.Nested(
console_ns.model(
"EvaluationTestFile",
file_info_fields,
)
),
"result_file": fields.Nested(
console_ns.model(
"EvaluationResultFile",
file_info_fields,
),
allow_null=True,
),
"version": fields.String,
}
evaluation_log_list_model = console_ns.model(
"EvaluationLogList",
{
"data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))),
},
)
evaluation_default_metric_node_info_fields = {
"node_id": fields.String,
"type": fields.String,
"title": fields.String,
}
evaluation_default_metric_item_fields = {
"metric": fields.String,
"value_type": fields.String,
"node_info_list": fields.List(
fields.Nested(
console_ns.model("EvaluationDefaultMetricNodeInfo", evaluation_default_metric_node_info_fields),
),
),
}
customized_metrics_fields = {
"evaluation_workflow_id": fields.String,
"input_fields": fields.Raw,
"output_fields": fields.Raw,
}
judgment_condition_fields = {
"variable_selector": fields.List(fields.String),
"comparison_operator": fields.String,
"value": fields.String,
}
judgment_config_fields = {
"logical_operator": fields.String,
"conditions": fields.List(fields.Nested(console_ns.model("JudgmentCondition", judgment_condition_fields))),
}
evaluation_detail_fields = {
"evaluation_model": fields.String,
"evaluation_model_provider": fields.String,
"default_metrics": fields.List(
fields.Nested(console_ns.model("EvaluationDefaultMetricItem_Detail", evaluation_default_metric_item_fields)),
allow_null=True,
),
"customized_metrics": fields.Nested(
console_ns.model("EvaluationCustomizedMetrics", customized_metrics_fields),
allow_null=True,
),
"judgment_config": fields.Nested(
console_ns.model("EvaluationJudgmentConfig", judgment_config_fields),
allow_null=True,
),
}
evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields)
available_evaluation_workflow_list_fields = {
"id": fields.String,
"app_id": fields.String,
"app_name": fields.String,
"type": fields.String,
"kind": fields.String,
"version": fields.String,
"marked_name": fields.String,
"marked_comment": fields.String,
"hash": fields.String,
"created_by": fields.Nested(simple_account_fields),
"created_at": TimestampField,
"updated_by": fields.Nested(simple_account_fields, allow_null=True),
"updated_at": TimestampField,
}
available_evaluation_workflow_pagination_fields = {
"items": fields.List(fields.Nested(available_evaluation_workflow_list_fields)),
"page": fields.Integer,
"limit": fields.Integer,
"has_more": fields.Boolean,
}
available_evaluation_workflow_pagination_model = console_ns.model(
"AvailableEvaluationWorkflowPagination",
available_evaluation_workflow_pagination_fields,
)
evaluation_default_metrics_response_model = console_ns.model(
"EvaluationDefaultMetricsResponse",
{
"default_metrics": fields.List(
fields.Nested(console_ns.model("EvaluationDefaultMetricItem", evaluation_default_metric_item_fields)),
),
},
)
evaluation_dataset_columns_response_model = console_ns.model(
"EvaluationDatasetColumnsResponse",
{
"columns": fields.List(
fields.Nested(
console_ns.model(
"EvaluationTemplateColumn",
{
"name": fields.String,
"type": fields.String,
},
)
)
),
},
)
def get_evaluation_target[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
"""
Decorator to resolve polymorphic evaluation target (apps or snippets).
Validates the target_type parameter and fetches the corresponding
model (App or CustomizedSnippet) with tenant isolation.
"""
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
target_type = kwargs.get("evaluate_target_type")
target_id = kwargs.get("evaluate_target_id")
if target_type not in EVALUATE_TARGET_TYPES:
raise NotFound(f"Invalid evaluation target type: {target_type}")
_, current_tenant_id = current_account_with_tenant()
target_id = str(target_id)
# Remove path parameters
del kwargs["evaluate_target_type"]
del kwargs["evaluate_target_id"]
target: Union[App, CustomizedSnippet] | None = None
if target_type == EvaluationTargetType.APPS.value:
target = db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
elif target_type == EvaluationTargetType.SNIPPETS.value:
target = (
db.session.query(CustomizedSnippet)
.where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
.first()
)
if not target:
raise NotFound(f"{str(target_type)} not found")
kwargs["target"] = target
kwargs["target_type"] = target_type
return view_func(*args, **kwargs)
return decorated_view
def _load_evaluation_run_request_and_dataset(tenant_id: str) -> tuple[EvaluationRunRequest, bytes, str]:
"""Validate the run payload and load the uploaded dataset bytes."""
body = request.get_json(force=True)
if not body:
raise BadRequest("Request body is required.")
try:
run_request = EvaluationRunRequest.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
upload_file = db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=tenant_id).first()
if not upload_file:
raise NotFound("Dataset file not found.")
try:
dataset_content = storage.load_once(upload_file.key)
except Exception:
raise BadRequest("Failed to read dataset file.")
if not dataset_content:
raise BadRequest("Dataset file is empty.")
return run_request, dataset_content, upload_file.name
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
class EvaluationDatasetTemplateDownloadApi(Resource):
@console_ns.doc("download_evaluation_dataset_template")
@console_ns.response(200, "Template file streamed as XLSX attachment")
@console_ns.response(400, "Invalid target type or excluded app mode")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Download evaluation dataset template.
Generates an XLSX template based on the target's input parameters
and streams it directly as a file attachment.
"""
try:
xlsx_content, filename = EvaluationService.generate_dataset_template(
target=target,
target_type=target_type,
)
except ValueError as e:
return {"message": str(e)}, 400
encoded_filename = quote(filename)
response = Response(
xlsx_content,
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Length"] = str(len(xlsx_content))
return response
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
class EvaluationDetailApi(Resource):
@console_ns.doc("get_evaluation_detail")
@console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model)
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation configuration for the target.
Returns evaluation configuration including model settings,
metrics config, and judgement conditions.
"""
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.get_evaluation_config(session, current_tenant_id, target_type, str(target.id))
if config is None:
return {
"evaluation_model": None,
"evaluation_model_provider": None,
"default_metrics": None,
"customized_metrics": None,
"judgment_config": None,
}
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"default_metrics": EvaluationService.serialize_console_default_metrics(config.default_metrics_list),
"customized_metrics": config.customized_metrics_dict,
"judgment_config": EvaluationService.serialize_console_judgment_config(config.judgment_config_dict),
}
@console_ns.doc("save_evaluation_detail")
@console_ns.response(200, "Evaluation configuration saved successfully")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def put(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Save evaluation configuration for the target.
"""
current_account, current_tenant_id = current_account_with_tenant()
body = request.get_json(force=True)
try:
config_data = EvaluationConfigData.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.save_evaluation_config(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
account_id=str(current_account.id),
data=config_data,
)
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"default_metrics": EvaluationService.serialize_console_default_metrics(config.default_metrics_list),
"customized_metrics": config.customized_metrics_dict,
"judgment_config": EvaluationService.serialize_console_judgment_config(config.judgment_config_dict),
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/template-columns")
class EvaluationTemplateColumnsApi(Resource):
@console_ns.doc("get_evaluation_template_columns")
@console_ns.response(200, "Evaluation dataset columns resolved", evaluation_dataset_columns_response_model)
@console_ns.response(400, "Invalid request body")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
"""Return the dataset template columns implied by the current evaluation config."""
body = request.get_json(silent=True) or {}
try:
config_data = EvaluationConfigData.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
return {
"columns": EvaluationService.get_dataset_column_names(
target=target,
target_type=target_type,
data=config_data,
)
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/logs")
class EvaluationLogsApi(Resource):
@console_ns.doc("get_evaluation_logs")
@console_ns.response(200, "Evaluation logs retrieved successfully")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation run history for the target.
Returns a paginated list of evaluation runs.
"""
_, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
with Session(db.engine, expire_on_commit=False) as session:
runs, total = EvaluationService.get_evaluation_runs(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
page=page,
page_size=page_size,
)
return {
"data": [_serialize_evaluation_run(run) for run in runs],
"total": total,
"page": page,
"page_size": page_size,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/run1")
class EvaluationRunApi(Resource):
@console_ns.doc("start_evaluation_run")
@console_ns.response(200, "Evaluation run started")
@console_ns.response(400, "Invalid request")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet, Dataset], target_type: str):
"""
Start an evaluation run.
Expects JSON body with:
- file_id: uploaded dataset file ID
- evaluation_model: evaluation model name
- evaluation_model_provider: evaluation model provider
- default_metrics: list of default metric objects
- customized_metrics: customized metrics object (optional)
- judgment_config: judgment conditions config (optional)
"""
current_account, current_tenant_id = current_account_with_tenant()
run_request, dataset_content, dataset_filename = _load_evaluation_run_request_and_dataset(current_tenant_id)
try:
with Session(db.engine, expire_on_commit=False) as session:
if target_type == EvaluationTargetType.APPS.value:
evaluation_run = EvaluationService.start_stub_evaluation_run(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
account_id=str(current_account.id),
dataset_file_content=dataset_content,
dataset_filename=dataset_filename,
run_request=run_request,
)
else:
evaluation_run = EvaluationService.start_evaluation_run(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
account_id=str(current_account.id),
dataset_file_content=dataset_content,
dataset_filename=dataset_filename,
run_request=run_request,
)
return _serialize_evaluation_run(evaluation_run), 200
except EvaluationFrameworkNotConfiguredError as e:
return {"message": str(e.description)}, 400
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except EvaluationMaxConcurrentRunsError as e:
return {"message": str(e.description)}, 429
except EvaluationDatasetInvalidError as e:
return {"message": str(e.description)}, 400
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/run")
class EvaluationRunRealApi(Resource):
@console_ns.doc("start_evaluation_run_real")
@console_ns.response(200, "Evaluation run started")
@console_ns.response(400, "Invalid request")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet, Dataset], target_type: str):
"""Start the real evaluation execution flow on the temporary dev path."""
current_account, current_tenant_id = current_account_with_tenant()
run_request, dataset_content, dataset_filename = _load_evaluation_run_request_and_dataset(current_tenant_id)
try:
with Session(db.engine, expire_on_commit=False) as session:
evaluation_run = EvaluationService.start_evaluation_run(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
account_id=str(current_account.id),
dataset_file_content=dataset_content,
dataset_filename=dataset_filename,
run_request=run_request,
)
return _serialize_evaluation_run(evaluation_run), 200
except EvaluationFrameworkNotConfiguredError as e:
return {"message": str(e.description)}, 400
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except EvaluationMaxConcurrentRunsError as e:
return {"message": str(e.description)}, 429
except EvaluationDatasetInvalidError as e:
return {"message": str(e.description)}, 400
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>")
class EvaluationRunDetailApi(Resource):
@console_ns.doc("get_evaluation_run_detail")
@console_ns.response(200, "Evaluation run detail retrieved")
@console_ns.response(404, "Run not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
"""
Get evaluation run detail including items.
"""
_, current_tenant_id = current_account_with_tenant()
run_id = str(run_id)
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 50, type=int)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.get_evaluation_run_detail(
session=session,
tenant_id=current_tenant_id,
run_id=run_id,
)
items, total_items = EvaluationService.get_evaluation_run_items(
session=session,
run_id=run_id,
page=page,
page_size=page_size,
)
return {
"run": _serialize_evaluation_run(run),
"items": {
"data": [_serialize_evaluation_run_item(item) for item in items],
"total": total_items,
"page": page,
"page_size": page_size,
},
}
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>/cancel")
class EvaluationRunCancelApi(Resource):
@console_ns.doc("cancel_evaluation_run")
@console_ns.response(200, "Evaluation run cancelled")
@console_ns.response(404, "Run not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
"""Cancel a running evaluation."""
_, current_tenant_id = current_account_with_tenant()
run_id = str(run_id)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.cancel_evaluation_run(
session=session,
tenant_id=current_tenant_id,
run_id=run_id,
)
return _serialize_evaluation_run(run)
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except ValueError as e:
return {"message": str(e)}, 400
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/metrics")
class EvaluationMetricsApi(Resource):
@console_ns.doc("get_evaluation_metrics")
@console_ns.response(200, "Available metrics retrieved")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get available evaluation metrics for the current framework.
"""
result = {}
for category in EvaluationCategory:
if category in EvaluationService.CONSOLE_DISABLED_CATEGORIES:
continue
result[category.value] = EvaluationService.get_supported_metrics(category)
return {"metrics": result}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/default-metrics")
class EvaluationDefaultMetricsApi(Resource):
@console_ns.doc(
"get_evaluation_default_metrics_with_nodes",
description=(
"List default metrics supported by the current evaluation framework with matching nodes "
"from the target's published workflow only (draft is ignored)."
),
)
@console_ns.response(
200,
"Default metrics and node candidates for the published workflow",
evaluation_default_metrics_response_model,
)
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
default_metrics = EvaluationService.get_default_metrics_with_nodes_for_published_target(
target=target,
target_type=target_type,
)
return {
"default_metrics": [
m.model_dump() for m in EvaluationService.filter_console_default_metrics(default_metrics)
]
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/node-info")
class EvaluationNodeInfoApi(Resource):
@console_ns.doc("get_evaluation_node_info")
@console_ns.response(200, "Node info grouped by metric")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
"""Return workflow/snippet node info grouped by requested metrics.
Request body (JSON):
- metrics: list[str] | None metric names to query; omit or pass
an empty list to get all nodes under key ``"all"``.
Response:
``{metric_or_all: [{"node_id": ..., "type": ..., "title": ...}, ...]}``
"""
body = request.get_json(silent=True) or {}
metrics: list[str] | None = body.get("metrics") or None
result = EvaluationService.get_nodes_for_metrics(
target=target,
target_type=target_type,
metrics=metrics,
)
if not metrics:
result = {
"all": [
node
for node in result.get("all", [])
if node.get("type") not in EvaluationService.CONSOLE_DISABLED_CATEGORIES
]
}
else:
result = {
metric: nodes
for metric, nodes in result.items()
if metric not in EvaluationService.CONSOLE_DISABLED_METRICS
}
return result
@console_ns.route("/evaluation/available-metrics")
class EvaluationAvailableMetricsApi(Resource):
@console_ns.doc("get_available_evaluation_metrics")
@console_ns.response(200, "Available metrics list")
@setup_required
@login_required
@account_initialization_required
def get(self):
"""Return the centrally-defined list of evaluation metrics."""
return {
"metrics": [
metric
for metric in EvaluationService.get_available_metrics()
if metric not in EvaluationService.CONSOLE_DISABLED_METRICS
]
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/files/<uuid:file_id>")
class EvaluationFileDownloadApi(Resource):
@console_ns.doc("download_evaluation_file")
@console_ns.response(200, "File download URL generated successfully")
@console_ns.response(404, "Target or file not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str):
"""
Download evaluation test file or result file.
Looks up the specified file, verifies it belongs to the same tenant,
and returns file info and download URL.
"""
file_id = str(file_id)
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(UploadFile).where(
UploadFile.id == file_id,
UploadFile.tenant_id == current_tenant_id,
)
upload_file = session.execute(stmt).scalar_one_or_none()
if not upload_file:
raise NotFound("File not found")
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
"download_url": download_url,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/version")
class EvaluationVersionApi(Resource):
@console_ns.doc("get_evaluation_version_detail")
@console_ns.expect(console_ns.models.get(VersionQuery.__name__))
@console_ns.response(200, "Version details retrieved successfully")
@console_ns.response(404, "Target or version not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation target version details.
Returns the workflow graph for the specified version.
"""
version = request.args.get("version")
if not version:
return {"message": "version parameter is required"}, 400
graph = {}
if target_type == EvaluationTargetType.SNIPPETS.value and isinstance(target, CustomizedSnippet):
graph = target.graph_dict
return {
"graph": graph,
}
@console_ns.route("/workspaces/current/available-evaluation-workflows")
class AvailableEvaluationWorkflowsApi(Resource):
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
@console_ns.doc("list_available_evaluation_workflows")
@console_ns.doc(description="List published evaluation workflows in the current workspace (all apps)")
@console_ns.response(
200,
"Available evaluation workflows retrieved",
available_evaluation_workflow_pagination_model,
)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self):
"""List published evaluation-type workflows for the current tenant (cross-app)."""
current_user, current_tenant_id = current_account_with_tenant()
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
page = args.page
limit = args.limit
user_id = args.user_id
named_only = args.named_only
keyword = args.keyword
if user_id and user_id != current_user.id:
raise Forbidden()
workflow_service = WorkflowService()
with Session(db.engine) as session:
workflows, has_more = workflow_service.list_published_evaluation_workflows(
session=session,
tenant_id=current_tenant_id,
page=page,
limit=limit,
user_id=user_id,
named_only=named_only,
keyword=keyword,
)
app_ids = {w.app_id for w in workflows}
if app_ids:
apps = session.scalars(select(App).where(App.id.in_(app_ids))).all()
app_names = {a.id: a.name for a in apps}
else:
app_names = {}
items = []
for wf in workflows:
items.append(
{
"id": wf.id,
"app_id": wf.app_id,
"app_name": app_names.get(wf.app_id, ""),
"type": wf.type.value,
"kind": wf.kind_or_standard,
"version": wf.version,
"marked_name": wf.marked_name,
"marked_comment": wf.marked_comment,
"hash": wf.unique_hash,
"created_by": wf.created_by_account,
"created_at": wf.created_at,
"updated_by": wf.updated_by_account,
"updated_at": wf.updated_at,
}
)
return (
marshal(
{"items": items, "page": page, "limit": limit, "has_more": has_more},
available_evaluation_workflow_pagination_fields,
),
200,
)
@console_ns.route("/workspaces/current/evaluation-workflows/<string:workflow_id>/associated-targets")
class EvaluationWorkflowAssociatedTargetsApi(Resource):
@console_ns.doc("list_evaluation_workflow_associated_targets")
@console_ns.doc(
description="List targets (apps / snippets / knowledge bases) that use the given workflow as customized metrics"
)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, workflow_id: str):
"""Return all evaluation targets that reference this workflow as customized metrics."""
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
configs = EvaluationService.list_targets_by_customized_workflow(
session=session,
tenant_id=current_tenant_id,
customized_workflow_id=workflow_id,
)
target_ids_by_type: dict[str, list[str]] = {}
for cfg in configs:
target_ids_by_type.setdefault(cfg.target_type, []).append(cfg.target_id)
app_names: dict[str, str] = {}
if EvaluationTargetType.APPS.value in target_ids_by_type:
apps = session.scalars(
select(App).where(App.id.in_(target_ids_by_type[EvaluationTargetType.APPS.value]))
).all()
app_names = {a.id: a.name for a in apps}
snippet_names: dict[str, str] = {}
if "snippets" in target_ids_by_type:
snippets = session.scalars(
select(CustomizedSnippet).where(CustomizedSnippet.id.in_(target_ids_by_type["snippets"]))
).all()
snippet_names = {s.id: s.name for s in snippets}
dataset_names: dict[str, str] = {}
if "knowledge_base" in target_ids_by_type:
datasets = session.scalars(
select(Dataset).where(Dataset.id.in_(target_ids_by_type["knowledge_base"]))
).all()
dataset_names = {d.id: d.name for d in datasets}
items = []
for cfg in configs:
name = ""
if cfg.target_type == EvaluationTargetType.APPS.value:
name = app_names.get(cfg.target_id, "")
elif cfg.target_type == EvaluationTargetType.SNIPPETS.value:
name = snippet_names.get(cfg.target_id, "")
elif cfg.target_type == "knowledge_base":
name = dataset_names.get(cfg.target_id, "")
items.append(
{
"target_type": cfg.target_type,
"target_id": cfg.target_id,
"target_name": name,
}
)
return {"items": items}, 200
# ---- Serialization Helpers ----
def _serialize_evaluation_run(run: EvaluationRun) -> dict[str, object]:
return {
"id": run.id,
"tenant_id": run.tenant_id,
"target_type": run.target_type,
"target_id": run.target_id,
"evaluation_config_id": run.evaluation_config_id,
"status": run.status,
"dataset_file_id": run.dataset_file_id,
"result_file_id": run.result_file_id,
"total_items": run.total_items,
"completed_items": run.completed_items,
"failed_items": run.failed_items,
"progress": run.progress,
"metrics_summary": run.metrics_summary_dict,
"error": run.error,
"created_by": run.created_by,
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
}
def _serialize_evaluation_run_item(item: EvaluationRunItem) -> dict[str, object]:
return {
"id": item.id,
"item_index": item.item_index,
"inputs": item.inputs_dict,
"expected_output": item.expected_output,
"actual_output": item.actual_output,
"metrics": item.metrics_list,
"judgment": item.judgment_dict,
"metadata": item.metadata_dict,
"error": item.error,
"overall_score": item.overall_score,
}

View File

@ -16,6 +16,7 @@ from extensions.ext_database import db
from fields.base import ResponseModel
from graphon.file import helpers as file_helpers
from libs.datetime_utils import naive_utc_now
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from models.model import IconType
@ -105,9 +106,7 @@ class InstalledAppResponse(ResponseModel):
@field_validator("last_used_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
return to_timestamp(value)
class InstalledAppListResponse(ResponseModel):

View File

@ -64,28 +64,15 @@ class RecommendedAppListResponse(ResponseModel):
categories: list[str]
class LearnDifyAppListResponse(ResponseModel):
recommended_apps: list[RecommendedAppResponse]
register_schema_models(
console_ns,
RecommendedAppsQuery,
RecommendedAppInfoResponse,
RecommendedAppResponse,
RecommendedAppListResponse,
LearnDifyAppListResponse,
)
def _resolve_language(language: str | None) -> str:
if language and language in languages:
return language
if current_user and current_user.interface_language:
return current_user.interface_language
return languages[0]
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
@console_ns.doc(params=query_params_from_model(RecommendedAppsQuery))
@ -95,7 +82,13 @@ class RecommendedAppListApi(Resource):
def get(self):
# language args
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True))
language_prefix = _resolve_language(args.language)
language = args.language
if language and language in languages:
language_prefix = language
elif current_user and current_user.interface_language:
language_prefix = current_user.interface_language
else:
language_prefix = languages[0]
return RecommendedAppListResponse.model_validate(
RecommendedAppService.get_recommended_apps_and_categories(language_prefix),
@ -103,22 +96,6 @@ class RecommendedAppListApi(Resource):
).model_dump(mode="json")
@console_ns.route("/explore/apps/learn-dify")
class LearnDifyAppListApi(Resource):
@console_ns.doc(params=query_params_from_model(RecommendedAppsQuery))
@console_ns.response(200, "Success", console_ns.models[LearnDifyAppListResponse.__name__])
@login_required
@account_initialization_required
def get(self):
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True))
language_prefix = _resolve_language(args.language)
return LearnDifyAppListResponse.model_validate(
RecommendedAppService.get_learn_dify_apps(language_prefix),
from_attributes=True,
).model_dump(mode="json")
@console_ns.route("/explore/apps/<uuid:app_id>")
class RecommendedAppApi(Resource):
@login_required

View File

@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator
from constants import HIDDEN_VALUE
from fields.base import ResponseModel
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
@ -40,12 +41,6 @@ def _mask_api_key(api_key: str) -> str:
return api_key[:3] + "******" + api_key[-3:]
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class APIBasedExtensionResponse(ResponseModel):
id: str
name: str
@ -61,7 +56,7 @@ class APIBasedExtensionResponse(ResponseModel):
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
return to_timestamp(value)
register_schema_models(console_ns, APIBasedExtensionPayload, CodeBasedExtensionResponse, APIBasedExtensionResponse)

View File

@ -105,7 +105,8 @@ class FilePreviewApi(Resource):
@account_initialization_required
def get(self, file_id):
file_id = str(file_id)
text = FileService(db.engine).get_file_preview(file_id)
_, tenant_id = current_account_with_tenant()
text = FileService(db.engine).get_file_preview(file_id, tenant_id)
return {"content": text}

View File

@ -1,142 +0,0 @@
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
class SnippetListQuery(BaseModel):
"""Query parameters for listing snippets."""
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=20, ge=1, le=100)
keyword: str | None = None
is_published: bool | None = Field(default=None, description="Filter by published status")
creators: list[str] | None = Field(default=None, description="Filter by creator account IDs")
@field_validator("creators", mode="before")
@classmethod
def parse_creators(cls, value: object) -> list[str] | None:
"""Normalize creators filter from query string or list input."""
if value is None:
return None
if isinstance(value, str):
return [creator.strip() for creator in value.split(",") if creator.strip()] or None
if isinstance(value, list):
return [str(creator).strip() for creator in value if str(creator).strip()] or None
return None
class IconInfo(BaseModel):
"""Icon information model."""
icon: str | None = None
icon_type: Literal["emoji", "image"] | None = None
icon_background: str | None = None
icon_url: str | None = None
class InputFieldDefinition(BaseModel):
"""Input field definition for snippet parameters."""
default: str | None = None
hint: bool | None = None
label: str | None = None
max_length: int | None = None
options: list[str] | None = None
placeholder: str | None = None
required: bool | None = None
type: str | None = None # e.g., "text-input"
class CreateSnippetPayload(BaseModel):
"""Payload for creating a new snippet."""
name: str = Field(..., min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
type: Literal["node", "group"] = "node"
icon_info: IconInfo | None = None
graph: dict[str, Any] | None = None
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
class UpdateSnippetPayload(BaseModel):
"""Payload for updating a snippet."""
name: str | None = Field(default=None, min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
icon_info: IconInfo | None = None
class SnippetDraftSyncPayload(BaseModel):
"""Payload for syncing snippet draft workflow."""
graph: dict[str, Any]
hash: str | None = None
conversation_variables: list[dict[str, Any]] | None = Field(
default=None,
description="Ignored. Snippet workflows do not persist conversation variables.",
)
input_fields: list[dict[str, Any]] | None = None
class SnippetWorkflowListQuery(BaseModel):
"""Query parameters for listing snippet published workflows."""
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
class WorkflowRunQuery(BaseModel):
"""Query parameters for workflow runs."""
last_id: str | None = None
limit: int = Field(default=20, ge=1, le=100)
class SnippetDraftRunPayload(BaseModel):
"""Payload for running snippet draft workflow."""
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
class SnippetDraftNodeRunPayload(BaseModel):
"""Payload for running a single node in snippet draft workflow."""
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
class SnippetIterationNodeRunPayload(BaseModel):
"""Payload for running an iteration node in snippet draft workflow."""
inputs: dict[str, Any] | None = None
class SnippetLoopNodeRunPayload(BaseModel):
"""Payload for running a loop node in snippet draft workflow."""
inputs: dict[str, Any] | None = None
class PublishWorkflowPayload(BaseModel):
"""Payload for publishing snippet workflow."""
knowledge_base_setting: dict[str, Any] | None = None
class SnippetImportPayload(BaseModel):
"""Payload for importing snippet from DSL."""
mode: str = Field(..., description="Import mode: yaml-content or yaml-url")
yaml_content: str | None = Field(default=None, description="YAML content (required for yaml-content mode)")
yaml_url: str | None = Field(default=None, description="YAML URL (required for yaml-url mode)")
name: str | None = Field(default=None, description="Override snippet name")
description: str | None = Field(default=None, description="Override snippet description")
snippet_id: str | None = Field(default=None, description="Snippet ID to update (optional)")
class IncludeSecretQuery(BaseModel):
"""Query parameter for including secret variables in export."""
include_secret: str = Field(default="false", description="Whether to include secret variables")

View File

@ -1,617 +0,0 @@
# import logging
# from collections.abc import Callable
# from functools import wraps
# from flask import request
# from flask_restx import Resource, fields, marshal, marshal_with
# from sqlalchemy.orm import Session
# from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
# from controllers.common.schema import register_schema_models
# from controllers.console import console_ns
# from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
# from controllers.console.app.workflow import (
# RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE,
# workflow_model,
# workflow_pagination_model,
# )
# from controllers.console.app.workflow_run import (
# workflow_run_detail_model,
# workflow_run_node_execution_list_model,
# workflow_run_node_execution_model,
# workflow_run_pagination_model,
# )
# from controllers.console.snippets.payloads import (
# PublishWorkflowPayload,
# SnippetDraftNodeRunPayload,
# SnippetDraftRunPayload,
# SnippetDraftSyncPayload,
# SnippetIterationNodeRunPayload,
# SnippetLoopNodeRunPayload,
# SnippetWorkflowListQuery,
# WorkflowRunQuery,
# )
# from controllers.console.wraps import (
# account_initialization_required,
# edit_permission_required,
# setup_required,
# )
# from core.app.apps.base_app_queue_manager import AppQueueManager
# from core.app.entities.app_invoke_entities import InvokeFrom
# from extensions.ext_database import db
# from extensions.ext_redis import redis_client
# from graphon.graph_engine.manager import GraphEngineManager
# from libs import helper
# from libs.helper import TimestampField
# from libs.login import current_account_with_tenant, login_required
# from models.snippet import CustomizedSnippet
# from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
# from services.snippet_generate_service import SnippetGenerateService
# from services.snippet_service import SnippetService
# logger = logging.getLogger(__name__)
# # Register Pydantic models with Swagger
# register_schema_models(
# console_ns,
# SnippetDraftSyncPayload,
# SnippetDraftNodeRunPayload,
# SnippetDraftRunPayload,
# SnippetIterationNodeRunPayload,
# SnippetLoopNodeRunPayload,
# SnippetWorkflowListQuery,
# WorkflowRunQuery,
# PublishWorkflowPayload,
# )
# snippet_workflow_model = console_ns.clone("SnippetWorkflow", workflow_model, {
# "input_fields": fields.Raw(default=[]),
# })
# class SnippetNotFoundError(Exception):
# """Snippet not found error."""
# pass
# def get_snippet[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
# """Decorator to fetch and validate snippet access."""
# @wraps(view_func)
# def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
# if not kwargs.get("snippet_id"):
# raise ValueError("missing snippet_id in path parameters")
# _, current_tenant_id = current_account_with_tenant()
# snippet_id = str(kwargs.get("snippet_id"))
# del kwargs["snippet_id"]
# snippet = SnippetService.get_snippet_by_id(
# snippet_id=snippet_id,
# tenant_id=current_tenant_id,
# )
# if not snippet:
# raise NotFound("Snippet not found")
# kwargs["snippet"] = snippet
# return view_func(*args, **kwargs)
# return decorated_view
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
# class SnippetDraftWorkflowApi(Resource):
# @console_ns.doc("get_snippet_draft_workflow")
# @console_ns.response(200, "Draft workflow retrieved successfully", snippet_workflow_model)
# @console_ns.response(404, "Snippet or draft workflow not found")
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @edit_permission_required
# @marshal_with(snippet_workflow_model)
# def get(self, snippet: CustomizedSnippet):
# """Get draft workflow for snippet."""
# snippet_service = SnippetService()
# workflow = snippet_service.get_draft_workflow(snippet=snippet)
# if not workflow:
# raise DraftWorkflowNotExist()
# db.session.expunge(workflow)
# workflow.conversation_variables = []
# workflow.input_fields = snippet.input_fields_list
# return workflow
# @console_ns.doc("sync_snippet_draft_workflow")
# @console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
# @console_ns.response(200, "Draft workflow synced successfully")
# @console_ns.response(400, "Hash mismatch")
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @edit_permission_required
# def post(self, snippet: CustomizedSnippet):
# """Sync draft workflow for snippet."""
# current_user, _ = current_account_with_tenant()
# payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
# try:
# snippet_service = SnippetService()
# workflow = snippet_service.sync_draft_workflow(
# snippet=snippet,
# graph=payload.graph,
# unique_hash=payload.hash,
# account=current_user,
# input_fields=payload.input_fields,
# )
# except WorkflowHashNotEqualError:
# raise DraftWorkflowNotSync()
# except ValueError as e:
# return {"message": str(e)}, 400
# return {
# "result": "success",
# "hash": workflow.unique_hash,
# "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
# }
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
# class SnippetDraftConfigApi(Resource):
# @console_ns.doc("get_snippet_draft_config")
# @console_ns.response(200, "Draft config retrieved successfully")
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @edit_permission_required
# def get(self, snippet: CustomizedSnippet):
# """Get snippet draft workflow configuration limits."""
# return {
# "parallel_depth_limit": 3,
# }
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
# class SnippetPublishedWorkflowApi(Resource):
# @console_ns.doc("get_snippet_published_workflow")
# @console_ns.response(200, "Published workflow retrieved successfully", snippet_workflow_model)
# @console_ns.response(404, "Snippet not found")
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @edit_permission_required
# @marshal_with(snippet_workflow_model)
# def get(self, snippet: CustomizedSnippet):
# """Get published workflow for snippet."""
# if not snippet.is_published:
# return None
# snippet_service = SnippetService()
# workflow = snippet_service.get_published_workflow(snippet=snippet)
# if workflow:
# workflow.input_fields = snippet.input_fields_list
# return workflow
# @console_ns.doc("publish_snippet_workflow")
# @console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
# @console_ns.response(200, "Workflow published successfully")
# @console_ns.response(400, "No draft workflow found")
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @edit_permission_required
# def post(self, snippet: CustomizedSnippet):
# """Publish snippet workflow."""
# current_user, _ = current_account_with_tenant()
# snippet_service = SnippetService()
# with Session(db.engine) as session:
# snippet = session.merge(snippet)
# try:
# workflow = snippet_service.publish_workflow(
# session=session,
# snippet=snippet,
# account=current_user,
# )
# workflow_created_at = TimestampField().format(workflow.created_at)
# session.commit()
# except ValueError as e:
# return {"message": str(e)}, 400
# return {
# "result": "success",
# "created_at": workflow_created_at,
# }
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
# class SnippetDefaultBlockConfigsApi(Resource):
# @console_ns.doc("get_snippet_default_block_configs")
# @console_ns.response(200, "Default block configs retrieved successfully")
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @edit_permission_required
# def get(self, snippet: CustomizedSnippet):
# """Get default block configurations for snippet workflow."""
# snippet_service = SnippetService()
# return snippet_service.get_default_block_configs()
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows")
# class SnippetPublishedAllWorkflowApi(Resource):
# @console_ns.expect(console_ns.models[SnippetWorkflowListQuery.__name__])
# @console_ns.doc("get_all_snippet_published_workflows")
# @console_ns.doc(description="Get all published workflows for a snippet")
# @console_ns.doc(params={"snippet_id": "Snippet ID"})
# @console_ns.response(200, "Published workflows retrieved successfully", workflow_pagination_model)
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @edit_permission_required
# def get(self, snippet: CustomizedSnippet):
# """Get all published workflow versions for snippet."""
# args = SnippetWorkflowListQuery.model_validate(request.args.to_dict(flat=True))
# snippet_service = SnippetService()
# with Session(db.engine) as session:
# workflows, has_more = snippet_service.get_all_published_workflows(
# session=session,
# snippet=snippet,
# page=args.page,
# limit=args.limit,
# )
# serialized_workflows = marshal(workflows, workflow_model)
# return {
# "items": serialized_workflows,
# "page": args.page,
# "limit": args.limit,
# "has_more": has_more,
# }
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/<string:workflow_id>/restore")
# class SnippetDraftWorkflowRestoreApi(Resource):
# @console_ns.doc("restore_snippet_workflow_to_draft")
# @console_ns.doc(description="Restore a published snippet workflow version into the draft workflow")
# @console_ns.doc(params={"snippet_id": "Snippet ID", "workflow_id": "Published workflow ID"})
# @console_ns.response(200, "Workflow restored successfully")
# @console_ns.response(400, "Source workflow must be published")
# @console_ns.response(404, "Workflow not found")
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @edit_permission_required
# def post(self, snippet: CustomizedSnippet, workflow_id: str):
# """Restore a published snippet workflow version into the draft workflow."""
# current_user, _ = current_account_with_tenant()
# snippet_service = SnippetService()
# try:
# workflow = snippet_service.restore_published_workflow_to_draft(
# snippet=snippet,
# workflow_id=workflow_id,
# account=current_user,
# )
# except IsDraftWorkflowError as exc:
# raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
# except WorkflowNotFoundError as exc:
# raise NotFound(str(exc)) from exc
# except ValueError as exc:
# raise BadRequest(str(exc)) from exc
# return {
# "result": "success",
# "hash": workflow.unique_hash,
# "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
# }
# @console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
# class SnippetWorkflowRunsApi(Resource):
# @console_ns.doc("list_snippet_workflow_runs")
# @console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @marshal_with(workflow_run_pagination_model)
# def get(self, snippet: CustomizedSnippet):
# """List workflow runs for snippet."""
# query = WorkflowRunQuery.model_validate(
# {
# "last_id": request.args.get("last_id"),
# "limit": request.args.get("limit", type=int, default=20),
# }
# )
# args = {
# "last_id": query.last_id,
# "limit": query.limit,
# }
# snippet_service = SnippetService()
# result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
# return result
# @console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
# class SnippetWorkflowRunDetailApi(Resource):
# @console_ns.doc("get_snippet_workflow_run_detail")
# @console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
# @console_ns.response(404, "Workflow run not found")
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @marshal_with(workflow_run_detail_model)
# def get(self, snippet: CustomizedSnippet, run_id):
# """Get workflow run detail for snippet."""
# run_id = str(run_id)
# snippet_service = SnippetService()
# workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
# if not workflow_run:
# raise NotFound("Workflow run not found")
# return workflow_run
# @console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
# class SnippetWorkflowRunNodeExecutionsApi(Resource):
# @console_ns.doc("list_snippet_workflow_run_node_executions")
# @console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @marshal_with(workflow_run_node_execution_list_model)
# def get(self, snippet: CustomizedSnippet, run_id):
# """List node executions for a workflow run."""
# run_id = str(run_id)
# snippet_service = SnippetService()
# node_executions = snippet_service.get_snippet_workflow_run_node_executions(
# snippet=snippet,
# run_id=run_id,
# )
# return {"data": node_executions}
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
# class SnippetDraftNodeRunApi(Resource):
# @console_ns.doc("run_snippet_draft_node")
# @console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
# @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
# @console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
# @console_ns.response(200, "Node run completed successfully", workflow_run_node_execution_model)
# @console_ns.response(404, "Snippet or draft workflow not found")
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @marshal_with(workflow_run_node_execution_model)
# @edit_permission_required
# def post(self, snippet: CustomizedSnippet, node_id: str):
# """
# Run a single node in snippet draft workflow.
# Executes a specific node with provided inputs for single-step debugging.
# Returns the node execution result including status, outputs, and timing.
# """
# current_user, _ = current_account_with_tenant()
# payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
# user_inputs = payload.inputs
# # Get draft workflow for file parsing
# snippet_service = SnippetService()
# draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
# if not draft_workflow:
# raise NotFound("Draft workflow not found")
# files = SnippetGenerateService.parse_files(draft_workflow, payload.files)
# workflow_node_execution = SnippetGenerateService.run_draft_node(
# snippet=snippet,
# node_id=node_id,
# user_inputs=user_inputs,
# account=current_user,
# query=payload.query,
# files=files,
# )
# return workflow_node_execution
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
# class SnippetDraftNodeLastRunApi(Resource):
# @console_ns.doc("get_snippet_draft_node_last_run")
# @console_ns.doc(description="Get last run result for a node in snippet draft workflow")
# @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
# @console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
# @console_ns.response(404, "Snippet, draft workflow, or node last run not found")
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @marshal_with(workflow_run_node_execution_model)
# def get(self, snippet: CustomizedSnippet, node_id: str):
# """
# Get the last run result for a specific node in snippet draft workflow.
# Returns the most recent execution record for the given node,
# including status, inputs, outputs, and timing information.
# """
# snippet_service = SnippetService()
# draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
# if not draft_workflow:
# raise NotFound("Draft workflow not found")
# node_exec = snippet_service.get_snippet_node_last_run(
# snippet=snippet,
# workflow=draft_workflow,
# node_id=node_id,
# )
# if node_exec is None:
# raise NotFound("Node last run not found")
# return node_exec
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
# class SnippetDraftRunIterationNodeApi(Resource):
# @console_ns.doc("run_snippet_draft_iteration_node")
# @console_ns.doc(description="Run draft workflow iteration node for snippet")
# @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
# @console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
# @console_ns.response(200, "Iteration node run started successfully (SSE stream)")
# @console_ns.response(404, "Snippet or draft workflow not found")
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @edit_permission_required
# def post(self, snippet: CustomizedSnippet, node_id: str):
# """
# Run a draft workflow iteration node for snippet.
# Iteration nodes execute their internal sub-graph multiple times over an input list.
# Returns an SSE event stream with iteration progress and results.
# """
# current_user, _ = current_account_with_tenant()
# args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
# try:
# response = SnippetGenerateService.generate_single_iteration(
# snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
# )
# return helper.compact_generate_response(response)
# except ValueError as e:
# raise e
# except Exception:
# logger.exception("internal server error.")
# raise InternalServerError()
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
# class SnippetDraftRunLoopNodeApi(Resource):
# @console_ns.doc("run_snippet_draft_loop_node")
# @console_ns.doc(description="Run draft workflow loop node for snippet")
# @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
# @console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
# @console_ns.response(200, "Loop node run started successfully (SSE stream)")
# @console_ns.response(404, "Snippet or draft workflow not found")
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @edit_permission_required
# def post(self, snippet: CustomizedSnippet, node_id: str):
# """
# Run a draft workflow loop node for snippet.
# Loop nodes execute their internal sub-graph repeatedly until a condition is met.
# Returns an SSE event stream with loop progress and results.
# """
# current_user, _ = current_account_with_tenant()
# args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
# try:
# response = SnippetGenerateService.generate_single_loop(
# snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
# )
# return helper.compact_generate_response(response)
# except ValueError as e:
# raise e
# except Exception:
# logger.exception("internal server error.")
# raise InternalServerError()
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
# class SnippetDraftWorkflowRunApi(Resource):
# @console_ns.doc("run_snippet_draft_workflow")
# @console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
# @console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
# @console_ns.response(404, "Snippet or draft workflow not found")
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @edit_permission_required
# def post(self, snippet: CustomizedSnippet):
# """
# Run draft workflow for snippet.
# Executes the snippet's draft workflow with the provided inputs
# and returns an SSE event stream with execution progress and results.
# """
# current_user, _ = current_account_with_tenant()
# payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
# args = payload.model_dump(exclude_none=True)
# try:
# response = SnippetGenerateService.generate(
# snippet=snippet,
# user=current_user,
# args=args,
# invoke_from=InvokeFrom.DEBUGGER,
# streaming=True,
# )
# return helper.compact_generate_response(response)
# except ValueError as e:
# raise e
# except Exception:
# logger.exception("internal server error.")
# raise InternalServerError()
# @console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
# class SnippetWorkflowTaskStopApi(Resource):
# @console_ns.doc("stop_snippet_workflow_task")
# @console_ns.response(200, "Task stopped successfully")
# @console_ns.response(404, "Snippet not found")
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @edit_permission_required
# def post(self, snippet: CustomizedSnippet, task_id: str):
# """
# Stop a running snippet workflow task.
# Uses both the legacy stop flag mechanism and the graph engine
# command channel for backward compatibility.
# """
# # Stop using both mechanisms for backward compatibility
# # Legacy stop flag mechanism (without user check)
# AppQueueManager.set_stop_flag_no_user_check(task_id)
# # New graph engine command channel mechanism
# GraphEngineManager(redis_client).send_stop_command(task_id)
# return {"result": "success"}

View File

@ -1,316 +0,0 @@
# """
# Snippet draft workflow variable APIs.
# Mirrors console app routes under /apps/.../workflows/draft/variables for snippet scope,
# using CustomizedSnippet.id as WorkflowDraftVariable.app_id (same invariant as snippet execution).
# Snippet workflows do not expose system variables (`node_id == sys`) or conversation variables
# (`node_id == conversation`): paginated list queries exclude those rows; single-variable GET/PATCH/DELETE/reset
# reject them; `GET .../system-variables` and `GET .../conversation-variables` return empty lists for API parity.
# Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`.
# """
# from collections.abc import Callable
# from functools import wraps
# from typing import Any
# from flask import Response, request
# from flask_restx import Resource, marshal, marshal_with
# from sqlalchemy.orm import Session
# from controllers.console import console_ns
# from controllers.console.app.error import DraftWorkflowNotExist
# from controllers.console.app.workflow_draft_variable import (
# WorkflowDraftVariableListQuery,
# WorkflowDraftVariableUpdatePayload,
# _ensure_variable_access,
# _file_access_controller,
# validate_node_id,
# workflow_draft_variable_list_model,
# workflow_draft_variable_list_without_value_model,
# workflow_draft_variable_model,
# )
# from controllers.console.snippets.snippet_workflow import get_snippet
# from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
# from controllers.web.error import InvalidArgumentError, NotFoundError
# from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
# from extensions.ext_database import db
# from factories.file_factory import build_from_mapping, build_from_mappings
# from factories.variable_factory import build_segment_with_type
# from graphon.variables.types import SegmentType
# from libs.login import current_user, login_required
# from models.snippet import CustomizedSnippet
# from models.workflow import WorkflowDraftVariable
# from services.snippet_service import SnippetService
# from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
# _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset(
# {SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID}
# )
# def _ensure_snippet_draft_variable_row_allowed(
# *,
# variable: WorkflowDraftVariable,
# variable_id: str,
# ) -> None:
# """Snippet scope only supports canvas-node draft variables; treat sys/conversation rows as not found."""
# if variable.node_id in _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS:
# raise NotFoundError(description=f"variable not found, id={variable_id}")
# def _snippet_draft_var_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R]:
# """Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs)."""
# @setup_required
# @login_required
# @account_initialization_required
# @get_snippet
# @edit_permission_required
# @wraps(f)
# def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# return f(*args, **kwargs)
# return wrapper
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables")
# class SnippetWorkflowVariableCollectionApi(Resource):
# @console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
# @console_ns.doc("get_snippet_workflow_variables")
# @console_ns.doc(description="List draft workflow variables without values (paginated, snippet scope)")
# @console_ns.response(
# 200,
# "Workflow variables retrieved successfully",
# workflow_draft_variable_list_without_value_model,
# )
# @_snippet_draft_var_prerequisite
# @marshal_with(workflow_draft_variable_list_without_value_model)
# def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
# args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# snippet_service = SnippetService()
# if snippet_service.get_draft_workflow(snippet=snippet) is None:
# raise DraftWorkflowNotExist()
# with Session(bind=db.engine, expire_on_commit=False) as session:
# draft_var_srv = WorkflowDraftVariableService(session=session)
# workflow_vars = draft_var_srv.list_variables_without_values(
# app_id=snippet.id,
# page=args.page,
# limit=args.limit,
# user_id=current_user.id,
# exclude_node_ids=_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS,
# )
# return workflow_vars
# @console_ns.doc("delete_snippet_workflow_variables")
# @console_ns.doc(description="Delete all draft workflow variables for the current user (snippet scope)")
# @console_ns.response(204, "Workflow variables deleted successfully")
# @_snippet_draft_var_prerequisite
# def delete(self, snippet: CustomizedSnippet) -> Response:
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
# draft_var_srv.delete_user_workflow_variables(snippet.id, user_id=current_user.id)
# db.session.commit()
# return Response("", 204)
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/variables")
# class SnippetNodeVariableCollectionApi(Resource):
# @console_ns.doc("get_snippet_node_variables")
# @console_ns.doc(description="Get variables for a specific node (snippet draft workflow)")
# @console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
# @_snippet_draft_var_prerequisite
# @marshal_with(workflow_draft_variable_list_model)
# def get(self, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList:
# validate_node_id(node_id)
# with Session(bind=db.engine, expire_on_commit=False) as session:
# draft_var_srv = WorkflowDraftVariableService(session=session)
# node_vars = draft_var_srv.list_node_variables(snippet.id, node_id, user_id=current_user.id)
# return node_vars
# @console_ns.doc("delete_snippet_node_variables")
# @console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)")
# @console_ns.response(204, "Node variables deleted successfully")
# @_snippet_draft_var_prerequisite
# def delete(self, snippet: CustomizedSnippet, node_id: str) -> Response:
# validate_node_id(node_id)
# srv = WorkflowDraftVariableService(db.session())
# srv.delete_node_variables(snippet.id, node_id, user_id=current_user.id)
# db.session.commit()
# return Response("", 204)
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>")
# class SnippetVariableApi(Resource):
# @console_ns.doc("get_snippet_workflow_variable")
# @console_ns.doc(description="Get a specific draft workflow variable (snippet scope)")
# @console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
# @console_ns.response(404, "Variable not found")
# @_snippet_draft_var_prerequisite
# @marshal_with(workflow_draft_variable_model)
# def get(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
# variable = _ensure_variable_access(
# variable=draft_var_srv.get_variable(variable_id=variable_id),
# app_id=snippet.id,
# variable_id=variable_id,
# )
# _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
# return variable
# @console_ns.doc("update_snippet_workflow_variable")
# @console_ns.doc(description="Update a draft workflow variable (snippet scope)")
# @console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
# @console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
# @console_ns.response(404, "Variable not found")
# @_snippet_draft_var_prerequisite
# @marshal_with(workflow_draft_variable_model)
# def patch(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
# args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
# variable = _ensure_variable_access(
# variable=draft_var_srv.get_variable(variable_id=variable_id),
# app_id=snippet.id,
# variable_id=variable_id,
# )
# _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
# new_name = args_model.name
# raw_value = args_model.value
# if new_name is None and raw_value is None:
# return variable
# new_value = None
# if raw_value is not None:
# if variable.value_type == SegmentType.FILE:
# if not isinstance(raw_value, dict):
# raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
# raw_value = build_from_mapping(
# mapping=raw_value,
# tenant_id=snippet.tenant_id,
# access_controller=_file_access_controller,
# )
# elif variable.value_type == SegmentType.ARRAY_FILE:
# if not isinstance(raw_value, list):
# raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
# if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
# raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
# raw_value = build_from_mappings(
# mappings=raw_value,
# tenant_id=snippet.tenant_id,
# access_controller=_file_access_controller,
# )
# new_value = build_segment_with_type(variable.value_type, raw_value)
# draft_var_srv.update_variable(variable, name=new_name, value=new_value)
# db.session.commit()
# return variable
# @console_ns.doc("delete_snippet_workflow_variable")
# @console_ns.doc(description="Delete a draft workflow variable (snippet scope)")
# @console_ns.response(204, "Variable deleted successfully")
# @console_ns.response(404, "Variable not found")
# @_snippet_draft_var_prerequisite
# def delete(self, snippet: CustomizedSnippet, variable_id: str) -> Response:
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
# variable = _ensure_variable_access(
# variable=draft_var_srv.get_variable(variable_id=variable_id),
# app_id=snippet.id,
# variable_id=variable_id,
# )
# _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
# draft_var_srv.delete_variable(variable)
# db.session.commit()
# return Response("", 204)
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>/reset")
# class SnippetVariableResetApi(Resource):
# @console_ns.doc("reset_snippet_workflow_variable")
# @console_ns.doc(description="Reset a draft workflow variable to its default value (snippet scope)")
# @console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
# @console_ns.response(204, "Variable reset (no content)")
# @console_ns.response(404, "Variable not found")
# @_snippet_draft_var_prerequisite
# def put(self, snippet: CustomizedSnippet, variable_id: str) -> Response | Any:
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
# snippet_service = SnippetService()
# draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
# if draft_workflow is None:
# raise NotFoundError(
# f"Draft workflow not found, snippet_id={snippet.id}",
# )
# variable = _ensure_variable_access(
# variable=draft_var_srv.get_variable(variable_id=variable_id),
# app_id=snippet.id,
# variable_id=variable_id,
# )
# _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
# resetted = draft_var_srv.reset_variable(draft_workflow, variable)
# db.session.commit()
# if resetted is None:
# return Response("", 204)
# return marshal(resetted, workflow_draft_variable_model)
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/conversation-variables")
# class SnippetConversationVariableCollectionApi(Resource):
# @console_ns.doc("get_snippet_conversation_variables")
# @console_ns.doc(
# description="Conversation variables are not used in snippet workflows; returns an empty list for API parity"
# )
# @console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
# @_snippet_draft_var_prerequisite
# @marshal_with(workflow_draft_variable_list_model)
# def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
# return WorkflowDraftVariableList(variables=[])
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/system-variables")
# class SnippetSystemVariableCollectionApi(Resource):
# @console_ns.doc("get_snippet_system_variables")
# @console_ns.doc(
# description="System variables are not used in snippet workflows; returns an empty list for API parity"
# )
# @console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
# @_snippet_draft_var_prerequisite
# @marshal_with(workflow_draft_variable_list_model)
# def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
# return WorkflowDraftVariableList(variables=[])
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/environment-variables")
# class SnippetEnvironmentVariableCollectionApi(Resource):
# @console_ns.doc("get_snippet_environment_variables")
# @console_ns.doc(description="Get environment variables from snippet draft workflow graph")
# @console_ns.response(200, "Environment variables retrieved successfully")
# @console_ns.response(404, "Draft workflow not found")
# @_snippet_draft_var_prerequisite
# def get(self, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]:
# snippet_service = SnippetService()
# workflow = snippet_service.get_draft_workflow(snippet=snippet)
# if workflow is None:
# raise DraftWorkflowNotExist()
# env_vars_list: list[dict[str, Any]] = []
# for v in workflow.environment_variables:
# env_vars_list.append(
# {
# "id": v.id,
# "type": "env",
# "name": v.name,
# "description": v.description,
# "selector": v.selector,
# "value_type": v.value_type.exposed_type().value,
# "value": v.value,
# "edited": False,
# "visible": True,
# "editable": True,
# }
# )
# return {"items": env_vars_list}

View File

@ -25,6 +25,10 @@ class TagBasePayload(BaseModel):
type: TagType = Field(description="Tag type")
class TagUpdateRequestPayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50)
class TagBindingPayload(BaseModel):
tag_ids: list[str] = Field(description="Tag IDs to bind")
target_id: str = Field(description="Target ID to bind tags to")
@ -68,6 +72,7 @@ class TagResponse(ResponseModel):
register_schema_models(
console_ns,
TagBasePayload,
TagUpdateRequestPayload,
TagBindingPayload,
TagBindingRemovePayload,
TagListQueryParam,
@ -118,7 +123,7 @@ class TagListApi(Resource):
@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource):
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@console_ns.expect(console_ns.models[TagUpdateRequestPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -129,8 +134,8 @@ class TagUpdateDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id)
payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)

View File

@ -42,7 +42,7 @@ from fields.base import ResponseModel
from fields.member_fields import Account as AccountResponse
from graphon.file import helpers as file_helpers
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, extract_remote_ip, timezone
from libs.helper import EmailStr, extract_remote_ip, timezone, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
@ -185,12 +185,6 @@ def _serialize_account(account) -> dict[str, Any]:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class AccountIntegrateResponse(ResponseModel):
provider: str
created_at: int | None = None
@ -200,7 +194,7 @@ class AccountIntegrateResponse(ResponseModel):
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
return to_timestamp(value)
class AccountIntegrateListResponse(ResponseModel):
@ -220,7 +214,7 @@ class EducationStatusResponse(ResponseModel):
@field_validator("expire_at", mode="before")
@classmethod
def _normalize_expire_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
return to_timestamp(value)
class EducationAutocompleteResponse(ResponseModel):

View File

@ -30,14 +30,13 @@ from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.enterprise import rbac_service as enterprise_rbac_service
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
class MemberInvitePayload(BaseModel):
emails: list[str] = Field(default_factory=list)
role: str
role: TenantAccountRole
language: str | None = None
@ -71,18 +70,6 @@ register_schema_models(
)
def _serialize_member_roles(current_role: str | None, member_roles: list[enterprise_rbac_service.MemberRoleSummary]) -> list[dict[str, str]]:
if member_roles:
return [{"id": role.id, "name": role.name} for role in member_roles]
if current_role:
return [{"id": current_role, "name": current_role}]
return []
def _normalize_enum_value(value: object) -> str:
normalized = getattr(value, "value", value)
return str(normalized) if normalized is not None else ""
def _is_role_enabled(role: TenantAccountRole | str, tenant_id: str) -> bool:
if role != TenantAccountRole.DATASET_OPERATOR:
return True
@ -102,36 +89,7 @@ class MemberListApi(Resource):
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
if dify_config.RBAC_ENABLED:
member_ids = [member.id for member in members]
member_roles = enterprise_rbac_service.RBACService.MemberRoles.batch_get(
str(current_user.current_tenant.id),
current_user.id,
member_ids,
)
roles_map = {item.account_id: item.roles for item in member_roles}
else:
roles_map = {}
serialized_members = []
for member in members:
current_role = _normalize_enum_value(member.current_role)
serialized_members.append(
{
"id": member.id,
"name": member.name,
"email": member.email,
"avatar": member.avatar,
"last_login_at": member.last_login_at,
"last_active_at": member.last_active_at,
"created_at": member.created_at,
"role": current_role,
"roles": _serialize_member_roles(current_role, roles_map.get(member.id, [])),
"status": _normalize_enum_value(member.status),
}
)
member_models = TypeAdapter(list[AccountWithRole]).validate_python(serialized_members)
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = AccountWithRoleList(accounts=member_models)
return response.model_dump(mode="json"), 200
@ -152,9 +110,8 @@ class MemberInviteEmailApi(Resource):
invitee_emails = args.emails
invitee_role = args.role
interface_language = args.language
if not dify_config.RBAC_ENABLED:
if not TenantAccountRole.is_valid_role(invitee_role) or not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
inviter = current_user
if not inviter.current_tenant:

View File

@ -1,15 +1,15 @@
import io
from collections.abc import Mapping
from typing import Any, Literal, TypedDict
from typing import Any, Literal
from flask import request, send_file
from flask_restx import Resource
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, Field
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.common.schema import query_params_from_model, register_enum_models, register_schema_models
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
@ -23,14 +23,6 @@ from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
class AutoUpgradeSettingsResponse(TypedDict):
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting
upgrade_time_of_day: int
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode
exclude_plugins: list[str]
include_plugins: list[str]
class ParserList(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
@ -94,8 +86,8 @@ class ParserUninstall(BaseModel):
class ParserPermissionChange(BaseModel):
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
install_permission: TenantPluginPermission.InstallPermission
debug_permission: TenantPluginPermission.DebugPermission
class ParserDynamicOptions(BaseModel):
@ -131,22 +123,13 @@ class PluginAutoUpgradeSettingsPayload(BaseModel):
include_plugins: list[str] = Field(default_factory=list)
class ParserAutoUpgradeChange(BaseModel):
model_config = ConfigDict(extra="forbid")
category: TenantPluginAutoUpgradeStrategy.PluginCategory
class ParserPreferencesChange(BaseModel):
permission: PluginPermissionSettingsPayload
auto_upgrade: PluginAutoUpgradeSettingsPayload
class ParserAutoUpgradeFetch(BaseModel):
category: TenantPluginAutoUpgradeStrategy.PluginCategory
class ParserExcludePlugin(BaseModel):
model_config = ConfigDict(extra="forbid")
plugin_id: str
category: TenantPluginAutoUpgradeStrategy.PluginCategory
class ParserReadme(BaseModel):
@ -173,8 +156,7 @@ register_schema_models(
ParserPermissionChange,
ParserDynamicOptions,
ParserDynamicOptionsWithCredentials,
ParserAutoUpgradeChange,
ParserAutoUpgradeFetch,
ParserPreferencesChange,
ParserExcludePlugin,
ParserReadme,
)
@ -182,36 +164,12 @@ register_schema_models(
register_enum_models(
console_ns,
TenantPluginPermission.DebugPermission,
TenantPluginAutoUpgradeStrategy.PluginCategory,
TenantPluginAutoUpgradeStrategy.UpgradeMode,
TenantPluginAutoUpgradeStrategy.StrategySetting,
TenantPluginPermission.InstallPermission,
)
def _default_auto_upgrade_settings(
tenant_id: str,
category: TenantPluginAutoUpgradeStrategy.PluginCategory,
) -> AutoUpgradeSettingsResponse:
return {
"strategy_setting": PluginAutoUpgradeService.default_strategy_setting_for_category(category),
"upgrade_time_of_day": PluginAutoUpgradeService.default_upgrade_time_of_day(tenant_id),
"upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
"exclude_plugins": [],
"include_plugins": [],
}
def _auto_upgrade_settings_to_dict(strategy: TenantPluginAutoUpgradeStrategy) -> AutoUpgradeSettingsResponse:
return {
"strategy_setting": strategy.strategy_setting,
"upgrade_time_of_day": strategy.upgrade_time_of_day,
"upgrade_mode": strategy.upgrade_mode,
"exclude_plugins": strategy.exclude_plugins,
"include_plugins": strategy.include_plugins,
}
def _read_upload_content(file: FileStorage, max_size: int) -> bytes:
"""
Read the uploaded file and validate its actual size before delegating to the plugin service.
@ -659,13 +617,11 @@ class PluginChangePermissionApi(Resource):
tenant_id = current_tenant_id
set_permission_result = PluginPermissionService.change_permission(
tenant_id, args.install_permission, args.debug_permission
)
if not set_permission_result:
return jsonable_encoder({"success": False, "message": "Failed to set permission"})
return jsonable_encoder({"success": True})
return {
"success": PluginPermissionService.change_permission(
tenant_id, args.install_permission, args.debug_permission
)
}
@console_ns.route("/workspaces/current/plugin/permission/fetch")
@ -754,9 +710,9 @@ class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
return jsonable_encoder({"options": options})
@console_ns.route("/workspaces/current/plugin/auto-upgrade/change")
class PluginChangeAutoUpgradeApi(Resource):
@console_ns.expect(console_ns.models[ParserAutoUpgradeChange.__name__])
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])
@setup_required
@login_required
@account_initialization_required
@ -765,17 +721,38 @@ class PluginChangeAutoUpgradeApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
args = ParserAutoUpgradeChange.model_validate(console_ns.payload)
args = ParserPreferencesChange.model_validate(console_ns.payload)
permission = args.permission
install_permission = permission.install_permission
debug_permission = permission.debug_permission
auto_upgrade = args.auto_upgrade
strategy_setting = auto_upgrade.strategy_setting
upgrade_time_of_day = auto_upgrade.upgrade_time_of_day
upgrade_mode = auto_upgrade.upgrade_mode
exclude_plugins = auto_upgrade.exclude_plugins
include_plugins = auto_upgrade.include_plugins
# set permission
set_permission_result = PluginPermissionService.change_permission(
tenant_id,
install_permission,
debug_permission,
)
if not set_permission_result:
return jsonable_encoder({"success": False, "message": "Failed to set permission"})
# set auto upgrade strategy
set_auto_upgrade_strategy_result = PluginAutoUpgradeService.change_strategy(
tenant_id,
auto_upgrade.strategy_setting,
auto_upgrade.upgrade_time_of_day,
auto_upgrade.upgrade_mode,
auto_upgrade.exclude_plugins,
auto_upgrade.include_plugins,
category=args.category,
strategy_setting,
upgrade_time_of_day,
upgrade_mode,
exclude_plugins,
include_plugins,
)
if not set_auto_upgrade_strategy_result:
return jsonable_encoder({"success": False, "message": "Failed to set auto upgrade strategy"})
@ -783,32 +760,46 @@ class PluginChangeAutoUpgradeApi(Resource):
return jsonable_encoder({"success": True})
@console_ns.route("/workspaces/current/plugin/auto-upgrade/fetch")
class PluginFetchAutoUpgradeApi(Resource):
@console_ns.doc(params=query_params_from_model(ParserAutoUpgradeFetch))
@console_ns.route("/workspaces/current/plugin/preferences/fetch")
class PluginFetchPreferencesApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
args = ParserAutoUpgradeFetch.model_validate(request.args.to_dict(flat=True))
auto_upgrade = PluginAutoUpgradeService.get_strategy(tenant_id, args.category)
auto_upgrade_dict = (
_auto_upgrade_settings_to_dict(auto_upgrade)
if auto_upgrade
else _default_auto_upgrade_settings(tenant_id, args.category)
)
permission = PluginPermissionService.get_permission(tenant_id)
permission_dict = {
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
}
return jsonable_encoder(
{
"category": args.category,
"auto_upgrade": auto_upgrade_dict,
if permission:
permission_dict["install_permission"] = permission.install_permission
permission_dict["debug_permission"] = permission.debug_permission
auto_upgrade = PluginAutoUpgradeService.get_strategy(tenant_id)
auto_upgrade_dict = {
"strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
"upgrade_time_of_day": 0,
"upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
"exclude_plugins": [],
"include_plugins": [],
}
if auto_upgrade:
auto_upgrade_dict = {
"strategy_setting": auto_upgrade.strategy_setting,
"upgrade_time_of_day": auto_upgrade.upgrade_time_of_day,
"upgrade_mode": auto_upgrade.upgrade_mode,
"exclude_plugins": auto_upgrade.exclude_plugins,
"include_plugins": auto_upgrade.include_plugins,
}
)
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
@console_ns.route("/workspaces/current/plugin/auto-upgrade/exclude")
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource):
@console_ns.expect(console_ns.models[ParserExcludePlugin.__name__])
@setup_required
@ -820,9 +811,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
args = ParserExcludePlugin.model_validate(console_ns.payload)
return jsonable_encoder(
{"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id, args.category)}
)
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)})
@console_ns.route("/workspaces/current/plugin/readme")

View File

@ -1,614 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, ValidationError, field_validator
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import console_ns
from libs.login import current_account_with_tenant, login_required
from services.enterprise import rbac_service as svc
_LEGACY_ROLE_PERMISSION_KEYS: dict[str, list[str]] = {
# This is a compatibility projection from the pre-RBAC workspace roles into
# the 2.0 permission matrix documented in "权限整理2.0". It intentionally
# models the product-facing role surface for the new RBAC UI instead of the
# legacy backend's exact hard-authorization checks.
"owner": [
*svc._LEGACY_WORKSPACE_OWNER_KEYS,
*svc._LEGACY_APP_OWNER_KEYS,
*svc._LEGACY_DATASET_OWNER_KEYS,
],
"admin": [
*svc._LEGACY_WORKSPACE_ADMIN_KEYS,
*svc._LEGACY_APP_ADMIN_KEYS,
*svc._LEGACY_DATASET_ADMIN_KEYS,
],
"editor": [
*svc._LEGACY_WORKSPACE_EDITOR_KEYS,
*svc._LEGACY_APP_EDITOR_KEYS,
*svc._LEGACY_DATASET_EDITOR_KEYS,
],
"normal": [
*svc._LEGACY_WORKSPACE_NORMAL_KEYS,
*svc._LEGACY_APP_NORMAL_KEYS,
],
"dataset_operator": [
*svc._LEGACY_WORKSPACE_DATASET_OPERATOR_KEYS,
*svc._LEGACY_DATASET_DATASET_OPERATOR_KEYS,
],
}
def _current_ids() -> tuple[str, str]:
"""Return ``(tenant_id, account_id)`` for the authenticated user, or
raise a 404 when no tenant is associated with the session.
"""
user, tenant_id = current_account_with_tenant()
if not tenant_id:
raise NotFound("Current workspace not found")
return tenant_id, user.id
def _payload(model: type[BaseModel]) -> Any:
"""Validate the JSON body against ``model`` or raise ``ValidationError``.
``ValidationError`` bubbles up as HTTP 400 thanks to
``controllers/common/helpers.py`` error handling.
"""
try:
return model.model_validate(console_ns.payload or {})
except ValidationError as exc:
# Re-raise as-is so the upstream error handler renders a 400.
raise exc
def _dump(model: BaseModel) -> dict[str, Any]:
return model.model_dump(mode="json")
class _PaginationQuery(BaseModel):
model_config = ConfigDict(extra="ignore")
page_number: int | None = Field(default=None, ge=1, validation_alias=AliasChoices("page", "page_number"))
results_per_page: int | None = Field(
default=None, ge=1, le=100, validation_alias=AliasChoices("limit", "results_per_page")
)
reverse: bool | None = None
def to_inner_options(self) -> svc.ListOption:
return svc.ListOption.model_validate(self.model_dump())
class _RolesListQuery(_PaginationQuery):
include_owner: int = Field(default=0, ge=0, le=1)
def _pagination_options() -> svc.ListOption:
return _PaginationQuery.model_validate(request.args.to_dict(flat=True)).to_inner_options()
def _filter_out_owner(paginated: svc.Paginated[svc.RBACRole]) -> svc.Paginated[svc.RBACRole]:
filtered = [r for r in paginated.data if r.name not in {"所有者", "owner"}]
return svc.Paginated[svc.RBACRole](
data=filtered,
pagination=paginated.pagination,
)
def _legacy_workspace_roles(options: svc.ListOption | None = None) -> svc.Paginated[svc.RBACRole]:
"""Return the built-in legacy workspace roles in the RBAC list shape.
This keeps the new `/rbac/roles` endpoint compatible with the original
Dify role model when enterprise RBAC is disabled.
"""
legacy_roles = [
svc.RBACRole(
id=role_name,
tenant_id="",
type=svc.RBACRoleType.WORKSPACE.value,
category="global_system_default",
name=role_name,
description="",
is_builtin=True,
permission_keys=list(_LEGACY_ROLE_PERMISSION_KEYS[role_name]),
role_tag="owner" if role_name == "owner" else "",
)
for role_name in ("owner", "admin", "editor", "normal", "dataset_operator")
]
page_number = options.page_number if options and options.page_number is not None else 1
results_per_page = options.results_per_page if options and options.results_per_page is not None else len(legacy_roles)
reverse = options.reverse if options and options.reverse is not None else False
ordered_roles = list(reversed(legacy_roles)) if reverse else legacy_roles
start = max(page_number - 1, 0) * results_per_page
end = start + results_per_page
paged_roles = ordered_roles[start:end]
total_count = len(legacy_roles)
total_pages = (total_count + results_per_page - 1) // results_per_page if results_per_page > 0 else 0
return svc.Paginated[svc.RBACRole](
data=paged_roles,
pagination=svc.Pagination(
total_count=total_count,
per_page=results_per_page,
current_page=page_number,
total_pages=total_pages,
),
)
# ---------------------------------------------------------------------------
# Permission catalogs.
# ---------------------------------------------------------------------------
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog")
class RBACWorkspaceCatalogApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.Catalog.workspace(tenant_id, account_id))
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog/app")
class RBACAppCatalogApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.Catalog.app(tenant_id, account_id))
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog/dataset")
class RBACDatasetCatalogApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.Catalog.dataset(tenant_id, account_id))
# ---------------------------------------------------------------------------
# Roles.
# ---------------------------------------------------------------------------
class _RoleUpsertRequest(BaseModel):
"""Accepts the payload sent by the Create/Edit Role dialog."""
name: str
description: str = ""
permission_keys: list[str] = []
def to_mutation(self) -> svc.RoleMutation:
return svc.RoleMutation(
name=self.name,
description=self.description,
permission_keys=list(self.permission_keys),
)
@console_ns.route("/workspaces/current/rbac/roles")
class RBACRolesApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
query = _RolesListQuery.model_validate(request.args.to_dict(flat=True))
options = query.to_inner_options()
if not dify_config.RBAC_ENABLED:
result = _legacy_workspace_roles(options)
else:
result = svc.RBACService.Roles.list(tenant_id, account_id, options=options)
if query.include_owner == 0:
result = _filter_out_owner(result)
data = []
for role in result.data:
if role.name in {"所有者", "owner"}:
role.role_tag = "owner"
else:
role.role_tag = ""
data.append(role)
result.data = data
return _dump(result)
@login_required
def post(self):
tenant_id, account_id = _current_ids()
request = _payload(_RoleUpsertRequest)
role = svc.RBACService.Roles.create(tenant_id, account_id, request.to_mutation())
return _dump(role), 201
@console_ns.route("/workspaces/current/rbac/roles/<uuid:role_id>")
class RBACRoleItemApi(Resource):
@login_required
def get(self, role_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.Roles.get(tenant_id, account_id, str(role_id)))
@login_required
def put(self, role_id):
tenant_id, account_id = _current_ids()
request = _payload(_RoleUpsertRequest)
role = svc.RBACService.Roles.update(tenant_id, account_id, str(role_id), request.to_mutation())
return _dump(role)
@login_required
def delete(self, role_id):
tenant_id, account_id = _current_ids()
svc.RBACService.Roles.delete(tenant_id, account_id, str(role_id))
return {"result": "success"}
@console_ns.route("/workspaces/current/rbac/roles/<uuid:role_id>/copy")
class RBACRoleCopyApi(Resource):
@login_required
def post(self, role_id):
tenant_id, account_id = _current_ids()
role = svc.RBACService.Roles.copy(tenant_id, account_id, str(role_id))
return _dump(role), 201
# ---------------------------------------------------------------------------
# Access policies (tenant-level permission sets).
# ---------------------------------------------------------------------------
class _AccessPolicyCreateRequest(BaseModel):
name: str
resource_type: svc.RBACResourceType
description: str = ""
permission_keys: list[str] = []
class _AccessPolicyUpdateRequest(BaseModel):
name: str
description: str = ""
permission_keys: list[str] = []
@console_ns.route("/workspaces/current/rbac/access-policies")
class RBACAccessPoliciesApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
# `resource_type` is exposed as a query argument so the UI can show
# only app-scoped or only dataset-scoped permission sets.
resource_type = request.args.get("resource_type") or None
return _dump(
svc.RBACService.AccessPolicies.list(
tenant_id,
account_id,
resource_type=resource_type,
options=_pagination_options(),
)
)
@login_required
def post(self):
tenant_id, account_id = _current_ids()
request = _payload(_AccessPolicyCreateRequest)
policy = svc.RBACService.AccessPolicies.create(
tenant_id,
account_id,
svc.AccessPolicyCreate(
name=request.name,
resource_type=request.resource_type,
description=request.description,
permission_keys=list(request.permission_keys),
),
)
return _dump(policy), 201
@console_ns.route("/workspaces/current/rbac/access-policies/<uuid:policy_id>")
class RBACAccessPolicyItemApi(Resource):
@login_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.AccessPolicies.get(tenant_id, account_id, str(policy_id)))
@login_required
def put(self, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_AccessPolicyUpdateRequest)
policy = svc.RBACService.AccessPolicies.update(
tenant_id,
account_id,
str(policy_id),
svc.AccessPolicyUpdate(
name=request.name,
description=request.description,
permission_keys=list(request.permission_keys),
),
)
return _dump(policy)
@login_required
def delete(self, policy_id):
tenant_id, account_id = _current_ids()
svc.RBACService.AccessPolicies.delete(tenant_id, account_id, str(policy_id))
return {"result": "success"}
@console_ns.route("/workspaces/current/rbac/access-policies/<uuid:policy_id>/copy")
class RBACAccessPolicyCopyApi(Resource):
@login_required
def post(self, policy_id):
tenant_id, account_id = _current_ids()
policy = svc.RBACService.AccessPolicies.copy(tenant_id, account_id, str(policy_id))
return _dump(policy), 201
# ---------------------------------------------------------------------------
# Per-app access (App Access Config).
# ---------------------------------------------------------------------------
class _ReplaceBindingsRequest(BaseModel):
role_ids: list[str] = []
account_ids: list[str] = []
@field_validator("role_ids", "account_ids", mode="before")
@classmethod
def _coerce_bindings(cls, value: Any) -> list[str]:
if value is None:
return []
return value
@console_ns.route("/workspaces/current/rbac/my-permissions")
class RBACMyPermissionsApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.MyPermissions.get(
tenant_id,
account_id,
app_id=request.args.get("app_id") or None,
dataset_id=request.args.get("dataset_id") or None,
)
)
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policy")
class RBACAppMatrixApi(Resource):
@login_required
def get(self, app_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.AppAccess.matrix(tenant_id, account_id, str(app_id)))
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/role-bindings")
class RBACAppRoleBindingsApi(Resource):
@login_required
def get(self, app_id, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.AppAccess.list_role_bindings(tenant_id, account_id, str(app_id), str(policy_id))
)
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/member-bindings")
class RBACAppMemberBindingsApi(Resource):
@login_required
def get(self, app_id, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.AppAccess.list_member_bindings(tenant_id, account_id, str(app_id), str(policy_id))
)
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/bindings")
class RBACAppBindingsApi(Resource):
@login_required
def put(self, app_id, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceBindingsRequest)
return _dump(
svc.RBACService.AppAccess.replace_bindings(
tenant_id,
account_id,
str(app_id),
str(policy_id),
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
)
)
# ---------------------------------------------------------------------------
# Per-dataset access (Knowledge Base Access Config).
# ---------------------------------------------------------------------------
@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policy")
class RBACDatasetMatrixApi(Resource):
@login_required
def get(self, dataset_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.DatasetAccess.matrix(tenant_id, account_id, str(dataset_id)))
@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/role-bindings")
class RBACDatasetRoleBindingsApi(Resource):
@login_required
def get(self, dataset_id, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.DatasetAccess.list_role_bindings(
tenant_id, account_id, str(dataset_id), str(policy_id)
)
)
@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/bindings")
class RBACDatasetBindingsApi(Resource):
@login_required
def put(self, dataset_id, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceBindingsRequest)
return _dump(
svc.RBACService.DatasetAccess.replace_bindings(
tenant_id,
account_id,
str(dataset_id),
str(policy_id),
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
)
)
@console_ns.route(
"/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/member-bindings"
)
class RBACDatasetMemberBindingsApi(Resource):
@login_required
def get(self, dataset_id, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.DatasetAccess.list_member_bindings(
tenant_id, account_id, str(dataset_id), str(policy_id)
)
)
# ---------------------------------------------------------------------------
# Workspace-level access (Settings > Access Rules).
# ---------------------------------------------------------------------------
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policy")
class RBACWorkspaceAppMatrixApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
options = _pagination_options()
return _dump(svc.RBACService.WorkspaceAccess.app_matrix(tenant_id, account_id, options=options))
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/role-bindings")
class RBACWorkspaceAppRoleBindingsApi(Resource):
@login_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.WorkspaceAccess.list_app_role_bindings(tenant_id, account_id, str(policy_id))
)
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/bindings")
class RBACWorkspaceAppBindingsApi(Resource):
@login_required
def put(self, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceBindingsRequest)
return _dump(
svc.RBACService.WorkspaceAccess.replace_app_bindings(
tenant_id,
account_id,
str(policy_id),
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
)
)
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/member-bindings")
class RBACWorkspaceAppMemberBindingsApi(Resource):
@login_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.WorkspaceAccess.list_app_member_bindings(tenant_id, account_id, str(policy_id))
)
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policy")
class RBACWorkspaceDatasetMatrixApi(Resource):
@login_required
def get(self):
tenant_id, account_id = _current_ids()
options = _pagination_options()
return _dump(svc.RBACService.WorkspaceAccess.dataset_matrix(tenant_id, account_id, options=options))
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/role-bindings")
class RBACWorkspaceDatasetRoleBindingsApi(Resource):
@login_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.WorkspaceAccess.list_dataset_role_bindings(tenant_id, account_id, str(policy_id))
)
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/bindings")
class RBACWorkspaceDatasetBindingsApi(Resource):
@login_required
def put(self, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceBindingsRequest)
return _dump(
svc.RBACService.WorkspaceAccess.replace_dataset_bindings(
tenant_id,
account_id,
str(policy_id),
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
)
)
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/member-bindings")
class RBACWorkspaceDatasetMemberBindingsApi(Resource):
@login_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.WorkspaceAccess.list_dataset_member_bindings(tenant_id, account_id, str(policy_id))
)
# ---------------------------------------------------------------------------
# Member ↔ role bindings (Settings > Members > Assign roles).
# ---------------------------------------------------------------------------
class _ReplaceMemberRolesRequest(BaseModel):
role_ids: list[str] = []
@field_validator("role_ids", mode="before")
@classmethod
def _coerce_role_ids(cls, value: Any) -> list[str]:
if value is None:
return []
return value
@console_ns.route("/workspaces/current/rbac/members/<uuid:member_id>/rbac-roles")
class RBACMemberRolesApi(Resource):
@login_required
def get(self, member_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.MemberRoles.get(tenant_id, account_id, str(member_id)))
@login_required
def put(self, member_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceMemberRolesRequest)
return _dump(
svc.RBACService.MemberRoles.replace(
tenant_id,
account_id,
str(member_id),
role_ids=list(request.role_ids),
)
)

View File

@ -1,380 +0,0 @@
import logging
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource, marshal
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.snippets.payloads import (
CreateSnippetPayload,
IncludeSecretQuery,
SnippetImportPayload,
SnippetListQuery,
UpdateSnippetPayload,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
from libs.login import current_account_with_tenant, login_required
from models.snippet import SnippetType
from services.app_dsl_service import ImportStatus
from services.snippet_dsl_service import SnippetDslService
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
# Register Pydantic models with Swagger
register_schema_models(
console_ns,
SnippetListQuery,
CreateSnippetPayload,
UpdateSnippetPayload,
SnippetImportPayload,
IncludeSecretQuery,
)
# Create namespace models for marshaling
snippet_model = console_ns.model("Snippet", snippet_fields)
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)
@console_ns.route("/workspaces/current/customized-snippets")
class CustomizedSnippetsApi(Resource):
@console_ns.doc("list_customized_snippets")
@console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
@console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List customized snippets with pagination and search."""
_, current_tenant_id = current_account_with_tenant()
query_params = request.args.to_dict()
query = SnippetListQuery.model_validate(query_params)
snippets, total, has_more = SnippetService.get_snippets(
tenant_id=current_tenant_id,
page=query.page,
limit=query.limit,
keyword=query.keyword,
is_published=query.is_published,
creators=query.creators,
)
return {
"data": marshal(snippets, snippet_list_fields),
"page": query.page,
"limit": query.limit,
"total": total,
"has_more": has_more,
}, 200
@console_ns.doc("create_customized_snippet")
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
@console_ns.response(201, "Snippet created successfully", snippet_model)
@console_ns.response(400, "Invalid request or name already exists")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
"""Create a new customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
try:
snippet_type = SnippetType(payload.type)
except ValueError:
snippet_type = SnippetType.NODE
try:
snippet = SnippetService.create_snippet(
tenant_id=current_tenant_id,
name=payload.name,
description=payload.description,
snippet_type=snippet_type,
icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
account=current_user,
)
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 201
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
class CustomizedSnippetDetailApi(Resource):
@console_ns.doc("get_customized_snippet")
@console_ns.response(200, "Snippet retrieved successfully", snippet_model)
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
def get(self, snippet_id: str):
"""Get customized snippet details."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
return marshal(snippet, snippet_fields), 200
@console_ns.doc("update_customized_snippet")
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
@console_ns.response(200, "Snippet updated successfully", snippet_model)
@console_ns.response(400, "Invalid request or name already exists")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def patch(self, snippet_id: str):
"""Update customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
update_data = payload.model_dump(exclude_unset=True)
if "icon_info" in update_data and update_data["icon_info"] is not None:
update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None
if not update_data:
return {"message": "No valid fields to update"}, 400
try:
with Session(db.engine, expire_on_commit=False) as session:
snippet = session.merge(snippet)
snippet = SnippetService.update_snippet(
session=session,
snippet=snippet,
account_id=current_user.id,
data=update_data,
)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 200
@console_ns.doc("delete_customized_snippet")
@console_ns.response(204, "Snippet deleted successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, snippet_id: str):
"""Delete customized snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
snippet = session.merge(snippet)
SnippetService.delete_snippet(
session=session,
snippet=snippet,
)
session.commit()
return "", 204
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/export")
class CustomizedSnippetExportApi(Resource):
@console_ns.doc("export_customized_snippet")
@console_ns.doc(description="Export snippet configuration as DSL")
@console_ns.doc(params={"snippet_id": "Snippet ID to export"})
@console_ns.response(200, "Snippet exported successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, snippet_id: str):
"""Export snippet as DSL."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
# Get include_secret parameter
query = IncludeSecretQuery.model_validate(request.args.to_dict())
with Session(db.engine) as session:
export_service = SnippetDslService(session)
result = export_service.export_snippet_dsl(snippet=snippet, include_secret=query.include_secret == "true")
# Set filename with .snippet extension
filename = f"{snippet.name}.snippet"
encoded_filename = quote(filename)
response = Response(
result,
mimetype="application/x-yaml",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/x-yaml"
return response
@console_ns.route("/workspaces/current/customized-snippets/imports")
class CustomizedSnippetImportApi(Resource):
@console_ns.doc("import_customized_snippet")
@console_ns.doc(description="Import snippet from DSL")
@console_ns.expect(console_ns.models.get(SnippetImportPayload.__name__))
@console_ns.response(200, "Snippet imported successfully")
@console_ns.response(202, "Import pending confirmation")
@console_ns.response(400, "Import failed")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
"""Import snippet from DSL."""
current_user, _ = current_account_with_tenant()
payload = SnippetImportPayload.model_validate(console_ns.payload or {})
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.import_snippet(
account=current_user,
import_mode=payload.mode,
yaml_content=payload.yaml_content,
yaml_url=payload.yaml_url,
snippet_id=payload.snippet_id,
name=payload.name,
description=payload.description,
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/imports/<string:import_id>/confirm")
class CustomizedSnippetImportConfirmApi(Resource):
@console_ns.doc("confirm_snippet_import")
@console_ns.doc(description="Confirm a pending snippet import")
@console_ns.doc(params={"import_id": "Import ID to confirm"})
@console_ns.response(200, "Import confirmed successfully")
@console_ns.response(400, "Import failed")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, import_id: str):
"""Confirm a pending snippet import."""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.confirm_import(import_id=import_id, account=current_user)
session.commit()
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/check-dependencies")
class CustomizedSnippetCheckDependenciesApi(Resource):
@console_ns.doc("check_snippet_dependencies")
@console_ns.doc(description="Check dependencies for a snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID"})
@console_ns.response(200, "Dependencies checked successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, snippet_id: str):
"""Check dependencies for a snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.check_dependencies(snippet=snippet)
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/use-count/increment")
class CustomizedSnippetUseCountIncrementApi(Resource):
@console_ns.doc("increment_snippet_use_count")
@console_ns.doc(description="Increment snippet use count by 1")
@console_ns.doc(params={"snippet_id": "Snippet ID"})
@console_ns.response(200, "Use count incremented successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, snippet_id: str):
"""Increment snippet use count when it is inserted into a workflow."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
snippet = session.merge(snippet)
SnippetService.increment_use_count(session=session, snippet=snippet)
session.commit()
session.refresh(snippet)
return {"result": "success", "use_count": snippet.use_count}, 200

View File

@ -29,7 +29,7 @@ from controllers.console.wraps import (
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import TimestampField
from libs.helper import TimestampField, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
from services.account_service import TenantService
@ -86,9 +86,7 @@ class TenantInfoResponse(ResponseModel):
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None):
if isinstance(value, datetime):
return int(value.timestamp())
return value
return to_timestamp(value)
register_schema_models(

View File

@ -20,10 +20,13 @@ class TenantUserPayload(BaseModel):
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
"""
Get current user
Get current user.
NOTE: user_id is not trusted, it could be maliciously set to any value.
As a result, it could only be considered as an end user id.
As a result, it could only be considered as an end user id. Even when a
concrete end-user ID is supplied, lookups must stay tenant-scoped so one
tenant cannot bind another tenant's user record into the plugin request
context.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
@ -42,7 +45,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
.limit(1)
)
else:
user_model = session.get(EndUser, user_id)
user_model = session.scalar(
select(EndUser)
.where(
EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.limit(1)
)
if not user_model:
user_model = EndUser(

View File

@ -22,7 +22,7 @@ from fields.conversation_fields import (
SimpleConversation,
)
from graphon.variables.types import SegmentType
from libs.helper import UUIDStrOrEmpty
from libs.helper import UUIDStrOrEmpty, to_timestamp
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@ -115,9 +115,7 @@ class ConversationVariableResponse(ResponseModel):
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
return to_timestamp(value)
class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel):

View File

@ -7,18 +7,18 @@ paused human input forms in workflow/chatflow runs.
import json
import logging
from datetime import datetime
from flask import Response
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.helper import to_timestamp
from models.model import App, EndUser
from services.human_input_service import Form, FormNotFoundError, HumanInputService
@ -28,30 +28,14 @@ logger = logging.getLogger(__name__)
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
result: dict[str, str] = {}
for key, value in values.items():
if value is None:
result[key] = ""
elif isinstance(value, (dict, list)):
result[key] = json.dumps(value, ensure_ascii=False)
else:
result[key] = str(value)
return result
def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
def _jsonify_form_definition(form: Form) -> Response:
definition_payload = form.get_definition().model_dump()
payload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": _to_timestamp(form.expiration_time),
"expiration_time": to_timestamp(form.expiration_time),
}
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Literal
from typing import Literal
from dateutil.parser import isoparse
from flask import request
@ -39,6 +39,7 @@ from graphon.enums import WorkflowExecutionStatus
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import to_timestamp
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
@ -68,12 +69,6 @@ class WorkflowLogQuery(BaseModel):
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
def _enum_value(value):
return getattr(value, "value", value)
@ -109,7 +104,7 @@ class WorkflowRunResponse(ResponseModel):
@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
return to_timestamp(value)
class WorkflowRunForLogResponse(ResponseModel):
@ -133,31 +128,13 @@ class WorkflowRunForLogResponse(ResponseModel):
@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class WorkflowAppLogEvaluationNodeInfoResponse(ResponseModel):
node_id: str
type: str
title: str
class WorkflowAppLogEvaluationItemResponse(ResponseModel):
name: str
value: Any = None
details: dict[str, Any] | None = None
node_info: WorkflowAppLogEvaluationNodeInfoResponse | None = Field(
default=None,
validation_alias="node_info",
serialization_alias="nodeInfo",
)
return to_timestamp(value)
class WorkflowAppLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForLogResponse | None = None
details: dict | list | str | int | float | bool | None = None
evaluation: list[WorkflowAppLogEvaluationItemResponse] = Field(default_factory=list)
created_from: str | None = None
created_by_role: str | None = None
created_by_account: SimpleAccount | None = None
@ -172,12 +149,7 @@ class WorkflowAppLogPartialResponse(ResponseModel):
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
@field_validator("evaluation", mode="before")
@classmethod
def _normalize_evaluation(cls, value: Any) -> list[dict[str, Any]] | list[WorkflowAppLogEvaluationItemResponse]:
return value or []
return to_timestamp(value)
class WorkflowAppLogPaginationResponse(ResponseModel):
@ -192,8 +164,6 @@ register_schema_models(
service_api_ns,
WorkflowRunResponse,
WorkflowRunForLogResponse,
WorkflowAppLogEvaluationNodeInfoResponse,
WorkflowAppLogEvaluationItemResponse,
WorkflowAppLogPartialResponse,
WorkflowAppLogPaginationResponse,
)

View File

@ -31,7 +31,9 @@ from services.tag_service import (
TagBindingCreatePayload,
TagBindingDeletePayload,
TagService,
UpdateTagPayload,
)
from services.tag_service import (
UpdateTagPayload as UpdateTagServicePayload,
)
register_enum_models(service_api_ns, DatasetPermissionEnum)
@ -556,7 +558,7 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
tag_id = payload.tag_id
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=TagType.KNOWLEDGE), tag_id)
tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)

View File

@ -23,7 +23,6 @@ from . import (
feature,
files,
forgot_password,
human_input_file_upload,
human_input_form,
login,
message,
@ -47,7 +46,6 @@ __all__ = [
"feature",
"files",
"forgot_password",
"human_input_file_upload",
"human_input_form",
"login",
"message",

View File

@ -1,181 +0,0 @@
import httpx
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, HttpUrl
import services
from controllers.common import helpers
from controllers.common.errors import (
BlockedFileExtensionError,
FileTooLargeError,
NoFileUploadedError,
RemoteFileUploadError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileResponse, FileWithSignedUrl
from graphon.file import helpers as file_helpers
from libs.exception import BaseHTTPException
from services.file_service import FileService
from services.human_input_file_upload_service import (
HITL_UPLOAD_TOKEN_PREFIX,
HumanInputFileUploadService,
InvalidUploadTokenError,
)
class InvalidUploadTokenBadRequestError(BaseHTTPException):
error_code = "invalid_upload_token"
description = "Invalid upload token."
code = 400
class InvalidUploadTokenUnauthorizedError(BaseHTTPException):
error_code = "invalid_upload_token"
description = "Upload token is required."
code = 401
class InvalidUploadTokenForbiddenError(BaseHTTPException):
error_code = "invalid_upload_token"
description = "Upload token is invalid or expired."
code = 403
class HumanInputRemoteFileUploadPayload(BaseModel):
url: HttpUrl = Field(description="Remote file URL")
register_schema_models(web_ns, HumanInputRemoteFileUploadPayload, FileResponse, FileWithSignedUrl)
def _extract_hitl_upload_token() -> str:
"""Read HITL upload token from Authorization without invoking other bearer auth chains."""
authorization = request.headers.get("Authorization")
if authorization is None:
raise InvalidUploadTokenUnauthorizedError()
parts = authorization.split()
if len(parts) != 2:
raise InvalidUploadTokenUnauthorizedError()
scheme, token = parts
if scheme.lower() != "bearer":
raise InvalidUploadTokenBadRequestError()
if not token:
raise InvalidUploadTokenUnauthorizedError()
if not token.startswith(HITL_UPLOAD_TOKEN_PREFIX):
raise InvalidUploadTokenBadRequestError()
return token
def _validate_context(service: HumanInputFileUploadService, token: str):
try:
return service.validate_upload_token(token)
except InvalidUploadTokenError as exc:
raise InvalidUploadTokenForbiddenError() from exc
def _parse_local_upload_file():
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.filename:
from controllers.common.errors import FilenameNotExistsError
raise FilenameNotExistsError()
return file
@web_ns.route("/form/human_input/files/upload")
class HumanInputFileUploadApi(Resource):
def post(self):
"""Upload one local file for a HITL human input form."""
token = _extract_hitl_upload_token()
upload_service = HumanInputFileUploadService(db.engine)
context = _validate_context(upload_service, token)
file = _parse_local_upload_file()
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename or "",
content=file.read(),
mimetype=file.mimetype,
user=context.owner,
source=None,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as exc:
raise BlockedFileExtensionError() from exc
upload_service.record_upload_file(context=context, file_id=upload_file.id)
response = FileResponse.model_validate(upload_file, from_attributes=True)
return response.model_dump(mode="json"), 201
@web_ns.route("/form/human_input/files/remote-upload")
class HumanInputRemoteFileUploadApi(Resource):
def post(self):
"""Upload one remote URL file for a HITL human input form."""
token = _extract_hitl_upload_token()
upload_service = HumanInputFileUploadService(db.engine)
context = _validate_context(upload_service, token)
payload = HumanInputRemoteFileUploadPayload.model_validate(request.get_json(silent=True) or {})
url = str(payload.url)
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as exc:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(exc)}")
file_info = helpers.guess_file_info_from_response(resp)
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError()
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=context.owner,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as exc:
raise BlockedFileExtensionError() from exc
upload_service.record_upload_file(context=context, file_id=upload_file.id)
payload1 = FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)
return payload1.model_dump(mode="json"), 201

View File

@ -4,39 +4,27 @@ Web App Human Input Form APIs.
import json
import logging
from datetime import datetime
from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload
from extensions.ext_database import db
from libs.helper import RateLimiter, extract_remote_ip
from libs.helper import RateLimiter, extract_remote_ip, to_timestamp
from models.account import TenantStatus
from models.model import App, Site
from services.human_input_file_upload_service import HumanInputFileUploadService
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
class HumanInputUploadTokenResponse(BaseModel):
upload_token: str
expires_at: int
register_schema_models(web_ns, HumanInputUploadTokenResponse)
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
@ -47,27 +35,6 @@ _FORM_ACCESS_RATE_LIMITER = RateLimiter(
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
)
_FORM_UPLOAD_TOKEN_RATE_LIMITER = RateLimiter(
prefix="web_form_upload_token_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
)
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
result: dict[str, str] = {}
for key, value in values.items():
if value is None:
result[key] = ""
elif isinstance(value, (dict, list)):
result[key] = json.dumps(value, ensure_ascii=False)
else:
result[key] = str(value)
return result
def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
class FormDefinitionPayload(TypedDict):
@ -85,42 +52,15 @@ def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Re
payload: FormDefinitionPayload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": _to_timestamp(form.expiration_time),
"expiration_time": to_timestamp(form.expiration_time),
}
if site_payload is not None:
payload["site"] = site_payload
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
@web_ns.route("/form/human_input/<string:form_token>/upload-token")
class HumanInputFormUploadTokenApi(Resource):
"""API for issuing HITL upload tokens for active human input forms."""
def post(self, form_token: str):
"""
Issue an upload token for a human input form.
POST /api/form/human_input/<form_token>/upload-token
"""
ip_address = extract_remote_ip(request)
if _FORM_UPLOAD_TOKEN_RATE_LIMITER.is_rate_limited(ip_address):
raise WebFormRateLimitExceededError()
_FORM_UPLOAD_TOKEN_RATE_LIMITER.increment_rate_limit(ip_address)
try:
token = HumanInputFileUploadService(db.engine).issue_upload_token(form_token)
except FormNotFoundError:
raise NotFoundError("Form not found")
response = HumanInputUploadTokenResponse(
upload_token=token.upload_token,
expires_at=_to_timestamp(token.expires_at),
)
return response.model_dump(mode="json"), 200
@web_ns.route("/form/human_input/<string:form_token>")
class HumanInputFormApi(Resource):
"""API for getting and submitting human input forms via the web app."""

View File

@ -9,7 +9,7 @@ from datetime import datetime
from threading import Thread
from typing import Any, Union
from sqlalchemy import select
from sqlalchemy import select, update
from sqlalchemy.orm import Session, sessionmaker
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@ -245,49 +245,50 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
return AdvancedChatPausedBlockingResponse(
task_id=stream_response.task_id,
data=AdvancedChatPausedBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
workflow_run_id=stream_response.data.workflow_run_id,
answer=self._task_state.answer,
metadata=self._message_end_to_stream_response().metadata,
created_at=self._message_created_at,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
status=stream_response.data.status,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
),
)
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
extras["metadata"] = stream_response.metadata
match stream_response:
case ErrorStreamResponse():
raise stream_response.err
case HumanInputRequiredResponse():
human_input_responses.append(stream_response)
case WorkflowPauseStreamResponse():
return AdvancedChatPausedBlockingResponse(
task_id=stream_response.task_id,
data=AdvancedChatPausedBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
workflow_run_id=stream_response.data.workflow_run_id,
answer=self._task_state.answer,
metadata=self._message_end_to_stream_response().metadata,
created_at=self._message_created_at,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
status=stream_response.data.status,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
),
)
case MessageEndStreamResponse():
extras = {}
if stream_response.metadata:
extras["metadata"] = stream_response.metadata
return ChatbotAppBlockingResponse(
task_id=stream_response.task_id,
data=ChatbotAppBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
answer=self._task_state.answer,
created_at=self._message_created_at,
**extras,
),
)
else:
continue
return ChatbotAppBlockingResponse(
task_id=stream_response.task_id,
data=ChatbotAppBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
answer=self._task_state.answer,
created_at=self._message_created_at,
**extras,
),
)
case _:
continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
@ -425,11 +426,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._workflow_run_id = run_id
with self._database_session() as session:
message = self._get_message(session=session)
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = run_id
session.execute(update(Message).where(Message.id == self._message_id).values(workflow_run_id=run_id))
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,

View File

@ -178,7 +178,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
if not isinstance(agent_mode, dict):
raise ValueError("agent_mode must be of object type")
# FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing
# FIXME(-LAN-): Cast needed because static checkers do not narrow this dict value.
agent_mode = cast(dict[str, Any], agent_mode)
if "enabled" not in agent_mode or not agent_mode["enabled"]:

View File

@ -408,19 +408,17 @@ class WorkflowResponseConverter:
self, *, event: QueueHumanInputFormFilledEvent, task_id: str
) -> HumanInputFormFilledResponse:
run_id = self._ensure_workflow_run_id()
data = HumanInputFormFilledResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
return HumanInputFormFilledResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputFormFilledResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
),
)
if event.submitted_data is not None:
runtime_type_converter = WorkflowRuntimeTypeConverter()
data.submitted_data = runtime_type_converter.value_to_json_encodable_recursive(event.submitted_data)
return HumanInputFormFilledResponse(task_id=task_id, workflow_run_id=run_id, data=data)
def human_input_form_timeout_to_stream_response(
self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str

View File

@ -45,20 +45,24 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
match sub_stream_response:
case PingStreamResponse():
yield "ping"
continue
case ErrorStreamResponse():
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
case _:
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
response_chunk.update(sub_stream_response.model_dump())
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
else:
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk
@classmethod
@ -74,20 +78,28 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
match sub_stream_response:
case PingStreamResponse():
yield "ping"
continue
case ErrorStreamResponse():
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
case NodeStartStreamResponse() | NodeFinishStreamResponse():
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
case _:
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
response_chunk.update(sub_stream_response.model_dump())
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
else:
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk

View File

@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Literal, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
@ -58,25 +58,6 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@staticmethod
def _ensure_snippet_start_node_in_worker(*, session: Session, workflow: Workflow) -> Workflow:
"""Re-apply snippet virtual Start injection after worker reloads workflow from DB."""
if workflow.kind_or_standard != "snippet":
return workflow
from models.snippet import CustomizedSnippet
from services.snippet_generate_service import SnippetGenerateService
snippet = session.scalar(
select(CustomizedSnippet).where(
CustomizedSnippet.id == workflow.app_id,
CustomizedSnippet.tenant_id == workflow.tenant_id,
)
)
if snippet is None:
return workflow
return SnippetGenerateService.ensure_start_node_for_worker(workflow, snippet)
@staticmethod
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
@ -594,8 +575,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
if workflow is None:
raise ValueError("Workflow not found")
workflow = self._ensure_snippet_start_node_in_worker(session=session, workflow=workflow)
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,

View File

@ -10,7 +10,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.node_factory import get_default_root_node_id
from core.workflow.snippet_start import get_compatible_start_aliases
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
@ -116,15 +115,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
),
)
root_node_id = self._root_node_id or get_default_root_node_id(self._workflow.graph_dict)
add_node_inputs_to_pool(
variable_pool,
node_id=root_node_id,
inputs=inputs,
aliases=get_compatible_start_aliases(
workflow_kind=getattr(self._workflow, "kind_or_standard", None),
root_node_id=root_node_id,
),
)
add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph = self._init_graph(

View File

@ -52,20 +52,24 @@ class WorkflowAppGenerateResponseConverter(
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
match sub_stream_response:
case PingStreamResponse():
yield "ping"
continue
case ErrorStreamResponse():
response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
case _:
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
response_chunk.update(sub_stream_response.model_dump(mode="json"))
response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -81,20 +85,28 @@ class WorkflowAppGenerateResponseConverter(
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
match sub_stream_response:
case PingStreamResponse():
yield "ping"
continue
case ErrorStreamResponse():
response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
case NodeStartStreamResponse() | NodeFinishStreamResponse():
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
case _:
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
response_chunk.update(sub_stream_response.model_dump(mode="json"))
response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -145,50 +145,51 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
return WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id,
data=WorkflowAppPausedBlockingResponse.Data(
id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id,
status=stream_response.data.status,
outputs=stream_response.data.outputs or {},
error=None,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at,
finished_at=None,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
),
)
elif isinstance(stream_response, WorkflowFinishStreamResponse):
return WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
id=stream_response.data.id,
workflow_id=stream_response.data.workflow_id,
status=stream_response.data.status,
outputs=stream_response.data.outputs,
error=stream_response.data.error,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at) if stream_response.data.finished_at else None,
),
)
else:
continue
match stream_response:
case ErrorStreamResponse():
raise stream_response.err
case HumanInputRequiredResponse():
human_input_responses.append(stream_response)
case WorkflowPauseStreamResponse():
return WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id,
data=WorkflowAppPausedBlockingResponse.Data(
id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id,
status=stream_response.data.status,
outputs=stream_response.data.outputs or {},
error=None,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at,
finished_at=None,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
),
)
case WorkflowFinishStreamResponse():
return WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
id=stream_response.data.id,
workflow_id=stream_response.data.workflow_id,
status=stream_response.data.status,
outputs=stream_response.data.outputs,
error=stream_response.data.error,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at)
if stream_response.data.finished_at
else None,
),
)
case _:
continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)

View File

@ -399,279 +399,281 @@ class WorkflowBasedAppRunner:
:param workflow_entry: workflow entry
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(QueueWorkflowStartedEvent(reason=event.reason))
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self._publish_event(
QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, GraphRunAbortedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
elif isinstance(event, GraphRunPausedEvent):
runtime_state = workflow_entry.graph_engine.graph_runtime_state
paused_nodes = runtime_state.get_paused_nodes()
self._enqueue_human_input_notifications(event.reasons)
self._publish_event(
QueueWorkflowPausedEvent(
reasons=event.reasons,
outputs=event.outputs,
paused_nodes=paused_nodes,
match event:
case GraphRunStartedEvent():
self._publish_event(QueueWorkflowStartedEvent(reason=event.reason))
case GraphRunSucceededEvent():
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
case GraphRunPartialSucceededEvent():
self._publish_event(
QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
)
)
elif isinstance(event, NodeRunHumanInputFormFilledEvent):
self._publish_event(
QueueHumanInputFormFilledEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
submitted_data=event.submitted_data,
case GraphRunFailedEvent():
self._publish_event(
QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)
)
)
elif isinstance(event, NodeRunHumanInputFormTimeoutEvent):
self._publish_event(
QueueHumanInputFormTimeoutEvent(
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
expiration_time=event.expiration_time,
case GraphRunAbortedEvent():
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
case GraphRunPausedEvent():
runtime_state = workflow_entry.graph_engine.graph_runtime_state
paused_nodes = runtime_state.get_paused_nodes()
self._enqueue_human_input_notifications(event.reasons)
self._publish_event(
QueueWorkflowPausedEvent(
reasons=event.reasons,
outputs=event.outputs,
paused_nodes=paused_nodes,
)
)
)
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = project_node_outputs_for_workflow_run(
node_type=event.node_type,
inputs=inputs,
outputs=node_run_result.outputs,
)
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
case NodeRunHumanInputFormFilledEvent():
self._publish_event(
QueueHumanInputFormFilledEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
)
)
case NodeRunHumanInputFormTimeoutEvent():
self._publish_event(
QueueHumanInputFormTimeoutEvent(
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
expiration_time=event.expiration_time,
)
)
case NodeRunRetryEvent():
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = project_node_outputs_for_workflow_run(
node_type=event.node_type,
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
inputs=inputs,
process_data=process_data,
outputs=outputs,
error=event.error,
execution_metadata=execution_metadata,
retry_index=event.retry_index,
provider_type=event.provider_type,
provider_id=event.provider_id,
outputs=node_run_result.outputs,
)
)
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
agent_strategy=self._build_agent_strategy_info(event),
provider_type=event.provider_type,
provider_id=event.provider_id,
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
inputs=inputs,
process_data=process_data,
outputs=outputs,
error=event.error,
execution_metadata=execution_metadata,
retry_index=event.retry_index,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
)
)
elif isinstance(event, NodeRunSucceededEvent):
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = project_node_outputs_for_workflow_run(
node_type=event.node_type,
inputs=inputs,
outputs=node_run_result.outputs,
)
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeSucceededEvent(
node_execution_id=event.id,
node_id=event.node_id,
case NodeRunStartedEvent():
self._publish_event(
QueueNodeStartedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
agent_strategy=self._build_agent_strategy_info(event),
provider_type=event.provider_type,
provider_id=event.provider_id,
)
)
case NodeRunSucceededEvent():
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = project_node_outputs_for_workflow_run(
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
outputs=node_run_result.outputs,
)
)
elif isinstance(event, NodeRunFailedEvent):
outputs = project_node_outputs_for_workflow_run(
node_type=event.node_type,
inputs=event.node_run_result.inputs,
outputs=event.node_run_result.outputs,
)
self._publish_event(
QueueNodeFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeSucceededEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
case NodeRunFailedEvent():
outputs = project_node_outputs_for_workflow_run(
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
outputs=event.node_run_result.outputs,
)
)
elif isinstance(event, NodeRunExceptionEvent):
outputs = project_node_outputs_for_workflow_run(
node_type=event.node_type,
inputs=event.node_run_result.inputs,
outputs=event.node_run_result.outputs,
)
self._publish_event(
QueueNodeExceptionEvent(
node_execution_id=event.id,
node_id=event.node_id,
self._publish_event(
QueueNodeFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
case NodeRunExceptionEvent():
outputs = project_node_outputs_for_workflow_run(
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
outputs=event.node_run_result.outputs,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
text=event.chunk,
from_variable_selector=list(event.selector),
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
self._publish_event(
QueueNodeExceptionEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):
self._publish_event(
QueueRetrieverResourcesEvent(
retriever_resources=[
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
],
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
case NodeRunStreamChunkEvent():
self._publish_event(
QueueTextChunkEvent(
text=event.chunk,
from_variable_selector=list(event.selector),
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
)
elif isinstance(event, NodeRunAgentLogEvent):
self._publish_event(
QueueAgentLogEvent(
id=event.message_id,
label=event.label,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
node_id=event.node_id,
case NodeRunRetrieverResourceEvent():
self._publish_event(
QueueRetrieverResourcesEvent(
retriever_resources=[
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
],
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
)
elif isinstance(event, NodeRunIterationStartedEvent):
self._publish_event(
QueueIterationStartEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
metadata=event.metadata,
case NodeRunAgentLogEvent():
self._publish_event(
QueueAgentLogEvent(
id=event.message_id,
label=event.label,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
node_id=event.node_id,
)
)
)
elif isinstance(event, NodeRunIterationNextEvent):
self._publish_event(
QueueIterationNextEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
case NodeRunIterationStartedEvent():
self._publish_event(
QueueIterationStartEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
metadata=event.metadata,
)
)
)
elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
self._publish_event(
QueueIterationCompletedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
case NodeRunIterationNextEvent():
self._publish_event(
QueueIterationNextEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
)
)
)
elif isinstance(event, NodeRunLoopStartedEvent):
self._publish_event(
QueueLoopStartEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
metadata=event.metadata,
case NodeRunIterationSucceededEvent() | NodeRunIterationFailedEvent():
self._publish_event(
QueueIterationCompletedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
)
)
)
elif isinstance(event, NodeRunLoopNextEvent):
self._publish_event(
QueueLoopNextEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_loop_output,
case NodeRunLoopStartedEvent():
self._publish_event(
QueueLoopStartEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
metadata=event.metadata,
)
)
)
elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
self._publish_event(
QueueLoopCompletedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
case NodeRunLoopNextEvent():
self._publish_event(
QueueLoopNextEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_loop_output,
)
)
case NodeRunLoopSucceededEvent() | NodeRunLoopFailedEvent():
self._publish_event(
QueueLoopCompletedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
)
)
)
def _enqueue_human_input_notifications(self, reasons: Sequence[object]) -> None:
for reason in reasons:

View File

@ -11,7 +11,6 @@ from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import PauseReason
from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from graphon.variables.segments import Segment
class QueueEvent(StrEnum):
@ -509,10 +508,6 @@ class QueueHumanInputFormFilledEvent(AppQueueEvent):
action_id: str
action_text: str
# Keep the field name aligned with Graphon so the app-layer bridge does not
# need to translate between two equivalent payload names.
submitted_data: Mapping[str, Segment] | None = None
class QueueHumanInputFormTimeoutEvent(AppQueueEvent):
"""

View File

@ -342,8 +342,6 @@ class HumanInputFormFilledResponse(StreamResponse):
action_id: str
action_text: str
submitted_data: Mapping[str, Any] | None = None
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED
workflow_run_id: str
data: Data

View File

@ -1,6 +1,13 @@
from .controller import DatabaseFileAccessController
from .protocols import FileAccessControllerProtocol
from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope
from .scope import (
FileAccessScope,
bind_file_access_scope,
get_current_file_access_scope,
grant_retriever_segment_access,
grant_upload_file_access,
is_retriever_segment_access_granted,
)
__all__ = [
"DatabaseFileAccessController",
@ -8,4 +15,7 @@ __all__ = [
"FileAccessScope",
"bind_file_access_scope",
"get_current_file_access_scope",
"grant_retriever_segment_access",
"grant_upload_file_access",
"is_retriever_segment_access_granted",
]

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Callable
from sqlalchemy import select
from sqlalchemy import and_, or_, select
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
@ -18,7 +18,8 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
Tenant scoping remains mandatory. When the current execution belongs to an
end user, the lookup is additionally constrained to that end user's file
ownership markers.
ownership markers, plus upload files explicitly granted by the current
execution context.
"""
_scope_getter: Callable[[], FileAccessScope | None]
@ -47,10 +48,19 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
if not resolved_scope.requires_user_ownership:
return scoped_stmt
return scoped_stmt.where(
user_owned_filter = and_(
UploadFile.created_by_role == CreatorUserRole.END_USER,
UploadFile.created_by == resolved_scope.user_id,
)
if not resolved_scope.granted_upload_file_ids:
return scoped_stmt.where(user_owned_filter)
return scoped_stmt.where(
or_(
user_owned_filter,
UploadFile.id.in_(resolved_scope.granted_upload_file_ids),
)
)
def apply_tool_file_filters(
self,

View File

@ -1,9 +1,9 @@
from __future__ import annotations
from collections.abc import Generator # Changed from Iterator
from collections.abc import Generator, Iterable
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from dataclasses import dataclass, field, replace
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
@ -15,12 +15,23 @@ _current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar(
@dataclass(frozen=True, slots=True)
class FileAccessScope:
"""Request-scoped ownership context used by workflow-layer file lookups."""
"""Request-scoped ownership context used by workflow-layer file lookups.
``granted_upload_file_ids`` is execution-local: callers may add upload files
that were returned by trusted retrieval paths without changing persistent
ownership markers.
``granted_retriever_segment_ids`` gates lazy attachment loading by segment
ID, so user-provided context cannot make a later LLM node load arbitrary
same-tenant knowledge attachments.
"""
tenant_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
granted_upload_file_ids: frozenset[str] = field(default_factory=frozenset)
granted_retriever_segment_ids: frozenset[str] = field(default_factory=frozenset)
@property
def requires_user_ownership(self) -> bool:
@ -31,8 +42,49 @@ def get_current_file_access_scope() -> FileAccessScope | None:
return _current_file_access_scope.get()
def grant_upload_file_access(upload_file_ids: Iterable[str]) -> None:
scope = _current_file_access_scope.get()
if scope is None:
return
granted_upload_file_ids = frozenset(str(file_id) for file_id in upload_file_ids if file_id)
if not granted_upload_file_ids:
return
_current_file_access_scope.set(
replace(
scope,
granted_upload_file_ids=scope.granted_upload_file_ids | granted_upload_file_ids,
)
)
def grant_retriever_segment_access(segment_ids: Iterable[str]) -> None:
scope = _current_file_access_scope.get()
if scope is None:
return
granted_segment_ids = frozenset(str(segment_id) for segment_id in segment_ids if segment_id)
if not granted_segment_ids:
return
_current_file_access_scope.set(
replace(
scope,
granted_retriever_segment_ids=scope.granted_retriever_segment_ids | granted_segment_ids,
)
)
def is_retriever_segment_access_granted(segment_id: str) -> bool:
scope = _current_file_access_scope.get()
if scope is None or not scope.requires_user_ownership:
return True
return str(segment_id) in scope.granted_retriever_segment_ids
@contextmanager
def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: # Changed from Iterator[None]
def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]:
token = _current_file_access_scope.set(scope)
try:
yield

View File

@ -140,42 +140,43 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:return:
"""
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump()
response: ChatbotAppBlockingResponse | CompletionAppBlockingResponse
if self._conversation_mode == AppMode.COMPLETION:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=CompletionAppBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
message_id=self._message_id,
answer=self._task_state.llm_result.message.get_text_content(),
created_at=self._message_created_at,
**extras,
),
)
else:
response = ChatbotAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=ChatbotAppBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
answer=self._task_state.llm_result.message.get_text_content(),
created_at=self._message_created_at,
**extras,
),
)
match stream_response:
case ErrorStreamResponse():
raise stream_response.err
case MessageEndStreamResponse():
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump()
response: ChatbotAppBlockingResponse | CompletionAppBlockingResponse
if self._conversation_mode == AppMode.COMPLETION:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=CompletionAppBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
message_id=self._message_id,
answer=self._task_state.llm_result.message.get_text_content(),
created_at=self._message_created_at,
**extras,
),
)
else:
response = ChatbotAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=ChatbotAppBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
answer=self._task_state.llm_result.message.get_text_content(),
created_at=self._message_created_at,
**extras,
),
)
return response
else:
continue
return response
case _:
continue
raise RuntimeError("queue listening stopped unexpectedly.")
@ -265,104 +266,107 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
publisher.publish(message)
event = message.event
if isinstance(event, QueueErrorEvent):
with sessionmaker(bind=db.engine).begin() as session:
err = self.handle_error(event=event, session=session, message_id=self._message_id)
yield self.error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
if event.llm_result:
self._task_state.llm_result = event.llm_result
else:
self._handle_stop(event)
match event:
case QueueErrorEvent():
with sessionmaker(bind=db.engine).begin() as session:
err = self.handle_error(event=event, session=session, message_id=self._message_id)
yield self.error_to_stream_response(err)
break
case QueueStopEvent() | QueueMessageEndEvent():
if isinstance(event, QueueMessageEndEvent):
if event.llm_result:
self._task_state.llm_result = event.llm_result
else:
self._handle_stop(event)
# handle output moderation
output_moderation_answer = self.handle_output_moderation_when_task_finished(
self._task_state.llm_result.message.get_text_content()
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=output_moderation_answer
# handle output moderation
output_moderation_answer = self.handle_output_moderation_when_task_finished(
self._task_state.llm_result.message.get_text_content()
)
with sessionmaker(bind=db.engine).begin() as session:
# Save message
self._save_message(session=session, trace_manager=trace_manager)
message_end_resp = self._message_end_to_stream_response()
yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent):
self._message_cycle_manager.handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
annotation = self._message_cycle_manager.handle_annotation_reply(event)
if annotation:
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
agent_thought_response = self._agent_thought_to_stream_response(event)
if agent_thought_response is not None:
yield agent_thought_response
elif isinstance(event, QueueMessageFileEvent):
response = self._message_cycle_manager.message_file_to_stream_response(event)
if response:
yield response
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
chunk = event.chunk
delta_text = chunk.delta.message.content
if delta_text is None:
continue
if isinstance(chunk.delta.message.content, list):
delta_text = ""
for content in chunk.delta.message.content:
logger.debug(
"The content type %s in LLM chunk delta message content.: %r", type(content), content
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=output_moderation_answer
)
if isinstance(content, TextPromptMessageContent):
delta_text += content.data
elif isinstance(content, str):
delta_text += content # failback to str
else:
logger.warning(
"Unsupported content type %s in LLM chunk delta message content.: %r",
type(content),
content,
with sessionmaker(bind=db.engine).begin() as session:
# Save message
self._save_message(session=session, trace_manager=trace_manager)
message_end_resp = self._message_end_to_stream_response()
yield message_end_resp
case QueueRetrieverResourcesEvent():
self._message_cycle_manager.handle_retriever_resources(event)
case QueueAnnotationReplyEvent():
annotation = self._message_cycle_manager.handle_annotation_reply(event)
if annotation:
self._task_state.llm_result.message.content = annotation.content
case QueueAgentThoughtEvent():
agent_thought_response = self._agent_thought_to_stream_response(event)
if agent_thought_response is not None:
yield agent_thought_response
case QueueMessageFileEvent():
response = self._message_cycle_manager.message_file_to_stream_response(event)
if response:
yield response
case QueueLLMChunkEvent() | QueueAgentMessageEvent():
chunk = event.chunk
delta_text = chunk.delta.message.content
if delta_text is None:
continue
if isinstance(chunk.delta.message.content, list):
delta_text = ""
for content in chunk.delta.message.content:
logger.debug(
"The content type %s in LLM chunk delta message content.: %r", type(content), content
)
continue
match content:
case TextPromptMessageContent():
delta_text += content.data
case str():
delta_text += content # failback to str
case _:
logger.warning(
"Unsupported content type %s in LLM chunk delta message content.: %r",
type(content),
content,
)
continue
if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
continue
current_content = cast(str, self._task_state.llm_result.message.content)
current_content += cast(str, delta_text)
self._task_state.llm_result.message.content = current_content
match event:
case QueueLLMChunkEvent():
# Determine the event type once, on first LLM chunk, and reuse for subsequent chunks
if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None:
self._precomputed_event_type = self._message_cycle_manager.get_message_event_type(
message_id=self._message_id
)
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
event_type=self._precomputed_event_type,
)
case _:
yield self._agent_message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
)
case QueueMessageReplaceEvent():
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
case QueuePingEvent():
yield self.ping_stream_response()
case _:
continue
current_content = cast(str, self._task_state.llm_result.message.content)
current_content += cast(str, delta_text)
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
# Determine the event type once, on first LLM chunk, and reuse for subsequent chunks
if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None:
self._precomputed_event_type = self._message_cycle_manager.get_message_event_type(
message_id=self._message_id
)
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
event_type=self._precomputed_event_type,
)
else:
yield self._agent_message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self.ping_stream_response()
else:
continue
if publisher:
publisher.publish(None)
if self._conversation_name_generate_thread:

View File

@ -95,13 +95,14 @@ class AppGeneratorTTSPublisher:
message_content = message.event.chunk.delta.message.content
if not message_content:
continue
if isinstance(message_content, str):
self.msg_text += message_content
elif isinstance(message_content, list):
for content in message_content:
if not isinstance(content, TextPromptMessageContent):
continue
self.msg_text += content.data
match message_content:
case str():
self.msg_text += message_content
case list():
for content in message_content:
if not isinstance(content, TextPromptMessageContent):
continue
self.msg_text += content.data
elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text
elif isinstance(message.event, QueueNodeSucceededEvent):

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any, TypeAlias
from pydantic import BaseModel, ConfigDict, Field, JsonValue
from pydantic import BaseModel, ConfigDict, Field
from graphon.nodes.human_input.entities import FormInputConfig, UserActionConfig
from models.execution_extra_content import ExecutionContentType
@ -19,8 +19,6 @@ class HumanInputFormDefinition(BaseModel):
inputs: Sequence[FormInputConfig] = Field(default_factory=list)
actions: Sequence[UserActionConfig] = Field(default_factory=list)
display_in_ui: bool = False
# `form_token` is `None` if the corresponding form has been submitted.
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int
@ -31,31 +29,16 @@ class HumanInputFormSubmissionData(BaseModel):
node_id: str
node_title: str
# deprecate: the rendered_content is deprecated and only for historical reasons.
rendered_content: str
# The identifier of action user has chosen.
action_id: str
# The button text of the action user has chosen.
action_text: str
# submitted_data records the submitted form data.
# Keys correspond to `output_variable_name` of HumanInput inputs.
# Values are serialized JSON forms of runtime values, including file dictionaries.
#
# For form submitted before this field is introduced, this field is populated from
# the stored submission data.
submitted_data: Mapping[str, JsonValue] | None = None
class HumanInputContent(BaseModel):
model_config = ConfigDict(frozen=True)
workflow_run_id: str
submitted: bool
# Both the form_defintion and the form_submission_data are present in
# HumanInputContent. For historical records, the
form_definition: HumanInputFormDefinition | None = None
form_submission_data: HumanInputFormSubmissionData | None = None
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)

View File

@ -171,9 +171,9 @@ class ProviderConfiguration(BaseModel):
current_credential_id = self.custom_configuration.provider.current_credential_id
if current_credential_id:
from core.helper.credential_utils import check_credential_policy_compliance
from core.helper.credential_utils import runtime_check_credential_policy_compliance
check_credential_policy_compliance(
runtime_check_credential_policy_compliance(
credential_id=current_credential_id,
provider=self.provider.provider,
credential_type=PluginCredentialType.MODEL,
@ -182,9 +182,9 @@ class ProviderConfiguration(BaseModel):
# no current credential id, check all available credentials
if self.custom_configuration.provider:
for credential_configuration in self.custom_configuration.provider.available_credentials:
from core.helper.credential_utils import check_credential_policy_compliance
from core.helper.credential_utils import runtime_check_credential_policy_compliance
check_credential_policy_compliance(
runtime_check_credential_policy_compliance(
credential_id=credential_configuration.credential_id,
provider=self.provider.provider,
credential_type=PluginCredentialType.MODEL,

View File

@ -1,9 +1,9 @@
class LLMError(ValueError):
"""Base class for all LLM exceptions."""
description: str | None = None
description: str = ""
def __init__(self, description: str | None = None):
def __init__(self, description: str = ""):
self.description = description

View File

@ -1,279 +0,0 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any
from core.evaluation.entities.evaluation_entity import (
CustomizedMetrics,
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
EvaluationMetric,
NodeInfo,
)
from graphon.node_events.base import NodeRunResult
logger = logging.getLogger(__name__)
class BaseEvaluationInstance(ABC):
"""Abstract base class for evaluation framework adapters."""
@abstractmethod
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate LLM outputs using the configured framework."""
...
@abstractmethod
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate retrieval quality using the configured framework."""
...
@abstractmethod
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate agent outputs using the configured framework."""
...
@abstractmethod
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
"""Return the list of supported metric names for a given evaluation category."""
...
def evaluate_with_customized_workflow(
self,
node_run_result_mapping_list: list[dict[str, NodeRunResult]],
customized_metrics: CustomizedMetrics,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate using a published workflow as the evaluator.
The evaluator workflow's output variables are treated as metrics:
each output variable name becomes a metric name, and its value
becomes the score.
Args:
node_run_result_mapping_list: One mapping per test-data item,
where each mapping is ``{node_id: NodeRunResult}`` from the
target execution.
customized_metrics: Contains ``evaluation_workflow_id`` (the
published evaluator workflow) and ``input_fields`` (value
sources for the evaluator's input variables).
tenant_id: Tenant scope.
Returns:
A list of ``EvaluationItemResult`` with metrics extracted from
the evaluator workflow's output variables.
"""
from sqlalchemy.orm import Session
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.evaluation.runners import get_service_account_for_app
from models.engine import db
from models.model import App
from services.workflow_service import WorkflowService
workflow_id = customized_metrics.evaluation_workflow_id
if not workflow_id:
raise ValueError("customized_metrics must contain 'evaluation_workflow_id' for customized evaluator")
# Load the evaluator workflow resources using a dedicated session
with Session(db.engine, expire_on_commit=False) as session, session.begin():
app = session.query(App).filter_by(id=workflow_id, tenant_id=tenant_id).first()
if not app:
raise ValueError(f"Evaluation workflow app {workflow_id} not found in tenant {tenant_id}")
service_account = get_service_account_for_app(session, workflow_id)
workflow_service = WorkflowService()
published_workflow = workflow_service.get_published_workflow(app_model=app)
if not published_workflow:
raise ValueError(f"No published workflow found for evaluation app {workflow_id}")
eval_results: list[EvaluationItemResult] = []
for idx, node_run_result_mapping in enumerate(node_run_result_mapping_list):
try:
workflow_inputs = self._build_workflow_inputs(
customized_metrics.input_fields,
node_run_result_mapping,
)
generator = WorkflowAppGenerator()
response: Mapping[str, Any] = generator.generate(
app_model=app,
workflow=published_workflow,
user=service_account,
args={"inputs": workflow_inputs},
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
call_depth=0,
)
metrics = self._extract_workflow_metrics(response, workflow_id)
eval_results.append(
EvaluationItemResult(
index=idx,
metrics=metrics,
)
)
except Exception:
logger.exception(
"Customized evaluator failed for item %d with workflow %s",
idx,
workflow_id,
)
eval_results.append(EvaluationItemResult(index=idx))
return eval_results
@staticmethod
def _build_workflow_inputs(
input_fields: dict[str, Any],
node_run_result_mapping: dict[str, NodeRunResult],
) -> dict[str, Any]:
"""Build customized workflow inputs by resolving value sources.
Each entry in ``input_fields`` maps a workflow input variable name
to its value source, which can be:
- **Constant**: a plain string without ``{{#…#}}`` used as-is.
- **Expression**: a string containing one or more
``{{#node_id.output_key#}}`` selectors (same format as
``VariableTemplateParser``) resolved from
``node_run_result_mapping``.
"""
from graphon.nodes.base.variable_template_parser import REGEX as VARIABLE_REGEX
workflow_inputs: dict[str, Any] = {}
for field_name, value_source in input_fields.items():
if not isinstance(value_source, str):
# Non-string values (numbers, bools, dicts) are used directly.
workflow_inputs[field_name] = value_source
continue
# Check if the entire value is a single expression.
full_match = VARIABLE_REGEX.fullmatch(value_source)
if full_match:
workflow_inputs[field_name] = resolve_variable_selector(
full_match.group(1),
node_run_result_mapping,
)
elif VARIABLE_REGEX.search(value_source):
# Mixed template: interpolate all expressions as strings.
workflow_inputs[field_name] = VARIABLE_REGEX.sub(
lambda m: str(resolve_variable_selector(m.group(1), node_run_result_mapping)),
value_source,
)
else:
# Plain constant — no expression markers.
workflow_inputs[field_name] = value_source
return workflow_inputs
@staticmethod
def _extract_workflow_metrics(
response: Mapping[str, object],
evaluation_workflow_id: str,
) -> list[EvaluationMetric]:
"""Extract evaluation metrics from workflow output variables.
Each metric's ``node_info`` is set with *evaluation_workflow_id* as
the ``node_id``, so that judgment conditions can reference customized
metrics via ``variable_selector: [evaluation_workflow_id, metric_name]``.
"""
metrics: list[EvaluationMetric] = []
node_info = NodeInfo(node_id=evaluation_workflow_id, type="customized", title="customized")
data = response.get("data")
if not isinstance(data, Mapping):
logger.warning("Unexpected workflow response format: missing 'data' dict")
return metrics
outputs = data.get("outputs")
if not isinstance(outputs, dict):
logger.warning("Unexpected workflow response format: 'outputs' is not a dict")
return metrics
for key, raw_value in outputs.items():
if not isinstance(key, str):
continue
metrics.append(EvaluationMetric(name=key, value=raw_value, node_info=node_info))
return metrics
def resolve_variable_selector(
selector_raw: str,
node_run_result_mapping: dict[str, NodeRunResult],
) -> object:
"""
Resolve a ``#node_id.output_key#`` selector against node run results.
"""
#
cleaned = selector_raw.strip("#")
parts = cleaned.split(".")
if len(parts) < 2:
logger.warning(
"Selector '%s' must have at least node_id.output_key",
selector_raw,
)
return ""
node_id = parts[0]
output_path = parts[1:]
node_result = node_run_result_mapping.get(node_id)
if not node_result or not node_result.outputs:
logger.warning(
"Selector '%s': node '%s' not found or has no outputs",
selector_raw,
node_id,
)
return ""
# Traverse the output path to support nested keys.
current: object = node_result.outputs
for key in output_path:
if isinstance(current, Mapping):
next_val = current.get(key)
if next_val is None:
logger.warning(
"Selector '%s': key '%s' not found in node '%s' outputs",
selector_raw,
key,
node_id,
)
return ""
current = next_val
else:
logger.warning(
"Selector '%s': cannot traverse into non-dict value at key '%s'",
selector_raw,
key,
)
return ""
return current if current is not None else ""

View File

@ -1,27 +0,0 @@
from enum import StrEnum
from pydantic import BaseModel
class EvaluationFrameworkEnum(StrEnum):
RAGAS = "ragas"
DEEPEVAL = "deepeval"
NONE = "none"
class BaseEvaluationConfig(BaseModel):
"""Base configuration for evaluation frameworks."""
pass
class RagasConfig(BaseEvaluationConfig):
"""RAGAS-specific configuration."""
pass
class DeepEvalConfig(BaseEvaluationConfig):
"""DeepEval-specific configuration."""
pass

View File

@ -1,280 +0,0 @@
import json
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from core.evaluation.entities.judgment_entity import JudgmentConfig, JudgmentResult
class EvaluationCategory(StrEnum):
LLM = "llm"
RETRIEVAL = "knowledge_retrieval"
AGENT = "agent"
WORKFLOW = "workflow"
SNIPPET = "snippet"
KNOWLEDGE_BASE = "knowledge_base"
class EvaluationMetricName(StrEnum):
"""Canonical metric names shared across all evaluation frameworks.
Each framework maps these names to its own internal implementation.
A framework that does not support a given metric should log a warning
and skip it rather than raising an error.
── LLM / general text-quality metrics ──────────────────────────────────
FAITHFULNESS
Measures whether every claim in the model's response is grounded in
the provided retrieved context. A high score means the answer
contains no hallucinated content — each statement can be traced back
to a passage in the context.
Required fields: user_input, response, retrieved_contexts.
ANSWER_RELEVANCY
Measures how well the model's response addresses the user's question.
A high score means the answer stays on-topic; a low score indicates
irrelevant content or a failure to answer the actual question.
Required fields: user_input, response.
ANSWER_CORRECTNESS
Measures the factual accuracy and completeness of the model's answer
relative to a ground-truth reference. It combines semantic similarity
with key-fact coverage, so both meaning and content matter.
Required fields: user_input, response, reference (expected_output).
SEMANTIC_SIMILARITY
Measures the cosine similarity between the model's response and the
reference answer in an embedding space. It evaluates whether the two
texts convey the same meaning, independent of factual correctness.
Required fields: response, reference (expected_output).
── Retrieval-quality metrics ────────────────────────────────────────────
CONTEXT_PRECISION
Measures the proportion of retrieved context chunks that are actually
relevant to the question (precision). A high score means the retrieval
pipeline returns little noise.
Required fields: user_input, reference, retrieved_contexts.
CONTEXT_RECALL
Measures the proportion of ground-truth information that is covered by
the retrieved context chunks (recall). A high score means the retrieval
pipeline does not miss important supporting evidence.
Required fields: user_input, reference, retrieved_contexts.
CONTEXT_RELEVANCE
Measures how relevant each individual retrieved chunk is to the query.
Similar to CONTEXT_PRECISION but evaluated at the chunk level rather
than against a reference answer.
Required fields: user_input, retrieved_contexts.
── Agent-quality metrics ────────────────────────────────────────────────
TOOL_CORRECTNESS
Measures the correctness of the tool calls made by the agent during
task execution — both the choice of tool and the arguments passed.
A high score means the agent's tool-use strategy matches the expected
behavior.
Required fields: actual tool calls vs. expected tool calls.
TASK_COMPLETION
Measures whether the agent ultimately achieves the user's stated goal.
It evaluates the reasoning chain, intermediate steps, and final output
holistically; a high score means the task was fully accomplished.
Required fields: user_input, actual_output.
"""
# LLM / general text-quality metrics
FAITHFULNESS = "faithfulness"
ANSWER_RELEVANCY = "answer_relevancy"
ANSWER_CORRECTNESS = "answer_correctness"
SEMANTIC_SIMILARITY = "semantic_similarity"
# Retrieval-quality metrics
CONTEXT_PRECISION = "context_precision"
CONTEXT_RECALL = "context_recall"
CONTEXT_RELEVANCE = "context_relevance"
# Agent-quality metrics
TOOL_CORRECTNESS = "tool_correctness"
TASK_COMPLETION = "task_completion"
# Per-category canonical metric lists used by get_supported_metrics().
LLM_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.FAITHFULNESS, # Every claim is grounded in context; no hallucinations
EvaluationMetricName.ANSWER_RELEVANCY, # Response stays on-topic and addresses the question
EvaluationMetricName.ANSWER_CORRECTNESS, # Factual accuracy and completeness vs. reference
EvaluationMetricName.SEMANTIC_SIMILARITY, # Semantic closeness to the reference answer
]
RETRIEVAL_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.CONTEXT_PRECISION, # Fraction of retrieved chunks that are relevant (precision)
EvaluationMetricName.CONTEXT_RECALL, # Fraction of ground-truth info covered by retrieval (recall)
EvaluationMetricName.CONTEXT_RELEVANCE, # Per-chunk relevance to the query
]
AGENT_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.TOOL_CORRECTNESS, # Correct tool selection and arguments
EvaluationMetricName.TASK_COMPLETION, # Whether the agent fully achieves the user's goal
]
WORKFLOW_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.FAITHFULNESS,
EvaluationMetricName.ANSWER_RELEVANCY,
EvaluationMetricName.ANSWER_CORRECTNESS,
]
METRIC_NODE_TYPE_MAPPING: dict[str, str] = {
**{m.value: "llm" for m in LLM_METRIC_NAMES},
**{m.value: "knowledge-retrieval" for m in RETRIEVAL_METRIC_NAMES},
**{m.value: "agent" for m in AGENT_METRIC_NAMES},
}
METRIC_VALUE_TYPE_MAPPING: dict[str, str] = {
EvaluationMetricName.FAITHFULNESS: "number",
EvaluationMetricName.ANSWER_RELEVANCY: "number",
EvaluationMetricName.ANSWER_CORRECTNESS: "number",
EvaluationMetricName.SEMANTIC_SIMILARITY: "number",
EvaluationMetricName.CONTEXT_PRECISION: "number",
EvaluationMetricName.CONTEXT_RECALL: "number",
EvaluationMetricName.CONTEXT_RELEVANCE: "number",
EvaluationMetricName.TOOL_CORRECTNESS: "number",
EvaluationMetricName.TASK_COMPLETION: "number",
}
class NodeInfo(BaseModel):
node_id: str
type: str
title: str
class EvaluationMetric(BaseModel):
name: str
value: Any
details: dict[str, Any] = Field(default_factory=dict)
node_info: NodeInfo | None = None
class EvaluationItemInput(BaseModel):
index: int
inputs: dict[str, Any]
output: str
expected_output: str | None = None
context: list[str] | None = None
class EvaluationDatasetInput(BaseModel):
"""Parsed dataset row used throughout evaluation execution.
``expected_output`` keeps backward compatibility with the original
single-reference template. When users upload node-specific reference
columns such as ``LLM 1 : expected_output``, they are stored in
``expected_outputs`` and resolved by node title at execution time.
"""
index: int
inputs: dict[str, Any]
expected_output: str | None = None
expected_outputs: dict[str, str] = Field(default_factory=dict)
def get_expected_output_for_node(self, node_title: str | None) -> str | None:
"""Return the best matching reference answer for the given node title."""
if node_title:
if node_title in self.expected_outputs:
return self.expected_outputs[node_title]
if self.expected_output is not None:
return self.expected_output
if len(self.expected_outputs) == 1:
return next(iter(self.expected_outputs.values()))
return None
def serialize_expected_output(self) -> str | None:
"""Serialize references for persistence and API responses.
Single-reference datasets stay unchanged, while multi-node references
are stored as JSON so history/detail APIs can still expose the full
uploaded payload without changing the database schema.
"""
if self.expected_output is not None and not self.expected_outputs:
return self.expected_output
if not self.expected_outputs:
return None
serialized_expected_outputs = dict(self.expected_outputs)
if self.expected_output is not None:
serialized_expected_outputs = {"expected_output": self.expected_output, **serialized_expected_outputs}
return json.dumps(serialized_expected_outputs, ensure_ascii=False, sort_keys=True)
def iter_expected_output_columns(self) -> list[tuple[str, str]]:
"""Return uploaded expected-output columns in display order."""
columns: list[tuple[str, str]] = []
if self.expected_output is not None:
columns.append(("expected_output", self.expected_output))
for node_title, value in self.expected_outputs.items():
columns.append((f"{node_title} : expected_output", value))
return columns
class EvaluationItemResult(BaseModel):
index: int
actual_output: str | None = None
metrics: list[EvaluationMetric] = Field(default_factory=list)
metadata: dict[str, Any] = Field(default_factory=dict)
judgment: JudgmentResult = Field(default_factory=JudgmentResult)
error: str | None = None
class DefaultMetric(BaseModel):
metric: str
value_type: str = ""
node_info_list: list[NodeInfo]
class CustomizedMetricOutputField(BaseModel):
variable: str
value_type: str
class CustomizedMetrics(BaseModel):
evaluation_workflow_id: str
input_fields: dict[str, Any]
output_fields: list[CustomizedMetricOutputField]
class EvaluationConfigData(BaseModel):
"""Structured data for saving evaluation configuration."""
evaluation_model: str = ""
evaluation_model_provider: str = ""
default_metrics: list[DefaultMetric] = Field(default_factory=list)
customized_metrics: CustomizedMetrics | None = None
judgment_config: JudgmentConfig | None = None
class EvaluationRunRequest(EvaluationConfigData):
"""Request body for starting an evaluation run."""
file_id: str
class EvaluationRunData(BaseModel):
"""Serializable data for Celery task."""
evaluation_run_id: str
tenant_id: str
target_type: str
target_id: str
evaluation_model_provider: str
evaluation_model: str
default_metrics: list[DefaultMetric] = Field(default_factory=list)
customized_metrics: CustomizedMetrics | None = None
judgment_config: JudgmentConfig | None = None
input_list: list[EvaluationDatasetInput]

View File

@ -1,118 +0,0 @@
"""Judgment condition entities for evaluation metric assessment.
Condition structure mirrors the workflow if-else ``Condition`` model from
``graphon.utils.condition.entities``. The left-hand side uses
``variable_selector`` — a two-element list ``[node_id, metric_name]`` — to
uniquely identify an evaluation metric (different nodes may produce metrics
with the same name).
Operators reuse ``SupportedComparisonOperator`` from the workflow engine so
that type semantics stay consistent across the platform.
Typical usage::
judgment_config = JudgmentConfig(
logical_operator="and",
conditions=[
JudgmentCondition(
variable_selector=["node_abc", "faithfulness"],
comparison_operator=">",
value="0.8",
)
],
)
"""
from collections.abc import Sequence
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from graphon.utils.condition.entities import SupportedComparisonOperator
COMPARISON_OPERATOR_ALIASES: dict[str, str] = {
"==": "=",
"!=": "",
">=": "",
"<=": "",
"is null": "null",
"is not null": "not null",
}
class JudgmentCondition(BaseModel):
"""A single judgment condition that checks one metric value.
Mirrors ``graphon.utils.condition.entities.Condition`` with the left-hand
side being a metric selector instead of a workflow variable selector.
Attributes:
variable_selector: ``[node_id, metric_name]`` identifying the metric.
comparison_operator: Reuses workflow's ``SupportedComparisonOperator``.
value: The comparison target (right side). For unary operators such
as ``empty`` or ``null`` this can be ``None``.
"""
variable_selector: list[str]
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | bool | None = None
@field_validator("comparison_operator", mode="before")
@classmethod
def normalize_comparison_operator(cls, value: Any) -> Any:
"""Accept common ASCII/API aliases for workflow comparison operators."""
if not isinstance(value, str):
return value
normalized_value = value.strip().lower()
alias = COMPARISON_OPERATOR_ALIASES.get(normalized_value)
if alias is not None:
return alias
return value.strip()
class JudgmentConfig(BaseModel):
"""A group of judgment conditions combined with a logical operator.
Attributes:
logical_operator: How to combine condition results — "and" requires
all conditions to pass, "or" requires at least one.
conditions: The list of individual conditions to evaluate.
"""
logical_operator: Literal["and", "or"] = "and"
conditions: list[JudgmentCondition] = Field(default_factory=list)
class JudgmentConditionResult(BaseModel):
"""Result of evaluating a single judgment condition.
Attributes:
variable_selector: ``[node_id, metric_name]`` that was checked.
comparison_operator: The operator that was applied.
expected_value: The resolved comparison value.
actual_value: The actual metric value that was evaluated.
passed: Whether this individual condition passed.
error: Error message if the condition evaluation failed.
"""
variable_selector: list[str]
comparison_operator: str
expected_value: Any = None
actual_value: Any = None
passed: bool = False
error: str | None = None
class JudgmentResult(BaseModel):
"""Overall result of evaluating all judgment conditions for one item.
Attributes:
passed: Whether the overall judgment passed (based on logical_operator).
logical_operator: The logical operator used to combine conditions.
condition_results: Detailed result for each individual condition.
"""
passed: bool = False
logical_operator: Literal["and", "or"] = "and"
condition_results: list[JudgmentConditionResult] = Field(default_factory=list)

View File

@ -1,61 +0,0 @@
import collections
import logging
from typing import Any
from configs import dify_config
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import EvaluationFrameworkEnum
from core.evaluation.entities.evaluation_entity import EvaluationCategory
logger = logging.getLogger(__name__)
class EvaluationFrameworkConfigMap(collections.UserDict[str, dict[str, Any]]):
"""Registry mapping framework enum -> {config_class, evaluator_class}."""
def __getitem__(self, framework: str) -> dict[str, Any]:
match framework:
case EvaluationFrameworkEnum.RAGAS:
from core.evaluation.entities.config_entity import RagasConfig
from core.evaluation.frameworks.ragas.ragas_evaluator import RagasEvaluator
return {
"config_class": RagasConfig,
"evaluator_class": RagasEvaluator,
}
case EvaluationFrameworkEnum.DEEPEVAL:
raise NotImplementedError("DeepEval adapter is not yet implemented.")
case _:
raise ValueError(f"Unknown evaluation framework: {framework}")
evaluation_framework_config_map = EvaluationFrameworkConfigMap()
class EvaluationManager:
"""Factory for evaluation instances based on global configuration."""
@staticmethod
def get_evaluation_instance() -> BaseEvaluationInstance | None:
"""Create and return an evaluation instance based on EVALUATION_FRAMEWORK env var."""
framework = dify_config.EVALUATION_FRAMEWORK
if not framework or framework == EvaluationFrameworkEnum.NONE:
return None
try:
config_map = evaluation_framework_config_map[framework]
evaluator_class = config_map["evaluator_class"]
config_class = config_map["config_class"]
config = config_class()
return evaluator_class(config)
except Exception:
logger.exception("Failed to create evaluation instance for framework: %s", framework)
return None
@staticmethod
def get_supported_metrics(category: EvaluationCategory) -> list[str]:
"""Return supported metrics for the current framework and given category."""
instance = EvaluationManager.get_evaluation_instance()
if instance is None:
return []
return instance.get_supported_metrics(category)

View File

@ -1,299 +0,0 @@
import logging
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import DeepEvalConfig
from core.evaluation.entities.evaluation_entity import (
AGENT_METRIC_NAMES,
LLM_METRIC_NAMES,
RETRIEVAL_METRIC_NAMES,
WORKFLOW_METRIC_NAMES,
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
EvaluationMetric,
EvaluationMetricName,
)
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper
logger = logging.getLogger(__name__)
# Maps canonical EvaluationMetricName to the corresponding deepeval metric class name.
# deepeval metric field requirements (LLMTestCase fields):
# - faithfulness: input, actual_output, retrieval_context
# - answer_relevancy: input, actual_output
# - context_precision: input, actual_output, expected_output, retrieval_context
# - context_recall: input, actual_output, expected_output, retrieval_context
# - context_relevance: input, actual_output, retrieval_context
# - tool_correctness: input, actual_output, expected_tools
# - task_completion: input, actual_output
# Metrics not listed here are unsupported by deepeval and will be skipped.
_DEEPEVAL_METRIC_MAP: dict[EvaluationMetricName, str] = {
EvaluationMetricName.FAITHFULNESS: "FaithfulnessMetric",
EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancyMetric",
EvaluationMetricName.CONTEXT_PRECISION: "ContextualPrecisionMetric",
EvaluationMetricName.CONTEXT_RECALL: "ContextualRecallMetric",
EvaluationMetricName.CONTEXT_RELEVANCE: "ContextualRelevancyMetric",
EvaluationMetricName.TOOL_CORRECTNESS: "ToolCorrectnessMetric",
EvaluationMetricName.TASK_COMPLETION: "TaskCompletionMetric",
}
class DeepEvalEvaluator(BaseEvaluationInstance):
"""DeepEval framework adapter for evaluation."""
def __init__(self, config: DeepEvalConfig):
self.config = config
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
match category:
case EvaluationCategory.LLM:
candidates = LLM_METRIC_NAMES
case EvaluationCategory.RETRIEVAL:
candidates = RETRIEVAL_METRIC_NAMES
case EvaluationCategory.AGENT:
candidates = AGENT_METRIC_NAMES
case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET:
candidates = WORKFLOW_METRIC_NAMES
case _:
return []
return [m for m in candidates if m in _DEEPEVAL_METRIC_MAP]
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
def evaluate_workflow(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)
def _evaluate(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Core evaluation logic using DeepEval."""
model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
requested_metrics = metric_names or self.get_supported_metrics(category)
try:
return self._evaluate_with_deepeval(items, requested_metrics, category)
except ImportError:
logger.warning("DeepEval not installed, falling back to simple evaluation")
return self._evaluate_simple(items, requested_metrics, model_wrapper)
def _evaluate_with_deepeval(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Evaluate using DeepEval library.
Builds LLMTestCase differently per category:
- LLM/Workflow: input=prompt, actual_output=output, retrieval_context=context
- Retrieval: input=query, actual_output=output, expected_output, retrieval_context=context
- Agent: input=query, actual_output=output
"""
metric_pairs = _build_deepeval_metrics(requested_metrics)
if not metric_pairs:
logger.warning("No valid DeepEval metrics found for: %s", requested_metrics)
return [EvaluationItemResult(index=item.index) for item in items]
results: list[EvaluationItemResult] = []
for item in items:
test_case = self._build_test_case(item, category)
metrics: list[EvaluationMetric] = []
for canonical_name, metric in metric_pairs:
try:
metric.measure(test_case)
if metric.score is not None:
metrics.append(EvaluationMetric(name=canonical_name, value=float(metric.score)))
except Exception:
logger.exception(
"Failed to compute metric %s for item %d",
canonical_name,
item.index,
)
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
return results
@staticmethod
def _build_test_case(item: EvaluationItemInput, category: EvaluationCategory) -> Any:
"""Build a deepeval LLMTestCase with the correct fields per category."""
from deepeval.test_case import LLMTestCase
user_input = _format_input(item.inputs, category)
match category:
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
# faithfulness needs: input, actual_output, retrieval_context
# answer_relevancy needs: input, actual_output
return LLMTestCase(
input=user_input,
actual_output=item.output,
expected_output=item.expected_output or None,
retrieval_context=item.context or None,
)
case EvaluationCategory.RETRIEVAL:
# contextual_precision/recall needs: input, actual_output, expected_output, retrieval_context
return LLMTestCase(
input=user_input,
actual_output=item.output or "",
expected_output=item.expected_output or "",
retrieval_context=item.context or [],
)
case _:
return LLMTestCase(
input=user_input,
actual_output=item.output,
)
def _evaluate_simple(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
model_wrapper: DifyModelWrapper,
) -> list[EvaluationItemResult]:
"""Simple LLM-as-judge fallback when DeepEval is not available."""
results: list[EvaluationItemResult] = []
for item in items:
metrics: list[EvaluationMetric] = []
for m_name in requested_metrics:
try:
score = self._judge_with_llm(model_wrapper, m_name, item)
metrics.append(EvaluationMetric(name=m_name, value=score))
except Exception:
logger.exception("Failed to compute metric %s for item %d", m_name, item.index)
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
return results
def _judge_with_llm(
self,
model_wrapper: DifyModelWrapper,
metric_name: str,
item: EvaluationItemInput,
) -> float:
"""Use the LLM to judge a single metric for a single item."""
prompt = self._build_judge_prompt(metric_name, item)
response = model_wrapper.invoke(prompt)
return self._parse_score(response)
@staticmethod
def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str:
"""Build a scoring prompt for the LLM judge."""
parts = [
f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
f"\nInput: {item.inputs}",
f"\nOutput: {item.output}",
]
if item.expected_output:
parts.append(f"\nExpected Output: {item.expected_output}")
if item.context:
parts.append(f"\nContext: {'; '.join(item.context)}")
parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.")
return "\n".join(parts)
@staticmethod
def _parse_score(response: str) -> float:
"""Parse a float score from LLM response."""
import re
cleaned = response.strip()
try:
score = float(cleaned)
return max(0.0, min(1.0, score))
except ValueError:
match = re.search(r"(\d+\.?\d*)", cleaned)
if match:
score = float(match.group(1))
return max(0.0, min(1.0, score))
return 0.0
def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str:
"""Extract the user-facing input string from the inputs dict."""
match category:
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
return str(inputs.get("prompt", ""))
case EvaluationCategory.RETRIEVAL:
return str(inputs.get("query", ""))
case _:
return str(next(iter(inputs.values()), "")) if inputs else ""
def _build_deepeval_metrics(requested_metrics: list[str]) -> list[tuple[str, Any]]:
"""Build DeepEval metric instances from canonical metric names.
Returns a list of (canonical_name, metric_instance) pairs so that callers
can record the canonical name rather than the framework-internal class name.
"""
try:
from deepeval.metrics import (
AnswerRelevancyMetric,
ContextualPrecisionMetric,
ContextualRecallMetric,
ContextualRelevancyMetric,
FaithfulnessMetric,
TaskCompletionMetric,
ToolCorrectnessMetric,
)
# Maps canonical name → deepeval metric class
deepeval_class_map: dict[str, Any] = {
EvaluationMetricName.FAITHFULNESS: FaithfulnessMetric,
EvaluationMetricName.ANSWER_RELEVANCY: AnswerRelevancyMetric,
EvaluationMetricName.CONTEXT_PRECISION: ContextualPrecisionMetric,
EvaluationMetricName.CONTEXT_RECALL: ContextualRecallMetric,
EvaluationMetricName.CONTEXT_RELEVANCE: ContextualRelevancyMetric,
EvaluationMetricName.TOOL_CORRECTNESS: ToolCorrectnessMetric,
EvaluationMetricName.TASK_COMPLETION: TaskCompletionMetric,
}
pairs: list[tuple[str, Any]] = []
for name in requested_metrics:
metric_class = deepeval_class_map.get(name)
if metric_class:
pairs.append((name, metric_class(threshold=0.5)))
else:
logger.warning("Metric '%s' is not supported by DeepEval, skipping", name)
return pairs
except ImportError:
logger.warning("DeepEval metrics not available")
return []

View File

@ -1,345 +0,0 @@
import logging
from importlib import import_module
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import RagasConfig
from core.evaluation.entities.evaluation_entity import (
AGENT_METRIC_NAMES,
LLM_METRIC_NAMES,
RETRIEVAL_METRIC_NAMES,
WORKFLOW_METRIC_NAMES,
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
EvaluationMetric,
EvaluationMetricName,
)
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper
logger = logging.getLogger(__name__)
# Maps canonical EvaluationMetricName to the corresponding ragas metric class.
# Metrics not listed here are unsupported by ragas and will be skipped.
_RAGAS_METRIC_MAP: dict[EvaluationMetricName, str] = {
EvaluationMetricName.FAITHFULNESS: "Faithfulness",
EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancy",
EvaluationMetricName.ANSWER_CORRECTNESS: "AnswerCorrectness",
EvaluationMetricName.SEMANTIC_SIMILARITY: "SemanticSimilarity",
EvaluationMetricName.CONTEXT_PRECISION: "ContextPrecision",
EvaluationMetricName.CONTEXT_RECALL: "ContextRecall",
EvaluationMetricName.CONTEXT_RELEVANCE: "ContextRelevance",
EvaluationMetricName.TOOL_CORRECTNESS: "ToolCallAccuracy",
}
class RagasEvaluator(BaseEvaluationInstance):
"""RAGAS framework adapter for evaluation."""
def __init__(self, config: RagasConfig):
self.config = config
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
match category:
case EvaluationCategory.LLM:
candidates = LLM_METRIC_NAMES
case EvaluationCategory.RETRIEVAL:
candidates = RETRIEVAL_METRIC_NAMES
case EvaluationCategory.AGENT:
candidates = AGENT_METRIC_NAMES
case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET:
candidates = WORKFLOW_METRIC_NAMES
case _:
return []
return [m for m in candidates if m in _RAGAS_METRIC_MAP]
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
def evaluate_workflow(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)
def _evaluate(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Core evaluation logic using RAGAS."""
model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
requested_metrics = metric_names or self.get_supported_metrics(category)
try:
return self._evaluate_with_ragas(items, requested_metrics, model_wrapper, category)
except ImportError:
logger.warning("RAGAS not installed, falling back to simple evaluation")
return self._evaluate_simple(items, requested_metrics, model_wrapper)
def _evaluate_with_ragas(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
model_wrapper: DifyModelWrapper,
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Evaluate using RAGAS library.
Builds SingleTurnSample differently per category to match ragas requirements:
- LLM/Workflow: user_input=prompt, response=output, reference=expected_output
- Retrieval: user_input=query, reference=expected_output, retrieved_contexts=context
- Agent: Not supported via EvaluationDataset (requires message-based API)
"""
from ragas import evaluate as ragas_evaluate
from ragas.dataset_schema import EvaluationDataset
samples: list[Any] = []
for item in items:
sample = self._build_sample(item, category)
samples.append(sample)
dataset = EvaluationDataset(samples=samples)
ragas_metrics = self._build_ragas_metrics(requested_metrics)
if not ragas_metrics:
logger.warning("No valid RAGAS metrics found for: %s", requested_metrics)
return [EvaluationItemResult(index=item.index, actual_output=item.output) for item in items]
try:
result = ragas_evaluate(
dataset=dataset,
metrics=ragas_metrics,
llm=model_wrapper,
)
results: list[EvaluationItemResult] = []
result_df = result.to_pandas()
for i, item in enumerate(items):
metrics: list[EvaluationMetric] = []
for m_name in requested_metrics:
if m_name in result_df.columns:
score = result_df.iloc[i][m_name]
if score is not None and not (isinstance(score, float) and score != score):
metrics.append(EvaluationMetric(name=m_name, value=float(score)))
results.append(EvaluationItemResult(index=item.index, metrics=metrics, actual_output=item.output))
return results
except Exception:
logger.exception("RAGAS evaluation failed, falling back to simple evaluation")
return self._evaluate_simple(items, requested_metrics, model_wrapper)
@staticmethod
def _build_sample(item: EvaluationItemInput, category: EvaluationCategory) -> Any:
"""Build a ragas SingleTurnSample with the correct fields per category.
ragas metric field requirements:
- faithfulness: user_input, response, retrieved_contexts
- answer_relevancy: user_input, response
- answer_correctness: user_input, response, reference
- semantic_similarity: user_input, response, reference
- context_precision: user_input, reference, retrieved_contexts
- context_recall: user_input, reference, retrieved_contexts
- context_relevance: user_input, retrieved_contexts
"""
from ragas.dataset_schema import SingleTurnSample
user_input = _format_input(item.inputs, category)
match category:
case EvaluationCategory.LLM:
# response = actual LLM output, reference = expected output
return SingleTurnSample(
user_input=user_input,
response=item.output,
reference=item.expected_output or "",
retrieved_contexts=item.context or [],
)
case EvaluationCategory.RETRIEVAL:
# context_precision/recall only need reference + retrieved_contexts
return SingleTurnSample(
user_input=user_input,
reference=item.expected_output or "",
retrieved_contexts=item.context or [],
)
case _:
return SingleTurnSample(
user_input=user_input,
response=item.output,
)
def _evaluate_simple(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
model_wrapper: DifyModelWrapper,
) -> list[EvaluationItemResult]:
"""Simple LLM-as-judge fallback when RAGAS is not available."""
results: list[EvaluationItemResult] = []
for item in items:
metrics: list[EvaluationMetric] = []
for m_name in requested_metrics:
try:
score = self._judge_with_llm(model_wrapper, m_name, item)
metrics.append(EvaluationMetric(name=m_name, value=score))
except Exception:
logger.exception("Failed to compute metric %s for item %d", m_name, item.index)
results.append(EvaluationItemResult(index=item.index, metrics=metrics, actual_output=item.output))
return results
def _judge_with_llm(
self,
model_wrapper: DifyModelWrapper,
metric_name: str,
item: EvaluationItemInput,
) -> float:
"""Use the LLM to judge a single metric for a single item."""
prompt = self._build_judge_prompt(metric_name, item)
response = model_wrapper.invoke(prompt)
return self._parse_score(response)
@staticmethod
def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str:
"""Build a scoring prompt for the LLM judge."""
parts = [
f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
f"\nInput: {item.inputs}",
f"\nOutput: {item.output}",
]
if item.expected_output:
parts.append(f"\nExpected Output: {item.expected_output}")
if item.context:
parts.append(f"\nContext: {'; '.join(item.context)}")
parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.")
return "\n".join(parts)
@staticmethod
def _parse_score(response: str) -> float:
"""Parse a float score from LLM response."""
import re
cleaned = response.strip()
try:
score = float(cleaned)
return max(0.0, min(1.0, score))
except ValueError:
match = re.search(r"(\d+\.?\d*)", cleaned)
if match:
score = float(match.group(1))
return max(0.0, min(1.0, score))
return 0.0
@staticmethod
def _build_ragas_metrics(requested_metrics: list[str]) -> list[Any]:
"""Build RAGAS metric instances from canonical metric names."""
try:
metrics_module = _import_ragas_metrics_module()
# Maps canonical name → ragas metric class
ragas_class_map: dict[str, Any] = {
EvaluationMetricName.FAITHFULNESS: getattr(metrics_module, "Faithfulness"),
EvaluationMetricName.ANSWER_RELEVANCY: getattr(metrics_module, "AnswerRelevancy"),
EvaluationMetricName.ANSWER_CORRECTNESS: getattr(metrics_module, "AnswerCorrectness"),
EvaluationMetricName.SEMANTIC_SIMILARITY: getattr(metrics_module, "SemanticSimilarity"),
EvaluationMetricName.CONTEXT_PRECISION: getattr(metrics_module, "ContextPrecision"),
EvaluationMetricName.CONTEXT_RECALL: getattr(metrics_module, "ContextRecall"),
EvaluationMetricName.CONTEXT_RELEVANCE: getattr(metrics_module, "ContextRelevance"),
EvaluationMetricName.TOOL_CORRECTNESS: getattr(metrics_module, "ToolCallAccuracy"),
}
metrics = []
for name in requested_metrics:
metric_class = ragas_class_map.get(name)
if metric_class:
if name == EvaluationMetricName.ANSWER_CORRECTNESS:
# ragas answer_correctness blends factuality with semantic
# similarity. The latter requires an embeddings backend,
# which is not wired through Dify's evaluation stack yet.
# Keep the metric usable by relying on the factuality
# component only for now.
metrics.append(metric_class(weights=[1.0, 0.0], embeddings=_NoopRagasEmbeddings()))
else:
metrics.append(metric_class())
else:
logger.warning("Metric '%s' is not supported by RAGAS, skipping", name)
return metrics
except ImportError:
logger.warning("RAGAS metrics not available")
return []
def _import_ragas_metrics_module() -> Any:
"""Load ragas metric classes across supported ragas versions.
ragas 0.3.x exposes metric classes from ``ragas.metrics`` while some older
versions used ``ragas.metrics.collections``. Support both so worker
environments do not silently drop all metrics because of a module path
mismatch.
"""
try:
return import_module("ragas.metrics")
except ImportError:
return import_module("ragas.metrics.collections")
class _NoopRagasEmbeddings:
"""Placeholder embeddings for ragas metrics whose embedding branch is disabled.
ragas eagerly injects a default embeddings backend for any metric that
subclasses ``MetricWithEmbeddings``. For answer_correctness we currently
disable the semantic-similarity weight, so no real embedding call should
happen. Supplying this placeholder keeps ragas from constructing its
default OpenAI embeddings client during setup.
"""
async def aembed_query(self, text: str) -> list[float]:
del text
return [0.0]
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return [[0.0] for _ in texts]
def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str:
"""Extract the user-facing input string from the inputs dict."""
match category:
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
return str(inputs.get("prompt", ""))
case EvaluationCategory.RETRIEVAL:
return str(inputs.get("query", ""))
case _:
return str(next(iter(inputs.values()), "")) if inputs else ""

View File

@ -1,165 +0,0 @@
import asyncio
import logging
from typing import Any
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.prompt_values import PromptValue
try:
from ragas.llms.base import BaseRagasLLM
except ImportError:
class BaseRagasLLM: # type: ignore[no-redef]
"""Lightweight shim so the module stays importable without ragas installed."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
del args, kwargs
@staticmethod
def get_temperature(n: int) -> float:
return 0.3 if n > 1 else 1e-8
logger = logging.getLogger(__name__)
class DifyModelWrapper(BaseRagasLLM):
"""Bridge Dify model invocation to ragas and fallback LLM-as-judge flows.
ragas can accept a custom ``BaseRagasLLM`` instance. Using one here keeps
evaluation requests on Dify's provider stack instead of falling back to
ragas' default OpenAI factory, which would require standalone environment
credentials and bypass tenant-scoped model configuration.
"""
model_provider: str
model_name: str
tenant_id: str
user_id: str | None
def __init__(self, model_provider: str, model_name: str, tenant_id: str, user_id: str | None = None):
super().__init__()
self.model_provider = model_provider
self.model_name = model_name
self.tenant_id = tenant_id
self.user_id = user_id
def _get_model_instance(self) -> Any:
from core.plugin.impl.model_runtime_factory import create_plugin_model_manager
from graphon.model_runtime.entities.model_entities import ModelType
model_manager = create_plugin_model_manager(tenant_id=self.tenant_id, user_id=self.user_id)
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=self.model_provider,
model_type=ModelType.LLM,
model=self.model_name,
)
return model_instance
def invoke(self, prompt: str) -> str:
"""Invoke the configured Dify model with a plain-text evaluation prompt."""
from graphon.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
model_instance = self._get_model_instance()
result = model_instance.invoke_llm(
prompt_messages=[
SystemPromptMessage(content="You are an evaluation judge. Answer precisely and concisely."),
UserPromptMessage(content=prompt),
],
model_parameters={"temperature": 0.0},
stream=False,
)
return result.message.content
def generate_text(
self,
prompt: PromptValue,
n: int = 1,
temperature: float = 1e-8,
stop: list[str] | None = None,
callbacks: Any = None,
) -> LLMResult:
"""Implement ragas' sync LLM interface on top of Dify's model runtime."""
del callbacks # Dify's invocation path does not currently use LangChain callbacks here.
prompt_messages = _convert_prompt_value(prompt)
model_instance = self._get_model_instance()
generations: list[list[ChatGeneration]] = [[]]
completions = max(1, n)
for _ in range(completions):
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"temperature": temperature},
stop=stop,
stream=False,
)
text = result.message.content
generations[0].append(
ChatGeneration(
text=text,
message=AIMessage(content=text, response_metadata={"finish_reason": "stop"}),
generation_info={"finish_reason": "stop"},
)
)
return LLMResult(generations=generations)
async def agenerate_text(
self,
prompt: PromptValue,
n: int = 1,
temperature: float | None = None,
stop: list[str] | None = None,
callbacks: Any = None,
) -> LLMResult:
"""Async ragas hook backed by the sync Dify invocation path."""
return await asyncio.to_thread(
self.generate_text,
prompt,
n,
self.get_temperature(n) if temperature is None else temperature,
stop,
callbacks,
)
def _convert_prompt_value(prompt: PromptValue) -> list[Any]:
"""Translate LangChain prompt values into graphon prompt messages."""
from graphon.model_runtime.entities.message_entities import (
AssistantPromptMessage,
SystemPromptMessage,
UserPromptMessage,
)
prompt_messages: list[Any] = []
for message in prompt.to_messages():
content = _message_content_to_text(message)
if isinstance(message, SystemMessage):
prompt_messages.append(SystemPromptMessage(content=content))
elif isinstance(message, AIMessage):
prompt_messages.append(AssistantPromptMessage(content=content))
elif isinstance(message, HumanMessage):
prompt_messages.append(UserPromptMessage(content=content))
else:
prompt_messages.append(UserPromptMessage(content=content))
return prompt_messages
def _message_content_to_text(message: BaseMessage) -> str:
"""Flatten LangChain message content into a plain-text string for Dify."""
if isinstance(message.content, str):
return message.content
if isinstance(message.content, list):
parts: list[str] = []
for block in message.content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, dict):
text = block.get("text")
if text:
parts.append(str(text))
return "\n".join(part for part in parts if part)
return str(message.content or "")

View File

@ -1,160 +0,0 @@
"""Judgment condition processor for evaluation metrics.
Evaluates pass/fail judgment conditions against evaluation metric values.
Each condition uses ``variable_selector`` (``[node_id, metric_name]``) to
look up the metric value, then delegates the actual comparison to the
workflow condition engine (``graphon.utils.condition.processor``).
The processor is intentionally decoupled from evaluation frameworks and
runners. It operates on plain ``dict`` mappings and can be invoked
anywhere that already has per-item metric results.
"""
import logging
from collections.abc import Sequence
from typing import Any, cast
from core.evaluation.entities.judgment_entity import (
JudgmentCondition,
JudgmentConditionResult,
JudgmentConfig,
JudgmentResult,
)
from graphon.utils.condition.entities import SupportedComparisonOperator
from graphon.utils.condition.processor import _evaluate_condition # pyright: ignore[reportPrivateUsage]
logger = logging.getLogger(__name__)
_UNARY_OPERATORS = frozenset({"null", "not null", "empty", "not empty"})
class JudgmentProcessor:
@staticmethod
def evaluate(
metric_values: dict[tuple[str, str], Any],
config: JudgmentConfig,
) -> JudgmentResult:
"""Evaluate all judgment conditions against the given metric values.
Args:
metric_values: Mapping of ``(node_id, metric_name)`` → metric
value (e.g. ``{("node_abc", "faithfulness"): 0.85}``).
config: The judgment configuration with logical_operator and
conditions.
Returns:
JudgmentResult with overall pass/fail and per-condition details.
"""
if not config.conditions:
return JudgmentResult(
passed=True,
logical_operator=config.logical_operator,
condition_results=[],
)
condition_results: list[JudgmentConditionResult] = []
for condition in config.conditions:
result = JudgmentProcessor._evaluate_single_condition(metric_values, condition)
condition_results.append(result)
if config.logical_operator == "and" and not result.passed:
return JudgmentResult(
passed=False,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
if config.logical_operator == "or" and result.passed:
return JudgmentResult(
passed=True,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
if config.logical_operator == "and":
final_passed = all(r.passed for r in condition_results)
else:
final_passed = any(r.passed for r in condition_results)
return JudgmentResult(
passed=final_passed,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
@staticmethod
def _evaluate_single_condition(
metric_values: dict[tuple[str, str], Any],
condition: JudgmentCondition,
) -> JudgmentConditionResult:
"""Evaluate a single judgment condition.
Steps:
1. Extract ``(node_id, metric_name)`` from ``variable_selector``.
2. Look up the metric value from ``metric_values``.
3. Delegate comparison to the workflow condition engine.
"""
selector = condition.variable_selector
if len(selector) < 2:
return JudgmentConditionResult(
variable_selector=selector,
comparison_operator=condition.comparison_operator,
expected_value=condition.value,
actual_value=None,
passed=False,
error=f"variable_selector must have at least 2 elements, got {len(selector)}",
)
node_id, metric_name = selector[0], selector[1]
actual_value = metric_values.get((node_id, metric_name))
if actual_value is None and condition.comparison_operator not in _UNARY_OPERATORS:
return JudgmentConditionResult(
variable_selector=selector,
comparison_operator=condition.comparison_operator,
expected_value=condition.value,
actual_value=None,
passed=False,
error=f"Metric '{metric_name}' on node '{node_id}' not found in evaluation results",
)
try:
expected = condition.value
# Numeric operators need the actual value coerced to int/float
# so that the workflow engine's numeric assertions work correctly.
coerced_actual: object = actual_value
if (
condition.comparison_operator in {"=", "", ">", "<", "", ""}
and actual_value is not None
and not isinstance(actual_value, (int, float, bool))
):
coerced_actual = float(actual_value)
passed = _evaluate_condition(
operator=cast(SupportedComparisonOperator, condition.comparison_operator),
value=coerced_actual,
expected=cast(str | Sequence[str] | bool | Sequence[bool] | None, expected),
)
return JudgmentConditionResult(
variable_selector=selector,
comparison_operator=condition.comparison_operator,
expected_value=expected,
actual_value=actual_value,
passed=passed,
)
except Exception as e:
logger.warning(
"Judgment condition evaluation failed for [%s, %s]: %s",
node_id,
metric_name,
str(e),
)
return JudgmentConditionResult(
variable_selector=selector,
comparison_operator=condition.comparison_operator,
expected_value=condition.value,
actual_value=actual_value,
passed=False,
error=str(e),
)

View File

@ -1,52 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from models import Account, App, CustomizedSnippet, TenantAccountJoin
def get_service_account_for_app(session: Session, app_id: str) -> Account:
"""Get the creator account for an app with tenant context set up.
This follows the same pattern as BaseTraceInstance.get_service_account_with_tenant().
"""
app = session.scalar(select(App).where(App.id == app_id))
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator")
account = session.scalar(select(Account).where(Account.id == app.created_by))
if not account:
raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
if not current_tenant:
raise ValueError(f"Current tenant not found for account {account.id}")
account.set_tenant_id(current_tenant.tenant_id)
return account
def get_service_account_for_snippet(session: Session, snippet_id: str) -> Account:
"""Get the creator account for a snippet with tenant context set up.
Mirrors :func:`get_service_account_for_app` but queries CustomizedSnippet.
"""
snippet = session.scalar(select(CustomizedSnippet).where(CustomizedSnippet.id == snippet_id))
if not snippet:
raise ValueError(f"Snippet with id {snippet_id} not found")
if not snippet.created_by:
raise ValueError(f"Snippet with id {snippet_id} has no creator")
account = session.scalar(select(Account).where(Account.id == snippet.created_by))
if not account:
raise ValueError(f"Creator account not found for snippet {snippet_id}")
current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
if not current_tenant:
raise ValueError(f"Current tenant not found for account {account.id}")
account.set_tenant_id(current_tenant.tenant_id)
return account

View File

@ -1,66 +0,0 @@
import logging
from collections.abc import Mapping
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationDatasetInput,
EvaluationItemInput,
EvaluationItemResult,
NodeInfo,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class AgentEvaluationRunner(BaseEvaluationRunner):
"""Runner for agent evaluation: collects tool calls and final output."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
dataset_items: list[EvaluationDatasetInput] | None = None,
node_info: NodeInfo | None = None,
) -> list[EvaluationItemResult]:
"""Compute agent evaluation metrics."""
if not node_run_result_list:
return []
merged_items = self._merge_results_into_items(node_run_result_list)
return self.evaluation_instance.evaluate_agent(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
"""Create EvaluationItemInput list from NodeRunResult for agent evaluation."""
merged = []
for i, item in enumerate(items):
output = _extract_agent_output(item.outputs)
merged.append(
EvaluationItemInput(
index=i,
inputs=dict(item.inputs),
output=output,
)
)
return merged
def _extract_agent_output(outputs: Mapping[str, Any]) -> str:
"""Extract the primary output text from agent NodeRunResult.outputs."""
if "answer" in outputs:
return str(outputs["answer"])
if "text" in outputs:
return str(outputs["text"])
values = list(outputs.values())
return str(values[0]) if values else ""

View File

@ -1,55 +0,0 @@
"""Base evaluation runner.
Provides the abstract interface for metric computation. Each concrete runner
(LLM, Retrieval, Agent, Workflow, Snippet) implements ``evaluate_metrics``
to compute scores for a specific node type.
Orchestration (merging results from multiple runners, applying judgment, and
persisting to the database) is handled by the evaluation task, not the runner.
"""
import logging
from abc import ABC, abstractmethod
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationDatasetInput,
NodeInfo,
EvaluationItemResult,
)
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class BaseEvaluationRunner(ABC):
"""Abstract base class for evaluation runners.
Runners are stateless metric calculators: they receive node execution
results and a metric specification, then return scored results. They
do **not** touch the database or apply judgment logic.
"""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
self.evaluation_instance = evaluation_instance
@abstractmethod
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
dataset_items: list[EvaluationDatasetInput] | None = None,
node_info: NodeInfo | None = None,
) -> list[EvaluationItemResult]:
"""Compute evaluation metrics on the collected results.
The returned ``EvaluationItemResult.index`` values are positional
(0-based) and correspond to the order of *node_run_result_list*.
The caller is responsible for mapping them back to the original
dataset indices.
"""
...

View File

@ -1,107 +0,0 @@
import logging
import re
from collections.abc import Mapping
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationDatasetInput,
EvaluationItemInput,
EvaluationItemResult,
NodeInfo,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class LLMEvaluationRunner(BaseEvaluationRunner):
"""Runner for LLM evaluation: extracts prompts/outputs then evaluates."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
dataset_items: list[EvaluationDatasetInput] | None = None,
node_info: NodeInfo | None = None,
) -> list[EvaluationItemResult]:
"""Use the evaluation instance to compute LLM metrics."""
if not node_run_result_list:
return []
merged_items = self._merge_results_into_items(node_run_result_list, dataset_items, node_info)
return self.evaluation_instance.evaluate_llm(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(
items: list[NodeRunResult],
dataset_items: list[EvaluationDatasetInput] | None = None,
node_info: NodeInfo | None = None,
) -> list[EvaluationItemInput]:
"""Create new items from NodeRunResult for ragas evaluation.
Extracts prompts from process_data and concatenates them into a single
string with role prefixes (e.g. "system: ...\nuser: ...\nassistant: ...").
The last assistant message in outputs is used as the actual output.
"""
merged = []
for i, item in enumerate(items):
prompts = item.process_data.get("prompts", [])
prompt = _format_prompts(prompts)
output = _extract_llm_output(item.outputs)
dataset_item = dataset_items[i] if dataset_items and i < len(dataset_items) else None
merged.append(
EvaluationItemInput(
index=i,
inputs={"prompt": prompt},
output=output,
expected_output=dataset_item.get_expected_output_for_node(node_info.title) if dataset_item else None,
context=_extract_context_blocks(prompts),
)
)
return merged
def _format_prompts(prompts: list[dict[str, Any]]) -> str:
"""Concatenate a list of prompt messages into a single string for evaluation.
Each message is formatted as "role: text" and joined with newlines.
"""
parts: list[str] = []
for msg in prompts:
role = msg.get("role", "unknown")
text = msg.get("text", "")
parts.append(f"{role}: {text}")
return "\n".join(parts)
def _extract_llm_output(outputs: Mapping[str, Any]) -> str:
"""Extract the LLM output text from NodeRunResult.outputs."""
if "text" in outputs:
return str(outputs["text"])
if "answer" in outputs:
return str(outputs["answer"])
values = list(outputs.values())
return str(values[0]) if values else ""
def _extract_context_blocks(prompts: list[dict[str, Any]]) -> list[str] | None:
"""Extract tagged context blocks from rendered prompts.
Evaluation only treats prompt content wrapped in ``<context>...</context>``
as retrieved evidence. This keeps faithfulness-style metrics opt-in and
avoids guessing which arbitrary prompt text should be considered context.
"""
prompt_text = "\n".join(str(prompt.get("text", "")) for prompt in prompts)
matches = re.findall(r"<context>(.*?)</context>", prompt_text, flags=re.DOTALL)
contexts = [match.strip() for match in matches if match.strip()]
return contexts or None

View File

@ -1,67 +0,0 @@
import logging
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationDatasetInput,
EvaluationItemInput,
EvaluationItemResult,
NodeInfo,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class RetrievalEvaluationRunner(BaseEvaluationRunner):
"""Runner for retrieval evaluation: performs knowledge base retrieval, then evaluates."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
dataset_items: list[EvaluationDatasetInput] | None = None,
node_info: NodeInfo | None = None,
) -> list[EvaluationItemResult]:
"""Compute retrieval evaluation metrics."""
if not node_run_result_list:
return []
merged_items = []
for i, node_result in enumerate(node_run_result_list):
outputs = node_result.outputs
query = self._extract_query(dict(node_result.inputs))
result_list = outputs.get("result", [])
contexts = [item.get("content", "") for item in result_list if item.get("content")]
output = "\n---\n".join(contexts)
dataset_item = dataset_items[i] if dataset_items and i < len(dataset_items) else None
merged_items.append(
EvaluationItemInput(
index=i,
inputs={"query": query},
output=output,
expected_output=dataset_item.get_expected_output_for_node(node_info.title) if dataset_item else None,
context=contexts,
)
)
return self.evaluation_instance.evaluate_retrieval(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _extract_query(inputs: dict[str, Any]) -> str:
for key in ("query", "question", "input", "text"):
if key in inputs:
return str(inputs[key])
values = list(inputs.values())
return str(values[0]) if values else ""

View File

@ -1,72 +0,0 @@
"""Runner for Snippet evaluation.
Snippets are essentially workflows, so we reuse ``evaluate_workflow`` from
the evaluation instance for metric computation.
"""
import logging
from collections.abc import Mapping
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationDatasetInput,
EvaluationItemInput,
EvaluationItemResult,
NodeInfo,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class SnippetEvaluationRunner(BaseEvaluationRunner):
"""Runner for snippet evaluation: evaluates a published Snippet workflow."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
dataset_items: list[EvaluationDatasetInput] | None = None,
node_info: NodeInfo | None = None,
) -> list[EvaluationItemResult]:
"""Compute evaluation metrics for snippet outputs."""
if not node_run_result_list:
return []
merged_items = self._merge_results_into_items(node_run_result_list)
return self.evaluation_instance.evaluate_workflow(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
"""Create EvaluationItemInput list from NodeRunResult for snippet evaluation."""
merged = []
for i, item in enumerate(items):
output = _extract_snippet_output(item.outputs)
merged.append(
EvaluationItemInput(
index=i,
inputs=dict(item.inputs),
output=output,
)
)
return merged
def _extract_snippet_output(outputs: Mapping[str, Any]) -> str:
"""Extract the primary output text from snippet NodeRunResult.outputs."""
if "answer" in outputs:
return str(outputs["answer"])
if "text" in outputs:
return str(outputs["text"])
values = list(outputs.values())
return str(values[0]) if values else ""

View File

@ -1,66 +0,0 @@
import logging
from collections.abc import Mapping
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationDatasetInput,
EvaluationItemInput,
EvaluationItemResult,
NodeInfo,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class WorkflowEvaluationRunner(BaseEvaluationRunner):
"""Runner for workflow evaluation: executes workflow App in non-streaming mode."""
def __init__(self, evaluation_instance: BaseEvaluationInstance):
super().__init__(evaluation_instance)
def evaluate_metrics(
self,
node_run_result_list: list[NodeRunResult],
default_metric: DefaultMetric,
model_provider: str,
model_name: str,
tenant_id: str,
dataset_items: list[EvaluationDatasetInput] | None = None,
node_info: NodeInfo | None = None,
) -> list[EvaluationItemResult]:
"""Compute workflow evaluation metrics (end-to-end)."""
if not node_run_result_list:
return []
merged_items = self._merge_results_into_items(node_run_result_list)
return self.evaluation_instance.evaluate_workflow(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
"""Create EvaluationItemInput list from NodeRunResult for workflow evaluation."""
merged = []
for i, item in enumerate(items):
output = _extract_workflow_output(item.outputs)
merged.append(
EvaluationItemInput(
index=i,
inputs=dict(item.inputs),
output=output,
)
)
return merged
def _extract_workflow_output(outputs: Mapping[str, Any]) -> str:
"""Extract the primary output text from workflow NodeRunResult.outputs."""
if "answer" in outputs:
return str(outputs["answer"])
if "text" in outputs:
return str(outputs["text"])
values = list(outputs.values())
return str(values[0]) if values else ""

View File

@ -2,6 +2,7 @@
Credential utility functions for checking credential existence and policy compliance.
"""
from configs import dify_config
from core.entities import PluginCredentialType
@ -39,6 +40,16 @@ def is_credential_exists(credential_id: str, credential_type: "PluginCredentialT
return False
def runtime_check_credential_policy_compliance(
credential_id: str, provider: str, credential_type: "PluginCredentialType", check_existence: bool = True
):
if dify_config.ENTERPRISE_DISABLE_RUNTIME_CREDENTIAL_CHECK:
return
check_credential_policy_compliance(
credential_id=credential_id, provider=provider, credential_type=credential_type, check_existence=check_existence
)
def check_credential_policy_compliance(
credential_id: str, provider: str, credential_type: "PluginCredentialType", check_existence: bool = True
) -> None:

View File

@ -13,8 +13,6 @@ from core.llm_generator.output_parser.rule_config_generator import RuleConfigGen
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
CONVERSATION_TITLE_PROMPT,
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
@ -217,8 +215,8 @@ class LLMGenerator:
else:
# Default-model generation keeps the built-in suggested-questions tuning.
model_parameters = {
"max_tokens": DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
"temperature": DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
"max_tokens": 2560,
"temperature": 0.0,
}
stop = []
@ -437,7 +435,7 @@ class LLMGenerator:
stream=False,
)
# Runtime type check since pyright has issues with the overload
# Runtime type check for overload narrowing.
if not isinstance(result, LLMResult):
raise TypeError("Expected LLMResult when stream=False")
response = result

View File

@ -104,10 +104,6 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
'["question1","question2","question3"]\n'
)
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS = 256
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE = 0.0
GENERATOR_QA_PROMPT = (
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
" in the long text. Please think step by step."

View File

@ -2,7 +2,7 @@ from collections.abc import Callable
from dataclasses import dataclass
from typing import Annotated, Any, Literal
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel, model_validator
from pydantic.networks import AnyUrl, UrlConstraints
"""
@ -173,7 +173,21 @@ class JSONRPCError(BaseModel):
class JSONRPCMessage(RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]):
pass
@model_validator(mode="before")
@classmethod
def _select_message_type(
cls, value: Any
) -> JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError | Any:
if isinstance(value, dict):
if "result" in value:
return JSONRPCResponse.model_validate(value)
if "error" in value:
return JSONRPCError.model_validate(value)
if "method" in value:
if "id" in value:
return JSONRPCRequest.model_validate(value)
return JSONRPCNotification.model_validate(value)
return value
class EmptyResult(Result):

Some files were not shown because too many files have changed in this diff Show More